mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add hifigan train
This commit is contained in:
@@ -25,7 +25,8 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, gan: bool=False):
|
||||
self.gan = gan
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
@@ -80,6 +81,63 @@ class Executor:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join):
    ''' Train one epoch in GAN mode.

    Every batch is processed twice: first with ``batch_dict['turn']`` set
    to ``'discriminator'`` (stepping ``optimizer_d`` / ``scheduler_d``),
    then with it set to ``'generator'`` (stepping ``optimizer`` /
    ``scheduler``). After each turn ``zero_grad()`` is called on the
    optimizer that was NOT stepped, and the step is logged.

    Args:
        model: model to train; ``model.join`` / ``model.no_sync`` are used
            when ``info_dict['train_engine'] == 'torch_ddp'``.
        optimizer, scheduler: stepped on the generator turn.
        optimizer_d, scheduler_d: stepped on the discriminator turn.
        train_data_loader: iterable yielding one ``batch_dict`` per batch;
            each ``batch_dict`` gets its ``'turn'`` key overwritten.
        cv_data_loader, writer: forwarded to ``self.cv`` / ``log_per_step``.
        info_dict: training state; ``accum_grad``, ``train_engine`` and
            ``save_per_step`` are read, ``tag``, ``step``, ``epoch`` and
            ``batch_idx`` are written on every batch.
        group_join: forwarded to ``cosyvoice_join``; the epoch stops early
            as soon as that call returns a truthy value.
    '''
    current_lr = optimizer.param_groups[0]['lr']
    logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, current_lr, self.rank))
    logging.info('using accumulate grad, new batch size is {} times'
                 ' larger than before'.format(info_dict['accum_grad']))

    model.train()

    # model.join is the DistributedDataParallel context manager that allows
    # training with uneven inputs across participating processes; any other
    # engine gets a no-op context.
    if info_dict['train_engine'] == 'torch_ddp':
        join_ctx = model.join
    else:
        join_ctx = nullcontext

    # (turn name, optimizer to step, scheduler to step, optimizer to clear).
    # NOTE(review): the non-stepped optimizer is presumably cleared because
    # the backward pass of one turn also leaves gradients on the other
    # sub-network - confirm against the model's forward implementation.
    turn_schedule = (
        ('discriminator', optimizer_d, scheduler_d, optimizer),
        ('generator', optimizer, scheduler, optimizer_d),
    )

    with join_ctx():
        for batch_idx, batch_dict in enumerate(train_data_loader):
            info_dict.update({
                "tag": "TRAIN",
                "step": self.step,
                "epoch": self.epoch,
                "batch_idx": batch_idx,
            })
            if cosyvoice_join(group_join, info_dict):
                break

            # Under DDP, skip gradient synchronization on every batch that
            # is not an accumulation boundary: gradients keep accumulating
            # on the module variables and are synchronized later. Single
            # GPU training and boundary batches use a no-op context.
            grad_sync_ctx = (
                model.no_sync
                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0
                else nullcontext
            )

            for turn, stepped_optimizer, stepped_scheduler, idle_optimizer in turn_schedule:
                with grad_sync_ctx():
                    batch_dict['turn'] = turn
                    info_dict = batch_forward(model, batch_dict, info_dict)
                    info_dict = batch_backward(model, info_dict)
                info_dict = update_parameter_and_lr(model, stepped_optimizer, stepped_scheduler, info_dict)
                idle_optimizer.zero_grad()
                log_per_step(writer, info_dict)

            # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
            if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0:
                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    dist.barrier()
                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                    # switch back to training mode after the mid-epoch cv run
                    model.train()

            # One optimizer step corresponds to accum_grad consecutive batches.
            if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                self.step += 1

    # End of epoch: wait for all ranks, then run the on-batch-end cv pass.
    dist.barrier()
    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
|
||||
''' Cross validation on
|
||||
@@ -96,6 +154,8 @@ class Executor:
|
||||
num_utts = len(batch_dict["utts"])
|
||||
total_num_utts += num_utts
|
||||
|
||||
if self.gan is True:
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
||||
|
||||
for k, v in info_dict['loss_dict'].items():
|
||||
|
||||
Reference in New Issue
Block a user