add flow unified training

commit fd1a951a6c
parent aea75207dd
Author: lyuxiang.lx
Date:   2025-01-26 16:56:06 +08:00

4 changed files with 38 additions and 26 deletions

File 1 of 4

@@ -210,6 +210,7 @@ class CausalAttention(Attention):
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ class CausalAttention(Attention):
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
@@ -505,7 +506,7 @@ class ConditionalDecoder(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -540,7 +541,7 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -558,7 +559,7 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -700,7 +701,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()

-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.

         Args:
@@ -735,7 +736,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -753,7 +757,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
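
In ConditionalDecoder the old full-attention mask was the outer product of the padding mask with itself; the commit reroutes it through add_optional_chunk_mask so the streaming and non-streaming paths share one code path. A minimal sketch of why the two masks agree, assuming WeNet-style semantics for add_optional_chunk_mask (with static chunk size 0 and -1 left chunks it falls back to the plain padding mask); shapes and names here are illustrative only:

import torch

B, T = 2, 5
lengths = torch.tensor([5, 3])
# (B, 1, T) padding mask, True on real frames
mask = (torch.arange(T)[None, :] < lengths[:, None]).unsqueeze(1)

# old mask: outer product of the padding mask with itself -> (B, T, T)
full_old = torch.matmul(mask.transpose(1, 2).float(), mask.float()) == 1

# new non-streaming mask: the padding mask repeated over the query axis,
# mirroring add_optional_chunk_mask(x, mask, False, False, 0, 0, -1).repeat(1, T, 1)
full_new = mask.repeat(1, T, 1)

# both agree on every real (non-padded) query row; padded query rows differ,
# but after mask_to_bias those rows feed softmax outputs nothing reads back
valid_q = mask.transpose(1, 2)
assert torch.equal(full_old & valid_q, full_new & valid_q)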

File 2 of 4

@@ -202,6 +202,9 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)

+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask

         # text encode
-        h, h_lengths = self.encoder(token, token_len)
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         h = self.encoder_proj(h)

         # get conditions
@@ -230,7 +233,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
+            streaming=streaming,
         )

         return {'loss': loss}
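
This is the heart of unified training: each batch flips a fair coin, so a single model is trained on both chunk-masked (streaming) attention and full-context attention, with the flag threaded through the encoder and down into the CFM decoder. One small note on the sampler, sketched below: random.random() < 0.5 already evaluates to a bool, so the committed True if ... else False conditional is redundant.

import random

def sample_streaming_flag(p_streaming: float = 0.5) -> bool:
    # equivalent to the committed `True if random.random() < 0.5 else False`;
    # the comparison itself yields the bool we need
    return random.random() < p_streaming

The flag then rides along as encoder(token, token_len, streaming=streaming) and decoder.compute_loss(..., streaming=streaming), so one checkpoint serves both inference modes.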

File 3 of 4

@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
                                    x.data_ptr()])
             return x

-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss

         Args:
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)

-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
-        if loss.isnan():
-            print(123)
-            pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         return loss, y
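
Besides threading streaming into the estimator call, this hunk drops a leftover nan-debug block (the print(123) and the dead pred_new recomputation). For orientation, a minimal sketch of the optimal-transport conditional flow matching objective that compute_loss implements, assuming Matcha-TTS-style conventions (a sigma_min floor, one uniform t per sample, no t scheduler); estimator stands for any velocity network with this commit's signature:

import torch
import torch.nn.functional as F

def cfm_loss(estimator, x1, mask, mu, spks, cond, streaming=False, sigma_min=1e-4):
    # straight-line probability path: push noise x0 toward data x1 and
    # regress the estimator onto the constant target velocity u
    b = x1.size(0)
    t = torch.rand((b, 1, 1), dtype=x1.dtype, device=x1.device)
    x0 = torch.randn_like(x1)
    y = (1 - (1 - sigma_min) * t) * x0 + t * x1   # sample on the path at time t
    u = x1 - (1 - sigma_min) * x0                 # target velocity field
    pred = estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
    return F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])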

File 4 of 4

@@ -255,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.

@@ -286,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

         if self.normalize_before:
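
Every argument of add_optional_chunk_mask now carries an `... if streaming is True else ...` guard that collapses the chunking options to full-context defaults (no dynamic chunking, chunk size 0, unlimited left context) when streaming is off. A behaviourally equivalent refactor sketch, not the committed code, that keeps the guard in one place; chunk_mask_args is a hypothetical helper:

def chunk_mask_args(encoder, streaming, decoding_chunk_size,
                    static_chunk_size, num_decoding_left_chunks):
    # streaming off -> full-context attention: no dynamic chunking,
    # chunk size 0, unlimited (-1) left chunks
    if not streaming:
        return False, False, 0, 0, -1
    return (encoder.use_dynamic_chunk, encoder.use_dynamic_left_chunk,
            decoding_chunk_size, static_chunk_size, num_decoding_left_chunks)

Each call site would then read chunk_masks = add_optional_chunk_mask(xs, masks, *chunk_mask_args(self, streaming, decoding_chunk_size, self.static_chunk_size, num_decoding_left_chunks)), with self.static_chunk_size * self.up_layer.stride passed at the upsampled second site.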