mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Merge branch 'inference_streaming' into flow_tensorrt
This commit is contained in:
@@ -43,7 +43,6 @@ class CosyVoice:
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
||||
'{}/llm.llm.fp16.zip'.format(model_dir))
|
||||
|
||||
if load_trt:
|
||||
self.model.load_trt(model_dir, use_fp16)
|
||||
|
||||
|
||||
@@ -137,7 +137,6 @@ class CosyVoiceModel:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
p.join()
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
@@ -158,7 +157,7 @@ class CosyVoiceModel:
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
# p.join()
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
with self.flow_hift_context:
|
||||
@@ -171,7 +170,7 @@ class CosyVoiceModel:
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
# p.join()
|
||||
p.join()
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
with self.flow_hift_context:
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
|
||||
Reference in New Issue
Block a user