From fd1a951a6c7cb4670e3c301eb7bf6cad9ec77ebc Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Sun, 26 Jan 2025 16:56:06 +0800
Subject: [PATCH] add flow unified training

---
 cosyvoice/flow/decoder.py                 | 28 +++++++++++++++--------
 cosyvoice/flow/flow.py                    |  8 +++++--
 cosyvoice/flow/flow_matching.py           |  7 ++----
 cosyvoice/transformer/upsample_encoder.py | 21 +++++++++--------
 4 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index 865dedc..565e0e2 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -210,6 +210,7 @@ class CausalAttention(Attention):
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ class CausalAttention(Attention):
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
@@ -505,7 +506,7 @@ class ConditionalDecoder(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -540,7 +541,7 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -558,7 +559,7 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -700,7 +701,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -735,7 +736,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -753,7 +757,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 71f5ae7..516c4f7 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -202,6 +202,9 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)

+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask

         # text encode
-        h, h_lengths = self.encoder(token, token_len)
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         h = self.encoder_proj(h)

         # get conditions
@@ -230,7 +233,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
+            streaming=streaming,
         )

         return {'loss': loss}
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 3a7de2e..ca2f583 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
                                        x.data_ptr()])
         return x

-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss

         Args:
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)

-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
-        if loss.isnan():
-            print(123)
-            pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)

         return loss, y

diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py
index 92267a8..2276577 100644
--- a/cosyvoice/transformer/upsample_encoder.py
+++ b/cosyvoice/transformer/upsample_encoder.py
@@ -255,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.

@@ -286,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

         if self.normalize_before:
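
Reviewer note, not part of the patch: the commit trains a single flow model for both streaming and non-streaming synthesis. flow.py flips a coin per batch (`streaming = True if random.random() < 0.5 else False`) and threads the flag through the encoder, the CFM loss, and every decoder attention mask, so streaming batches see chunked attention (`static_chunk_size`, limited left context) while non-streaming batches see full-context attention. Below is a minimal standalone sketch of that mask logic; `chunk_attn_mask` is a simplified, hypothetical stand-in for the wenet-style `add_optional_chunk_mask` (full left context assumed, padding ignored, sizes arbitrary):

import random

import torch


def chunk_attn_mask(num_frames: int, chunk_size: int) -> torch.Tensor:
    # Returns a (T, T) boolean attention mask; chunk_size <= 0 means
    # non-streaming, i.e. every frame may attend to every other frame.
    if chunk_size <= 0:
        return torch.ones(num_frames, num_frames, dtype=torch.bool)
    chunk_idx = torch.arange(num_frames) // chunk_size
    # query i may attend to key j iff j's chunk does not lie after i's chunk
    return chunk_idx.unsqueeze(1) >= chunk_idx.unsqueeze(0)


T, static_chunk_size = 8, 2
streaming = random.random() < 0.5  # the per-batch coin flip from flow.py
attn_mask = chunk_attn_mask(T, static_chunk_size if streaming else 0)
print(streaming)
print(attn_mask.int())

With chunk size 2 each frame attends to its own two-frame chunk and all earlier chunks, mirroring the `self.static_chunk_size` branch; with chunk size 0 the mask is all ones over valid frames, mirroring the `add_optional_chunk_mask(..., 0, 0, -1).repeat(1, x.size(1), 1)` calls the patch uses for non-streaming batches.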