mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update stream code
This commit is contained in:
@@ -105,6 +105,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
@@ -112,17 +113,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
feat_len = (token_len / 50 * 22050 / 256).int()
|
||||
h, h_lengths = self.length_regulator(h, feat_len)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
|
||||
if prompt_feat.shape[1] != 0:
|
||||
for i, j in enumerate(prompt_feat_len):
|
||||
conds[i, :j] = prompt_feat[i]
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# mask = (~make_pad_mask(feat_len)).to(h)
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
@@ -130,6 +130,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
cond=conds,
|
||||
n_timesteps=10
|
||||
)
|
||||
if prompt_feat.shape[1] != 0:
|
||||
feat = feat[:, :, prompt_feat.shape[1]:]
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from typing import Tuple
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
|
||||
@@ -47,3 +48,21 @@ class InterpolateRegulator(nn.Module):
|
||||
out = self.model(x).transpose(1, 2).contiguous()
|
||||
olens = ylens
|
||||
return out * mask, olens
|
||||
|
||||
def inference(self, x1, x2, mel_len1, mel_len2):
|
||||
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
||||
# x in (B, T, D)
|
||||
if x2.shape[1] > 40:
|
||||
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
|
||||
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
||||
else:
|
||||
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
||||
if x1.shape[1] != 0:
|
||||
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
||||
x = torch.concat([x1, x2], dim=2)
|
||||
else:
|
||||
x = x2
|
||||
out = self.model(x).transpose(1, 2).contiguous()
|
||||
return out, mel_len1 + mel_len2
|
||||
|
||||
Reference in New Issue
Block a user