From 39ffc50dec737c491c61a0f9ef235d3b0eb2e934 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Mon, 7 Apr 2025 21:23:09 +0800 Subject: [PATCH] add flow cache inference code --- README.md | 2 +- cosyvoice/cli/cosyvoice.py | 6 +++--- cosyvoice/cli/model.py | 21 ++++++++++--------- .../libritts/cosyvoice2/conf/cosyvoice2.yaml | 8 +++---- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 62bdf1d..bd18016 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ import torchaudio **CosyVoice2 Usage** ```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False) +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference # zero_shot usage diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index a511d78..bcff6ab 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -129,7 +129,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -151,9 +151,9 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache) self.model.load('{}/llm.pt'.format(model_dir), - '{}/flow.pt'.format(model_dir), + 
'{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 68f9967..841e55d 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -288,19 +288,20 @@ class CosyVoice2Model(CosyVoiceModel): llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool): + fp16: bool, + use_flow_cache: bool): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.use_flow_cache = use_flow_cache if self.fp16 is True: self.llm.half() self.flow.half() - self.token_hop_len = self.flow.encoder.static_chunk_size - # flow decoder required_cache_size - # TODO 基模型训练时没有设置num_decoding_left_chunks,需要重新训一下才能指定flow_decoder_required_cache_size - self.flow_decoder_required_cache_size = 999 + # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml + self.token_hop_len = 25 + self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len # hift cache self.mel_cache_len = 8 self.source_cache_len = int(self.mel_cache_len * 480) @@ -339,7 +340,7 @@ class CosyVoice2Model(CosyVoiceModel): return cache def trim_flow_cache(self, cache): - if cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size: + if self.flow_decoder_required_cache_size > 0: cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, 
:, -self.flow_decoder_required_cache_size:] @@ -399,10 +400,10 @@ class CosyVoice2Model(CosyVoiceModel): prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) - # NOTE flow model is only trained with static_chunk_size, so we need to trim flow prompt - n_chunk = int(flow_prompt_speech_token.size(1) / self.token_hop_len) - flow_prompt_speech_token = flow_prompt_speech_token[:, :n_chunk * self.token_hop_len] - prompt_speech_feat = prompt_speech_feat[:, :n_chunk * self.token_hop_len * 2] + # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size + if self.use_flow_cache is True: + flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:] + prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:] with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None diff --git a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml index 72335ff..d6bdeb6 100644 --- a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml +++ b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml @@ -14,8 +14,8 @@ token_frame_rate: 25 token_mel_ratio: 2 # stream related params -chunk_size: 2 # streaming inference chunk size, in second -num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size +chunk_size: 25 # streaming inference chunk size, in token +num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks # model params # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 
@@ -60,7 +60,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec input_size: 512 use_cnn_module: False macaron_style: False - static_chunk_size: !ref <chunk_size> * <token_frame_rate> + static_chunk_size: !ref <chunk_size> decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM in_channels: 240 n_spks: 1 @@ -83,7 +83,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec num_mid_blocks: 12 num_heads: 8 act_fn: 'gelu' - static_chunk_size: !ref <chunk_size> * <token_frame_rate> * <token_mel_ratio> # here we use static_chunk_size because we want to fix kv cache size during inference + static_chunk_size: !ref <chunk_size> * <token_mel_ratio> num_decoding_left_chunks: !ref <num_decoding_left_chunks> hift: !new:cosyvoice.hifigan.generator.HiFTGenerator