diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py
index 338375e..3b4710e 100644
--- a/cosyvoice/bin/train.py
+++ b/cosyvoice/bin/train.py
@@ -118,9 +118,15 @@ def main():
 
     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
-            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
         else:
             logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
@@ -129,19 +135,25 @@ def main():
 
     # Get optimizer & scheduler
     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
     # Get executor
     executor = Executor(gan=gan)
+    executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
-
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py
index eab92f8..72e291a 100644
--- a/cosyvoice/utils/train_utils.py
+++ b/cosyvoice/utils/train_utils.py
@@ -199,7 +199,7 @@ def save_model(model, model_name, info_dict):
 
     if info_dict["train_engine"] == "torch_ddp":
         if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
+            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
     else:
         with torch.no_grad():
             model.save_checkpoint(save_dir=model_dir,
@@ -284,7 +284,8 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # We don't check grad here since that if the gradient
             # has inf/nan values, scaler.step will skip
             # optimizer.step().
-            scaler.step(optimizer)
+            if torch.isfinite(grad_norm):
+                scaler.step(optimizer)
             scaler.update()
         else:
             grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
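
For context, a minimal sketch (not part of the patch) of how a checkpoint written by the modified `save_model` could be inspected before resuming, mirroring the resume logic added to `train.py`; the checkpoint path below is a hypothetical example.

```python
import torch

# Hypothetical checkpoint path; not taken from the patch.
ckpt_path = 'exp/llm/torch_ddp/epoch_5_whole.pt'

# The patched save_model() stores 'step' and 'epoch' alongside the model
# weights, so one dict serves both load_state_dict() and resume bookkeeping.
state_dict = torch.load(ckpt_path, map_location='cpu')
start_step = state_dict.get('step', 0)      # default matches train.py's start_step = 0
start_epoch = state_dict.get('epoch', -1)   # default matches train.py's start_epoch = -1
print('resume from step {}, epoch {}'.format(start_step, start_epoch))
```

Because `load_state_dict` is called with `strict=False`, the extra non-tensor `'step'` and `'epoch'` entries mixed into the saved dict are treated as unexpected keys and skipped rather than raising an error.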