mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
feat: data preprocessing and training (#294)
* docs: update readme
* docs: update readme
* feat: training codes
* feat: data preprocess
* docs: release training
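A minimal sketch of how the new training utilities might be wired together in a training entry point (the config path and the Accelerator setup below are illustrative assumptions, not part of this commit):

    from accelerate import Accelerator
    import torch
    from omegaconf import OmegaConf
    from musetalk.utils.training_utils import (
        initialize_models_and_optimizers, initialize_dataloaders,
        initialize_loss_functions, initialize_syncnet, initialize_vgg,
    )

    cfg = OmegaConf.load("configs/training/stage1.yaml")  # hypothetical config path
    accelerator = Accelerator(gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps)
    weight_dtype = torch.float16

    models = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    loaders = initialize_dataloaders(cfg)
    losses = initialize_loss_functions(cfg, accelerator, models['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
    # ... training loop; periodically call validation(...) and save_models(...)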
337  musetalk/utils/training_utils.py  (Normal file)
@@ -0,0 +1,337 @@
import os
import json
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import WhisperModel
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from einops import rearrange

from musetalk.models.syncnet import SyncNet
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.basic_loss import Interpolate
import musetalk.loss.vgg_face as vgg_face
from musetalk.data.dataset import PortraitDataset
from musetalk.utils.utils import (
    get_image_pred,
    process_audio_features,
    process_and_save_images
)


class Net(nn.Module):
    def __init__(
        self,
        unet: UNet2DConditionModel,
    ):
        super().__init__()
        self.unet = unet

    def forward(
        self,
        input_latents,
        timesteps,
        audio_prompts,
    ):
        model_pred = self.unet(
            input_latents,
            timesteps,
            encoder_hidden_states=audio_prompts
        ).sample
        return model_pred

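# Note on the forward inputs (inferred from get_image_pred, defined later in this commit):
# input_latents is the channel-wise concatenation of masked-frame latents and
# reference-frame latents from the VAE, and audio_prompts are stacked Whisper encoder
# hidden states passed to the UNet as cross-attention context.
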
logger = logging.getLogger(__name__)


def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
    """Initialize models and optimizers"""
    model_dict = {
        'vae': None,
        'unet': None,
        'net': None,
        'wav2vec': None,
        'optimizer': None,
        'lr_scheduler': None,
        'scheduler_max_steps': None,
        'trainable_params': None
    }

    model_dict['vae'] = AutoencoderKL.from_pretrained(
        cfg.pretrained_model_name_or_path,
        subfolder=cfg.vae_type,
    )

    unet_config_file = os.path.join(
        cfg.pretrained_model_name_or_path,
        cfg.unet_sub_folder + "/musetalk.json"
    )

    with open(unet_config_file, 'r') as f:
        unet_config = json.load(f)
    model_dict['unet'] = UNet2DConditionModel(**unet_config)

    if not cfg.random_init_unet:
        pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
        print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
        checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
        model_dict['unet'].load_state_dict(checkpoint)

    unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
    logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")

    model_dict['vae'].requires_grad_(False)
    model_dict['unet'].requires_grad_(True)

    model_dict['vae'].to(accelerator.device, dtype=weight_dtype)

    model_dict['net'] = Net(model_dict['unet'])

    model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
        device="cuda", dtype=weight_dtype).eval()
    model_dict['wav2vec'].requires_grad_(False)

    if cfg.solver.gradient_checkpointing:
        model_dict['unet'].enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
    if accelerator.is_main_process:
        print('trainable params')
        for n, p in model_dict['net'].named_parameters():
            if p.requires_grad:
                print(n)

    model_dict['optimizer'] = optimizer_cls(
        model_dict['trainable_params'],
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
    model_dict['lr_scheduler'] = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=model_dict['optimizer'],
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=model_dict['scheduler_max_steps'],
    )

    return model_dict

def initialize_dataloaders(cfg):
    """Initialize training and validation dataloaders"""
    dataloader_dict = {
        'train_dataset': None,
        'val_dataset': None,
        'train_dataloader': None,
        'val_dataloader': None
    }

    dataloader_dict['train_dataset'] = PortraitDataset(cfg={
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        "sample_method": cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        "contorl_face_min_size": cfg.data.contorl_face_min_size,
        "dataset_key": cfg.data.dataset_key,
        "padding_pixel_mouth": cfg.padding_pixel_mouth,
        "whisper_path": cfg.whisper_path,
        "min_face_size": cfg.data.min_face_size,
        "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
        "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
        "crop_type": cfg.crop_type,
        "random_margin_method": cfg.random_margin_method,
    })

    dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['train_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )

    dataloader_dict['val_dataset'] = PortraitDataset(cfg={
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        "sample_method": cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        "contorl_face_min_size": cfg.data.contorl_face_min_size,
        "dataset_key": cfg.data.dataset_key,
        "padding_pixel_mouth": cfg.padding_pixel_mouth,
        "whisper_path": cfg.whisper_path,
        "min_face_size": cfg.data.min_face_size,
        "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
        "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
        "crop_type": cfg.crop_type,
        "random_margin_method": cfg.random_margin_method,
    })

    dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['val_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=1,
    )

    return dataloader_dict

def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
    """Initialize loss functions and discriminators"""
    loss_dict = {
        'L1_loss': nn.L1Loss(reduction='mean'),
        'discriminator': None,
        'mouth_discriminator': None,
        'optimizer_D': None,
        'mouth_optimizer_D': None,
        'scheduler_D': None,
        'mouth_scheduler_D': None,
        'disc_scales': None,
        'discriminator_full': None,
        'mouth_discriminator_full': None
    }

    if cfg.loss_params.gan_loss > 0:
        loss_dict['discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
        loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
        loss_dict['optimizer_D'] = optim.AdamW(
            loss_dict['discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['scheduler_D'] = CosineAnnealingLR(
            loss_dict['optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6
        )

    if cfg.loss_params.mouth_gan_loss > 0:
        loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
        loss_dict['mouth_optimizer_D'] = optim.AdamW(
            loss_dict['mouth_discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
            loss_dict['mouth_optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6
        )

    return loss_dict

def initialize_syncnet(cfg, accelerator, weight_dtype):
    """Initialize SyncNet model"""
    if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
        if cfg.data.n_sample_frames != 16:
            raise ValueError(
                f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
            )
        syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
        syncnet = SyncNet(OmegaConf.to_container(
            syncnet_config.model)).to(accelerator.device)
        print(
            f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
        checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
        syncnet.load_state_dict(checkpoint["state_dict"])
        syncnet.to(dtype=weight_dtype)
        syncnet.requires_grad_(False)
        syncnet.eval()
        return syncnet
    return None

def initialize_vgg(cfg, accelerator):
    """Initialize VGG model"""
    if cfg.loss_params.vgg_loss > 0:
        vgg_IN = vgg_face.Vgg19().to(accelerator.device)
        pyramid = vgg_face.ImagePyramide(
            cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
        vgg_IN.eval()
        downsampler = Interpolate(
            size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
        return vgg_IN, pyramid, downsampler
    return None, None, None

def validation(
    cfg,
    val_dataloader,
    net,
    vae,
    wav2vec,
    accelerator,
    save_dir,
    global_step,
    weight_dtype,
    syncnet_score=1,
):
    """Validation function for model evaluation"""
    net.eval()  # Set the model to evaluation mode
    for batch in val_dataloader:
        # The same ref_latents
        ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        bsz, num_frames, c, h, w = ref_pixel_values.shape

        audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
        # audio feature for unet
        audio_prompts = rearrange(
            audio_prompts,
            'b f c h w-> (b f) c h w'
        )
        audio_prompts = rearrange(
            audio_prompts,
            '(b f) c h w -> (b f) (c h) w',
            b=bsz
        )
        # different masked_latents
        image_pred_train = get_image_pred(
            pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
        image_pred_infer = get_image_pred(
            ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)

        process_and_save_images(
            batch,
            image_pred_train,
            image_pred_infer,
            save_dir,
            global_step,
            accelerator,
            cfg.num_images_to_keep,
            syncnet_score
        )
        # only infer 1 image in validation
        break
    net.train()  # Set the model back to training mode
@@ -2,6 +2,11 @@ import os
import cv2
import numpy as np
import torch
from typing import Union, List
import torch.nn.functional as F
from einops import rearrange
import shutil
import os.path as osp

ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
@@ -11,7 +16,6 @@ elif ffmpeg_path not in os.getenv('PATH'):
    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"


from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet, PositionalEncoding

@@ -76,3 +80,248 @@ def datagen(
        latent_batch = torch.cat(latent_batch, dim=0)

        yield whisper_batch.to(device), latent_batch.to(device)


def cast_training_params(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype=torch.float32,
):
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)

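# A typical use (a sketch, not part of this commit): when training under fp16/bf16 mixed
# precision, keep the trainable parameters in fp32 while frozen modules stay in the
# lower-precision weight_dtype, e.g. cast_training_params(net, dtype=torch.float32).
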
def rand_log_normal(
    shape,
    loc=0.,
    scale=1.,
    device='cpu',
    dtype=torch.float32,
    generator=None
):
    """Draw samples from a log-normal distribution."""
    rnd_normal = torch.randn(
        shape, device=device, dtype=dtype, generator=generator)  # N(0, I)
    sigma = (rnd_normal * scale + loc).exp()
    return sigma

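# Equivalently, sigma = exp(loc + scale * z) with z ~ N(0, 1); for example,
# rand_log_normal((4,)) draws four positive scales whose logarithms are standard normal.
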
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
    # Initialize lists to store the results for each image in the batch
    mouth_real_list = []
    mouth_generated_list = []

    # Process each image in the batch
    for b in range(frames.shape[0]):
        # Find the non-zero area in the face mask
        non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
        # If there are no non-zero indices, skip this image
        if non_zero_indices.numel() == 0:
            continue

        min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
            non_zero_indices[:, 1])
        min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
            non_zero_indices[:, 2])

        # Crop the frames and image_pred according to the non-zero area
        frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
        image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
        # Resize the cropped images to 256*256
        frames_resized = F.interpolate(frames_cropped.unsqueeze(
            0), size=(256, 256), mode='bilinear', align_corners=False)
        image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
            0), size=(256, 256), mode='bilinear', align_corners=False)

        # Append the resized images to the result lists
        mouth_real_list.append(frames_resized)
        mouth_generated_list.append(image_pred_resized)

    # Convert the lists to tensors if they are not empty
    mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
    mouth_generated = torch.cat(
        mouth_generated_list, dim=0) if mouth_generated_list else None

    return mouth_real, mouth_generated

def get_image_pred(pixel_values,
                   ref_pixel_values,
                   audio_prompts,
                   vae,
                   net,
                   weight_dtype):
    with torch.no_grad():
        bsz, num_frames, c, h, w = pixel_values.shape

        masked_pixel_values = pixel_values.clone()
        masked_pixel_values[:, :, :, h//2:, :] = -1

        masked_frames = rearrange(
            masked_pixel_values, 'b f c h w -> (b f) c h w')
        masked_latents = vae.encode(masked_frames).latent_dist.mode()
        masked_latents = masked_latents * vae.config.scaling_factor
        masked_latents = masked_latents.float()

        ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
        ref_latents = vae.encode(ref_frames).latent_dist.mode()
        ref_latents = ref_latents * vae.config.scaling_factor
        ref_latents = ref_latents.float()

        input_latents = torch.cat([masked_latents, ref_latents], dim=1)
        input_latents = input_latents.to(weight_dtype)
        timesteps = torch.tensor([0], device=input_latents.device)
        latents_pred = net(
            input_latents,
            timesteps,
            audio_prompts,
        )
        latents_pred = (1 / vae.config.scaling_factor) * latents_pred
        image_pred = vae.decode(latents_pred).sample
        image_pred = image_pred.float()

    return image_pred

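# Note: this is a single UNet forward pass at a fixed timestep of 0 rather than an
# iterative diffusion sampling loop; the lower half of each frame is masked to -1 and the
# prediction is decoded straight back to image space for visualization.
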
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    with torch.no_grad():
        audio_feature_length_per_frame = 2 * \
            (cfg.data.audio_padding_length_left +
             cfg.data.audio_padding_length_right + 1)
        audio_feats = batch['audio_feature'].to(weight_dtype)
        audio_feats = wav2vec.encoder(
            audio_feats, output_hidden_states=True).hidden_states
        audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)  # [B, T, 10, 5, 384]

        start_ts = batch['audio_offset']
        step_ts = batch['audio_step']
        audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
                                 audio_feats,
                                 torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
        audio_prompts = []
        for bb in range(bsz):
            audio_feats_list = []
            for f in range(num_frames):
                cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
                audio_clip = audio_feats[bb:bb+1,
                                         cur_t: cur_t+audio_feature_length_per_frame]

                audio_feats_list.append(audio_clip)
            audio_feats_list = torch.stack(audio_feats_list, 1)
            audio_prompts.append(audio_feats_list)
        audio_prompts = torch.cat(audio_prompts)  # B, T, 10, 5, 384
    return audio_prompts

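# Windowing note: each frame f takes audio_feature_length_per_frame = 2 * (left + right + 1)
# Whisper encoder steps starting at cur_t = 2 * (audio_offset + f * audio_step) in the
# zero-padded feature sequence; if audio_padding_length_left = audio_padding_length_right = 2,
# that is 10 steps per frame, matching the trailing "# B, T, 10, 5, 384" shape comment.
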
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
    save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.endswith(".pth")]
        checkpoints = [d for d in checkpoints if name in d]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(
                f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(
                    save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    state_dict = model.state_dict()
    torch.save(state_dict, save_path)

def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
    unwarp_net = accelerator.unwrap_model(net)
    save_checkpoint(
        unwarp_net.unet,
        save_dir,
        global_step,
        name="unet",
        total_limit=cfg.total_limit,
        logger=logger
    )

def delete_additional_ckpt(base_path, num_keep):
    dirs = []
    for d in os.listdir(base_path):
        if d.startswith("checkpoint-"):
            dirs.append(d)
    num_tot = len(dirs)
    if num_tot <= num_keep:
        return
    # ensure ckpts are sorted and delete the earliest ones
    del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
    for d in del_dirs:
        path_to_dir = osp.join(base_path, d)
        if osp.exists(path_to_dir):
            shutil.rmtree(path_to_dir)

def seed_everything(seed):
    import random

    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)

def process_and_save_images(
    batch,
    image_pred,
    image_pred_infer,
    save_dir,
    global_step,
    accelerator,
    num_images_to_keep=10,
    syncnet_score=1
):
    # Rearrange the tensors
    print("image_pred.shape: ", image_pred.shape)
    pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
    pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')

    # Create masked pixel values
    masked_pixel_values = batch["pixel_values_vid"].clone()
    _, _, _, h, _ = batch["pixel_values_vid"].shape
    masked_pixel_values[:, :, :, h//2:, :] = -1
    masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')

    # Keep only the specified number of images
    pixel_values = pixel_values[:num_images_to_keep, :, :, :]
    masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
    pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
    image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
    image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]

    # Concatenate images
    concat = torch.cat([
        masked_pixel_values * 0.5 + 0.5,
        pixel_values_ref_img * 0.5 + 0.5,
        image_pred * 0.5 + 0.5,
        pixel_values * 0.5 + 0.5,
        image_pred_infer * 0.5 + 0.5,
    ], dim=2)
    print("concat.shape: ", concat.shape)

    # Create the save directory if it doesn't exist
    os.makedirs(f'{save_dir}/samples/', exist_ok=True)

    # Try to save the concatenated image
    try:
        # Concatenate images horizontally and convert to numpy array
        final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
        # Save the image
        cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
        print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
    except Exception as e:
        print(f"Failed to save image: {e}")