mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add flow unified training
This commit is contained in:
@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Args:
|
||||
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
|
||||
spks = spks * cfg_mask.view(-1, 1)
|
||||
cond = cond * cfg_mask.view(-1, 1, 1)
|
||||
|
||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
||||
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
||||
if loss.isnan():
|
||||
print(123)
|
||||
pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
||||
return loss, y
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user