[debug] handle cache with prompt

Author: boji123
Date:   2024-09-29 19:12:30 +08:00
Parent: c9acce1482
Commit: 8130abb5ea

2 changed files with 4 additions and 3 deletions


@@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
             spks=embedding,
             cond=conds,
             n_timesteps=10,
+            prompt_len=mel_len1,
             required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
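The added prompt_len=mel_len1 argument tells the CFM decoder how many mel frames at the start of the sequence belong to the prompt audio. A minimal sketch of where such a value typically comes from, purely for orientation (prompt_feat and its shape are illustrative assumptions, not taken from this commit):

import torch

# Toy prompt feature: (batch, prompt_frames, mel_bins); only the frame count matters here.
prompt_feat = torch.zeros(1, 120, 80)
mel_len1 = prompt_feat.shape[1]   # number of mel frames contributed by the prompt audio
# mel_len1 is then passed to the flow decoder as prompt_len so the prompt region
# can be kept in the streaming flow cache (see the second file below).
print(mel_len1)                   # 120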


@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator

     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
         """Forward diffusion

         Args:
@@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM):
         next_cache_start = max(z.size(2) - required_cache_size, 0)
         flow_cache = [
-            z[..., next_cache_start:],
-            mu[..., next_cache_start:]
+            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
+            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
         ]
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
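Before this change the cache kept only the trailing required_cache_size frames of z and mu; now the first prompt_len frames are prepended as well, apparently so the prompt region is reused unchanged across streaming chunks rather than re-sampled. A self-contained sketch of the new cache construction (build_flow_cache is a hypothetical helper, not a function in the repository):

import torch

def build_flow_cache(z, mu, prompt_len, required_cache_size):
    # Mirrors the updated list construction in ConditionalCFM.forward:
    # keep the prompt frames plus the trailing required_cache_size frames.
    next_cache_start = max(z.size(2) - required_cache_size, 0)
    return [
        torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
        torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2),
    ]

# Toy shapes: batch=1, 80 mel bins, 200 frames, of which the first 50 are prompt.
z = torch.randn(1, 80, 200)
mu = torch.randn(1, 80, 200)
cache = build_flow_cache(z, mu, prompt_len=50, required_cache_size=16)
print(cache[0].shape)  # torch.Size([1, 80, 66]) -> 50 prompt frames + 16 trailing frames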