mirror of https://github.com/FunAudioLLM/CosyVoice.git
fix flow cache bug: count flow_decoder_required_cache_size in mel frames (token_hop_len * token_mel_ratio) instead of speech tokens, guard trim_flow_cache against caches shorter than the required size, and trim the flow prompt token/feat in matching units
@@ -301,7 +301,7 @@ class CosyVoice2Model(CosyVoiceModel):
             self.flow.half()
         # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
+        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
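
For context, a minimal sketch of the unit change in this hunk: the required cache size is now counted in mel frames rather than speech tokens, matching the axis the decoder KV caches are trimmed along. The token_mel_ratio value of 2 below is an assumed illustrative value (the real code reads it from self.flow.token_mel_ratio), not taken from this commit.

    # Sketch of the new cache-size arithmetic with assumed values.
    token_hop_len = 25
    token_mel_ratio = 2      # assumption: each speech token maps to 2 mel frames
    use_flow_cache = True

    # Before this fix: 25 (speech tokens). After: 25 * 2 = 50 (mel frames).
    flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * token_hop_len * token_mel_ratio
    print(flow_decoder_required_cache_size)  # -> 50
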
@@ -340,7 +340,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return cache

     def trim_flow_cache(self, cache):
-        if self.flow_decoder_required_cache_size > 0:
+        if self.flow_decoder_required_cache_size > 0 and cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
             cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
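
A toy reproduction of the guard added here, using a zero tensor with a made-up shape in place of the real KV cache; only dim 4, the time axis being sliced, matters for the check.

    import torch

    required = 50                          # flow_decoder_required_cache_size
    kv = torch.zeros(2, 4, 8, 16, 30)      # made-up shape; dim 4 is the time axis

    # New condition: only slice when there is more cache than required.
    # Note that [-50:] on a length-30 axis is a no-op in PyTorch (negative
    # slices clamp like Python lists), so the guard skips a redundant slice
    # and re-assignment early in a stream and makes the intent explicit.
    if required > 0 and kv.size(4) > required:
        kv = kv[:, :, :, :, -required:]
    print(kv.size(4))  # -> 30, untouched: already shorter than required
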
@@ -410,8 +410,8 @@ class CosyVoice2Model(CosyVoiceModel):
         if stream is True:
             assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
             # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
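
A sketch of the prompt trimming after the fix, with made-up tensor sizes: since the required cache size is now in mel frames, the token prompt keeps cache_size / token_mel_ratio tokens and the feature prompt keeps cache_size frames, so both slices cover the same trailing span of prompt audio.

    import torch

    flow_decoder_required_cache_size = 50  # mel frames (token_hop_len * ratio)
    token_mel_ratio = 2                    # assumed illustrative value

    flow_prompt_speech_token = torch.zeros(1, 100, dtype=torch.long)  # made-up length
    prompt_speech_feat = torch.zeros(1, 200, 80)                      # made-up length

    # Tokens trimmed in token units, features in mel-frame units:
    flow_prompt_speech_token = flow_prompt_speech_token[:, -int(flow_decoder_required_cache_size / token_mel_ratio):]
    prompt_speech_feat = prompt_speech_feat[:, -flow_decoder_required_cache_size:]
    print(flow_prompt_speech_token.size(1), prompt_speech_feat.size(1))  # -> 25 50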