From dd2d9261478b97b197090c9bf69f108b382da5ef Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Wed, 20 Aug 2025 16:55:03 +0800
Subject: [PATCH] update

---
 cosyvoice/flow/DiT/{dit_model.py => dit.py}  | 76 +++++--------------
 .../flow/DiT/{dit_modules.py => modules.py}  |  1 +
 2 files changed, 22 insertions(+), 55 deletions(-)
 rename cosyvoice/flow/DiT/{dit_model.py => dit.py} (68%)
 rename cosyvoice/flow/DiT/{dit_modules.py => modules.py} (99%)

diff --git a/cosyvoice/flow/DiT/dit_model.py b/cosyvoice/flow/DiT/dit.py
similarity index 68%
rename from cosyvoice/flow/DiT/dit_model.py
rename to cosyvoice/flow/DiT/dit.py
index ada9392..73a5423 100644
--- a/cosyvoice/flow/DiT/dit_model.py
+++ b/cosyvoice/flow/DiT/dit.py
@@ -1,3 +1,4 @@
+
 """
 ein notation:
 b - batch
@@ -14,9 +15,8 @@ from torch import nn
 import torch.nn.functional as F
 from einops import repeat
 from x_transformers.x_transformers import RotaryEmbedding
-from funasr.models.transformer.utils.mask import causal_block_mask
-
-from cosyvoice.flow.DiT.dit_modules import (
+from cosyvoice.utils.mask import add_optional_chunk_mask
+from cosyvoice.flow.DiT.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     CausalConvPositionEmbedding,
@@ -115,7 +115,8 @@ class DiT(nn.Module):
         mu_dim=None,
         long_skip_connection=False,
         spk_dim=None,
-        **kwargs
+        static_chunk_size=50,
+        num_decoding_left_chunks=2
     ):
         super().__init__()
 
@@ -136,50 +137,20 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
-        self.causal_mask_type = kwargs.get("causal_mask_type", None)
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
 
-    def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
-        b, _, _, t = attn_mask.shape
-        if rand is None:
-            rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
-        mixed_mask = attn_mask.clone()
-        for item in self.causal_mask_type:
-            prob_min, prob_max = item["prob_min"], item["prob_max"]
-            _ratio = 1
-            if "ratio" in item:
-                _ratio = item["ratio"]
-            if ratio is not None:
-                _ratio = ratio
-            block_size = item["block_size"] * _ratio
-            if block_size <= 0:
-                causal_mask = attn_mask
-            else:
-                causal_mask = causal_block_mask(
-                    t, block_size, attn_mask.device, torch.float32
-                ).unsqueeze(0).unsqueeze(1)  # 1,1,T,T
-            flag = (prob_min <= rand) & (rand < prob_max)
-            mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag
-
-        return mixed_mask
-
-    def forward(
-        self,
-        x: float["b n d"],  # nosied input audio
-        cond: float["b n d"],  # masked cond audio
-        mu: int["b nt d"],  # mu
-        spks: float["b 1 d"],  # spk xvec
-        time: float["b"] | float[""],  # time step
-        return_hidden: bool = False,
-        mask: bool["b 1 n"] | None = None,
-        mask_rand: float["b 1 1"] = None,  # for mask flag type
-        **kwargs,
-    ):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
         batch, seq_len = x.shape[0], x.shape[1]
-        if time.ndim == 0:
-            time = time.repeat(batch)
+        if t.ndim == 0:
+            t = t.repeat(batch)
 
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
-        t = self.time_embed(time)
+        t = self.time_embed(t)
         x = self.input_embed(x, cond, mu, spks.squeeze(1))
 
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -187,22 +158,17 @@
         if self.long_skip_connection is not None:
             residual = x
 
-        mask = mask.unsqueeze(1)  # B,1,1,T
-        if self.causal_mask_type is not None:
-            mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))
+        if streaming is True:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+        else:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
 
         for block in self.transformer_blocks:
-            # mask-out padded values for amp training
-            x = x * mask[:, 0, -1, :].unsqueeze(-1)
-            x = block(x, t, mask=mask.bool(), rope=rope)
+            x = block(x, t, mask=attn_mask.bool(), rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
 
         x = self.norm_out(x, t)
-        output = self.proj_out(x)
-
-        if return_hidden:
-            return output, None
-
+        output = self.proj_out(x).transpose(1, 2)
         return output
diff --git a/cosyvoice/flow/DiT/dit_modules.py b/cosyvoice/flow/DiT/modules.py
similarity index 99%
rename from cosyvoice/flow/DiT/dit_modules.py
rename to cosyvoice/flow/DiT/modules.py
index 1c1ee01..542ba36 100644
--- a/cosyvoice/flow/DiT/dit_modules.py
+++ b/cosyvoice/flow/DiT/modules.py
@@ -1,3 +1,4 @@
+
 """
 ein notation:
 b - batch
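
Note (not part of the patch): a minimal sketch of how the reworked DiT.forward() might be called after this change, given an already-constructed DiT instance named model. The tensor shapes and sizes below are assumptions inferred from the transposes and unsqueezes in the diff, not values taken from the rest of the repository.

import torch

# Illustrative sizes only; not from the patch.
B, T, mel_dim, spk_dim = 2, 200, 80, 192
x = torch.randn(B, mel_dim, T)     # noised mel, channel-first; transposed to (B, T, C) inside forward
mu = torch.randn(B, mel_dim, T)    # upsampled token features, channel-first
cond = torch.randn(B, mel_dim, T)  # masked prompt mel, channel-first
spks = torch.randn(B, spk_dim)     # speaker embedding; unsqueezed to (B, 1, spk_dim) inside forward
mask = torch.ones(B, T)            # 1 for real frames, 0 for padding
t = torch.rand(())                 # scalar flow-matching timestep; repeated to (B,) inside forward

# streaming=True builds a chunk-based attention mask (static_chunk_size frames per chunk, full left context);
# streaming=False builds a full attention mask restricted to non-padded frames.
out = model(x, mask, mu, t, spks=spks, cond=cond, streaming=True)
# out keeps the channel-first layout of x: (B, mel_dim, T)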