mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix vocoder speech overlap
This commit is contained in:
@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
|
||||
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
||||
return inverse_transform
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
f0 = self.f0_predictor(x)
|
||||
s = self._f02source(f0)
|
||||
|
||||
# Use cache_source to avoid a glitch at the start of this call's output:
# when a non-empty cache is supplied, overwrite the leading
# cache_source.shape[2] samples of the newly generated source `s` with it.
# (Presumably the cache is the tail of the source from the previous
# streaming chunk -- confirm against the caller.)
#
# The guard must be "!= 0": with "== 0" the assignment only ran for an
# empty cache, where it is a no-op, so a real cache was silently ignored.
if cache_source.shape[2] != 0:
    s[:, :, :cache_source.shape[2]] = cache_source
|
||||
|
||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||
|
||||
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
|
||||
|
||||
x = self._istft(magnitude, phase)
|
||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||
return x
|
||||
return x, s
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
|
||||
l.remove_weight_norm()
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward(x=mel)
|
||||
def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
return self.forward(x=mel, cache_source=cache_source)
|
||||
|
||||
Reference in New Issue
Block a user