add flow unified training

This commit is contained in:
lyuxiang.lx
2025-01-26 16:56:06 +08:00
parent aea75207dd
commit fd1a951a6c
4 changed files with 38 additions and 26 deletions

View File

@@ -202,6 +202,9 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE: unified training - each call randomly selects streaming or non-streaming mode (static_chunk_size > 0 or = 0)
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
h = self.encoder_proj(h)
# get conditions
@@ -230,7 +233,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
cond=conds,
streaming=streaming,
)
return {'loss': loss}