mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
338 lines
12 KiB
Python
338 lines
12 KiB
Python
import os
|
|
import json
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
from diffusers import AutoencoderKL, UNet2DConditionModel
|
|
from transformers import WhisperModel
|
|
from diffusers.optimization import get_scheduler
|
|
from omegaconf import OmegaConf
|
|
from einops import rearrange
|
|
|
|
from musetalk.models.syncnet import SyncNet
|
|
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
|
|
from musetalk.loss.basic_loss import Interpolate
|
|
import musetalk.loss.vgg_face as vgg_face
|
|
from musetalk.data.dataset import PortraitDataset
|
|
from musetalk.utils.utils import (
|
|
get_image_pred,
|
|
process_audio_features,
|
|
process_and_save_images
|
|
)
|
|
|
|
class Net(nn.Module):
|
|
def __init__(
|
|
self,
|
|
unet: UNet2DConditionModel,
|
|
):
|
|
super().__init__()
|
|
self.unet = unet
|
|
|
|
def forward(
|
|
self,
|
|
input_latents,
|
|
timesteps,
|
|
audio_prompts,
|
|
):
|
|
model_pred = self.unet(
|
|
input_latents,
|
|
timesteps,
|
|
encoder_hidden_states=audio_prompts
|
|
).sample
|
|
return model_pred
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
|
|
"""Initialize models and optimizers"""
|
|
model_dict = {
|
|
'vae': None,
|
|
'unet': None,
|
|
'net': None,
|
|
'wav2vec': None,
|
|
'optimizer': None,
|
|
'lr_scheduler': None,
|
|
'scheduler_max_steps': None,
|
|
'trainable_params': None
|
|
}
|
|
|
|
model_dict['vae'] = AutoencoderKL.from_pretrained(
|
|
cfg.pretrained_model_name_or_path,
|
|
subfolder=cfg.vae_type,
|
|
)
|
|
|
|
unet_config_file = os.path.join(
|
|
cfg.pretrained_model_name_or_path,
|
|
cfg.unet_sub_folder + "/musetalk.json"
|
|
)
|
|
|
|
with open(unet_config_file, 'r') as f:
|
|
unet_config = json.load(f)
|
|
model_dict['unet'] = UNet2DConditionModel(**unet_config)
|
|
|
|
if not cfg.random_init_unet:
|
|
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
|
|
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
|
|
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
|
|
model_dict['unet'].load_state_dict(checkpoint)
|
|
|
|
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
|
|
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
|
|
|
|
model_dict['vae'].requires_grad_(False)
|
|
model_dict['unet'].requires_grad_(True)
|
|
|
|
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
|
|
|
|
model_dict['net'] = Net(model_dict['unet'])
|
|
|
|
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
|
|
device="cuda", dtype=weight_dtype).eval()
|
|
model_dict['wav2vec'].requires_grad_(False)
|
|
|
|
if cfg.solver.gradient_checkpointing:
|
|
model_dict['unet'].enable_gradient_checkpointing()
|
|
|
|
if cfg.solver.scale_lr:
|
|
learning_rate = (
|
|
cfg.solver.learning_rate
|
|
* cfg.solver.gradient_accumulation_steps
|
|
* cfg.data.train_bs
|
|
* accelerator.num_processes
|
|
)
|
|
else:
|
|
learning_rate = cfg.solver.learning_rate
|
|
|
|
if cfg.solver.use_8bit_adam:
|
|
try:
|
|
import bitsandbytes as bnb
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
|
)
|
|
optimizer_cls = bnb.optim.AdamW8bit
|
|
else:
|
|
optimizer_cls = torch.optim.AdamW
|
|
|
|
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
|
|
if accelerator.is_main_process:
|
|
print('trainable params')
|
|
for n, p in model_dict['net'].named_parameters():
|
|
if p.requires_grad:
|
|
print(n)
|
|
|
|
model_dict['optimizer'] = optimizer_cls(
|
|
model_dict['trainable_params'],
|
|
lr=learning_rate,
|
|
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
|
weight_decay=cfg.solver.adam_weight_decay,
|
|
eps=cfg.solver.adam_epsilon,
|
|
)
|
|
|
|
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
|
|
model_dict['lr_scheduler'] = get_scheduler(
|
|
cfg.solver.lr_scheduler,
|
|
optimizer=model_dict['optimizer'],
|
|
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
|
|
num_training_steps=model_dict['scheduler_max_steps'],
|
|
)
|
|
|
|
return model_dict
|
|
|
|
def initialize_dataloaders(cfg):
|
|
"""Initialize training and validation dataloaders"""
|
|
dataloader_dict = {
|
|
'train_dataset': None,
|
|
'val_dataset': None,
|
|
'train_dataloader': None,
|
|
'val_dataloader': None
|
|
}
|
|
|
|
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
|
|
'image_size': cfg.data.image_size,
|
|
'T': cfg.data.n_sample_frames,
|
|
"sample_method": cfg.data.sample_method,
|
|
'top_k_ratio': cfg.data.top_k_ratio,
|
|
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
|
"dataset_key": cfg.data.dataset_key,
|
|
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
|
"whisper_path": cfg.whisper_path,
|
|
"min_face_size": cfg.data.min_face_size,
|
|
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
|
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
|
"crop_type": cfg.crop_type,
|
|
"random_margin_method": cfg.random_margin_method,
|
|
})
|
|
|
|
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
|
|
dataloader_dict['train_dataset'],
|
|
batch_size=cfg.data.train_bs,
|
|
shuffle=True,
|
|
num_workers=cfg.data.num_workers,
|
|
)
|
|
|
|
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
|
|
'image_size': cfg.data.image_size,
|
|
'T': cfg.data.n_sample_frames,
|
|
"sample_method": cfg.data.sample_method,
|
|
'top_k_ratio': cfg.data.top_k_ratio,
|
|
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
|
"dataset_key": cfg.data.dataset_key,
|
|
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
|
"whisper_path": cfg.whisper_path,
|
|
"min_face_size": cfg.data.min_face_size,
|
|
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
|
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
|
"crop_type": cfg.crop_type,
|
|
"random_margin_method": cfg.random_margin_method,
|
|
})
|
|
|
|
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
|
|
dataloader_dict['val_dataset'],
|
|
batch_size=cfg.data.train_bs,
|
|
shuffle=True,
|
|
num_workers=1,
|
|
)
|
|
|
|
return dataloader_dict
|
|
|
|
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
|
|
"""Initialize loss functions and discriminators"""
|
|
loss_dict = {
|
|
'L1_loss': nn.L1Loss(reduction='mean'),
|
|
'discriminator': None,
|
|
'mouth_discriminator': None,
|
|
'optimizer_D': None,
|
|
'mouth_optimizer_D': None,
|
|
'scheduler_D': None,
|
|
'mouth_scheduler_D': None,
|
|
'disc_scales': None,
|
|
'discriminator_full': None,
|
|
'mouth_discriminator_full': None
|
|
}
|
|
|
|
if cfg.loss_params.gan_loss > 0:
|
|
loss_dict['discriminator'] = MultiScaleDiscriminator(
|
|
**cfg.model_params.discriminator_params).to(accelerator.device)
|
|
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
|
|
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
|
|
loss_dict['optimizer_D'] = optim.AdamW(
|
|
loss_dict['discriminator'].parameters(),
|
|
lr=cfg.discriminator_train_params.lr,
|
|
weight_decay=cfg.discriminator_train_params.weight_decay,
|
|
betas=cfg.discriminator_train_params.betas,
|
|
eps=cfg.discriminator_train_params.eps)
|
|
loss_dict['scheduler_D'] = CosineAnnealingLR(
|
|
loss_dict['optimizer_D'],
|
|
T_max=scheduler_max_steps,
|
|
eta_min=1e-6
|
|
)
|
|
|
|
if cfg.loss_params.mouth_gan_loss > 0:
|
|
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
|
|
**cfg.model_params.discriminator_params).to(accelerator.device)
|
|
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
|
|
loss_dict['mouth_optimizer_D'] = optim.AdamW(
|
|
loss_dict['mouth_discriminator'].parameters(),
|
|
lr=cfg.discriminator_train_params.lr,
|
|
weight_decay=cfg.discriminator_train_params.weight_decay,
|
|
betas=cfg.discriminator_train_params.betas,
|
|
eps=cfg.discriminator_train_params.eps)
|
|
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
|
|
loss_dict['mouth_optimizer_D'],
|
|
T_max=scheduler_max_steps,
|
|
eta_min=1e-6
|
|
)
|
|
|
|
return loss_dict
|
|
|
|
def initialize_syncnet(cfg, accelerator, weight_dtype):
|
|
"""Initialize SyncNet model"""
|
|
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
|
|
if cfg.data.n_sample_frames != 16:
|
|
raise ValueError(
|
|
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
|
|
)
|
|
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
|
|
syncnet = SyncNet(OmegaConf.to_container(
|
|
syncnet_config.model)).to(accelerator.device)
|
|
print(
|
|
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
|
|
checkpoint = torch.load(
|
|
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
|
|
syncnet.load_state_dict(checkpoint["state_dict"])
|
|
syncnet.to(dtype=weight_dtype)
|
|
syncnet.requires_grad_(False)
|
|
syncnet.eval()
|
|
return syncnet
|
|
return None
|
|
|
|
def initialize_vgg(cfg, accelerator):
|
|
"""Initialize VGG model"""
|
|
if cfg.loss_params.vgg_loss > 0:
|
|
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
|
|
pyramid = vgg_face.ImagePyramide(
|
|
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
|
|
vgg_IN.eval()
|
|
downsampler = Interpolate(
|
|
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
|
|
return vgg_IN, pyramid, downsampler
|
|
return None, None, None
|
|
|
|
def validation(
|
|
cfg,
|
|
val_dataloader,
|
|
net,
|
|
vae,
|
|
wav2vec,
|
|
accelerator,
|
|
save_dir,
|
|
global_step,
|
|
weight_dtype,
|
|
syncnet_score=1,
|
|
):
|
|
"""Validation function for model evaluation"""
|
|
net.eval() # Set the model to evaluation mode
|
|
for batch in val_dataloader:
|
|
# The same ref_latents
|
|
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
|
accelerator.device, non_blocking=True
|
|
)
|
|
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
|
accelerator.device, non_blocking=True
|
|
)
|
|
bsz, num_frames, c, h, w = ref_pixel_values.shape
|
|
|
|
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
|
|
# audio feature for unet
|
|
audio_prompts = rearrange(
|
|
audio_prompts,
|
|
'b f c h w-> (b f) c h w'
|
|
)
|
|
audio_prompts = rearrange(
|
|
audio_prompts,
|
|
'(b f) c h w -> (b f) (c h) w',
|
|
b=bsz
|
|
)
|
|
# different masked_latents
|
|
image_pred_train = get_image_pred(
|
|
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
|
image_pred_infer = get_image_pred(
|
|
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
|
|
|
process_and_save_images(
|
|
batch,
|
|
image_pred_train,
|
|
image_pred_infer,
|
|
save_dir,
|
|
global_step,
|
|
accelerator,
|
|
cfg.num_images_to_keep,
|
|
syncnet_score
|
|
)
|
|
# only infer 1 image in validation
|
|
break
|
|
net.train() # Set the model back to training mode
|