Merge pull request #495 from FunAudioLLM/dev/lyuxiang.lx

fix bug
This commit is contained in:
Xiang Lyu
2024-10-16 13:31:15 +08:00
committed by GitHub
2 changed files with 4 additions and 4 deletions

View File

@@ -99,7 +99,7 @@ def main():
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
tts_speeches = []
-        for model_output in model.inference(**model_input):
+        for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])

View File

@@ -56,14 +56,14 @@ class CosyVoiceModel:
self.hift_cache_dict = {}
def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
self.llm.to(self.device).eval()
if self.fp16 is True:
self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=False)
self.hift.to(self.device).eval()