diff --git a/web_demo_2.5.py b/web_demo_2.5.py index 2076da3..6f6b81a 100644 --- a/web_demo_2.5.py +++ b/web_demo_2.5.py @@ -31,8 +31,7 @@ if 'int4' in model_path: exit() model = AutoModel.from_pretrained(model_path, trust_remote_code=True) else: - model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16) - model = model.to(device=device) + model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) model.eval()