update flow cache
@@ -110,8 +110,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-                  required_cache_size=0,
-                  flow_cache=None):
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -142,7 +141,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-            required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
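
Taken together, the two hunks above tighten the caller-facing contract: required_cache_size is gone and flow_cache becomes a required tensor that the caller threads through successive streaming calls. A caller-side sketch under those assumptions; flow, the chunk iterator, and the prompt tensors are placeholders, and the sketch further assumes inference returns the updated cache alongside the features (the assignment line itself sits outside the hunk):

    import torch

    # Start streaming with an empty cache: (batch, n_feats, cached_frames, 2);
    # the trailing dimension stacks the (z, mu) pair, matching the new default.
    flow_cache = torch.zeros(1, 80, 0, 2)

    for chunk in token_chunks:  # hypothetical iterator over streamed token chunks
        feat, flow_cache = flow.inference(token=chunk['token'],
                                          token_len=chunk['token_len'],
                                          prompt_token=prompt_token,
                                          prompt_token_len=prompt_token_len,
                                          prompt_feat=prompt_feat,
                                          prompt_feat_len=prompt_feat_len,
                                          embedding=embedding,
                                          flow_cache=flow_cache)  # thread it through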
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
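
One subtlety in the new signature: the default flow_cache=torch.zeros(1, 80, 0, 2) is evaluated once, at function-definition time, and shared across calls. That is safe here because the body only reads from the cache and rebinds the local name rather than mutating the default in place; a minimal illustration of the distinction (illustrative code, not from the commit):

    import torch

    def safe(cache=torch.zeros(1, 80, 0, 2)):
        # Rebinds the local name; the shared default tensor is never modified.
        cache = torch.cat([cache, torch.ones(1, 80, 1, 2)], dim=2)
        return cache.shape[2]

    print(safe(), safe())  # 1 1 -- each call starts from the pristine empty default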
@@ -51,20 +51,15 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        if flow_cache is not None:
-            z_cache = flow_cache[0]
-            mu_cache = flow_cache[1]
-            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
-            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
-            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
-        else:
-            z = torch.randn_like(mu) * temperature
-
-        next_cache_start = max(z.size(2) - required_cache_size, 0)
-        flow_cache = [
-            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
-            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
-        ]
+        z = torch.randn_like(mu) * temperature
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
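
The rewritten body always samples full-length noise, then overwrites the cached prefix of both z and mu so the prompt and the overlap region stay bit-identical across chunks; the new cache keeps the prompt frames plus the last 34 frames of the current chunk (the overlap size hard-coded in this commit), which is why required_cache_size is no longer needed. A standalone sketch of that update with concrete, made-up sizes:

    import torch

    B, D, prompt_len, T = 1, 80, 50, 200   # illustrative sizes, not from the commit
    mu = torch.randn(B, D, T)              # conditioning for the current chunk
    flow_cache = torch.zeros(B, D, 0, 2)   # empty on the first chunk

    z = torch.randn_like(mu)               # full-length noise every call
    cache_size = flow_cache.shape[2]
    if cache_size != 0:                    # later chunks: pin the cached prefix
        z[:, :, :cache_size] = flow_cache[:, :, :, 0]
        mu[:, :, :cache_size] = flow_cache[:, :, :, 1]

    # New cache = prompt frames + trailing 34-frame overlap, stacked along dim -1.
    z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
    mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
    flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
    print(flow_cache.shape)                # torch.Size([1, 80, 84, 2])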