mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 09:59:23 +08:00
resume training
This commit is contained in:
@@ -199,7 +199,7 @@ def save_model(model, model_name, info_dict):
|
||||
|
||||
if info_dict["train_engine"] == "torch_ddp":
|
||||
if rank == 0:
|
||||
-torch.save(model.module.state_dict(), save_model_path)
+torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=model_dir,
|
||||
@@ -284,7 +284,8 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
|
||||
# We don't check grad here since that if the gradient
|
||||
# has inf/nan values, scaler.step will skip
|
||||
# optimizer.step().
|
||||
-scaler.step(optimizer)
+if torch.isfinite(grad_norm):
+    scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||
|
||||
Reference in New Issue
Block a user