diff --git a/cosyvoice/flow/DiT/dit_model.py b/cosyvoice/flow/DiT/dit_model.py new file mode 100644 index 0000000..ada9392 --- /dev/null +++ b/cosyvoice/flow/DiT/dit_model.py @@ -0,0 +1,208 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat +from x_transformers.x_transformers import RotaryEmbedding +from funasr.models.transformer.utils.mask import causal_block_mask + +from cosyvoice.flow.DiT.dit_modules import ( + TimestepEmbedding, + ConvNeXtV2Block, + CausalConvPositionEmbedding, + DiTBlock, + AdaLayerNormZero_Final, + precompute_freqs_cis, + get_pos_embed_indices, +) + + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + batch, text_len = text.shape[0], text.shape[1] + text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + text = F.pad(text, (0, seq_len - text_len), value=0) + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None): + super().__init__() + spk_dim = 0 if spk_dim is None else spk_dim + self.spk_dim = spk_dim + self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim) + self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim) + + def forward( + self, + x: float["b n d"], + cond: float["b n d"], + text_embed: float["b n d"], + spks: float["b d"], + ): + to_cat = [x, cond, text_embed] + if self.spk_dim > 0: + spks = repeat(spks, "b c -> b t c", t=x.shape[1]) + to_cat.append(spks) + + x = self.proj(torch.cat(to_cat, dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=80, + mu_dim=None, + long_skip_connection=False, + spk_dim=None, + **kwargs + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if mu_dim is None: + mu_dim = mel_dim + self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] + ) + self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + self.causal_mask_type = kwargs.get("causal_mask_type", None) + + def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None): + b, _, _, t = attn_mask.shape + if rand is None: + rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32) + mixed_mask = attn_mask.clone() + for item in self.causal_mask_type: + prob_min, prob_max = item["prob_min"], item["prob_max"] + _ratio = 1 + if "ratio" in item: + _ratio = item["ratio"] + if ratio is not None: + _ratio = ratio + block_size = item["block_size"] * _ratio + if block_size <= 0: + causal_mask = attn_mask + else: + causal_mask = causal_block_mask( + t, block_size, attn_mask.device, torch.float32 + ).unsqueeze(0).unsqueeze(1) # 1,1,T,T + flag = (prob_min <= rand) & (rand < prob_max) + mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag + + return mixed_mask + + def forward( + self, + x: float["b n d"], # nosied input audio + cond: float["b n d"], # masked cond audio + mu: int["b nt d"], # mu + spks: float["b 1 d"], # spk xvec + time: float["b"] | float[""], # time step + return_hidden: bool = False, + mask: bool["b 1 n"] | None = None, + mask_rand: float["b 1 1"] = None, # for mask flag type + **kwargs, + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + x = self.input_embed(x, cond, mu, spks.squeeze(1)) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + mask = mask.unsqueeze(1) # B,1,1,T + if self.causal_mask_type is not None: + mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1)) + + for block in self.transformer_blocks: + # mask-out padded values for amp training + x = x * mask[:, 0, -1, :].unsqueeze(-1) + x = block(x, t, mask=mask.bool(), rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + if return_hidden: + return output, None + + return output diff --git a/cosyvoice/flow/DiT/dit_modules.py b/cosyvoice/flow/DiT/dit_modules.py new file mode 100644 index 0000000..1c1ee01 --- /dev/null +++ b/cosyvoice/flow/DiT/dit_modules.py @@ -0,0 +1,615 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from typing import Optional +import math + +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + +from x_transformers.x_transformers import apply_rotary_pos_emb + + +# raw wav to mel spec +class MelSpec(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + normalize=False, + power=1, + norm=None, + center=True, + ): + super().__init__() + self.n_mel_channels = n_mel_channels + + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=filter_length, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=power, + center=center, + normalized=normalize, + norm=norm, + ) + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, inp): + if len(inp.shape) == 3: + inp = inp.squeeze(1) # 'b 1 nw -> b nw' + + assert len(inp.shape) == 2 + + if self.dummy.device != inp.device: + self.to(inp.device) + + mel = self.mel_stft(inp) + mel = mel.clamp(min=1e-5).log() + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +class CausalConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.kernel_size = kernel_size + self.conv1 = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0), + nn.Mish(), + ) + self.conv2 = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = F.pad(x, (self.kernel_size - 1, 0, 0, 0)) + x = self.conv1(x) + x = F.pad(x, (self.kernel_size - 1, 0, 0, 0)) + x = self.conv2(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor | AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + if mask.dim() == 2: + mask = mask.unsqueeze(-1) + else: + mask = mask[:, 0, -1].unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) # no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(ff_norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. (right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) + + # attention + x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index a068288..415d44e 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -37,14 +37,11 @@ class MaskedDiffWithXvec(torch.nn.Module): 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}): super().__init__() self.input_size = input_size self.output_size = output_size self.decoder_conf = decoder_conf - self.mel_feat_conf = mel_feat_conf self.vocab_size = vocab_size self.output_type = output_type self.input_frame_rate = input_frame_rate @@ -165,14 +162,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}): super().__init__() self.input_size = input_size self.output_size = output_size self.decoder_conf = decoder_conf - self.mel_feat_conf = mel_feat_conf self.vocab_size = vocab_size self.output_type = output_type self.input_frame_rate = input_frame_rate @@ -279,3 +273,158 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 return feat.float(), None + + +class CausalMaskedDiffWithDiT(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + token_mel_ratio: int = 2, + pre_lookahead_len: int = 3, + pre_lookahead_layer: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, + 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', + 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), + 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) + self.pre_lookahead_len = pre_lookahead_len + self.pre_lookahead_layer = pre_lookahead_layer + self.decoder = decoder + self.only_mask_loss = only_mask_loss + self.token_mel_ratio = token_mel_ratio + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) + feat = batch['speech_feat'].to(device) + feat_len = batch['speech_feat_len'].to(device) + embedding = batch['embedding'].to(device) + + # NOTE unified training, static_chunk_size > 0 or = 0 + streaming = True if random.random() < 0.5 else False + + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len, streaming=streaming) + h = self.encoder_proj(h) + + # get conditions + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h) + loss, _ = self.decoder.compute_loss( + feat.transpose(1, 2).contiguous(), + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + embedding, + cond=conds, + streaming=streaming, + ) + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding, + streaming, + finalize): + assert token.shape[0] == 1 + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len + mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + if finalize is True: + h = self.pre_lookahead_layer(token) + else: + h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:]) + h = h.repeat_interleave(self.token_mel_ratio, dim=1) + mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] + + # get conditions + conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) + conds[:, :mel_len1] = prompt_feat + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) + feat, _ = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10, + streaming=streaming + ) + feat = feat[:, :, mel_len1:] + assert feat.shape[2] == mel_len2 + return feat.float(), None + +if __name__ == '__main__': + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + from hyperpyyaml import load_hyperpyyaml + with open('./pretrained_models/CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f: + configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None}) + model = configs['flow'] + device = 'cuda' if torch.cuda.is_available() else 'cpu' + model.to(device) + model.eval() + max_len = 10 * model.decoder.estimator.static_chunk_size + chunk_size = model.decoder.estimator.static_chunk_size + context_size = model.pre_lookahead_layer.pre_lookahead_len + token = torch.randint(0, 6561, size=(1, max_len)).to(device) + token_len = torch.tensor([max_len]).to(device) + prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device) + prompt_token_len = torch.tensor([chunk_size]).to(device) + prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device) + prompt_feat_len = torch.tensor([chunk_size * 2]).to(device) + prompt_embedding = torch.rand(1, 192).to(device) + pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True) + for i in range(0, max_len, chunk_size): + finalize = True if i + chunk_size + context_size >= max_len else False + pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device), prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize) + pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:] + print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item()) \ No newline at end of file diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index c3f8b1f..9647ef3 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -736,7 +736,7 @@ if __name__ == '__main__': model.to(device) model.eval() max_len, chunk_size, context_size = 300, 30, 8 - mel = torch.rand(1, 80, max_len) + mel = torch.rand(1, 80, max_len).to(device) pred_gt, _ = model.inference(mel) for i in range(0, max_len, chunk_size): finalize = True if i + chunk_size + context_size >= max_len else False diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py index 6ffda6a..baf7481 100644 --- a/cosyvoice/transformer/upsample_encoder.py +++ b/cosyvoice/transformer/upsample_encoder.py @@ -64,17 +64,18 @@ class Upsample1D(nn.Module): class PreLookaheadLayer(nn.Module): - def __init__(self, channels: int, pre_lookahead_len: int = 1): + def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1): super().__init__() + self.in_channels = in_channels self.channels = channels self.pre_lookahead_len = pre_lookahead_len self.conv1 = nn.Conv1d( - channels, channels, + in_channels, channels, kernel_size=pre_lookahead_len + 1, stride=1, padding=0, ) self.conv2 = nn.Conv1d( - channels, channels, + channels, in_channels, kernel_size=3, stride=1, padding=0, ) @@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module): # convolution module definition convolution_layer_args = (output_size, cnn_module_kernel, activation, cnn_module_norm, causal) - self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) + self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3) self.encoders = torch.nn.ModuleList([ ConformerEncoderLayer( output_size,