mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
fix hifigan init bug
This commit is contained in:
@@ -32,7 +32,7 @@ class Executor:
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
@@ -65,10 +65,10 @@ class Executor:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
||||
info_dict = batch_backward(model, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||
@@ -82,7 +82,7 @@ class Executor:
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, group_join):
|
||||
writer, info_dict, scaler, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
@@ -116,16 +116,16 @@ class Executor:
|
||||
|
||||
with context():
|
||||
batch_dict['turn'] = 'discriminator'
|
||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
||||
info_dict = batch_backward(model, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
||||
optimizer.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
with context():
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
||||
info_dict = batch_backward(model, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
optimizer_d.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
@@ -157,7 +157,7 @@ class Executor:
|
||||
|
||||
if self.gan is True:
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
||||
|
||||
for k, v in info_dict['loss_dict'].items():
|
||||
if k not in total_loss_dict:
|
||||
|
||||
Reference in New Issue
Block a user