Merge pull request #153 from whn09/main

Combine dtype and device into the `from_pretrained` call to save CPU memory
This commit is contained in:
Hongji Zhu
2024-05-31 11:23:13 +08:00
committed by GitHub

View File

@@ -31,8 +31,7 @@ if 'int4' in model_path:
exit()
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
- model = model.to(device=device)
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()