From 9dc559fc2aff47aa4b18fb548bad762110ecaebb Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Tue, 8 Apr 2025 12:23:26 +0800
Subject: [PATCH] force set use_flow_cache

---
 cosyvoice/cli/model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 9a50991..a7157bc 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -401,10 +401,6 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-        if self.use_flow_cache is True:
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
@@ -412,6 +408,10 @@ class CosyVoice2Model(CosyVoiceModel):
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
+            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
+            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
@@ -442,6 +442,7 @@ class CosyVoice2Model(CosyVoiceModel):
                     yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
+            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
             p.join()
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
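
Note (not part of the patch): below is a minimal usage sketch of the constraint this change enforces. After the patch, CosyVoice2Model.tts asserts that stream=True is only used when the model was loaded with use_flow_cache=True, and stream=False only when it was loaded with use_flow_cache=False. The sketch assumes the high-level CosyVoice2 wrapper exposes a use_flow_cache constructor flag and the inference_zero_shot API; names and paths are taken from the CosyVoice repository but are not verified against this exact revision.

    # Sketch under assumed API: models must be loaded with a use_flow_cache
    # setting that matches how tts() will be called, or an AssertionError is raised.
    import torchaudio
    from cosyvoice.cli.cosyvoice import CosyVoice2
    from cosyvoice.utils.file_utils import load_wav

    prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)  # assumed sample asset path

    # Streaming inference: flow cache must be enabled at load time (assumed flag).
    cosyvoice_stream = CosyVoice2('pretrained_models/CosyVoice2-0.5B', use_flow_cache=True)
    for i, chunk in enumerate(cosyvoice_stream.inference_zero_shot(
            'Target text to synthesize.', 'Transcript of the prompt audio.',
            prompt_speech_16k, stream=True)):
        torchaudio.save(f'stream_chunk_{i}.wav', chunk['tts_speech'], cosyvoice_stream.sample_rate)

    # Non-streaming inference: use a model loaded with use_flow_cache=False (the default).
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
    for i, out in enumerate(cosyvoice.inference_zero_shot(
            'Target text to synthesize.', 'Transcript of the prompt audio.',
            prompt_speech_16k, stream=False)):
        torchaudio.save(f'nonstream_{i}.wav', out['tts_speech'], cosyvoice.sample_rate)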