update flow cache

This commit is contained in:
lyuxiang.lx
2024-10-16 15:24:47 +08:00
parent ace734def8
commit a4db3db8ed
3 changed files with 26 additions and 31 deletions

View File

@@ -110,8 +110,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
required_cache_size=0,
flow_cache=None):
flow_cache):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
@@ -142,7 +141,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
required_cache_size=required_cache_size,
flow_cache=flow_cache
)
feat = feat[:, :, mel_len1:]