mirror of https://github.com/FunAudioLLM/CosyVoice.git
update
@@ -1,3 +1,4 @@
+
 """
 ein notation:
 b - batch
@@ -14,9 +15,8 @@ from torch import nn
 import torch.nn.functional as F
 from einops import repeat
 from x_transformers.x_transformers import RotaryEmbedding
-from funasr.models.transformer.utils.mask import causal_block_mask
-
-from cosyvoice.flow.DiT.dit_modules import (
+from cosyvoice.utils.mask import add_optional_chunk_mask
+from cosyvoice.flow.DiT.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     CausalConvPositionEmbedding,
@@ -115,7 +115,8 @@ class DiT(nn.Module):
         mu_dim=None,
         long_skip_connection=False,
         spk_dim=None,
-        **kwargs
+        static_chunk_size=50,
+        num_decoding_left_chunks=2
     ):
         super().__init__()

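The two new constructor arguments parameterize the chunked attention used for streaming synthesis: static_chunk_size is the chunk width in frames (50 by default), and num_decoding_left_chunks is the usual wenet-style bound on how many earlier chunks stay visible, where a negative value means unlimited history. Below is a minimal sketch of the attention pattern this implies; chunk_attention_mask is an illustrative helper written for this note, not code from the repo.

# Illustrative only: each frame may attend up to the end of its own chunk,
# plus num_left_chunks previous chunks (-1 = all previous chunks).
import torch

def chunk_attention_mask(seq_len: int, chunk_size: int, num_left_chunks: int = -1) -> torch.Tensor:
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        end = min((i // chunk_size + 1) * chunk_size, seq_len)
        start = 0 if num_left_chunks < 0 else max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        mask[i, start:end] = True
    return mask  # (T, T), True = may attend

print(chunk_attention_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])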
@@ -136,50 +137,20 @@ class DiT(nn.Module):

         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
-        self.causal_mask_type = kwargs.get("causal_mask_type", None)
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks

-    def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
-        b, _, _, t = attn_mask.shape
-        if rand is None:
-            rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
-        mixed_mask = attn_mask.clone()
-        for item in self.causal_mask_type:
-            prob_min, prob_max = item["prob_min"], item["prob_max"]
-            _ratio = 1
-            if "ratio" in item:
-                _ratio = item["ratio"]
-            if ratio is not None:
-                _ratio = ratio
-            block_size = item["block_size"] * _ratio
-            if block_size <= 0:
-                causal_mask = attn_mask
-            else:
-                causal_mask = causal_block_mask(
-                    t, block_size, attn_mask.device, torch.float32
-                ).unsqueeze(0).unsqueeze(1)  # 1,1,T,T
-            flag = (prob_min <= rand) & (rand < prob_max)
-            mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag
-
-        return mixed_mask
-
-    def forward(
-        self,
-        x: float["b n d"],  # nosied input audio
-        cond: float["b n d"],  # masked cond audio
-        mu: int["b nt d"],  # mu
-        spks: float["b 1 d"],  # spk xvec
-        time: float["b"] | float[""],  # time step
-        return_hidden: bool = False,
-        mask: bool["b 1 n"] | None = None,
-        mask_rand: float["b 1 1"] = None,  # for mask flag type
-        **kwargs,
-    ):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
         batch, seq_len = x.shape[0], x.shape[1]
-        if time.ndim == 0:
-            time = time.repeat(batch)
+        if t.ndim == 0:
+            t = t.repeat(batch)

         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
-        t = self.time_embed(time)
+        t = self.time_embed(t)
         x = self.input_embed(x, cond, mu, spks.squeeze(1))

         rope = self.rotary_embed.forward_from_seq_len(seq_len)
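The forward signature moves from keyword-heavy, ein-annotated arguments to a positional interface that takes channel-first tensors and transposes them internally. A hypothetical call under assumed shapes follows; the sizes, the 192-dim speaker embedding, and mask being (B, 1, T) are placeholders inferred from the transposes and the add_optional_chunk_mask call, not values taken from the repo, and dit is assumed to be a DiT instance constructed with matching dims.

# Hypothetical usage sketch for the new channel-first interface.
import torch

B, T, mel_dim = 2, 200, 80
x = torch.randn(B, mel_dim, T)      # noised mel, channel-first
mu = torch.randn(B, mel_dim, T)     # encoder output aligned to mel frames (mu_dim = mel_dim assumed)
cond = torch.randn(B, mel_dim, T)   # masked prompt mel
spks = torch.randn(B, 192)          # speaker embedding (dim is a placeholder)
mask = torch.ones(B, 1, T)          # non-padding frames
t = torch.tensor(0.5)               # scalar flow time step, repeated to (B,) inside forward

out = dit(x, mask, mu, t, spks=spks, cond=cond, streaming=False)  # -> (B, mel_dim, T)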
@@ -187,22 +158,17 @@ class DiT(nn.Module):
         if self.long_skip_connection is not None:
             residual = x

-        mask = mask.unsqueeze(1)  # B,1,1,T
-        if self.causal_mask_type is not None:
-            mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))
+        if streaming is True:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+        else:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)

         for block in self.transformer_blocks:
-            # mask-out padded values for amp training
-            x = x * mask[:, 0, -1, :].unsqueeze(-1)
-            x = block(x, t, mask=mask.bool(), rope=rope)
+            x = block(x, t, mask=attn_mask.bool(), rope=rope)

         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

         x = self.norm_out(x, t)
-        output = self.proj_out(x)
-
-        if return_hidden:
-            return output, None
-
+        output = self.proj_out(x).transpose(1, 2)
         return output
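The probabilistic build_mix_causal_mask mixing is replaced by a single streaming switch that builds one of two attention masks. A small sketch of the two shapes that branch produces, reusing the same add_optional_chunk_mask calls as the diff; the tensor sizes are placeholders, and the function is assumed to follow the wenet-style signature it has in cosyvoice.utils.mask.

# Sketch of the two masks built in forward (sizes are placeholders).
import torch
from cosyvoice.utils.mask import add_optional_chunk_mask

B, T, D = 1, 150, 80
x = torch.randn(B, T, D)                      # already (B, T, D) at this point in forward
mask = torch.ones(B, 1, T, dtype=torch.bool)  # non-padding frames

# streaming=True: chunk-causal mask with chunks of static_chunk_size frames
chunk_mask = add_optional_chunk_mask(x, mask, False, False, 0, 50, -1).unsqueeze(dim=1)  # (B, 1, T, T)

# streaming=False: full attention restricted only to non-padded frames
full_mask = add_optional_chunk_mask(x, mask, False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)  # (B, 1, T, T)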