Add warning when inference with mps and bf16 on Mac

This commit is contained in:
Hongji Zhu
2024-06-04 11:26:46 +08:00
parent dfbc3211ef
commit 1c5fadf592

View File

@@ -26,7 +26,11 @@ args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
if args.dtype == 'bf16':
dtype = torch.bfloat16
if device == 'mps':
print('Warning: MPS does not support bf16, will use fp16 instead')
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
dtype = torch.float16