mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
update lint
This commit is contained in:
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -52,5 +52,5 @@ jobs:
|
||||
set -eux
|
||||
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
|
||||
flake8 --version
|
||||
flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
if [ $? != 0 ]; then exit 1; fi
|
||||
@@ -16,7 +16,6 @@
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
|
||||
@@ -138,7 +138,8 @@ def main():
|
||||
dist.barrier()
|
||||
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
||||
if gan is True:
|
||||
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join)
|
||||
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, group_join)
|
||||
else:
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
|
||||
dist.destroy_process_group(group_join)
|
||||
|
||||
@@ -177,6 +177,7 @@ def compute_fbank(data,
|
||||
sample['speech_feat'] = mat
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_f0(data, pitch_extractor, mode='train'):
|
||||
""" Extract f0
|
||||
|
||||
@@ -404,8 +405,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
pitch_feat = pad_sequence(pitch_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch["pitch_feat"] = pitch_feat
|
||||
batch["pitch_feat_len"] = pitch_feat_len
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
@@ -6,6 +5,7 @@ from typing import List, Optional, Tuple
|
||||
from einops import rearrange
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
|
||||
class MultipleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self, mpd: nn.Module, mrd: nn.Module
|
||||
@@ -28,6 +28,7 @@ class MultipleDiscriminator(nn.Module):
|
||||
fmap_gs += this_fmap_gs
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -112,7 +113,7 @@ class DiscriminatorR(nn.Module):
|
||||
x = torch.view_as_real(x)
|
||||
x = rearrange(x, "b f t c -> b c t f")
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||
@@ -136,4 +137,4 @@ class DiscriminatorR(nn.Module):
|
||||
fmap.append(x)
|
||||
x += h
|
||||
|
||||
return x, fmap
|
||||
return x, fmap
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch.nn.functional as F
|
||||
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
||||
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
||||
|
||||
|
||||
class HiFiGan(nn.Module):
|
||||
def __init__(self, generator, discriminator, mel_spec_transform,
|
||||
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
||||
@@ -44,7 +45,9 @@ class HiFiGan(nn.Module):
|
||||
else:
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||
loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
||||
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
||||
self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
|
||||
def forward_discriminator(self, batch, device):
|
||||
@@ -63,4 +66,4 @@ class HiFiGan(nn.Module):
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||
loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
|
||||
@@ -25,7 +25,7 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool=False):
|
||||
def __init__(self, gan: bool = False):
|
||||
self.gan = gan
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
@@ -81,7 +81,8 @@ class Executor:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join):
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss = 0
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
@@ -9,10 +10,11 @@ def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss += tau - F.relu(tau - L_rel)
|
||||
return loss
|
||||
|
||||
|
||||
def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||
loss = 0
|
||||
for transform in mel_transforms:
|
||||
mel_r = transform(real_speech)
|
||||
mel_g = transform(generated_speech)
|
||||
loss += F.l1_loss(mel_g, mel_r)
|
||||
return loss
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user