mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
send streaming as args
This commit is contained in:
@@ -419,10 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if hasattr(self, 'streaming'):
|
||||
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
|
||||
streaming = self.streaming
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
streaming,
|
||||
finalize):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
@@ -254,10 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
|
||||
# text encode
|
||||
if finalize is True:
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||
else:
|
||||
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
||||
h, h_lengths = self.encoder(token, token_len, context=context)
|
||||
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
@@ -273,6 +274,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
streaming=streaming
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
|
||||
@@ -69,7 +69,7 @@ class ConditionalCFM(BASECFM):
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
@@ -110,7 +110,8 @@ class ConditionalCFM(BASECFM):
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in
|
||||
cond_in,
|
||||
streaming
|
||||
)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
@@ -122,9 +123,9 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
return sol[-1].float()
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator(x, mask, mu, t, spks, cond)
|
||||
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||
else:
|
||||
estimator, trt_engine = self.estimator.acquire_estimator()
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
@@ -196,7 +197,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
@@ -220,4 +221,4 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
||||
|
||||
Reference in New Issue
Block a user