update vc/tts code

This commit is contained in:
lyuxiang.lx
2024-09-26 11:53:10 +08:00
parent 49015f63e6
commit 72b89a52fb
5 changed files with 74 additions and 21 deletions

View File

@@ -125,7 +125,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)

View File

@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
olens = ylens
return out * mask, olens
def inference(self, x1, x2, mel_len1, mel_len2):
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
# x in (B, T, D)
if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
else:
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')