From cb200b21c5cf1e39216cc52491fd0c13e5b111cf Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 9 Oct 2024 17:36:42 +0800 Subject: [PATCH] add hifigan train code --- cosyvoice/bin/train.py | 2 +- cosyvoice/bin/train_gan.py | 137 +++++++++++++++++ cosyvoice/dataset/processor.py | 55 ++++++- cosyvoice/hifigan/discriminator.py | 139 +++++++++++++++++ cosyvoice/hifigan/generator.py | 89 ++++++----- cosyvoice/hifigan/hifigan.py | 66 ++++++++ cosyvoice/utils/executor_gan.py | 118 +++++++++++++++ cosyvoice/utils/losses.py | 18 +++ cosyvoice/utils/train_utils.py | 43 ++++++ .../cosyvoice/conf/cosyvoice.hifigan.yaml | 141 ++++++++++++++++++ 10 files changed, 768 insertions(+), 40 deletions(-) create mode 100644 cosyvoice/bin/train_gan.py create mode 100644 cosyvoice/hifigan/discriminator.py create mode 100644 cosyvoice/hifigan/hifigan.py create mode 100644 cosyvoice/utils/executor_gan.py create mode 100644 cosyvoice/utils/losses.py create mode 100644 examples/libritts/cosyvoice/conf/cosyvoice.hifigan.yaml diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py index ae43fa7..016663f 100644 --- a/cosyvoice/bin/train.py +++ b/cosyvoice/bin/train.py @@ -87,7 +87,7 @@ def main(): logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') - override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model} + override_dict = {k: None for k in ['llm', 'flow', 'hifigan'] if k != args.model} with open(args.config, 'r') as f: configs = load_hyperpyyaml(f, overrides=override_dict) configs['train_conf'].update(vars(args)) diff --git a/cosyvoice/bin/train_gan.py b/cosyvoice/bin/train_gan.py new file mode 100644 index 0000000..96bf988 --- /dev/null +++ b/cosyvoice/bin/train_gan.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import torch +import torch.distributed as dist +import deepspeed + +from hyperpyyaml import load_hyperpyyaml + +from torch.distributed.elastic.multiprocessing.errors import record + +from cosyvoice.utils.executor_gan import Executor +from cosyvoice.utils.train_utils import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler_gan, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=30, + type=int, + help='timeout (in seconds) of cosyvoice_join.') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + override_dict = {k: None for k in ['llm', 'flow', 'hifigan'] if k != args.model} + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict, overrides_must_match=False) + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + if args.checkpoint is not None: + model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + + # Get optimizer & scheduler + model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler_gan(args, configs, model) + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + save_model(model, 'init', info_dict) + + # Get executor + executor = Executor() + + # Start training loop + for epoch in range(info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + executor.train_one_epoc(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join) + dist.destroy_process_group(group_join) + + +if __name__ == '__main__': + main() diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 3a1486c..ba92911 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -85,6 +85,7 @@ def filter(data, """ for sample in data: sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) del sample['audio_data'] # sample['wav'] is torch.Tensor, we have 100 frames every second num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 @@ -134,6 +135,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): yield sample +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['speech'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1) + sample['speech'] = waveform + yield sample + + def compute_fbank(data, feat_extractor, mode='train'): @@ -153,7 +175,26 @@ def compute_fbank(data, waveform = sample['speech'] mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) sample['speech_feat'] = mat - del sample['speech'] + yield sample + +def compute_f0(data, pitch_extractor, mode='train'): + """ Extract f0 + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = pitch_extractor(waveform).transpose(1, 2) + mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear') + sample['pitch_feat'] = mat[0, 0] yield sample @@ -325,6 +366,9 @@ def padding(data, use_spk_embedding, mode='train'): order = torch.argsort(speech_feat_len, descending=True) utts = [sample[i]['utt'] for i in order] + speech = [sample[i]['speech'].squeeze(dim=0) for i in order] + speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32) + speech = pad_sequence(speech, batch_first=True, padding_value=0) speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) speech_token = pad_sequence(speech_token, @@ -335,6 +379,11 @@ def padding(data, use_spk_embedding, mode='train'): speech_feat = pad_sequence(speech_feat, batch_first=True, padding_value=0) + pitch_feat = [sample[i]['pitch_feat'] for i in order] + pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) + pitch_feat = pad_sequence(pitch_feat, + batch_first=True, + padding_value=0) text = [sample[i]['text'] for i in order] text_token = [torch.tensor(sample[i]['text_token']) for i in order] text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) @@ -343,10 +392,14 @@ def padding(data, use_spk_embedding, mode='train'): spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) batch = { "utts": utts, + "speech": speech, + "speech_len": speech_len, "speech_token": speech_token, "speech_token_len": speech_token_len, "speech_feat": speech_feat, "speech_feat_len": speech_feat_len, + "pitch_feat": pitch_feat, + "pitch_feat_len": pitch_feat_len, "text": text, "text_token": text_token, "text_token_len": text_token_len, diff --git a/cosyvoice/hifigan/discriminator.py b/cosyvoice/hifigan/discriminator.py new file mode 100644 index 0000000..d128652 --- /dev/null +++ b/cosyvoice/hifigan/discriminator.py @@ -0,0 +1,139 @@ +from typing import List +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +from typing import List, Optional, Tuple +from einops import rearrange +from torchaudio.transforms import Spectrogram + +class MultipleDiscriminator(nn.Module): + def __init__( + self, mpd: nn.Module, mrd: nn.Module + ): + super().__init__() + self.mpd = mpd + self.mrd = mrd + + def forward(self, y: torch.Tensor, y_hat: torch.Tensor): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = rearrange(x, "b f t c -> b c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap \ No newline at end of file diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 3d7e459..0098b90 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -14,7 +14,7 @@ """HIFI-GAN""" -import typing as tp +from typing import Dict, Optional, List import numpy as np from scipy.signal import get_window import torch @@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module): self, channels: int = 512, kernel_size: int = 3, - dilations: tp.List[int] = [1, 3, 5], + dilations: List[int] = [1, 3, 5], ): super(ResBlock, self).__init__() self.convs1 = nn.ModuleList() @@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module): nsf_alpha: float = 0.1, nsf_sigma: float = 0.003, nsf_voiced_threshold: float = 10, - upsample_rates: tp.List[int] = [8, 8], - upsample_kernel_sizes: tp.List[int] = [16, 16], - istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4}, - resblock_kernel_sizes: tp.List[int] = [3, 7, 11], - resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - source_resblock_kernel_sizes: tp.List[int] = [7, 11], - source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]], + upsample_rates: List[int] = [8, 8], + upsample_kernel_sizes: List[int] = [16, 16], + istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: List[int] = [7, 11], + source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], lrelu_slope: float = 0.1, audio_limit: float = 0.99, f0_predictor: torch.nn.Module = None, @@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module): self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) self.f0_predictor = f0_predictor - def _f02source(self, f0: torch.Tensor) -> torch.Tensor: - f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t - - har_source, _, _ = self.m_source(f0) - return har_source.transpose(1, 2) + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.m_source.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() def _stft(self, x): spec = torch.stft( @@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module): self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) return inverse_transform - def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: - f0 = self.f0_predictor(x) - s = self._f02source(f0) - - # use cache_source to avoid glitch - if cache_source.shape[2] != 0: - s[:, :, :cache_source.shape[2]] = cache_source - + def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) @@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module): x = self._istft(magnitude, phase) x = torch.clamp(x, -self.audio_limit, self.audio_limit) - return x, s + return x - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - self.source_module.remove_weight_norm() - for l in self.source_downs: - remove_weight_norm(l) - for l in self.source_resblocks: - l.remove_weight_norm() + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + speech_feat = batch['speech_feat'].transpose(1, 2).to(device) + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # mel+source->speech + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, f0 @torch.inference_mode() - def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: - return self.forward(x=mel, cache_source=cache_source) + def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, s diff --git a/cosyvoice/hifigan/hifigan.py b/cosyvoice/hifigan/hifigan.py new file mode 100644 index 0000000..ed18e8a --- /dev/null +++ b/cosyvoice/hifigan/hifigan.py @@ -0,0 +1,66 @@ +from typing import Dict, Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss +from cosyvoice.utils.losses import tpr_loss, mel_loss + +class HiFiGan(nn.Module): + def __init__(self, generator, discriminator, mel_spec_transform, + multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, + tpr_loss_weight=1.0, tpr_loss_tau=0.04): + super(HiFiGan, self).__init__() + self.generator = generator + self.discriminator = discriminator + self.mel_spec_transform = mel_spec_transform + self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight + self.feat_match_loss_weight = feat_match_loss_weight + self.tpr_loss_weight = tpr_loss_weight + self.tpr_loss_tau = tpr_loss_tau + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + if batch['turn'] == 'generator': + return self.forward_generator(batch, device) + else: + return self.forward_discriminator(batch, device) + + def forward_generator(self, batch, device): + real_speech = batch['speech'].to(device) + pitch_feat = batch['pitch_feat'].to(device) + # 1. calculate generator outputs + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional] + loss_gen, _ = generator_loss(y_d_gs) + loss_fm = feature_loss(fmap_rs, fmap_gs) + loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss_f0 = F.l1_loss(generated_f0, pitch_feat) + loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0 + return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} + + def forward_discriminator(self, batch, device): + real_speech = batch['speech'].to(device) + pitch_feat = batch['pitch_feat'].to(device) + # 1. calculate generator outputs + with torch.no_grad(): + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. calculate discriminator losses, tpr losses [Optional] + loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss_f0 = F.l1_loss(generated_f0, pitch_feat) + loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0 + return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} \ No newline at end of file diff --git a/cosyvoice/utils/executor_gan.py b/cosyvoice/utils/executor_gan.py new file mode 100644 index 0000000..9fb1b51 --- /dev/null +++ b/cosyvoice/utils/executor_gan.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join + + +class Executor: + + def __init__(self): + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + + def train_one_epoc(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + batch_dict['turn'] = 'discriminator' + info_dict = batch_forward(model, batch_dict, info_dict) + info_dict = batch_backward(model, info_dict) + info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict) + log_per_step(writer, info_dict) + with context(): + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, info_dict) + info_dict = batch_backward(model, info_dict) + info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, info_dict) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/cosyvoice/utils/losses.py b/cosyvoice/utils/losses.py new file mode 100644 index 0000000..46d9883 --- /dev/null +++ b/cosyvoice/utils/losses.py @@ -0,0 +1,18 @@ +import torch +import torch.nn.functional as F + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss \ No newline at end of file diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py index b7a6a48..2fbba78 100644 --- a/cosyvoice/utils/train_utils.py +++ b/cosyvoice/utils/train_utils.py @@ -142,6 +142,49 @@ def init_optimizer_and_scheduler(args, configs, model): return model, optimizer, scheduler +def init_optimizer_and_scheduler_gan(args, configs, model): + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + if configs['train_conf']['optim_d'] == 'adam': + optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim_d'] == 'adamw': + optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler_d'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler_d = ConstantLR(optimizer_d) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # currently we wrap generator and discriminator in one model, so we cannot use deepspeed + return model, optimizer, scheduler, optimizer_d, scheduler_d + + def init_summarywriter(args): writer = None if int(os.environ.get('RANK', 0)) == 0: diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.hifigan.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.hifigan.yaml new file mode 100644 index 0000000..80b5745 --- /dev/null +++ b/examples/libritts/cosyvoice/conf/cosyvoice.hifigan.yaml @@ -0,0 +1,141 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 22050 +text_encoder_input_size: 512 +llm_input_size: 1024 +llm_output_size: 1024 +spk_embed_dim: 192 + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +hift: !new:cosyvoice.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 8] + upsample_kernel_sizes: [16, 16] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 8000 + center: False +hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan + generator: !ref + discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator + mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator + mel_spec_transform: [ + !ref + ] + +# processor functions +parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener +get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe + multilingual: True + num_languages: 100 + language: 'en' + task: 'transcribe' +tokenize: !name:cosyvoice.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: 'all' +filter: !name:cosyvoice.dataset.processor.filter + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:cosyvoice.dataset.processor.resample + resample_rate: !ref +truncate: !name:cosyvoice.dataset.processor.truncate + truncate_length: 24576 # must be a multiplier of hop_size +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 8000 + center: False +compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank + feat_extractor: !ref +pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch + sample_rate: !ref + frame_length: 46.4 # match feat_extractor win_size/sampling_rate + frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate +compute_f0: !name:cosyvoice.dataset.processor.compute_f0 + pitch_extractor: !ref +parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding + normalize: True +shuffle: !name:cosyvoice.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:cosyvoice.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:cosyvoice.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 1200 +padding: !name:cosyvoice.dataset.processor.padding + use_spk_embedding: False # change to True during sft + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + +# train conf +train_conf: + optim: adam + optim_conf: + lr: 0.002 # change to 0.001 if you want to train flow from scratch + scheduler: warmuplr + scheduler_conf: + warmup_steps: 25000 + optim_d: adam + optim_conf_d: + lr: 0.002 # change to 0.001 if you want to train flow from scratch + scheduler_d: warmuplr + scheduler_conf_d: + warmup_steps: 25000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: -1 \ No newline at end of file