diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index b75774c..ea0ec4a 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -63,11 +63,11 @@ class CosyVoiceModel: self.hift.to(self.device).eval() def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): - llm_text_encoder = torch.jit.load(llm_text_encoder_model) + llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) self.llm.text_encoder = llm_text_encoder - llm_llm = torch.jit.load(llm_llm_model) + llm_llm = torch.jit.load(llm_llm_model, map_location=self.device) self.llm.llm = llm_llm - flow_encoder = torch.jit.load(flow_encoder_model) + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder def load_onnx(self, flow_decoder_estimator_model):