diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 0eff9b3..677f486 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -301,7 +301,7 @@ class CosyVoice2Model(CosyVoiceModel):
             self.flow.half()
         # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
+        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -325,11 +325,11 @@ class CosyVoice2Model(CosyVoiceModel):
                          'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
         decoder_cache = {'offset': 0,
                          'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
-                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
-                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
+                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
-                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
         if self.fp16 is True:
             for cache in [encoder_cache, decoder_cache]:
@@ -339,13 +339,6 @@ class CosyVoice2Model(CosyVoiceModel):
         cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
         return cache

-    def trim_flow_cache(self, cache):
-        if self.flow_decoder_required_cache_size > 0 and cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
-            cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-            cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-            cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-        return cache
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
@@ -369,7 +362,6 @@ class CosyVoice2Model(CosyVoiceModel):
                                                       embedding=embedding.to(self.device),
                                                       cache=self.flow_cache_dict[uuid],
                                                       finalize=finalize)
-        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 4039896..9b44ee3 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -243,11 +243,6 @@ class CausalConditionalCFM(ConditionalCFM):
         # Or in future might add like a return_all_steps flag
         sol = []

-        # estimator cache for each step
-        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
-        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
-        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
-
         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
         x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
         mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
@@ -255,6 +250,7 @@ class CausalConditionalCFM(ConditionalCFM):
         t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
         spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
         cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             x_in[:] = x
@@ -271,13 +267,15 @@ class CausalConditionalCFM(ConditionalCFM):
                 cond_in,
                 cache_step
             )
-            cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-            down_blocks_kv_cache_new[step - 1] = cache_step[1]
-            cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-            mid_blocks_kv_cache_new[step - 1] = cache_step[3]
-            cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-            up_blocks_kv_cache_new[step - 1] = cache_step[5]
-            cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
+            # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
+            if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
+                cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
+                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:,:,:,-flow_cache_size:]
+                cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
+                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:,:,:,-flow_cache_size:]
+                cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
+                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:,:,:,-flow_cache_size:]
+                cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
@@ -285,9 +283,6 @@ class CausalConditionalCFM(ConditionalCFM):
             sol.append(x)
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
-        cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
-        cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
-        cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
        return sol[-1].float(), cache

    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
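In short, the patch replaces the old grow-then-trim handling of the flow decoder KV caches (trim_flow_cache plus the torch.concat calls) with caches that are pre-allocated at flow_decoder_required_cache_size and overwritten in place on each estimator step, which keeps cache memory bounded and avoids the concat that the code itself warns can change memory format and break TRT inference. The sketch below illustrates that rolling-cache update in isolation; the tensor shapes, the cache_size value, and the update_kv_cache helper are simplified placeholders for this explanation, not the real 7-D cache layout or any CosyVoice API.

    import torch

    # Rolling KV-cache update, sketched under simplified assumptions:
    # the cache is allocated once at a fixed size and only the newest
    # `cache_size` frames along the time axis are kept, written in place.
    cache_size = 50                                # e.g. token_hop_len * token_mel_ratio
    kv_cache = torch.zeros(4, 2, cache_size, 64)   # hypothetical (heads, cfg, time, dim) layout

    def update_kv_cache(kv_cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
        """Overwrite the cache with the most recent `cache_size` frames."""
        size = kv_cache.shape[2]
        # If the new chunk is shorter than the cache, it is the final chunk and
        # nothing more will be generated, so there is no need to update the cache.
        if size != 0 and new_kv.shape[2] >= size:
            kv_cache[:] = new_kv[:, :, -size:]
        return kv_cache

    new_kv = torch.randn(4, 2, 120, 64)            # KV produced by one estimator step
    kv_cache = update_kv_cache(kv_cache, new_kv)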