add cosyvoice2

This commit is contained in:
lyuxiang.lx
2024-12-11 16:14:19 +08:00
parent cde3cec6fa
commit 3e381002d7
9 changed files with 484 additions and 32 deletions

View File

@@ -13,16 +13,84 @@
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class Transpose(torch.nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor):
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor):
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype
)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
def forward(self, x: torch.Tensor):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
causal=False,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
@@ -39,7 +107,7 @@ class ConditionalDecoder(nn.Module):
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.causal = causal
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
@@ -56,7 +124,7 @@ class ConditionalDecoder(nn.Module):
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
@@ -70,14 +138,14 @@ class ConditionalDecoder(nn.Module):
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
@@ -99,7 +167,11 @@ class ConditionalDecoder(nn.Module):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
) if self.causal else ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
@@ -119,10 +191,10 @@ class ConditionalDecoder(nn.Module):
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
@@ -175,7 +247,9 @@ class ConditionalDecoder(nn.Module):
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -192,7 +266,9 @@ class ConditionalDecoder(nn.Module):
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -207,7 +283,9 @@ class ConditionalDecoder(nn.Module):
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask==1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -218,4 +296,4 @@ class ConditionalDecoder(nn.Module):
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
return output * mask

View File

@@ -146,3 +146,83 @@ class MaskedDiffWithXvec(torch.nn.Module):
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat, flow_cache
class CausalMaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.only_mask_loss = only_mask_loss
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat, None

View File

@@ -89,17 +89,25 @@ class ConditionalCFM(BASECFM):
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.forward_estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
torch.zeros_like(cond)
)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x_in = torch.concat([x, x], dim=0)
mask_in = torch.concat([mask, mask], dim=0)
mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
t_in = torch.concat([t, t], dim=0)
spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
else:
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in
)
if self.inference_cfg_rate > 0:
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
@@ -163,3 +171,37 @@ class ConditionalCFM(BASECFM):
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
z[:] = 0
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None