mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add hifigan train code
This commit is contained in:
@@ -142,6 +142,49 @@ def init_optimizer_and_scheduler(args, configs, model):
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
def init_optimizer_and_scheduler_gan(args, configs, model):
|
||||
if configs['train_conf']['optim'] == 'adam':
|
||||
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim'] == 'adamw':
|
||||
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler = ConstantLR(optimizer)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['optim_d'] == 'adam':
|
||||
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim_d'] == 'adamw':
|
||||
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler_d = ConstantLR(optimizer_d)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
|
||||
return model, optimizer, scheduler, optimizer_d, scheduler_d
|
||||
|
||||
|
||||
def init_summarywriter(args):
|
||||
writer = None
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
|
||||
Reference in New Issue
Block a user