diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
index 32acf3a..2cb831a 100644
--- a/cosyvoice/bin/inference.py
+++ b/cosyvoice/bin/inference.py
@@ -99,7 +99,7 @@ def main():
                            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                            'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
             tts_speeches = []
-            for model_output in model.inference(**model_input):
+            for model_output in model.tts(**model_input):
                 tts_speeches.append(model_output['tts_speech'])
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index cf6389d..7ac6cf9 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -56,14 +56,14 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=False)
         self.hift.to(self.device).eval()
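
Note on the cosyvoice/cli/model.py hunk: iterating a dict directly yields only its
keys, so the old comprehension's two-target unpacking (`for k, v in torch.load(...)`)
raises a ValueError the moment load() runs; `.items()` yields the (key, tensor) pairs
it needs. The added `strict=False` arguments make `load_state_dict` tolerate missing
or unexpected keys instead of raising. A minimal sketch of the failure mode, using a
hypothetical checkpoint path 'hift.pt':

    import torch

    # A checkpoint state dict: string keys mapping to tensors.
    state = torch.load('hift.pt', map_location='cpu')  # hypothetical path

    # Broken: iterating a dict yields keys only, so unpacking each key string
    # into (k, v) raises "ValueError: too many values to unpack (expected 2)".
    # bad = {k.replace('generator.', ''): v for k, v in state}

    # Fixed: .items() yields (key, tensor) pairs, and the prefix strip maps
    # e.g. 'generator.conv_pre.weight' -> 'conv_pre.weight'.
    hift_state_dict = {k.replace('generator.', ''): v for k, v in state.items()}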