add hifigan train code

This commit is contained in:
lyuxiang.lx
2024-10-09 17:36:42 +08:00
parent 67f298d94a
commit cb200b21c5
10 changed files with 768 additions and 40 deletions

View File

@@ -14,7 +14,7 @@
"""HIFI-GAN"""
import typing as tp
from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
@@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module):
self,
channels: int = 512,
kernel_size: int = 3,
dilations: tp.List[int] = [1, 3, 5],
dilations: List[int] = [1, 3, 5],
):
super(ResBlock, self).__init__()
self.convs1 = nn.ModuleList()
@@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module):
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: tp.List[int] = [8, 8],
upsample_kernel_sizes: tp.List[int] = [16, 16],
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
upsample_rates: List[int] = [8, 8],
upsample_kernel_sizes: List[int] = [16, 16],
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: List[int] = [7, 11],
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
f0_predictor: torch.nn.Module = None,
@@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module):
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.f0_predictor = f0_predictor
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, _, _ = self.m_source(f0)
return har_source.transpose(1, 2)
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
self.m_source.remove_weight_norm()
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
def _stft(self, x):
spec = torch.stft(
@@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module):
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
f0 = self.f0_predictor(x)
s = self._f02source(f0)
# use cache_source to avoid glitch
if cache_source.shape[2] != 0:
s[:, :, :cache_source.shape[2]] = cache_source
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
@@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module):
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x, s
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
self.source_module.remove_weight_norm()
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
# mel->f0
f0 = self.f0_predictor(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
# mel+source->speech
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, f0
@torch.inference_mode()
def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
return self.forward(x=mel, cache_source=cache_source)
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
# mel->f0
f0 = self.f0_predictor(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
# use cache_source to avoid glitch
if cache_source.shape[2] != 0:
s[:, :, :cache_source.shape[2]] = cache_source
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, s