Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
add flow unified training
@@ -210,6 +210,7 @@ class CausalAttention(Attention):
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ class CausalAttention(Attention):
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
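The only change to CausalAttention is threading a new qk_norm argument through to the parent diffusers-style Attention class. As a reminder of what such an option usually does, here is a minimal, hypothetical sketch of query-key normalization (the SimpleQKNormAttention module and its shapes are illustrative assumptions, not the CosyVoice implementation): normalizing q and k per head before the dot product keeps attention logits bounded.

# Illustrative sketch only: what a qk_norm option typically enables inside attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleQKNormAttention(nn.Module):  # hypothetical helper, not repo code
    def __init__(self, dim: int, heads: int = 8, qk_norm: bool = True):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        # LayerNorm over the per-head channels, applied to queries and keys only
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x, attn_mask=None):
        b, t, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (y.view(b, t, self.heads, self.head_dim).transpose(1, 2) for y in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)  # the extra step that qk_norm turns on
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return self.to_out(out.transpose(1, 2).reshape(b, t, -1))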
@@ -505,7 +506,7 @@ class ConditionalDecoder(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -540,7 +541,7 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -558,7 +559,7 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
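For the non-causal ConditionalDecoder the new mask construction is intended as a refactor rather than a behaviour change: the outer-product padding mask is rebuilt with the same add_optional_chunk_mask utility the causal decoder uses, with all chunking disabled (False, False, 0, 0, -1). A quick self-contained check of that equivalence, under the assumption that the utility simply returns the padding mask when chunking is off (which is why the diff expands it with .repeat(1, x.size(1), 1)):

# Sanity sketch: the old matmul-based mask and the new repeat-expanded padding mask
# agree on every valid query row; rows for padded queries differ but are masked out
# downstream anyway. Pure torch, no wenet/cosyvoice imports needed.
import torch

B, T = 2, 6
lengths = torch.tensor([6, 4])
pad = (torch.arange(T)[None, :] < lengths[:, None]).unsqueeze(1)  # (B, 1, T), True = valid frame

old_mask = torch.matmul(pad.transpose(1, 2).float(), pad.float()) == 1  # (B, T, T)
new_mask = pad.repeat(1, T, 1)                                          # (B, T, T)

valid_query = pad.transpose(1, 2)  # (B, T, 1)
assert torch.equal(old_mask & valid_query, new_mask & valid_query)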
@@ -700,7 +701,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -735,7 +736,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -753,7 +757,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
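In the causal decoder the streaming branch keeps the chunk-limited mask (self.static_chunk_size, self.num_decoding_left_chunks), while the non-streaming branch falls back to the full mask used above, so the same weights are trained under both attention patterns. The helper below is a hypothetical stand-in, not the WeNet/CosyVoice add_optional_chunk_mask, sketched only to show the kind of chunk-causal mask the streaming branch requests:

# chunk_causal_mask is an illustrative stand-in: each query position may attend to
# its own chunk, at most num_left_chunks chunks to the left, and nothing to the right.
import torch

def chunk_causal_mask(T: int, chunk_size: int, num_left_chunks: int = -1) -> torch.Tensor:
    """Return a (T, T) boolean mask, True where query i may attend to key j."""
    mask = torch.zeros(T, T, dtype=torch.bool)
    for i in range(T):
        end = ((i // chunk_size) + 1) * chunk_size  # right edge of i's own chunk
        start = 0 if num_left_chunks < 0 else max(0, (i // chunk_size - num_left_chunks) * chunk_size)
        mask[i, start:min(end, T)] = True
    return mask

# streaming=True  -> finite receptive field per frame (deployable chunk by chunk)
# streaming=False -> every valid frame attends to every valid frame
print(chunk_causal_mask(8, chunk_size=2, num_left_chunks=1).int())
print(torch.ones(8, 8, dtype=torch.bool).int())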
@@ -202,6 +202,9 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)

+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask

         # text encode
-        h, h_lengths = self.encoder(token, token_len)
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         h = self.encoder_proj(h)

         # get conditions
@@ -230,7 +233,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
+            streaming=streaming,
         )
         return {'loss': loss}
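Taken together, the flow-model forward pass now flips a coin per batch and trains either the streaming (chunked) or the offline (full-attention) configuration with the same parameters. A condensed sketch of that control flow, with encoder, encoder_proj and decoder standing in for the actual CosyVoice modules (only the flag threading is meant to mirror the diff):

# Condensed sketch of the unified-training step added in this commit; module objects
# are placeholders, the control flow mirrors CausalMaskedDiffWithXvec.forward above.
import random

def training_step(encoder, encoder_proj, decoder, token, token_len, feat, mask, embedding, conds):
    # NOTE unified training, static_chunk_size > 0 or = 0
    streaming = True if random.random() < 0.5 else False

    h, h_lengths = encoder(token, token_len, streaming=streaming)  # chunked or full self-attention
    h = encoder_proj(h)

    loss, _ = decoder.compute_loss(
        feat.transpose(1, 2).contiguous(),
        mask.unsqueeze(1),
        h.transpose(1, 2).contiguous(),
        embedding,
        cond=conds,
        streaming=streaming,  # the same flag reaches the flow-matching estimator
    )
    return {'loss': loss}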
@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
                                    x.data_ptr()])
             return x

-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss

         Args:
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)

-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
-        if loss.isnan():
-            print(123)
-            pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         return loss, y
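compute_loss itself only gains the streaming pass-through (and drops the leftover NaN debug block). For context, a hedged sketch of the surrounding conditional flow-matching loss: the y/u construction below is the usual optimal-transport CFM target used by Matcha-style decoders and is an assumption about the unchanged lines, while the masked-MSE normalisation and the estimator call match the diff.

# Hedged sketch, assuming the standard OT-CFM interpolation; only the estimator call
# and the masked MSE line are taken directly from the diff above.
import torch
import torch.nn.functional as F

def compute_loss(estimator, x1, mask, mu, spks=None, cond=None, streaming=False, sigma_min=1e-4):
    b = x1.size(0)
    t = torch.rand([b, 1, 1], device=x1.device, dtype=x1.dtype)  # random flow time per sample
    z = torch.randn_like(x1)                                     # noise endpoint x0
    y = (1 - (1 - sigma_min) * t) * z + t * x1                   # point on the probability path
    u = x1 - (1 - sigma_min) * z                                 # target velocity field

    pred = estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
    loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
    return loss, y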
@@ -255,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.

@@ -286,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

         if self.normalize_before:
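The encoder applies the same gate five times: when streaming is False every chunking knob is forced to its disabled value (no dynamic chunks, static chunk size 0, unlimited left context), so the conformer runs with full bidirectional attention; when True, the configured chunk settings apply, with the static chunk size scaled by up_layer.stride after upsampling, presumably so chunk boundaries stay aligned across the two resolutions. A hypothetical helper that captures the pattern (chunk_mask_args is not repo code):

# chunk_mask_args is a hypothetical helper illustrating the repeated ternaries above.
def chunk_mask_args(streaming, use_dynamic_chunk, use_dynamic_left_chunk,
                    decoding_chunk_size, static_chunk_size, num_decoding_left_chunks):
    if not streaming:
        # offline branch: no dynamic chunks, no static chunks, unlimited left context
        return False, False, 0, 0, -1
    return (use_dynamic_chunk, use_dynamic_left_chunk,
            decoding_chunk_size, static_chunk_size, num_decoding_left_chunks)

# usage sketch:
# chunk_masks = add_optional_chunk_mask(xs, masks, *chunk_mask_args(streaming, ...))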