add hifigan train code

lyuxiang.lx
2024-10-09 17:36:42 +08:00
parent 67f298d94a
commit cb200b21c5
10 changed files with 768 additions and 40 deletions


@@ -142,6 +142,49 @@ def init_optimizer_and_scheduler(args, configs, model):
    return model, optimizer, scheduler


def init_optimizer_and_scheduler_gan(args, configs, model):
    # generator optimizer
    if configs['train_conf']['optim'] == 'adam':
        optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
    elif configs['train_conf']['optim'] == 'adamw':
        optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
    else:
        raise ValueError("unknown optimizer: " + configs['train_conf']['optim'])

    # generator scheduler
    if configs['train_conf']['scheduler'] == 'warmuplr':
        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler'] == 'constantlr':
        scheduler = ConstantLR(optimizer)
    else:
        raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])

    # discriminator optimizer
    if configs['train_conf']['optim_d'] == 'adam':
        optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
    elif configs['train_conf']['optim_d'] == 'adamw':
        optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
    else:
        raise ValueError("unknown optimizer: " + configs['train_conf']['optim_d'])

    # discriminator scheduler (keyed by 'scheduler_d', not 'scheduler')
    if configs['train_conf']['scheduler_d'] == 'warmuplr':
        scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
        scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler_d'] == 'constantlr':
        scheduler_d = ConstantLR(optimizer_d)
    else:
        raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler_d'])

    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
    return model, optimizer, scheduler, optimizer_d, scheduler_d
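
For reference, here is a minimal sketch of a train_conf that covers every key this function reads, together with the alternating generator/discriminator update the two optimizer/scheduler pairs are meant to drive. The hyperparameter values, the loop, and the two loss helpers are illustrative assumptions, not part of this commit.

# Minimal sketch (not part of this commit): a train_conf with every key
# init_optimizer_and_scheduler_gan reads; values are illustrative only.
configs = {
    'train_conf': {
        'optim': 'adam',
        'optim_conf': {'lr': 2e-4},
        'scheduler': 'warmuplr',
        'scheduler_conf': {'warmup_steps': 25000},
        'optim_d': 'adam',           # discriminator optimizer type
        'scheduler_d': 'constantlr'  # discriminator scheduler type
    }
}
model, optimizer, scheduler, optimizer_d, scheduler_d = \
    init_optimizer_and_scheduler_gan(args, configs, model)

# Hypothetical alternating GAN step using the returned pairs; the loss
# functions below are placeholders, not APIs from this repository.
for batch in dataloader:
    # 1. discriminator step on the current generator output
    optimizer_d.zero_grad()
    loss_d = compute_discriminator_loss(model, batch)  # placeholder
    loss_d.backward()
    optimizer_d.step()
    scheduler_d.step()
    # 2. generator step (adversarial + reconstruction losses)
    optimizer.zero_grad()
    loss_g = compute_generator_loss(model, batch)      # placeholder
    loss_g.backward()
    optimizer.step()
    scheduler.step()

Keeping separate optimizer/scheduler pairs is the standard setup for GAN training; since this function wraps generator and discriminator in one model, the closing comment notes that deepspeed (which assumes a single optimizer per engine) cannot be used here.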

def init_summarywriter(args):
    writer = None
    if int(os.environ.get('RANK', 0)) == 0: