diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 50d96f7..e430b83 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module): spks=embedding, cond=conds, n_timesteps=10, + prompt_len=mel_len1, required_cache_size=required_cache_size, flow_cache=flow_cache ) diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 4b1503b..83dc971 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM): self.estimator = estimator @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None): """Forward diffusion Args: @@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM): next_cache_start = max(z.size(2) - required_cache_size, 0) flow_cache = [ - z[..., next_cache_start:], - mu[..., next_cache_start:] + torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2), + torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2) ] t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)