From ba3d9693da978e6bc3211f30e8ee5596f520b96d Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 26 Sep 2024 14:55:03 +0800 Subject: [PATCH] load jit to device --- cosyvoice/cli/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index b75774c..ea0ec4a 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -63,11 +63,11 @@ class CosyVoiceModel: self.hift.to(self.device).eval() def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): - llm_text_encoder = torch.jit.load(llm_text_encoder_model) + llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) self.llm.text_encoder = llm_text_encoder - llm_llm = torch.jit.load(llm_llm_model) + llm_llm = torch.jit.load(llm_llm_model, map_location=self.device) self.llm.llm = llm_llm - flow_encoder = torch.jit.load(flow_encoder_model) + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder def load_onnx(self, flow_decoder_estimator_model):