send streaming as args

This commit is contained in:
lyuxiang.lx
2025-05-27 13:47:12 +08:00
parent 54d21b40f0
commit cbfed4a9ee
5 changed files with 14 additions and 19 deletions

View File

@@ -241,6 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
@@ -254,10 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
# text encode
if finalize is True:
h, h_lengths = self.encoder(token, token_len)
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
h, h_lengths = self.encoder(token, token_len, context=context)
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
@@ -273,6 +274,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2