mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
[debug] support flow cache, for sharper tts_mel output
This commit is contained in:
@@ -109,7 +109,9 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding):
|
||||
embedding,
|
||||
required_cache_size=0,
|
||||
flow_cache=None):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
@@ -133,13 +135,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat = self.decoder(
|
||||
feat, flow_cache = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10
|
||||
n_timesteps=10,
|
||||
required_cache_size=required_cache_size,
|
||||
flow_cache=flow_cache
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat
|
||||
return feat, flow_cache
|
||||
|
||||
Reference in New Issue
Block a user