mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 09:59:23 +08:00
send streaming as args
This commit is contained in:
@@ -241,6 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 prompt_feat,
 prompt_feat_len,
 embedding,
+streaming,
 finalize):
 assert token.shape[0] == 1
 # xvec projection
@@ -254,10 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 
 # text encode
 if finalize is True:
-h, h_lengths = self.encoder(token, token_len)
+h, h_lengths = self.encoder(token, token_len, streaming=streaming)
 else:
 token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-h, h_lengths = self.encoder(token, token_len, context=context)
+h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
 mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
 h = self.encoder_proj(h)
 
@@ -273,6 +274,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 spks=embedding,
 cond=conds,
 n_timesteps=10,
+streaming=streaming
 )
 feat = feat[:, :, mel_len1:]
 assert feat.shape[2] == mel_len2
Reference in New Issue
Block a user