mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
[debug] handle cache with prompt
This commit is contained in:
@@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
spks=embedding,
|
spks=embedding,
|
||||||
cond=conds,
|
cond=conds,
|
||||||
n_timesteps=10,
|
n_timesteps=10,
|
||||||
|
prompt_len=mel_len1,
|
||||||
required_cache_size=required_cache_size,
|
required_cache_size=required_cache_size,
|
||||||
flow_cache=flow_cache
|
flow_cache=flow_cache
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
|
|||||||
self.estimator = estimator
|
self.estimator = estimator
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
|
||||||
"""Forward diffusion
|
"""Forward diffusion
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM):
|
|||||||
|
|
||||||
next_cache_start = max(z.size(2) - required_cache_size, 0)
|
next_cache_start = max(z.size(2) - required_cache_size, 0)
|
||||||
flow_cache = [
|
flow_cache = [
|
||||||
z[..., next_cache_start:],
|
torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
|
||||||
mu[..., next_cache_start:]
|
torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
|
||||||
]
|
]
|
||||||
|
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user