fix vocoder speech overlap

This commit is contained in:
lyuxiang.lx
2024-08-29 19:10:08 +08:00
parent f1e374a9bb
commit 1d881df8b2
3 changed files with 93 additions and 74 deletions

View File

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
f0 = self.f0_predictor(x)
s = self._f02source(f0)
# Use cache_source to avoid a glitch: when a non-empty cache is supplied,
# overwrite the leading cache_source.shape[2] positions of s (along dim 2)
# with it. The default empty cache (shape[2] == 0) leaves s untouched.
# The guard must be "!= 0": with "== 0" the assignment only ever ran as a
# zero-width no-op and a real cache was silently ignored.
if cache_source.shape[2] != 0:
    s[:, :, :cache_source.shape[2]] = cache_source
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
return x, s
def remove_weight_norm(self):
print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
l.remove_weight_norm()
@torch.inference_mode()
def inference(self, mel: torch.Tensor) -> torch.Tensor:
return self.forward(x=mel)
def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
return self.forward(x=mel, cache_source=cache_source)