From 1c5fadf592a86de7361688b325e5a8f06da6a2b5 Mon Sep 17 00:00:00 2001 From: Hongji Zhu Date: Tue, 4 Jun 2024 11:26:46 +0800 Subject: [PATCH] Add warning when inference with mps and bf16 on Mac --- web_demo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/web_demo.py b/web_demo.py index a9a0727..668dcf4 100644 --- a/web_demo.py +++ b/web_demo.py @@ -26,7 +26,11 @@ args = parser.parse_args() device = args.device assert device in ['cuda', 'mps'] if args.dtype == 'bf16': - dtype = torch.bfloat16 + if device == 'mps': + print('Warning: MPS does not support bf16, will use fp16 instead') + dtype = torch.float16 + else: + dtype = torch.bfloat16 else: dtype = torch.float16