[debug] support flow cache for sharper tts_mel output

boji123
2024-09-20 12:35:44 +08:00
parent d49259855b
commit c9acce1482
3 changed files with 34 additions and 9 deletions


@@ -109,7 +109,9 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  required_cache_size=0,
+                  flow_cache=None):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -133,13 +135,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            required_cache_size=required_cache_size,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
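
With this change, `inference` accepts a `required_cache_size` and an optional `flow_cache`, and returns the updated cache alongside the mel features, so a caller can carry decoder state across chunked synthesis (the decoder-side handling of these arguments lives in the other files of this commit). Below is a minimal caller-side sketch, not part of the commit: `flow`, `token_chunks`, the prompt tensors and `embedding` are placeholder names, and `required_cache_size=24` is an arbitrary example value (the commit's default is 0).

import torch

flow_cache = None  # no cache before the first chunk
mel_chunks = []
for token_chunk, token_chunk_len in token_chunks:  # hypothetical chunked token stream
    mel_chunk, flow_cache = flow.inference(
        token=token_chunk,
        token_len=token_chunk_len,
        prompt_token=prompt_token,
        prompt_token_len=prompt_token_len,
        prompt_feat=prompt_feat,
        prompt_feat_len=prompt_feat_len,
        embedding=embedding,
        required_cache_size=24,   # example value; defaults to 0 in this commit
        flow_cache=flow_cache,    # feed back the cache returned by the previous call
    )
    mel_chunks.append(mel_chunk)
mel = torch.cat(mel_chunks, dim=2)  # feat is [B, n_mels, T], so concatenate along time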