mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add hifigan train code
This commit is contained in:
139
cosyvoice/hifigan/discriminator.py
Normal file
139
cosyvoice/hifigan/discriminator.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
from typing import List, Optional, Tuple
|
||||
from einops import rearrange
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
class MultipleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self, mpd: nn.Module, mrd: nn.Module
|
||||
):
|
||||
super().__init__()
|
||||
self.mpd = mpd
|
||||
self.mrd = mrd
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
||||
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
||||
y_d_rs += this_y_d_rs
|
||||
y_d_gs += this_y_d_gs
|
||||
fmap_rs += this_fmap_rs
|
||||
fmap_gs += this_fmap_gs
|
||||
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
||||
y_d_rs += this_y_d_rs
|
||||
y_d_gs += this_y_d_gs
|
||||
fmap_rs += this_fmap_rs
|
||||
fmap_gs += this_fmap_gs
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
||||
num_embeddings: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
||||
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
||||
|
||||
Args:
|
||||
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
||||
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
channels: int = 32,
|
||||
hop_factor: float = 0.25,
|
||||
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
||||
):
|
||||
super().__init__()
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.spec_fn = Spectrogram(
|
||||
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
||||
)
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x):
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
x = self.spec_fn(x)
|
||||
x = torch.view_as_real(x)
|
||||
x = rearrange(x, "b f t c -> b c t f")
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||
x_bands = self.spectrogram(x)
|
||||
fmap = []
|
||||
x = []
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for i, layer in enumerate(stack):
|
||||
band = layer(band)
|
||||
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||
if i > 0:
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
x = torch.cat(x, dim=-1)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
|
||||
return x, fmap
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
"""HIFI-GAN"""
|
||||
|
||||
import typing as tp
|
||||
from typing import Dict, Optional, List
|
||||
import numpy as np
|
||||
from scipy.signal import get_window
|
||||
import torch
|
||||
@@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module):
|
||||
self,
|
||||
channels: int = 512,
|
||||
kernel_size: int = 3,
|
||||
dilations: tp.List[int] = [1, 3, 5],
|
||||
dilations: List[int] = [1, 3, 5],
|
||||
):
|
||||
super(ResBlock, self).__init__()
|
||||
self.convs1 = nn.ModuleList()
|
||||
@@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module):
|
||||
nsf_alpha: float = 0.1,
|
||||
nsf_sigma: float = 0.003,
|
||||
nsf_voiced_threshold: float = 10,
|
||||
upsample_rates: tp.List[int] = [8, 8],
|
||||
upsample_kernel_sizes: tp.List[int] = [16, 16],
|
||||
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
|
||||
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
|
||||
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||
upsample_rates: List[int] = [8, 8],
|
||||
upsample_kernel_sizes: List[int] = [16, 16],
|
||||
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||
lrelu_slope: float = 0.1,
|
||||
audio_limit: float = 0.99,
|
||||
f0_predictor: torch.nn.Module = None,
|
||||
@@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module):
|
||||
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||
self.f0_predictor = f0_predictor
|
||||
|
||||
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
|
||||
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
|
||||
har_source, _, _ = self.m_source(f0)
|
||||
return har_source.transpose(1, 2)
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
self.m_source.remove_weight_norm()
|
||||
for l in self.source_downs:
|
||||
remove_weight_norm(l)
|
||||
for l in self.source_resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
def _stft(self, x):
|
||||
spec = torch.stft(
|
||||
@@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module):
|
||||
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
||||
return inverse_transform
|
||||
|
||||
def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
f0 = self.f0_predictor(x)
|
||||
s = self._f02source(f0)
|
||||
|
||||
# use cache_source to avoid glitch
|
||||
if cache_source.shape[2] != 0:
|
||||
s[:, :, :cache_source.shape[2]] = cache_source
|
||||
|
||||
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||
|
||||
@@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module):
|
||||
|
||||
x = self._istft(magnitude, phase)
|
||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||
return x, s
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
self.source_module.remove_weight_norm()
|
||||
for l in self.source_downs:
|
||||
remove_weight_norm(l)
|
||||
for l in self.source_resblocks:
|
||||
l.remove_weight_norm()
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
||||
# mel->f0
|
||||
f0 = self.f0_predictor(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
# mel+source->speech
|
||||
generated_speech = self.decode(x=speech_feat, s=s)
|
||||
return generated_speech, f0
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
return self.forward(x=mel, cache_source=cache_source)
|
||||
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
# mel->f0
|
||||
f0 = self.f0_predictor(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
# use cache_source to avoid glitch
|
||||
if cache_source.shape[2] != 0:
|
||||
s[:, :, :cache_source.shape[2]] = cache_source
|
||||
generated_speech = self.decode(x=speech_feat, s=s)
|
||||
return generated_speech, s
|
||||
|
||||
66
cosyvoice/hifigan/hifigan.py
Normal file
66
cosyvoice/hifigan/hifigan.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
||||
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
||||
|
||||
class HiFiGan(nn.Module):
|
||||
def __init__(self, generator, discriminator, mel_spec_transform,
|
||||
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
||||
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
|
||||
super(HiFiGan, self).__init__()
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.mel_spec_transform = mel_spec_transform
|
||||
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
|
||||
self.feat_match_loss_weight = feat_match_loss_weight
|
||||
self.tpr_loss_weight = tpr_loss_weight
|
||||
self.tpr_loss_tau = tpr_loss_tau
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
if batch['turn'] == 'generator':
|
||||
return self.forward_generator(batch, device)
|
||||
else:
|
||||
return self.forward_discriminator(batch, device)
|
||||
|
||||
def forward_generator(self, batch, device):
|
||||
real_speech = batch['speech'].to(device)
|
||||
pitch_feat = batch['pitch_feat'].to(device)
|
||||
# 1. calculate generator outputs
|
||||
generated_speech, generated_f0 = self.generator(batch, device)
|
||||
# 2. calculate discriminator outputs
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
||||
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
|
||||
loss_gen, _ = generator_loss(y_d_gs)
|
||||
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
||||
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
||||
if self.tpr_loss_weight != 0:
|
||||
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
||||
else:
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||
loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
|
||||
def forward_discriminator(self, batch, device):
|
||||
real_speech = batch['speech'].to(device)
|
||||
pitch_feat = batch['pitch_feat'].to(device)
|
||||
# 1. calculate generator outputs
|
||||
with torch.no_grad():
|
||||
generated_speech, generated_f0 = self.generator(batch, device)
|
||||
# 2. calculate discriminator outputs
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
||||
# 3. calculate discriminator losses, tpr losses [Optional]
|
||||
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
||||
if self.tpr_loss_weight != 0:
|
||||
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
||||
else:
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||
loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
Reference in New Issue
Block a user