add train cfg in flow matching

This commit is contained in:
lyuxiang.lx
2024-07-11 17:36:59 +08:00
parent c7d9754eee
commit 44aea805ea
8 changed files with 17 additions and 10 deletions

View File

@@ -126,6 +126,13 @@ class ConditionalCFM(BASECFM):
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y