add hifigan train

lyuxiang.lx
2024-10-16 11:37:32 +08:00
parent cb200b21c5
commit 789ee9e5e7
13 changed files with 314 additions and 477 deletions


@@ -86,8 +86,12 @@ def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False

-    override_dict = {k: None for k in ['llm', 'flow', 'hifigan'] if k != args.model}
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
     configs['train_conf'].update(vars(args))
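
A note on the override logic above: load_hyperpyyaml applies the overrides mapping before instantiating the YAML, and a key set to None suppresses that sub-model entirely. Popping 'hift' from the overrides therefore keeps the hift section alive during hifigan training, presumably because hift is the generator being trained adversarially. A minimal sketch of the same selection logic (the config filename here is illustrative, not taken from the diff):

from hyperpyyaml import load_hyperpyyaml

model = 'hifigan'              # stands in for args.model
gan = model == 'hifigan'

# Null out every top-level model section except the one being trained.
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != model}
if gan:
    # GAN training still needs the hift section, so drop its override.
    override_dict.pop('hift')

# 'cosyvoice.yaml' is an illustrative path.
with open('cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f, overrides=override_dict)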
@@ -97,7 +101,7 @@ def main():

     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs)
+        init_dataset_and_dataloader(args, configs, gan)

     # Do some sanity checks and save config to args.model_dir
     configs = check_modify_and_save_config(args, configs)
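
The body of init_dataset_and_dataloader is not shown in this hunk, so the following is only a guess at its shape: a GAN run presumably needs raw waveforms for the discriminator and therefore selects a different data pipeline. The 'data_pipeline_gan' key and the dataset factories below are assumptions, not taken from the diff:

from torch.utils.data import DataLoader

def init_dataset_and_dataloader(args, configs, gan=False):
    # Hypothetical: a GAN pipeline would also emit raw speech for the
    # discriminator; 'data_pipeline_gan' is an assumed config key.
    pipeline = configs['data_pipeline_gan'] if gan else configs['data_pipeline']
    train_dataset = configs['train_dataset_cls'](args.train_data, pipeline)  # assumed factory
    cv_dataset = configs['cv_dataset_cls'](args.cv_data, pipeline)           # assumed factory
    train_loader = DataLoader(train_dataset, batch_size=None, num_workers=args.num_workers)
    cv_loader = DataLoader(cv_dataset, batch_size=None, num_workers=args.num_workers)
    return train_dataset, cv_dataset, train_loader, cv_loader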
@@ -108,13 +112,13 @@ def main():
     # load checkpoint
     model = configs[args.model]
     if args.checkpoint is not None:
-        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)

     # Dispatch model from cpu to gpu
     model = wrap_cuda_model(args, model)

     # Get optimizer & scheduler
-    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)

     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
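
The two changes in this hunk fit together: strict=False lets a generator-only checkpoint load into a module that, in GAN mode, presumably also holds discriminator weights the checkpoint lacks, and the optimizer helper now returns a second optimizer/scheduler pair for the discriminator. A sketch of what the GAN branch of that helper might look like; the generator/discriminator attribute names and the Adam/ExponentialLR choices are assumptions, not taken from the diff:

import torch

def init_optimizer_and_scheduler(args, configs, model, gan=False):
    lr = configs['train_conf']['optim_conf']['lr']
    optimizer_d, scheduler_d = None, None
    if gan:
        # Hypothetical attribute names: split the parameters so generator
        # and discriminator are stepped by separate optimizers.
        optimizer = torch.optim.Adam(model.generator.parameters(), lr=lr)
        optimizer_d = torch.optim.Adam(model.discriminator.parameters(), lr=lr)
        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
    return model, optimizer, scheduler, optimizer_d, scheduler_d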
@@ -129,7 +133,10 @@ def main():
         train_dataset.set_epoch(epoch)
         dist.barrier()
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
         dist.destroy_process_group(group_join)
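
train_one_epoc_gan itself is not part of this diff. In generic adversarial training the inner loop alternates a discriminator step, taken against detached generator output, with a generator step, each driven by its own optimizer, which is why the call above receives both optimizer and optimizer_d. An illustrative single step; every method name on model here is hypothetical, not the repo's API:

def gan_step(model, batch, optimizer, optimizer_d):
    # Illustrative adversarial update, not the repo's implementation.
    fake = model.generate(batch)                          # hypothetical API
    # 1) Discriminator: score real speech vs. detached generated speech,
    #    so this backward pass does not touch generator weights.
    d_loss = model.discriminator_loss(batch['speech'], fake.detach())
    optimizer_d.zero_grad()
    d_loss.backward()
    optimizer_d.step()
    # 2) Generator: adversarial + reconstruction terms on the same output.
    g_loss = model.generator_loss(batch['speech'], fake)
    optimizer.zero_grad()
    g_loss.backward()
    optimizer.step()
    return g_loss.item(), d_loss.item()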