mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update lint
This commit is contained in:
@@ -25,7 +25,7 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool=False):
|
||||
def __init__(self, gan: bool = False):
|
||||
self.gan = gan
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
@@ -81,7 +81,8 @@ class Executor:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join):
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss = 0
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
@@ -9,10 +10,11 @@ def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss += tau - F.relu(tau - L_rel)
|
||||
return loss
|
||||
|
||||
|
||||
def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||
loss = 0
|
||||
for transform in mel_transforms:
|
||||
mel_r = transform(real_speech)
|
||||
mel_g = transform(generated_speech)
|
||||
loss += F.l1_loss(mel_g, mel_r)
|
||||
return loss
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user