update lint

This commit is contained in:
lyuxiang.lx
2024-10-16 13:06:31 +08:00
parent 555efd0301
commit 29507bc77a
8 changed files with 21 additions and 13 deletions

View File

@@ -1,4 +1,3 @@
from typing import List
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
@@ -6,6 +5,7 @@ from typing import List, Optional, Tuple
from einops import rearrange
from torchaudio.transforms import Spectrogram
class MultipleDiscriminator(nn.Module):
def __init__(
self, mpd: nn.Module, mrd: nn.Module
@@ -28,6 +28,7 @@ class MultipleDiscriminator(nn.Module):
fmap_gs += this_fmap_gs
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
@@ -112,7 +113,7 @@ class DiscriminatorR(nn.Module):
x = torch.view_as_real(x)
x = rearrange(x, "b f t c -> b c t f")
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
@@ -136,4 +137,4 @@ class DiscriminatorR(nn.Module):
fmap.append(x)
x += h
return x, fmap
return x, fmap

View File

@@ -5,6 +5,7 @@ import torch.nn.functional as F
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from cosyvoice.utils.losses import tpr_loss, mel_loss
class HiFiGan(nn.Module):
def __init__(self, generator, discriminator, mel_spec_transform,
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
@@ -44,7 +45,9 @@ class HiFiGan(nn.Module):
else:
loss_tpr = torch.zeros(1).to(device)
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
self.tpr_loss_weight * loss_tpr + loss_f0
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
def forward_discriminator(self, batch, device):
@@ -63,4 +66,4 @@ class HiFiGan(nn.Module):
loss_tpr = torch.zeros(1).to(device)
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}