Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 01:49:25 +08:00
Merge pull request #653 from FunAudioLLM/dev/lyuxiang.lx
resume training
@@ -118,9 +118,15 @@ def main():
 
     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
-            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
         else:
             logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
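The load path stays backward compatible: a weight-only checkpoint simply lacks the 'step'/'epoch' keys, so the defaults 0 and -1 survive, and strict=False lets load_state_dict ignore the two extra scalar entries written by the save_model change later in this diff. Below is a minimal round-trip sketch of that behaviour; the nn.Linear stand-in, file name, and numbers are placeholders, not taken from the repository.

    import torch
    import torch.nn as nn

    # Placeholder module standing in for configs[args.model]; any nn.Module works here.
    model = nn.Linear(4, 4)

    # A checkpoint in the new format: weight tensors plus two scalar bookkeeping keys.
    torch.save({**model.state_dict(), 'step': 1200, 'epoch': 3}, 'demo_checkpoint.pt')

    state_dict = torch.load('demo_checkpoint.pt', map_location='cpu')

    # Older weight-only checkpoints simply lack these keys, so fall back to the
    # same defaults the training script starts from (step 0, epoch -1).
    start_step = state_dict.get('step', 0)
    start_epoch = state_dict.get('epoch', -1)

    # strict=False: the two non-tensor entries are reported as unexpected keys
    # instead of raising, so the weights still load.
    model.load_state_dict(state_dict, strict=False)
    print('resuming from step {} epoch {}'.format(start_step, start_epoch))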
@@ -129,19 +135,25 @@ def main():
 
     # Get optimizer & scheduler
     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
     # Get executor
     executor = Executor(gan=gan)
+    executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
-
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
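Two details make the resumed run line up with where it stopped: scheduler.set_step(start_step) fast-forwards the learning-rate schedule to the recorded global step instead of re-running the warm-up, and range(start_epoch + 1, max_epoch) starts at the epoch after the one stored in the checkpoint, while a fresh run keeps start_epoch = -1 and therefore still covers epoch 0 onwards. The toy scheduler below is written for this note only and is not the scheduler class used by the repository; the numbers are made up.

    # ToyWarmupLR is a stand-in written for illustration, not CosyVoice's scheduler;
    # it only shows why set_step(start_step) matters when resuming.
    class ToyWarmupLR:
        def __init__(self, base_lr=1e-3, warmup_steps=2500):
            self.base_lr = base_lr
            self.warmup_steps = warmup_steps
            self.step_num = 0

        def set_step(self, step):
            # Jump the schedule to the global step recorded in the checkpoint.
            self.step_num = step

        def current_lr(self):
            # Linear warm-up, then inverse-square-root decay.
            s = max(self.step_num, 1)
            return self.base_lr * min(s / self.warmup_steps, (self.warmup_steps / s) ** 0.5)


    max_epoch, start_epoch, start_step = 10, 3, 1200   # made-up resume point

    scheduler = ToyWarmupLR()
    scheduler.set_step(start_step)   # without this the warm-up would restart from step 0
    print('lr after resume:', scheduler.current_lr())

    # A fresh run (start_epoch = -1) covers epochs 0..9; resuming from an epoch-3
    # checkpoint covers epochs 4..9.
    print('epochs to run:', list(range(start_epoch + 1, max_epoch)))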
@@ -199,7 +199,7 @@ def save_model(model, model_name, info_dict):
 
     if info_dict["train_engine"] == "torch_ddp":
         if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
+            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
     else:
         with torch.no_grad():
             model.save_checkpoint(save_dir=model_dir,
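Because the two bookkeeping values now sit flat next to the parameter names in the same dict, a consumer that loads with strict=True (for example an export or inference script) would need to drop them first; the training path above gets away with strict=False. A short sketch, again with a placeholder module and file name rather than repository code:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)   # placeholder for the real model
    torch.save({**model.state_dict(), 'epoch': 3, 'step': 1200}, 'demo_checkpoint.pt')

    state_dict = torch.load('demo_checkpoint.pt', map_location='cpu')
    # Remove the bookkeeping entries so the remaining keys match the module exactly.
    for key in ('epoch', 'step'):
        state_dict.pop(key, None)
    model.load_state_dict(state_dict, strict=True)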
@@ -284,6 +284,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # We don't check grad here since that if the gradient
             # has inf/nan values, scaler.step will skip
             # optimizer.step().
-            scaler.step(optimizer)
+            if torch.isfinite(grad_norm):
+                scaler.step(optimizer)
             scaler.update()
         else:
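For context, clip_grad_norm_ returns the total gradient norm, which can come back as inf or nan when an AMP step has overflowed; scaler.step already skips optimizer.step in that situation (that is what the surrounding comment refers to), and the added isfinite check makes the skip explicit at the call site, while scaler.update() still runs so the loss scale can adjust. A self-contained sketch of the pattern with a toy model; everything below is a placeholder, not repository code.

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_

    model = nn.Linear(8, 1)                                   # toy model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = nn.MSELoss()(model(x), y)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                # gradients back in real scale
    grad_norm = clip_grad_norm_(model.parameters(), max_norm=5.0)

    if torch.isfinite(grad_norm):
        scaler.step(optimizer)                                # skipped when the norm overflowed
    scaler.update()                                           # always update the loss scale
    optimizer.zero_grad()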