fix cosyvoice3 training

This commit is contained in:
lyuxiang.lx
2025-12-29 10:02:59 +00:00
parent 8524c81acd
commit 4d7295a9a7
5 changed files with 25 additions and 23 deletions

View File

@@ -332,8 +332,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
h = self.encoder_proj(h)
h = self.pre_lookahead_layer(token)
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
@@ -344,7 +345,6 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),