mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update stream code
This commit is contained in:
@@ -105,6 +105,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
@@ -112,17 +113,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
feat_len = (token_len / 50 * 22050 / 256).int()
|
||||
h, h_lengths = self.length_regulator(h, feat_len)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
|
||||
if prompt_feat.shape[1] != 0:
|
||||
for i, j in enumerate(prompt_feat_len):
|
||||
conds[i, :j] = prompt_feat[i]
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# mask = (~make_pad_mask(feat_len)).to(h)
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
@@ -130,6 +130,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
cond=conds,
|
||||
n_timesteps=10
|
||||
)
|
||||
if prompt_feat.shape[1] != 0:
|
||||
feat = feat[:, :, prompt_feat.shape[1]:]
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat
|
||||
|
||||
Reference in New Issue
Block a user