import os
import json
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import WhisperModel
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from einops import rearrange

from musetalk.models.syncnet import SyncNet
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.basic_loss import Interpolate
import musetalk.loss.vgg_face as vgg_face
from musetalk.data.dataset import PortraitDataset
from musetalk.utils.utils import (
    get_image_pred,
    process_audio_features,
    process_and_save_images
)


class Net(nn.Module):
    """Wrapper that exposes the UNet denoising forward pass."""

    def __init__(
        self,
        unet: UNet2DConditionModel,
    ):
        super().__init__()
        self.unet = unet

    def forward(
        self,
        input_latents,
        timesteps,
        audio_prompts,
    ):
        model_pred = self.unet(
            input_latents,
            timesteps,
            encoder_hidden_states=audio_prompts
        ).sample
        return model_pred


logger = logging.getLogger(__name__)


def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
    """Initialize models and optimizers"""
    model_dict = {
        'vae': None,
        'unet': None,
        'net': None,
        'wav2vec': None,
        'optimizer': None,
        'lr_scheduler': None,
        'scheduler_max_steps': None,
        'trainable_params': None
    }

    model_dict['vae'] = AutoencoderKL.from_pretrained(
        cfg.pretrained_model_name_or_path,
        subfolder=cfg.vae_type,
    )

    unet_config_file = os.path.join(
        cfg.pretrained_model_name_or_path,
        cfg.unet_sub_folder + "/musetalk.json"
    )
    with open(unet_config_file, 'r') as f:
        unet_config = json.load(f)
    model_dict['unet'] = UNet2DConditionModel(**unet_config)

    if not cfg.random_init_unet:
        pretrained_unet_path = os.path.join(
            cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
        print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
        checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
        model_dict['unet'].load_state_dict(checkpoint)

    unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
    logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")

    # Only the UNet is trained; the VAE stays frozen in inference dtype.
    model_dict['vae'].requires_grad_(False)
    model_dict['unet'].requires_grad_(True)
    model_dict['vae'].to(accelerator.device, dtype=weight_dtype)

    model_dict['net'] = Net(model_dict['unet'])

    # Frozen audio encoder (a Whisper model, stored under the historical 'wav2vec' key).
    model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
        device="cuda", dtype=weight_dtype).eval()
    model_dict['wav2vec'].requires_grad_(False)

    if cfg.solver.gradient_checkpointing:
        model_dict['unet'].enable_gradient_checkpointing()

    # Optionally scale the base learning rate by the effective batch size.
    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. "
                "You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    model_dict['trainable_params'] = list(
        filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
    if accelerator.is_main_process:
        print('trainable params')
        for n, p in model_dict['net'].named_parameters():
            if p.requires_grad:
                print(n)

    model_dict['optimizer'] = optimizer_cls(
        model_dict['trainable_params'],
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    model_dict['scheduler_max_steps'] = (
        cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
    )
    model_dict['lr_scheduler'] = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=model_dict['optimizer'],
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=model_dict['scheduler_max_steps'],
    )

    return model_dict
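
# --- Illustrative sketch, not part of the training pipeline in this file ---
# A minimal, hypothetical example of how the dictionary returned by
# initialize_models_and_optimizers() and the dataloaders from initialize_dataloaders()
# could be wired together with an `accelerate` Accelerator. The function name and the
# `cfg.solver.mixed_precision` key are assumptions made only for this sketch; the real
# entry point lives elsewhere in the repository.
def _example_prepare_for_training(cfg):
    """Hypothetical wiring of the initializers with accelerate (illustration only)."""
    from accelerate import Accelerator

    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,  # assumed config key
    )
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32

    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    dataloader_dict = initialize_dataloaders(cfg)

    # accelerate wraps only the trainable network, its optimizer/scheduler and the
    # dataloaders; the frozen VAE and Whisper encoder are used as-is.
    net, optimizer, lr_scheduler, train_dataloader, val_dataloader = accelerator.prepare(
        model_dict['net'],
        model_dict['optimizer'],
        model_dict['lr_scheduler'],
        dataloader_dict['train_dataloader'],
        dataloader_dict['val_dataloader'],
    )
    return accelerator, net, optimizer, lr_scheduler, train_dataloader, val_dataloader
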
def initialize_dataloaders(cfg):
    """Initialize training and validation dataloaders"""
    dataloader_dict = {
        'train_dataset': None,
        'val_dataset': None,
        'train_dataloader': None,
        'val_dataloader': None
    }

    dataloader_dict['train_dataset'] = PortraitDataset(cfg={
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        "sample_method": cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        "contorl_face_min_size": cfg.data.contorl_face_min_size,
        "dataset_key": cfg.data.dataset_key,
        "padding_pixel_mouth": cfg.padding_pixel_mouth,
        "whisper_path": cfg.whisper_path,
        "min_face_size": cfg.data.min_face_size,
        "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
        "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
        "crop_type": cfg.crop_type,
        "random_margin_method": cfg.random_margin_method,
    })
    dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['train_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )

    dataloader_dict['val_dataset'] = PortraitDataset(cfg={
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        "sample_method": cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        "contorl_face_min_size": cfg.data.contorl_face_min_size,
        "dataset_key": cfg.data.dataset_key,
        "padding_pixel_mouth": cfg.padding_pixel_mouth,
        "whisper_path": cfg.whisper_path,
        "min_face_size": cfg.data.min_face_size,
        "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
        "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
        "crop_type": cfg.crop_type,
        "random_margin_method": cfg.random_margin_method,
    })
    dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['val_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=1,
    )

    return dataloader_dict


def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
    """Initialize loss functions and discriminators"""
    loss_dict = {
        'L1_loss': nn.L1Loss(reduction='mean'),
        'discriminator': None,
        'mouth_discriminator': None,
        'optimizer_D': None,
        'mouth_optimizer_D': None,
        'scheduler_D': None,
        'mouth_scheduler_D': None,
        'disc_scales': None,
        'discriminator_full': None,
        'mouth_discriminator_full': None
    }

    if cfg.loss_params.gan_loss > 0:
        loss_dict['discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
        loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
        loss_dict['optimizer_D'] = optim.AdamW(
            loss_dict['discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['scheduler_D'] = CosineAnnealingLR(
            loss_dict['optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6
        )

    if cfg.loss_params.mouth_gan_loss > 0:
        loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
        loss_dict['mouth_optimizer_D'] = optim.AdamW(
            loss_dict['mouth_discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
            loss_dict['mouth_optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6
        )

    return loss_dict
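
# --- Illustrative sketch, not called anywhere in this file ---
# The helper below only shows the usual ordering of a discriminator update when
# gan_loss is enabled: zero the gradients, backpropagate the discriminator loss, step
# the optimizer, then advance its cosine schedule. The real loss value is produced by
# DiscriminatorFullModel (musetalk/loss/discriminator.py); its call signature is not
# shown in this file, so a placeholder scalar stands in for `d_loss`.
def _example_discriminator_step(loss_dict):
    """Hypothetical single discriminator update using objects from initialize_loss_functions()."""
    if loss_dict['discriminator'] is None:
        return
    # Placeholder: in the actual training loop this comes from loss_dict['discriminator_full'].
    d_loss = torch.zeros((), requires_grad=True)
    loss_dict['optimizer_D'].zero_grad()
    d_loss.backward()
    loss_dict['optimizer_D'].step()
    loss_dict['scheduler_D'].step()
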
def initialize_syncnet(cfg, accelerator, weight_dtype):
    """Initialize SyncNet model"""
    if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
        if cfg.data.n_sample_frames != 16:
            raise ValueError(
                f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
            )
        syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
        syncnet = SyncNet(OmegaConf.to_container(
            syncnet_config.model)).to(accelerator.device)
        print(
            f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
        checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
        syncnet.load_state_dict(checkpoint["state_dict"])
        syncnet.to(dtype=weight_dtype)
        syncnet.requires_grad_(False)
        syncnet.eval()
        return syncnet
    return None


def initialize_vgg(cfg, accelerator):
    """Initialize VGG model"""
    if cfg.loss_params.vgg_loss > 0:
        vgg_IN = vgg_face.Vgg19().to(accelerator.device)
        pyramid = vgg_face.ImagePyramide(
            cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
        vgg_IN.eval()
        downsampler = Interpolate(
            size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
        return vgg_IN, pyramid, downsampler
    return None, None, None


def validation(
    cfg,
    val_dataloader,
    net,
    vae,
    wav2vec,
    accelerator,
    save_dir,
    global_step,
    weight_dtype,
    syncnet_score=1,
):
    """Validation function for model evaluation"""
    net.eval()  # Set the model to evaluation mode

    for batch in val_dataloader:
        # The same ref_latents
        ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        bsz, num_frames, c, h, w = ref_pixel_values.shape

        audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)

        # audio feature for unet
        audio_prompts = rearrange(
            audio_prompts, 'b f c h w -> (b f) c h w'
        )
        audio_prompts = rearrange(
            audio_prompts, '(b f) c h w -> (b f) (c h) w', b=bsz
        )

        # different masked_latents
        image_pred_train = get_image_pred(
            pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
        image_pred_infer = get_image_pred(
            ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)

        process_and_save_images(
            batch,
            image_pred_train,
            image_pred_infer,
            save_dir,
            global_step,
            accelerator,
            cfg.num_images_to_keep,
            syncnet_score
        )
        # only infer 1 image in validation
        break

    net.train()  # Set the model back to training mode
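
# --- Illustrative sketch, not called anywhere in this file ---
# validation() runs a single batch and saves comparison images. A training loop would
# typically call it on the main process at a fixed interval; `cfg.val_freq` is an
# assumed config key used only for this example.
def _example_periodic_validation(cfg, val_dataloader, net, vae, wav2vec,
                                 accelerator, save_dir, global_step, weight_dtype):
    """Hypothetical periodic validation call inside a training loop."""
    if accelerator.is_main_process and global_step % cfg.val_freq == 0:
        validation(
            cfg,
            val_dataloader,
            accelerator.unwrap_model(net),  # unwrap the accelerate-prepared module for inference
            vae,
            wav2vec,
            accelerator,
            save_dir,
            global_step,
            weight_dtype,
        )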