remove flow_cache

This commit is contained in:
lyuxiang.lx
2025-05-23 12:50:47 +08:00
parent 88f467a8ac
commit 68100c267a
14 changed files with 365 additions and 955 deletions

View File

@@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
cache,
finalize):
assert token.shape[0] == 1
# xvec projection
@@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
# text encode
if finalize is True:
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
h, h_lengths = self.encoder(token, token_len)
else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
cache['encoder_cache']['offset'] = encoder_cache[0]
cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
h, h_lengths = self.encoder(token, token_len, context=context)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
@@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, cache['decoder_cache'] = self.decoder(
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
cache=cache['decoder_cache']
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), cache
return feat.float(), None