From cbfed4a9ee71a2a9abc795b01bcd0757e6d17576 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Tue, 27 May 2025 13:47:12 +0800
Subject: [PATCH] send streaming as args

---
 cosyvoice/cli/model.py                    |  7 +++----
 cosyvoice/flow/decoder.py                 |  4 ----
 cosyvoice/flow/flow.py                    |  6 ++++--
 cosyvoice/flow/flow_matching.py           | 13 +++++++------
 cosyvoice/transformer/upsample_encoder.py |  3 ---
 5 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 811b2cb..c1e441f 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -258,9 +258,6 @@ class CosyVoice2Model(CosyVoiceModel):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
-        # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
-        self.flow.encoder.streaming = True
-        self.flow.decoder.estimator.streaming = True
         self.hift = hift
         self.fp16 = fp16
         self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ class CosyVoice2Model(CosyVoiceModel):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_feat=prompt_feat.to(self.device),
                                              prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                              embedding=embedding.to(self.device),
+                                             streaming=stream,
                                              finalize=finalize)
             tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
@@ -356,6 +354,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                                  embedding=flow_embedding,
                                                  token_offset=token_offset,
                                                  uuid=this_uuid,
+                                                 stream=stream,
                                                  finalize=False)
                 token_offset += this_token_hop_len
                 yield {'tts_speech': this_tts_speech.cpu()}
diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index 9e28c3f..97768a4 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -419,10 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
         Returns:
             _type_: _description_
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
-
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index d9e832b..a068288 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -241,6 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
+                  streaming,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -254,10 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 
         # text encode
         if finalize is True:
-            h, h_lengths = self.encoder(token, token_len)
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths = self.encoder(token, token_len, context=context)
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
@@ -273,6 +274,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                                spks=embedding,
                                cond=conds,
                                n_timesteps=10,
+                               streaming=streaming
                                )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 735889f..704ced3 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -69,7 +69,7 @@ class ConditionalCFM(BASECFM):
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
 
-    def solve_euler(self, x, t_span, mu, mask, spks, cond):
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
         """
         Fixed euler solver for ODEs.
         Args:
@@ -110,7 +110,8 @@ class ConditionalCFM(BASECFM):
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
-                cond_in
+                cond_in,
+                streaming
             )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -122,9 +123,9 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1].float()
 
-    def forward_estimator(self, x, mask, mu, t, spks, cond):
+    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond)
+            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
             estimator, trt_engine = self.estimator.acquire_estimator()
             estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -196,7 +197,7 @@ class CausalConditionalCFM(ConditionalCFM):
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
         """Forward diffusion
 
         Args:
@@ -220,4 +221,4 @@ class CausalConditionalCFM(ConditionalCFM):
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py
index e17b188..6ffda6a 100644
--- a/cosyvoice/transformer/upsample_encoder.py
+++ b/cosyvoice/transformer/upsample_encoder.py
@@ -272,9 +272,6 @@ class UpsampleConformerEncoder(torch.nn.Module):
             checkpointing API because `__call__` attaches all the hooks of the module.
             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         if self.global_cmvn is not None:
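
Note (not part of the patch): the change removes the module-level `streaming` attributes that were set once in `CosyVoice2Model.__init__` and instead threads an explicit flag per call, from `token2wav(stream=...)` through `flow.inference(streaming=...)` into the upsample encoder, the CFM solver and the estimator. The snippet below is only a self-contained toy illustrating that refactor pattern; `ToyEstimator`, its shapes, and its streaming behaviour are stand-ins, not the real CosyVoice estimator.

    import torch


    class ToyEstimator(torch.nn.Module):
        """Stand-in module; only illustrates per-call flag plumbing."""

        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, x, streaming=False):
            # The real decoder selects its streaming (chunked) inference path from
            # this flag; the toy just branches so the per-call flag is observable.
            h = self.proj(x)
            return h[:, -1:, :] if streaming else h


    est = ToyEstimator()
    x = torch.randn(2, 5, 8)

    # Old style (removed by this patch): hidden module state, which had to be
    # special-cased for jit/onnx export and asserted against training mode.
    #   est.streaming = True
    #   y = est(x)

    # New style: the caller decides per call, with no stateful attribute to reset.
    y_offline = est(x, streaming=False)     # full-utterance path
    y_stream = est(x, streaming=True)       # chunked/streaming path
    print(y_offline.shape, y_stream.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 1, 8])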