send streaming as args

This commit is contained in:
lyuxiang.lx
2025-05-27 13:47:12 +08:00
parent 54d21b40f0
commit cbfed4a9ee
5 changed files with 14 additions and 19 deletions

View File

@@ -69,7 +69,7 @@ class ConditionalCFM(BASECFM):
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
def solve_euler(self, x, t_span, mu, mask, spks, cond):
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
"""
Fixed euler solver for ODEs.
Args:
@@ -110,7 +110,8 @@ class ConditionalCFM(BASECFM):
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in
cond_in,
streaming
)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -122,9 +123,9 @@ class ConditionalCFM(BASECFM):
return sol[-1].float()
def forward_estimator(self, x, mask, mu, t, spks, cond):
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond)
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
estimator, trt_engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -196,7 +197,7 @@ class CausalConditionalCFM(ConditionalCFM):
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
"""Forward diffusion
Args:
@@ -220,4 +221,4 @@ class CausalConditionalCFM(ConditionalCFM):
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None