This commit is contained in:
lyuxiang.lx
2024-10-16 13:30:13 +08:00
parent 29507bc77a
commit 7e6d60c24c
2 changed files with 4 additions and 4 deletions

View File

@@ -99,7 +99,7 @@ def main():
     'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
     'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
     tts_speeches = []
-    for model_output in model.inference(**model_input):
+    for model_output in model.tts(**model_input):
         tts_speeches.append(model_output['tts_speech'])
     tts_speeches = torch.concat(tts_speeches, dim=1)
     tts_key = '{}_{}'.format(utts[0], tts_index[0])

View File

@@ -56,14 +56,14 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=False)
         self.hift.to(self.device).eval()