From 8130abb5ea67c77c1da7be2c3ab0e0f7707a8c52 Mon Sep 17 00:00:00 2001 From: boji123 Date: Sun, 29 Sep 2024 19:12:30 +0800 Subject: [PATCH] [debug] handle cache with prompt --- cosyvoice/flow/flow.py | 1 + cosyvoice/flow/flow_matching.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 50d96f7..e430b83 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module): spks=embedding, cond=conds, n_timesteps=10, + prompt_len=mel_len1, required_cache_size=required_cache_size, flow_cache=flow_cache ) diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 4b1503b..83dc971 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM): self.estimator = estimator @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None): """Forward diffusion Args: @@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM): next_cache_start = max(z.size(2) - required_cache_size, 0) flow_cache = [ - z[..., next_cache_start:], - mu[..., next_cache_start:] + torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2), + torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2) ] t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)