add flow unified training

This commit is contained in:
lyuxiang.lx
2025-01-26 16:56:06 +08:00
parent aea75207dd
commit fd1a951a6c
4 changed files with 38 additions and 26 deletions

View File

@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
x.data_ptr()])
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
"""Computes diffusion loss
Args:
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
if loss.isnan():
print(123)
pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
return loss, y