Merge remote-tracking branch 'origin/inference_streaming' into inference_streaming

禾息
2024-09-03 11:13:25 +08:00
15 changed files with 754 additions and 6 deletions


@@ -159,7 +159,6 @@ class CosyVoiceModel:
         self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -180,7 +179,7 @@ class CosyVoiceModel:
                 token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -193,7 +192,7 @@ class CosyVoiceModel:
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
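
The net effect of this change: p.join() no longer blocks immediately after p.start(), so the LLM token producer runs concurrently with the streaming loop, and the join moves to after the loop (streaming branch) or before the single-shot synthesis (non-streaming branch). Below is a minimal standalone sketch of that producer-consumer pattern, not CosyVoice code; the names (llm_job, tokens, llm_done, token_hop_len) only mirror the diff for illustration.

# Minimal sketch (assumed names, not CosyVoice code) of the pattern this
# commit enables: join the producer thread only after the streaming loop,
# so token generation overlaps consumption.
import threading
import time

tokens = []        # stands in for self.tts_speech_token_dict[this_uuid]
llm_done = False   # stands in for self.llm_end_dict[this_uuid]

def llm_job(n_tokens):
    """Producer: append tokens incrementally, as the autoregressive LLM would."""
    global llm_done
    for i in range(n_tokens):
        tokens.append(i)
        time.sleep(0.01)   # simulated per-token generation latency
    llm_done = True

def stream(token_hop_len=4):
    """Consumer: yield fixed-size hops while the producer is still running."""
    p = threading.Thread(target=llm_job, args=(20,))
    p.start()              # no immediate join: this is the point of the commit
    offset = 0
    while True:
        if len(tokens) - offset >= token_hop_len:
            yield tokens[offset:offset + token_hop_len]   # token2wav would run here
            offset += token_hop_len
        elif llm_done:
            break          # producer finished and fewer than one hop remains
        else:
            time.sleep(0.005)   # wait for the producer to catch up
    p.join()               # join after the loop, as the diff does
    if offset < len(tokens):
        yield tokens[offset:]   # deal with the remaining tokens

for chunk in stream():
    print(chunk)

In the real method each hop also carries token_overlap_len extra tokens for cross-fading between chunks, which this sketch omits.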