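"""Training entry point for MuseTalk (https://github.com/TMElyralab/MuseTalk).

Runs the distributed training loop (via Hugging Face Accelerate) that optimizes
the MuseTalk UNet with a combination of L1, VGG-perceptual, GAN,
feature-matching and lip-sync (SyncNet) losses. The default --config points at
./configs/training/stage2.yaml, so this file presumably serves as the stage-2
training script; all hyperparameters come from that OmegaConf config.

Typical launch (illustrative only -- the exact filename of this script is not
shown in this mirror):

    accelerate launch train.py --config ./configs/training/stage2.yaml
"""
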
import argparse
import diffusers
import logging
import math
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import warnings
import random

from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from datetime import datetime
from datetime import timedelta

from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from musetalk.utils.utils import (
    delete_additional_ckpt,
    seed_everything,
    get_mouth_region,
    process_audio_features,
    save_models
)
from musetalk.loss.basic_loss import set_requires_grad
from musetalk.loss.syncnet import get_sync_loss
from musetalk.utils.training_utils import (
    initialize_models_and_optimizers,
    initialize_dataloaders,
    initialize_loss_functions,
    initialize_syncnet,
    initialize_vgg,
    validation
)

logger = get_logger(__name__, log_level="INFO")
warnings.filterwarnings("ignore")
check_min_version("0.10.0.dev0")


def main(cfg):
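    """Run the MuseTalk training loop described by the OmegaConf config `cfg`.

    The config is expected to provide at least the fields referenced in this
    function: `exp_name`, `output_dir`, `seed`, `resume_from_checkpoint`,
    `val_freq`, `checkpointing_steps`, `total_limit`,
    `save_model_epoch_interval`, `use_adapted_weight`, `adapted_weight_type`,
    `ref_dropout_rate`, `num_backward_frames`, plus the `solver.*`, `data.*`,
    `loss_params.*` and `discriminator_train_params.*` groups.
    """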
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    os.makedirs(save_dir, exist_ok=True)

    kwargs = DistributedDataParallelKwargs()
    process_group_kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=5400))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        log_with=["tensorboard", LoggerType.TENSORBOARD],
        project_dir=os.path.join(save_dir, "./tensorboard"),
        kwargs_handlers=[kwargs, process_group_kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        print('cfg.seed', cfg.seed, accelerator.process_index)
        seed_everything(cfg.seed + accelerator.process_index)

    weight_dtype = torch.float32

    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    dataloader_dict = initialize_dataloaders(cfg)
    loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
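    # vgg_IN, pyramid and downsampler are only consumed further down: pyramid and
    # downsampler build the multi-scale image stacks fed to the VGG perceptual loss
    # and to the (mouth) discriminators.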

    # Prepare everything with our `accelerator`.
    model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
        model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
    )
    print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))

    # Calculate training steps and epochs
    num_update_steps_per_epoch = math.ceil(
        len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # Initialize trackers on the main process
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )

    # Calculate total batch size
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    # Log training information
    logger.info("***** Running training *****")
    logger.info(f"Num Epochs = {num_train_epochs}")
    logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Load checkpoint if resuming training
    if cfg.resume_from_checkpoint:
        resume_dir = save_dir
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) > 0:
            path = dirs[-1]
            accelerator.load_state(os.path.join(resume_dir, path))
            accelerator.print(f"Resuming from checkpoint {path}")
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = global_step % num_update_steps_per_epoch

    # Initialize progress bar
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Log model types
    print("log type of models")
    print("unet", model_dict['unet'].dtype)
    print("vae", model_dict['vae'].dtype)
    print("wav2vec", model_dict['wav2vec'].dtype)

    def get_ganloss_weight(step):
        """Calculate GAN loss weight based on training step"""
        if step < cfg.discriminator_train_params.start_gan:
            return 0.0
        else:
            return 1.0
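    # The GAN generator terms and discriminator losses below are scaled by this
    # weight, so adversarial training only starts once global_step reaches
    # cfg.discriminator_train_params.start_gan.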

    # Training loop
    for epoch in range(first_epoch, num_train_epochs):
        # Set models to training mode
        model_dict['unet'].train()
        if cfg.loss_params.gan_loss > 0:
            loss_dict['discriminator'].train()
        if cfg.loss_params.mouth_gan_loss > 0:
            loss_dict['mouth_discriminator'].train()

        # Initialize loss accumulators
        train_loss = 0.0
        train_loss_D = 0.0
        train_loss_D_mouth = 0.0
        l1_loss_accum = 0.0
        vgg_loss_accum = 0.0
        gan_loss_accum = 0.0
        gan_loss_accum_mouth = 0.0
        fm_loss_accum = 0.0
        sync_loss_accum = 0.0
        adapted_weight_accum = 0.0

        t_data_start = time.time()
        for step, batch in enumerate(dataloader_dict['train_dataloader']):
            t_data = time.time() - t_data_start
            t_model_start = time.time()

            with torch.no_grad():
                # Process input data
                pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True
                )
                bsz, num_frames, c, h, w = pixel_values.shape

                # Process reference images
                ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True
                )

                # Get face mask for GAN
                pixel_values_face_mask = batch['pixel_values_face_mask']

                # Process audio features
                audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)

                # Initialize adapted weight
                adapted_weight = 1
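                # adapted_weight scales every loss term for this batch; when sync loss is
                # enabled it may be reduced (or zeroed) below if the ground-truth clip's
                # audio-visual SyncNet similarity looks unreliable.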

                # Process sync loss if enabled
                if cfg.loss_params.sync_loss > 0:
                    mels = batch['mel']
                    # Prepare frames for latentsync (combine channels and frames)
                    gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
                    # Use lower half of face for latentsync
                    height = gt_frames.shape[2]
                    gt_frames = gt_frames[:, :, height // 2:, :]

                    # Get audio embeddings
                    audio_embed = syncnet.get_audio_embed(mels)

                    # Calculate adapted weight based on audio-visual similarity
                    if cfg.use_adapted_weight:
                        vision_embed_gt = syncnet.get_vision_embed(gt_frames)
                        image_audio_sim_gt = F.cosine_similarity(
                            audio_embed,
                            vision_embed_gt,
                            dim=1
                        )[0]

                        if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
                            if cfg.adapted_weight_type == "cut_off":
                                adapted_weight = 0.0  # Skip this batch
                                print(
                                    f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
                            elif cfg.adapted_weight_type == "linear":
                                adapted_weight = image_audio_sim_gt
                            else:
                                print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
                                adapted_weight = 1

                    # Random frame selection for memory efficiency
                    max_start = 16 - cfg.num_backward_frames
                    frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
                    frames_right_index = frames_left_index + cfg.num_backward_frames
                else:
                    frames_left_index = 0
                    frames_right_index = cfg.data.n_sample_frames
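                # Only frames inside [frames_left_index, frames_right_index) take part in
                # the backward pass; with sync loss enabled this is a random
                # cfg.num_backward_frames-wide window of the 16-frame clip to limit memory.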

                # Extract frames for backward pass
                pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
                ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
                pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
                audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]

                # Encode target images
                frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
                latents = model_dict['vae'].encode(frames).latent_dist.mode()
                latents = latents * model_dict['vae'].config.scaling_factor
                latents = latents.float()

                # Create masked images
                masked_pixel_values = pixel_values_backward.clone()
                masked_pixel_values[:, :, :, h//2:, :] = -1
                masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
                masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
                masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
                masked_latents = masked_latents.float()

                # Encode reference images
                ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
                ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
                ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
                ref_latents = ref_latents.float()

                # Prepare face mask and audio features
                pixel_values_face_mask_backward = rearrange(
                    pixel_values_face_mask_backward,
                    "b f c h w -> (b f) c h w"
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    'b f c h w-> (b f) c h w'
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    '(b f) c h w -> (b f) (c h) w',
                    b=bsz
                )

                # Apply reference dropout (currently inactive)
                dropout = nn.Dropout(p=cfg.ref_dropout_rate)
                ref_latents = dropout(ref_latents)

            # Prepare model inputs
            input_latents = torch.cat([masked_latents, ref_latents], dim=1)
            input_latents = input_latents.to(weight_dtype)
            timesteps = torch.tensor([0], device=input_latents.device)
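            # The timestep is pinned to 0: the UNet is effectively used here as a
            # single-pass inpainting network rather than as an iterative diffusion sampler.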

            # Forward pass
            latents_pred = model_dict['net'](
                input_latents,
                timesteps,
                audio_prompts_backward,
            )
            latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
            image_pred = model_dict['vae'].decode(latents_pred).sample

            # Convert to float
            image_pred = image_pred.float()
            frames = frames.float()

            # Calculate L1 loss
            l1_loss = loss_dict['L1_loss'](frames, image_pred)
            l1_loss_accum += l1_loss.item()
            loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight

            # Process mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                frames_mouth, image_pred_mouth = get_mouth_region(
                    frames,
                    image_pred,
                    pixel_values_face_mask_backward
                )
                pyramide_real_mouth = pyramid(downsampler(frames_mouth))
                pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))

            # Process VGG loss if enabled
            if cfg.loss_params.vgg_loss > 0:
                pyramide_real = pyramid(downsampler(frames))
                pyramide_generated = pyramid(downsampler(image_pred))

                loss_IN = 0
                for scale in cfg.loss_params.pyramid_scale:
                    x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
                    y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
                    for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
                        value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                        loss_IN += weight * value
                loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
                loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
                vgg_loss_accum += loss_IN.item()

            # Process GAN loss if enabled
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], False)
                loss_G = 0.
                discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
                discriminator_maps_real = loss_dict['discriminator'](pyramide_real)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum += loss_G.item()

                loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight
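                # NOTE: pyramide_real / pyramide_generated are built in the VGG branch above,
                # so enabling gan_loss implicitly assumes vgg_loss > 0 as well. The generator
                # term here (and the mouth variant below) is the least-squares GAN objective
                # (1 - D(G(x)))^2.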

                # Process feature matching loss if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], False)
                loss_G = 0.
                mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
                mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum_mouth += loss_G.item()

                loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight

                # Process feature matching loss for mouth if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process sync loss if enabled
            if cfg.loss_params.sync_loss > 0:
                pred_frames = rearrange(
                    image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
                pred_frames = pred_frames[:, :, height // 2:, :]
                sync_loss, image_audio_sim_pred = get_sync_loss(
                    audio_embed,
                    gt_frames,
                    pred_frames,
                    syncnet,
                    adapted_weight,
                    frames_left_index=frames_left_index,
                    frames_right_index=frames_right_index,
                )
                sync_loss_accum += sync_loss.item()
                loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight
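                # The sync loss compares SyncNet embeddings of the audio against the lower
                # half of the predicted frames, mirroring how gt_frames were cropped above.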

            # Backward pass
            avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
            train_loss += avg_loss.item()
            accelerator.backward(loss)

            # Train discriminator if GAN loss is enabled
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], True)
                loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
                avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D += avg_loss_D.item() / 1
                loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['optimizer_D'].step()
                    loss_dict['scheduler_D'].step()
                    loss_dict['optimizer_D'].zero_grad()

            # Train mouth discriminator if mouth GAN loss is enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], True)
                mouth_loss_D = loss_dict['mouth_discriminator_full'](
                    frames_mouth, image_pred_mouth.detach())
                avg_mouth_loss_D = accelerator.gather(
                    mouth_loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D_mouth += avg_mouth_loss_D.item() / 1
                mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(mouth_loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['mouth_optimizer_D'].step()
                    loss_dict['mouth_scheduler_D'].step()
                    loss_dict['mouth_optimizer_D'].zero_grad()

            # Update main model
            if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model_dict['trainable_params'],
                        cfg.solver.max_grad_norm,
                    )
                model_dict['optimizer'].step()
                model_dict['lr_scheduler'].step()
                model_dict['optimizer'].zero_grad()
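            # Generator and discriminator optimizers all step on the same cadence:
            # once every cfg.solver.gradient_accumulation_steps micro-batches.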

            # Update progress and log metrics
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "train_loss_D": train_loss_D,
                    "train_loss_D_mouth": train_loss_D_mouth,
                    "l1_loss": l1_loss_accum,
                    "vgg_loss": vgg_loss_accum,
                    "gan_loss": gan_loss_accum,
                    "fm_loss": fm_loss_accum,
                    "sync_loss": sync_loss_accum,
                    "adapted_weight": adapted_weight_accum,
                    "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                }, step=global_step)

                # Reset loss accumulators
                train_loss = 0.0
                l1_loss_accum = 0.0
                vgg_loss_accum = 0.0
                gan_loss_accum = 0.0
                fm_loss_accum = 0.0
                sync_loss_accum = 0.0
                adapted_weight_accum = 0.0
                train_loss_D = 0.0
                train_loss_D_mouth = 0.0

                # Run validation if needed
                if global_step % cfg.val_freq == 0 or global_step == 10:
                    try:
                        validation(
                            cfg,
                            dataloader_dict['val_dataloader'],
                            model_dict['net'],
                            model_dict['vae'],
                            model_dict['wav2vec'],
                            accelerator,
                            save_dir,
                            global_step,
                            weight_dtype,
                            syncnet_score=adapted_weight,
                        )
                    except Exception as e:
                        print(f"An error occurred during validation: {e}")

                # Save checkpoint if needed
                if global_step % cfg.checkpointing_steps == 0:
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    try:
                        start_time = time.time()
                        if accelerator.is_main_process:
                            save_models(
                                accelerator,
                                model_dict['net'],
                                save_dir,
                                global_step,
                                cfg,
                                logger=logger
                            )
                            delete_additional_ckpt(save_dir, cfg.total_limit)
                        elapsed_time = time.time() - start_time
                        if elapsed_time > 300:
                            print(f"Skipping storage as it took too long in step {global_step}.")
                        else:
                            print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
                    except Exception as e:
                        print(f"Error when saving model in step {global_step}:", e)

            # Update progress bar
            t_model = time.time() - t_model_start
            logs = {
                "step_loss": loss.detach().item(),
                "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                "td": f"{t_data:.2f}s",
                "tm": f"{t_model:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

        # Save model after each epoch
        if (epoch + 1) % cfg.save_model_epoch_interval == 0:
            try:
                start_time = time.time()
                if accelerator.is_main_process:
                    save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
                    accelerator.save_state(save_path)
                elapsed_time = time.time() - start_time
                if elapsed_time > 120:
                    print(f"Skipping storage as it took too long in step {global_step}.")
                else:
                    print(f"Model saved successfully in {elapsed_time}s.")
            except Exception as e:
                print(f"Error when saving model in step {global_step}:", e)
            accelerator.wait_for_everyone()

    # End training
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)