Merge pull request #153 from whn09/main

Combine dtype and device in from_pretrained to save CPU memory
Author: Hongji Zhu
Date: 2024-05-31 11:23:13 +08:00
Committed by: GitHub


@@ -31,8 +31,7 @@ if 'int4' in model_path:
         exit()
     model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
 else:
-    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
-    model = model.to(device=device)
+    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model.eval()
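
For context (not part of the commit), a minimal sketch of the before/after loading patterns under the stated assumptions: the checkpoint name and device are placeholders, and passing device_map to from_pretrained requires the accelerate package.

```python
# Sketch only; model path and device are placeholders, not values from the diff.
import torch
from transformers import AutoModel

model_path = 'openbmb/MiniCPM-Llama3-V-2_5'  # placeholder checkpoint
device = 'cuda'                              # or 'mps' / 'cpu'

# Before: the checkpoint is first materialized on CPU (float32 by default),
# then cast to float16 and moved, so peak CPU usage is roughly the full
# fp32 model.
# model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
# model = model.to(device=device)

# After: weights are loaded directly in float16 and placed on the target
# device, skipping the intermediate fp32 copy on CPU.
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map=device,
)
model.eval()
```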