fix fm train bug

This commit is contained in:
lyuxiang.lx
2026-01-19 10:48:24 +08:00
parent 1dcc59676f
commit 1822c5c908

View File

@@ -174,8 +174,7 @@ class ConditionalCFM(BASECFM):
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)