From d71d790f5530c085b68f522c4bd59d5bdeec83c0 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Tue, 15 Apr 2025 13:12:25 +0800
Subject: [PATCH] fix flow cache bug

---
 cosyvoice/cli/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index a7157bc..0eff9b3 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -301,7 +301,7 @@ class CosyVoice2Model(CosyVoiceModel):
             self.flow.half()
         # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
+        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -340,7 +340,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return cache
 
     def trim_flow_cache(self, cache):
-        if self.flow_decoder_required_cache_size > 0:
+        if self.flow_decoder_required_cache_size > 0 and cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
             cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
@@ -410,8 +410,8 @@ class CosyVoice2Model(CosyVoiceModel):
         if stream is True:
             assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
             # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
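
The arithmetic behind the fix, as a standalone sketch: the decoder KV cache's last dimension counts mel frames, while the old code sized flow_decoder_required_cache_size in speech tokens. Each token expands to token_mel_ratio mel frames, so the cache size and the prompt trimming have to convert between the two units (the patch also guards trim_flow_cache so it only slices once the cache has grown past the required size). The snippet below is illustrative, not part of the repository; token_mel_ratio = 2 is an assumption inferred from the removed `* 2` on prompt_speech_feat.

    # Illustrative sketch of the unit conversion this patch enforces.
    token_hop_len = 25    # speech tokens generated per streaming hop (from the diff)
    token_mel_ratio = 2   # mel frames per speech token (assumed value)

    # The KV cache is indexed in mel frames, so the required cache size must
    # be expressed in mel frames, not tokens (the bug the patch fixes).
    flow_decoder_required_cache_size = token_hop_len * token_mel_ratio  # 50 mel frames

    # Prompt trimming then converts back where needed: tokens are recovered
    # from mel frames, mel features use the frame count directly.
    keep_tokens = int(flow_decoder_required_cache_size / token_mel_ratio)  # 25 tokens
    keep_mel_frames = flow_decoder_required_cache_size                     # 50 frames

    assert keep_tokens * token_mel_ratio == keep_mel_frames
    print(keep_tokens, keep_mel_frames)  # 25 50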