mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
[debug] support flow cache, for sharper tts_mel output
This commit is contained in:
@@ -109,7 +109,9 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding):
|
||||
embedding,
|
||||
required_cache_size=0,
|
||||
flow_cache=None):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
@@ -133,13 +135,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat = self.decoder(
|
||||
feat, flow_cache = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10
|
||||
n_timesteps=10,
|
||||
required_cache_size=required_cache_size,
|
||||
flow_cache=flow_cache
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat
|
||||
return feat, flow_cache
|
||||
|
||||
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
|
||||
self.estimator = estimator
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
@@ -50,11 +50,26 @@ class ConditionalCFM(BASECFM):
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
z = torch.randn_like(mu) * temperature
|
||||
|
||||
if flow_cache is not None:
|
||||
z_cache = flow_cache[0]
|
||||
mu_cache = flow_cache[1]
|
||||
z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
|
||||
z = torch.cat((z_cache, z), dim=2) # [B, 80, T]
|
||||
mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2) # [B, 80, T]
|
||||
else:
|
||||
z = torch.randn_like(mu) * temperature
|
||||
|
||||
next_cache_start = max(z.size(2) - required_cache_size, 0)
|
||||
flow_cache = [
|
||||
z[..., next_cache_start:],
|
||||
mu[..., next_cache_start:]
|
||||
]
|
||||
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user