mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Add training-time classifier-free guidance (CFG) condition dropout to flow matching
This commit is contained in:
@@ -126,6 +126,13 @@ class ConditionalCFM(BASECFM):
|
||||
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
||||
u = x1 - (1 - self.sigma_min) * z
|
||||
|
||||
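# During training, randomly zero out the conditioning inputs (mu, spks, cond) per sample, with probability training_cfg_rate, to trade off mode coverage against sample fidelity (classifier-free guidance).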
|
||||
if self.training_cfg_rate > 0:
|
||||
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
||||
mu = mu * cfg_mask.view(-1, 1, 1)
|
||||
spks = spks * cfg_mask.view(-1, 1)
|
||||
cond = cond * cfg_mask.view(-1, 1, 1)
|
||||
|
||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
||||
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
||||
return loss, y
|
||||
|
||||
Reference in New Issue
Block a user