From 1ab53a626b806c1b2778272dae6a600d8d941b2d Mon Sep 17 00:00:00 2001 From: Zhizhou Zhong Date: Fri, 4 Apr 2025 22:10:03 +0800 Subject: [PATCH] feat: data preprocessing and training (#294) * docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training --- .gitignore | 6 +- README.md | 77 +++- configs/training/gpu.yaml | 21 ++ configs/training/preprocess.yaml | 31 ++ configs/training/stage1.yaml | 89 +++++ configs/training/stage2.yaml | 89 +++++ configs/training/syncnet.yaml | 19 + inference.sh | 2 +- musetalk/data/audio.py | 168 +++++++++ musetalk/data/dataset.py | 607 +++++++++++++++++++++++++++++++ musetalk/data/sample_method.py | 233 ++++++++++++ musetalk/loss/basic_loss.py | 81 +++++ musetalk/loss/conv.py | 44 +++ musetalk/loss/discriminator.py | 145 ++++++++ musetalk/loss/resnet.py | 152 ++++++++ musetalk/loss/syncnet.py | 95 +++++ musetalk/loss/vgg_face.py | 237 ++++++++++++ musetalk/models/syncnet.py | 240 ++++++++++++ musetalk/utils/training_utils.py | 337 +++++++++++++++++ musetalk/utils/utils.py | 251 ++++++++++++- scripts/preprocess.py | 322 ++++++++++++++++ train.py | 580 +++++++++++++++++++++++++++++ train.sh | 34 ++ 23 files changed, 3854 insertions(+), 6 deletions(-) create mode 100755 configs/training/gpu.yaml create mode 100755 configs/training/preprocess.yaml create mode 100755 configs/training/stage1.yaml create mode 100755 configs/training/stage2.yaml create mode 100644 configs/training/syncnet.yaml create mode 100755 musetalk/data/audio.py create mode 100755 musetalk/data/dataset.py create mode 100755 musetalk/data/sample_method.py create mode 100755 musetalk/loss/basic_loss.py create mode 100755 musetalk/loss/conv.py create mode 100755 musetalk/loss/discriminator.py create mode 100755 musetalk/loss/resnet.py create mode 100755 musetalk/loss/syncnet.py create mode 100755 musetalk/loss/vgg_face.py create mode 100755 musetalk/models/syncnet.py create mode 100644 musetalk/utils/training_utils.py create mode 100755 scripts/preprocess.py create mode 100755 train.py create mode 100644 train.sh diff --git a/.gitignore b/.gitignore index b0f4f3c..aa31084 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,8 @@ results/ ./models **/__pycache__/ *.py[cod] -*$py.class \ No newline at end of file +*$py.class +dataset/ +ffmpeg* +debug +exp_out \ No newline at end of file diff --git a/README.md b/README.md index 2f4664d..5d29054 100644 --- a/README.md +++ b/README.md @@ -130,8 +130,9 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde - [x] codes for real-time inference. - [x] [technical report](https://arxiv.org/abs/2410.10122v2). - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122). -- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon). -- [ ] training and dataloader code (Expected completion on 04/04/2025). +- [x] realtime inference code for 1.5 version. +- [x] training and data preprocessing codes. +- [ ] **always** welcome to submit issues and PRs to improve this repository! 
😊 # Getting Started @@ -187,6 +188,7 @@ huggingface-cli download TMElyralab/MuseTalk --local-dir models/ - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main) - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch) - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth) + - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main) Finally, these weights should be organized in `models` as follows: @@ -198,6 +200,8 @@ Finally, these weights should be organized in `models` as follows: β”œβ”€β”€ musetalkV15 β”‚ └── musetalk.json β”‚ └── unet.pth +β”œβ”€β”€ syncnet +β”‚ └── latentsync_syncnet.pt β”œβ”€β”€ dwpose β”‚ └── dw-ll_ucoco_384.pth β”œβ”€β”€ face-parse-bisent @@ -265,6 +269,73 @@ For faster generation without saving images, you can use: python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images ``` +## Training + +### Data Preparation +To train MuseTalk, you need to prepare your dataset following these steps: + +1. **Place your source videos** + + For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`. + +2. **Run the preprocessing script** + ```bash + python -m scripts.preprocess --config ./configs/training/preprocess.yaml + ``` + This script will: + - Extract frames from videos + - Detect and align faces + - Generate audio features + - Create the necessary data structure for training + +### Training Process +After data preprocessing, you can start the training process: + +1. **First Stage** + ```bash + sh train.sh stage1 + ``` + +2. **Second Stage** + ```bash + sh train.sh stage2 + ``` + +### Configuration Adjustment +Before starting the training, you should adjust the configuration files according to your hardware and requirements: + +1. **GPU Configuration** (`configs/training/gpu.yaml`): + - `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3") + - `num_processes`: Set this to match the number of GPUs you're using + +2. **Stage 1 Configuration** (`configs/training/stage1.yaml`): + - `data.train_bs`: Adjust batch size based on your GPU memory (default: 32) + - `data.n_sample_frames`: Number of sampled frames per video (default: 1) + +3. 
**Stage 2 Configuration** (`configs/training/stage2.yaml`): + - `random_init_unet`: Must be set to `False` to use the model from stage 1 + - `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2) + - `data.n_sample_frames`: Higher value for temporal consistency (default: 16) + - `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8) + + +### GPU Memory Requirements +Based on our testing on a machine with 8 NVIDIA H20 GPUs: + +#### Stage 1 Memory Usage +| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation | +|:----------:|:----------------------:|:--------------:|:--------------:| +| 8 | 1 | ~32GB | | +| 16 | 1 | ~45GB | | +| 32 | 1 | ~74GB | βœ“ | + +#### Stage 2 Memory Usage +| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation | +|:----------:|:----------------------:|:--------------:|:--------------:| +| 1 | 8 | ~54GB | | +| 2 | 2 | ~80GB | | +| 2 | 8 | ~85GB | βœ“ | + ## TestCases For 1.0 @@ -368,7 +439,7 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference). # Acknowledgement -1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch). +1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main). 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings). 1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets. 
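As a quick sanity check after running the preprocessing step described above, the per-clip meta files that the training dataloader reads can be inspected directly. This is only a sketch: the field names (`mp4_path`, `wav_path`, `face_list`, `landmark_list`, `frames`) are the ones consumed by `FaceDataset` in `musetalk/data/dataset.py`, while the exact file naming and extension under `meta_root` are assumptions.

```python
import json
from pathlib import Path

# meta_root from configs/training/preprocess.yaml; the *.json naming is an assumption
meta_root = Path("./dataset/HDTF/meta")

for meta_path in sorted(meta_root.rglob("*.json"))[:1]:
    with open(meta_path) as f:
        meta = json.load(f)
    # Fields consumed by FaceDataset.__getitem__ in musetalk/data/dataset.py
    print("clip video  :", meta["mp4_path"])
    print("clip audio  :", meta["wav_path"])
    print("frame count :", meta["frames"])
    print("face boxes  :", len(meta["face_list"]))      # one bbox per frame
    print("landmarks   :", len(meta["landmark_list"]))  # 68 (x, y) points per frame
```

The dataloader expects one bounding box and one 68-point landmark set per video frame and skips any clip whose counts do not match the frame count.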
diff --git a/configs/training/gpu.yaml b/configs/training/gpu.yaml new file mode 100755 index 0000000..e2ccee1 --- /dev/null +++ b/configs/training/gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: True +deepspeed_config: + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: False + zero_stage: 2 + +distributed_type: DEEPSPEED +downcast_bf16: 'no' +gpu_ids: "5, 7" # modify this according to your GPU number +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 # it should be the same as the number of GPUs +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/configs/training/preprocess.yaml b/configs/training/preprocess.yaml new file mode 100755 index 0000000..c7c12e7 --- /dev/null +++ b/configs/training/preprocess.yaml @@ -0,0 +1,31 @@ +clip_len_second: 30 # the length of the video clip +video_root_raw: "./dataset/HDTF/source/" # the path of the original video +val_list_hdtf: + - RD_Radio7_000 + - RD_Radio8_000 + - RD_Radio9_000 + - WDA_TinaSmith_000 + - WDA_TomCarper_000 + - WDA_TomPerez_000 + - WDA_TomUdall_000 + - WDA_VeronicaEscobar0_000 + - WDA_VeronicaEscobar1_000 + - WDA_WhipJimClyburn_000 + - WDA_XavierBecerra_000 + - WDA_XavierBecerra_001 + - WDA_XavierBecerra_002 + - WDA_ZoeLofgren_000 + - WRA_SteveScalise1_000 + - WRA_TimScott_000 + - WRA_ToddYoung_000 + - WRA_TomCotton_000 + - WRA_TomPrice_000 + - WRA_VickyHartzler_000 + +# following dir will be automatically generated +video_root_25fps: "./dataset/HDTF/video_root_25fps/" +video_file_list: "./dataset/HDTF/video_file_list.txt" +video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/" +meta_root: "./dataset/HDTF/meta/" +video_clip_file_list_train: "./dataset/HDTF/train.txt" +video_clip_file_list_val: "./dataset/HDTF/val.txt" diff --git a/configs/training/stage1.yaml b/configs/training/stage1.yaml new file mode 100755 index 0000000..952e7b6 --- /dev/null +++ b/configs/training/stage1.yaml @@ -0,0 +1,89 @@ +exp_name: 'test' # Name of the experiment +output_dir: './exp_out/stage1/' # Directory to save experiment outputs +unet_sub_folder: musetalk # Subfolder name for UNet model +random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2) +whisper_path: "./models/whisper" # Path to the Whisper model +pretrained_model_name_or_path: "./models" # Path to pretrained models +resume_from_checkpoint: True # Whether to resume training from a checkpoint +padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region +vae_type: "sd-vae" # Type of VAE model to use +# Validation parameters +num_images_to_keep: 8 # Number of validation images to keep +ref_dropout_rate: 0 # Dropout rate for reference images +syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration +use_adapted_weight: False # Whether to use adapted weights for loss calculation +cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping +cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping +crop_type: "crop_resize" # Type of cropping method +random_margin_method: "normal" # Method for random margin generation +num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet + +data: + dataset_key: "HDTF" # Dataset to use for training + train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames) + image_size: 256 # Size of input images + 
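  # Note: the effective number of face crops per optimizer step is
  #   train_bs * n_sample_frames * num_processes * solver.gradient_accumulation_steps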
n_sample_frames: 1 # Number of frames to sample per batch + num_workers: 8 # Number of data loading workers + audio_padding_length_left: 2 # Left padding length for audio features + audio_padding_length_right: 2 # Right padding length for audio features + sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames + top_k_ratio: 0.51 # Ratio for top-k sampling + contorl_face_min_size: True # Whether to control minimum face size + min_face_size: 150 # Minimum face size in pixels + +loss_params: + l1_loss: 1.0 # Weight for L1 loss + vgg_loss: 0.01 # Weight for VGG perceptual loss + vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers + pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid + gan_loss: 0 # Weight for GAN loss + fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss + sync_loss: 0 # Weight for sync loss + mouth_gan_loss: 0 # Weight for mouth-specific GAN loss + +model_params: + discriminator_params: + scales: [1] # Scales for discriminator + block_expansion: 32 # Expansion factor for discriminator blocks + max_features: 512 # Maximum number of features in discriminator + num_blocks: 4 # Number of blocks in discriminator + sn: True # Whether to use spectral normalization + image_channel: 3 # Number of image channels + estimate_jacobian: False # Whether to estimate Jacobian + +discriminator_train_params: + lr: 0.000005 # Learning rate for discriminator + eps: 0.00000001 # Epsilon for optimizer + weight_decay: 0.01 # Weight decay for optimizer + patch_size: 1 # Size of patches for discriminator + betas: [0.5, 0.999] # Beta parameters for Adam optimizer + epochs: 10000 # Number of training epochs + start_gan: 1000 # Step to start GAN training + +solver: + gradient_accumulation_steps: 1 # Number of steps for gradient accumulation + uncond_steps: 10 # Number of unconditional steps + mixed_precision: 'fp32' # Precision mode for training + enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention + gradient_checkpointing: True # Whether to use gradient checkpointing + max_train_steps: 250000 # Maximum number of training steps + max_grad_norm: 1.0 # Maximum gradient norm for clipping + # Learning rate parameters + learning_rate: 2.0e-5 # Base learning rate + scale_lr: False # Whether to scale learning rate + lr_warmup_steps: 1000 # Number of warmup steps for learning rate + lr_scheduler: "linear" # Type of learning rate scheduler + # Optimizer parameters + use_8bit_adam: False # Whether to use 8-bit Adam optimizer + adam_beta1: 0.5 # Beta1 parameter for Adam optimizer + adam_beta2: 0.999 # Beta2 parameter for Adam optimizer + adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer + adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer + +total_limit: 10 # Maximum number of checkpoints to keep +save_model_epoch_interval: 250000 # Interval between model saves +checkpointing_steps: 10000 # Number of steps between checkpoints +val_freq: 2000 # Frequency of validation + +seed: 41 # Random seed for reproducibility + diff --git a/configs/training/stage2.yaml b/configs/training/stage2.yaml new file mode 100755 index 0000000..9431fbb --- /dev/null +++ b/configs/training/stage2.yaml @@ -0,0 +1,89 @@ +exp_name: 'test' # Name of the experiment +output_dir: './exp_out/stage2/' # Directory to save experiment outputs +unet_sub_folder: musetalk # Subfolder name for UNet model +random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2) +whisper_path: 
"./models/whisper" # Path to the Whisper model +pretrained_model_name_or_path: "./models" # Path to pretrained models +resume_from_checkpoint: True # Whether to resume training from a checkpoint +padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region +vae_type: "sd-vae" # Type of VAE model to use +# Validation parameters +num_images_to_keep: 8 # Number of validation images to keep +ref_dropout_rate: 0 # Dropout rate for reference images +syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration +use_adapted_weight: False # Whether to use adapted weights for loss calculation +cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping +cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping +crop_type: "dynamic_margin_crop_resize" # Type of cropping method +random_margin_method: "normal" # Method for random margin generation +num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet + +data: + dataset_key: "HDTF" # Dataset to use for training + train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames) + image_size: 256 # Size of input images + n_sample_frames: 16 # Number of frames to sample per batch + num_workers: 8 # Number of data loading workers + audio_padding_length_left: 2 # Left padding length for audio features + audio_padding_length_right: 2 # Right padding length for audio features + sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames + top_k_ratio: 0.51 # Ratio for top-k sampling + contorl_face_min_size: True # Whether to control minimum face size + min_face_size: 200 # Minimum face size in pixels + +loss_params: + l1_loss: 1.0 # Weight for L1 loss + vgg_loss: 0.01 # Weight for VGG perceptual loss + vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers + pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid + gan_loss: 0.01 # Weight for GAN loss + fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss + sync_loss: 0.05 # Weight for sync loss + mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss + +model_params: + discriminator_params: + scales: [1] # Scales for discriminator + block_expansion: 32 # Expansion factor for discriminator blocks + max_features: 512 # Maximum number of features in discriminator + num_blocks: 4 # Number of blocks in discriminator + sn: True # Whether to use spectral normalization + image_channel: 3 # Number of image channels + estimate_jacobian: False # Whether to estimate Jacobian + +discriminator_train_params: + lr: 0.000005 # Learning rate for discriminator + eps: 0.00000001 # Epsilon for optimizer + weight_decay: 0.01 # Weight decay for optimizer + patch_size: 1 # Size of patches for discriminator + betas: [0.5, 0.999] # Beta parameters for Adam optimizer + epochs: 10000 # Number of training epochs + start_gan: 1000 # Step to start GAN training + +solver: + gradient_accumulation_steps: 8 # Number of steps for gradient accumulation + uncond_steps: 10 # Number of unconditional steps + mixed_precision: 'fp32' # Precision mode for training + enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention + gradient_checkpointing: True # Whether to use gradient checkpointing + max_train_steps: 250000 # Maximum number of training steps + max_grad_norm: 1.0 # Maximum gradient norm for clipping + # Learning rate parameters + learning_rate: 5.0e-6 # Base learning rate + scale_lr: False # Whether to scale learning rate + 
lr_warmup_steps: 1000 # Number of warmup steps for learning rate + lr_scheduler: "linear" # Type of learning rate scheduler + # Optimizer parameters + use_8bit_adam: False # Whether to use 8-bit Adam optimizer + adam_beta1: 0.5 # Beta1 parameter for Adam optimizer + adam_beta2: 0.999 # Beta2 parameter for Adam optimizer + adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer + adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer + +total_limit: 10 # Maximum number of checkpoints to keep +save_model_epoch_interval: 250000 # Interval between model saves +checkpointing_steps: 2000 # Number of steps between checkpoints +val_freq: 2000 # Frequency of validation + +seed: 41 # Random seed for reproducibility + diff --git a/configs/training/syncnet.yaml b/configs/training/syncnet.yaml new file mode 100644 index 0000000..88494a7 --- /dev/null +++ b/configs/training/syncnet.yaml @@ -0,0 +1,19 @@ +# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml). +model: + audio_encoder: # input (1, 80, 52) + in_channels: 1 + block_out_channels: [32, 64, 128, 256, 512, 1024, 2048] + downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]] + attn_blocks: [0, 0, 0, 0, 0, 0, 0] + dropout: 0.0 + visual_encoder: # input (48, 128, 256) + in_channels: 48 + block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048] + downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2] + attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0] + dropout: 0.0 + +ckpt: + resume_ckpt_path: "" + inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main) + save_ckpt_steps: 2500 diff --git a/inference.sh b/inference.sh index 43cccd9..8bc0c39 100644 --- a/inference.sh +++ b/inference.sh @@ -59,7 +59,7 @@ cmd_args="--inference_config $config_path \ --result_dir $result_dir \ --unet_model_path $unet_model_path \ --unet_config $unet_config \ - --version $version_ar" + --version $version_arg" # Add realtime-specific arguments if in realtime mode if [ "$mode" = "realtime" ]; then diff --git a/musetalk/data/audio.py b/musetalk/data/audio.py new file mode 100755 index 0000000..d80ba4f --- /dev/null +++ b/musetalk/data/audio.py @@ -0,0 +1,168 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile + +class HParams: + # copy from wav2lip + def __init__(self): + self.n_fft = 800 + self.hop_size = 200 + self.win_size = 800 + self.sample_rate = 16000 + self.frame_shift_ms = None + self.signal_normalization = True + + self.allow_clipping_in_normalization = True + self.symmetric_mels = True + self.max_abs_value = 4.0 + self.preemphasize = True + self.preemphasis = 0.97 + self.min_level_db = -100 + self.ref_level_db = 20 + self.fmin = 55 + self.fmax=7600 + + self.use_lws=False + self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality + self.rescale=True # Whether to rescale audio prior to preprocessing + self.rescaling_max=0.9 # Rescaling value + self.use_lws=False + + +hp = HParams() + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] +#def load_wav(path, sr): +# audio, sr_native = sf.read(path) +# if sr != sr_native: +# audio = librosa.resample(audio.T, sr_native, sr).T +# return audio + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + 
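    # NOTE: librosa.output.write_wav was removed in librosa >= 0.8, so this helper
    # (kept from the original wav2lip audio utilities) only runs with older librosa
    # releases; save_wav() above uses scipy.io.wavfile and has no such constraint.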
librosa.output.write_wav(path, wav, sr=sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _lws_processor(): + import lws + return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") + +def _stft(y): + if hp.use_lws: + return _lws_processor(hp).stft(y).T + else: + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if 
hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) \ No newline at end of file diff --git a/musetalk/data/dataset.py b/musetalk/data/dataset.py new file mode 100755 index 0000000..4b72449 --- /dev/null +++ b/musetalk/data/dataset.py @@ -0,0 +1,607 @@ +import os +import numpy as np +import random +from PIL import Image +import torch +from torch.utils.data import Dataset, ConcatDataset +import torchvision.transforms as transforms +from transformers import AutoFeatureExtractor +import librosa +import time +import json +import math +from decord import AudioReader, VideoReader +from decord.ndarray import cpu + +from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark +from musetalk.data import audio + +syncnet_mel_step_size = math.ceil(16 / 5 * 16) # latentsync + + +class FaceDataset(Dataset): + """Dataset class for loading and processing video data + + Each video can be represented as: + - Concatenated frame images + - '.mp4' or '.gif' files + - Folder containing all frames + """ + def __init__(self, + cfg, + list_paths, + root_path='./dataset/', + repeats=None): + # Initialize dataset paths + meta_paths = [] + if repeats is None: + repeats = [1] * len(list_paths) + assert len(repeats) == len(list_paths) + + # Load data list + for list_path, repeat_time in zip(list_paths, repeats): + with open(list_path, 'r') as f: + num = 0 + f.readline() # Skip header line + for line in f.readlines(): + line_info = line.strip() + meta = line_info.split() + meta = meta[0] + meta_paths.extend([os.path.join(root_path, meta)] * repeat_time) + num += 1 + print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples') + + # Set basic attributes + self.meta_paths = meta_paths + self.root_path = root_path + self.image_size = cfg['image_size'] + self.min_face_size = cfg['min_face_size'] + self.T = cfg['T'] + self.sample_method = cfg['sample_method'] + self.top_k_ratio = cfg['top_k_ratio'] + self.max_attempts = 200 + self.padding_pixel_mouth = cfg['padding_pixel_mouth'] + + # Cropping related parameters + self.crop_type = cfg['crop_type'] + self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean'] + self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std'] + self.random_margin_method = cfg['random_margin_method'] + + # Image transformations + self.to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + self.pose_to_tensor = transforms.Compose([ + transforms.ToTensor(), + ]) + + # Feature extractor + self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path']) + self.contorl_face_min_size = cfg["contorl_face_min_size"] + + print("The sample method is: ", self.sample_method) + print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size) + + def generate_random_value(self): + """Generate random value + + Returns: + float: Generated random value + """ + if self.random_margin_method == "uniform": + random_value = np.random.uniform( + self.jaw2edge_margin_mean - self.jaw2edge_margin_std, + self.jaw2edge_margin_mean + self.jaw2edge_margin_std + ) + elif self.random_margin_method == "normal": + random_value = np.random.normal( + loc=self.jaw2edge_margin_mean, + scale=self.jaw2edge_margin_std + ) + random_value = np.clip( + random_value, + self.jaw2edge_margin_mean - self.jaw2edge_margin_std, + 
self.jaw2edge_margin_mean + self.jaw2edge_margin_std, + ) + else: + raise ValueError(f"Invalid random margin method: {self.random_margin_method}") + return max(0, random_value) + + def dynamic_margin_crop(self, img, original_bbox, extra_margin=None): + """Dynamically crop image with dynamic margin + + Args: + img: Input image + original_bbox: Original bounding box + extra_margin: Extra margin + + Returns: + tuple: (x1, y1, x2, y2, extra_margin) + """ + if extra_margin is None: + extra_margin = self.generate_random_value() + w, h = img.size + x1, y1, x2, y2 = original_bbox + y2 = min(y2 + int(extra_margin), h) + return x1, y1, x2, y2, extra_margin + + def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None): + """Crop and resize image + + Args: + img: Input image + bbox: Bounding box + crop_type: Type of cropping + extra_margin: Extra margin + + Returns: + tuple: (Processed image, extra_margin, mask_scaled_factor) + """ + mask_scaled_factor = 1. + if crop_type == 'crop_resize': + x1, y1, x2, y2 = bbox + img = img.crop((x1, y1, x2, y2)) + img = img.resize((self.image_size, self.image_size), Image.LANCZOS) + elif crop_type == 'dynamic_margin_crop_resize': + x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin) + w_original, _ = img.size + img = img.crop((x1, y1, x2, y2)) + w_cropped, _ = img.size + mask_scaled_factor = w_cropped / w_original + img = img.resize((self.image_size, self.image_size), Image.LANCZOS) + elif crop_type == 'resize': + w, h = img.size + scale = np.sqrt(self.image_size ** 2 / (h * w)) + new_w = int(w * scale) / 64 * 64 + new_h = int(h * scale) / 64 * 64 + img = img.resize((new_w, new_h), Image.LANCZOS) + return img, extra_margin, mask_scaled_factor + + def get_audio_file(self, wav_path, start_index): + """Get audio file features + + Args: + wav_path: Audio file path + start_index: Starting index + + Returns: + tuple: (Audio features, start index) + """ + if not os.path.exists(wav_path): + return None + audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000) + assert sampling_rate == 16000 + + while start_index >= 25 * 30: + audio_input = audio_input_librosa[16000*30:] + start_index -= 25 * 30 + if start_index + 2 * 25 >= 25 * 30: + start_index -= 4 * 25 + audio_input = audio_input_librosa[16000*4:16000*34] + else: + audio_input = audio_input_librosa[:16000*30] + + assert 2 * (start_index) >= 0 + assert 2 * (start_index + 2 * 25) <= 1500 + + audio_input = self.feature_extractor( + audio_input, + return_tensors="pt", + sampling_rate=sampling_rate + ).input_features + return audio_input, start_index + + def get_audio_file_mel(self, wav_path, start_index): + """Get mel spectrogram of audio file + + Args: + wav_path: Audio file path + start_index: Starting index + + Returns: + tuple: (Mel spectrogram, start index) + """ + if not os.path.exists(wav_path): + return None + + audio_input, sampling_rate = librosa.load(wav_path, sr=16000) + assert sampling_rate == 16000 + + audio_input = self.mel_feature_extractor(audio_input) + return audio_input, start_index + + def mel_feature_extractor(self, audio_input): + """Extract mel spectrogram features + + Args: + audio_input: Input audio + + Returns: + ndarray: Mel spectrogram features + """ + orig_mel = audio.melspectrogram(audio_input) + return orig_mel.T + + def crop_audio_window(self, spec, start_frame_num, fps=25): + """Crop audio window + + Args: + spec: Spectrogram + start_frame_num: Starting frame number + fps: Frames per second + + Returns: + ndarray: Cropped 
spectrogram + """ + start_idx = int(80. * (start_frame_num / float(fps))) + end_idx = start_idx + syncnet_mel_step_size + return spec[start_idx: end_idx, :] + + def get_syncnet_input(self, video_path): + """Get SyncNet input features + + Args: + video_path: Video file path + + Returns: + ndarray: SyncNet input features + """ + ar = AudioReader(video_path, sample_rate=16000) + original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0)) + return original_mel.T + + def get_resized_mouth_mask( + self, + img_resized, + landmark_array, + face_shape, + padding_pixel_mouth=0, + image_size=256, + crop_margin=0 + ): + landmark_array = np.array(landmark_array) + resized_landmark = resize_landmark( + landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size) + + landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format + min_x, min_y = np.min(landmark_array, axis=0) + max_x, max_y = np.max(landmark_array, axis=0) + min_x = min_x - padding_pixel_mouth + max_x = max_x + padding_pixel_mouth + + # Calculate x-axis length and use it for y-axis + width = max_x - min_x + + # Calculate old center point + center_y = (max_y + min_y) / 2 + + # Determine new min_y and max_y based on width + min_y = center_y - width / 4 + max_y = center_y + width / 4 + + # Adjust mask position for dynamic crop, shift y-axis + min_y = min_y - crop_margin + max_y = max_y - crop_margin + + # Prevent out of bounds + min_x = max(min_x, 0) + min_y = max(min_y, 0) + max_x = min(max_x, face_shape[0]) + max_y = min(max_y, face_shape[1]) + + mask = np.zeros_like(np.array(img_resized)) + mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255 + return Image.fromarray(mask) + + def __len__(self): + return 100000 + + def __getitem__(self, idx): + attempts = 0 + while attempts < self.max_attempts: + try: + meta_path = random.sample(self.meta_paths, k=1)[0] + with open(meta_path, 'r') as f: + meta_data = json.load(f) + except Exception as e: + print(f"meta file error:{meta_path}") + print(e) + attempts += 1 + time.sleep(0.1) + continue + + video_path = meta_data["mp4_path"] + wav_path = meta_data["wav_path"] + bbox_list = meta_data["face_list"] + landmark_list = meta_data["landmark_list"] + T = self.T + + s = 0 + e = meta_data["frames"] + len_valid_clip = e - s + + if len_valid_clip < T * 10: + attempts += 1 + print(f"video {video_path} has less than {T * 10} frames") + continue + + try: + cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0)) + total_frames = len(cap) + assert total_frames == len(landmark_list) + assert total_frames == len(bbox_list) + landmark_shape = np.array(landmark_list).shape + if landmark_shape != (total_frames, 68, 2): + attempts += 1 + print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks + continue + except Exception as e: + print(f"video file error:{video_path}") + print(e) + attempts += 1 + time.sleep(0.1) + continue + + shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates( + landmark_list, + bbox_list + ) + if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size: + print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}") + attempts += 1 + continue + + step = 1 + drive_idx_start = random.randint(s, e - T * step) + drive_idx_list = list( + range(drive_idx_start, drive_idx_start + T * step, step)) + assert len(drive_idx_list) == T + + src_idx_list = [] + 
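            # For each driving frame, pick a reference (source) frame according to
            # sample_method: with the default 'pose_similarity_and_mouth_dissimilarity',
            # the reference has a similar head pose but a different mouth shape, so lip
            # motion must come from the audio rather than be copied from the reference.
            # get_src_idx returns None on failure, in which case this clip is skipped
            # and a new one is sampled.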
list_index_out_of_range = False + for drive_idx in drive_idx_list: + src_idx = get_src_idx( + drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio) + if src_idx is None: + list_index_out_of_range = True + break + src_idx = min(src_idx, e - 1) + src_idx = max(src_idx, s) + src_idx_list.append(src_idx) + + if list_index_out_of_range: + attempts += 1 + print(f"video {video_path} has invalid source index for drive frames") + continue + + ref_face_valid_flag = True + extra_margin = self.generate_random_value() + + # Get reference images + ref_imgs = [] + for src_idx in src_idx_list: + imSrc = Image.fromarray(cap[src_idx].asnumpy()) + bbox_s = bbox_list_union[src_idx] + imSrc, _, _ = self.crop_resize_img( + imSrc, + bbox_s, + self.crop_type, + extra_margin=None + ) + if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size: + ref_face_valid_flag = False + break + ref_imgs.append(imSrc) + + if not ref_face_valid_flag: + attempts += 1 + print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}") + continue + + # Get target images and masks + imSameIDs = [] + bboxes = [] + face_masks = [] + face_mask_valid = True + target_face_valid_flag = True + + for drive_idx in drive_idx_list: + imSameID = Image.fromarray(cap[drive_idx].asnumpy()) + bbox_s = bbox_list_union[drive_idx] + imSameID, _ , mask_scaled_factor = self.crop_resize_img( + imSameID, + bbox_s, + self.crop_type, + extra_margin=extra_margin + ) + if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size: + target_face_valid_flag = False + break + crop_margin = extra_margin * mask_scaled_factor + face_mask = self.get_resized_mouth_mask( + imSameID, + shift_landmarks[drive_idx], + face_shapes[drive_idx], + self.padding_pixel_mouth, + self.image_size, + crop_margin=crop_margin + ) + if np.count_nonzero(face_mask) == 0: + face_mask_valid = False + break + + if face_mask.size[1] == 0 or face_mask.size[0] == 0: + print(f"video {video_path} has invalid face mask size at frame {drive_idx}") + face_mask_valid = False + break + + imSameIDs.append(imSameID) + bboxes.append(bbox_s) + face_masks.append(face_mask) + + if not face_mask_valid: + attempts += 1 + print(f"video {video_path} has invalid face mask") + continue + + if not target_face_valid_flag: + attempts += 1 + print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}") + continue + + # Process audio features + audio_offset = drive_idx_list[0] + audio_step = step + fps = 25.0 / step + + try: + audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset) + _, audio_offset = self.get_audio_file_mel(wav_path, audio_offset) + audio_feature_mel = self.get_syncnet_input(video_path) + except Exception as e: + print(f"audio file error:{wav_path}") + print(e) + attempts += 1 + time.sleep(0.1) + continue + + mel = self.crop_audio_window(audio_feature_mel, audio_offset) + if mel.shape[0] != syncnet_mel_step_size: + attempts += 1 + print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}") + continue + + mel = torch.FloatTensor(mel.T).unsqueeze(0) + + # Build sample dictionary + sample = dict( + pixel_values_vid=torch.stack( + [self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0), + pixel_values_ref_img=torch.stack( + [self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0), + pixel_values_face_mask=torch.stack( + [self.pose_to_tensor(face_mask) for 
face_mask in face_masks], dim=0), + audio_feature=audio_feature[0], + audio_offset=audio_offset, + audio_step=audio_step, + mel=mel, + wav_path=wav_path, + fps=fps, + ) + + return sample + + raise ValueError("Unable to find a valid sample after maximum attempts.") + +class HDTFDataset(FaceDataset): + """HDTF dataset class""" + def __init__(self, cfg): + root_path = './dataset/HDTF/meta' + list_paths = [ + './dataset/HDTF/train.txt', + ] + + + repeats = [10] + super().__init__(cfg, list_paths, root_path, repeats) + print('HDTFDataset: ', len(self)) + +class VFHQDataset(FaceDataset): + """VFHQ dataset class""" + def __init__(self, cfg): + root_path = './dataset/VFHQ/meta' + list_paths = [ + './dataset/VFHQ/train.txt', + ] + repeats = [1] + super().__init__(cfg, list_paths, root_path, repeats) + print('VFHQDataset: ', len(self)) + +def PortraitDataset(cfg=None): + """Return dataset based on configuration + + Args: + cfg: Configuration dictionary + + Returns: + Dataset: Combined dataset + """ + if cfg["dataset_key"] == "HDTF": + return ConcatDataset([HDTFDataset(cfg)]) + elif cfg["dataset_key"] == "VFHQ": + return ConcatDataset([VFHQDataset(cfg)]) + else: + print("############ use all dataset ############ ") + return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)]) + + +if __name__ == '__main__': + # Set random seeds for reproducibility + seed = 42 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Create dataset with configuration parameters + dataset = PortraitDataset(cfg={ + 'T': 1, # Number of frames to process at once + 'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform" + 'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both + 'image_size': 256, # Size of processed images (height and width) + 'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames + 'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling + 'contorl_face_min_size': True, # Whether to enforce minimum face size + 'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask + 'min_face_size': 200, # Minimum face size requirement for dataset + 'whisper_path': "./models/whisper", # Path to Whisper model + 'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping + 'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping + 'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize" + }) + print(len(dataset)) + + import torchvision + os.makedirs('debug', exist_ok=True) + for i in range(10): # Check 10 samples + sample = dataset[0] + print(f"processing {i}") + + # Get images and mask + ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w) + target_img = (sample['pixel_values_vid'] + 1.0) / 2 + face_mask = sample['pixel_values_face_mask'] + + # Print dimension information + print(f"ref_img shape: {ref_img.shape}") + print(f"target_img shape: {target_img.shape}") + print(f"face_mask shape: {face_mask.shape}") + + # Create visualization images + b, c, h, w = ref_img.shape + + # Apply mask only to target image + target_mask = face_mask + + # Keep reference image unchanged + ref_with_mask = ref_img.clone() + + # Create mask overlay for target image + target_with_mask = target_img.clone() + target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target + + # 
Save original images, mask, and overlay results + # First row: original images + # Second row: mask + # Third row: overlay effect + concatenated_img = torch.cat(( + ref_img, target_img, # Original images + torch.zeros_like(ref_img), target_mask, # Mask (black for ref) + ref_with_mask, target_with_mask # Overlay effect + ), dim=3) + + torchvision.utils.save_image( + concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2) diff --git a/musetalk/data/sample_method.py b/musetalk/data/sample_method.py new file mode 100755 index 0000000..c7d1265 --- /dev/null +++ b/musetalk/data/sample_method.py @@ -0,0 +1,233 @@ +import numpy as np +import random + +def summarize_tensor(x): + return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)" + +def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True): + num_landmarks = len(landmarks_list) + mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array + print(np.shape(landmarks_list)) + ## Calculate mouth opening ratios + for i, landmarks in enumerate(landmarks_list): + # Assuming landmarks are in the format [x, y] and accessible by index + mouth_top = landmarks[165] # Adjust index according to your landmarks format + mouth_bottom = landmarks[147] # Adjust index according to your landmarks format + mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom) + mouth_open_ratios[i] = mouth_open_ratio + + # Calculate differences matrix + differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]) + differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx] + print(differences_matrix.shape) + # Find top_k similar indices for each landmark set + if ascending: + top_indices = np.argsort(differences_matrix[i])[:top_k] + else: + top_indices = np.argsort(-differences_matrix[i])[:top_k] + similar_landmarks_indices = top_indices.tolist() + similar_landmarks_distances = differences_matrix_with_signs[i].tolist() #ζ³¨ζ„θΏ™ι‡ŒδΈθ¦ζŽ’εΊ + + return similar_landmarks_indices, similar_landmarks_distances +############################################################################################# +def get_closed_mouth(landmarks_list,ascending=True,top_k=50): + num_landmarks = len(landmarks_list) + + mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array + ## Calculate mouth opening ratios + #print("landmarks shape",np.shape(landmarks_list)) + for i, landmarks in enumerate(landmarks_list): + # Assuming landmarks are in the format [x, y] and accessible by index + #print(landmarks[165]) + mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format + mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format + mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom) + mouth_open_ratios[i] = mouth_open_ratio + + # Find top_k similar indices for each landmark set + if ascending: + top_indices = np.argsort(mouth_open_ratios)[:top_k] + else: + top_indices = np.argsort(-mouth_open_ratios)[:top_k] + return top_indices + +def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True): + """ + Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces. 
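    Similarity is the mean Euclidean distance between the selected landmark range of
    `selected_idx` and the same range in every other frame, computed after resizing all
    landmarks to a common 256x256 reference frame.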
+ + Parameters: + landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks. + image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple. + start_index (int): The starting index of the facial landmarks. + end_index (int): The ending index of the facial landmarks. + top_k (int): The number of most similar landmark sets to return. Default is 50. + ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True. + + Returns: + similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face. + resized_landmarks (list): A list containing the resized facial landmarks. + """ + num_landmarks = len(landmarks_list) + resized_landmarks = [] + + # Preprocess landmarks + for i in range(num_landmarks): + landmark_array = np.array(landmarks_list[i]) + selected_landmarks = landmark_array[start_index:end_index] + resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256) + resized_landmarks.append(resized_landmark) + + resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation + + # Calculate similarity + distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2) + overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks + + if ascending: + sorted_indices = np.argsort(overall_distances) + similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k + else: + sorted_indices = np.argsort(-overall_distances) + similar_landmarks_indices = sorted_indices[0:top_k].tolist() + + return similar_landmarks_indices + +def process_bbox_musetalk(face_array, landmark_array): + x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array) + x_min_lm = min([int(x) for x, y in landmark_array]) + y_min_lm = min([int(y) for x, y in landmark_array]) + x_max_lm = max([int(x) for x, y in landmark_array]) + y_max_lm = max([int(y) for x, y in landmark_array]) + x_min = min(x_min_face, x_min_lm) + y_min = min(y_min_face, y_min_lm) + x_max = max(x_max_face, x_max_lm) + y_max = max(y_max_face, y_max_lm) + + x_min = max(x_min, 0) + y_min = max(y_min, 0) + + return [x_min, y_min, x_max, y_max] + +def shift_landmarks_to_face_coordinates(landmark_list, face_list): + """ + Translates the data in landmark_list to the coordinates of the cropped larger face. + + Parameters: + landmark_list (list): A list containing multiple sets of facial landmarks. + face_list (list): A list containing multiple facial images. + + Returns: + landmark_list_shift (list): The list of translated landmarks. + bbox_union (list): The list of union bounding boxes. + face_shapes (list): The list of facial shapes. 
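    Note: each returned bounding box is the union of the detector box and the 68
    landmarks (see process_bbox_musetalk), with its top-left corner clamped to
    non-negative coordinates, and the landmarks are shifted into that box's local
    coordinate system.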
+ """ + landmark_list_shift = [] + bbox_union = [] + face_shapes = [] + + for i in range(len(face_list)): + landmark_array = np.array(landmark_list[i]) # 转捒为numpyζ•°η»„εΉΆεˆ›ε»Ίε‰―ζœ¬ + face_array = face_list[i] + f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array) + x_min, y_min, x_max, y_max = f_landmark_bbox + landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0] + landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1] + landmark_list_shift.append(landmark_array) + bbox_union.append(f_landmark_bbox) + face_shapes.append((x_max - x_min, y_max - y_min)) + + return landmark_list_shift, bbox_union, face_shapes + +def resize_landmark(landmark, w, h, new_w, new_h): + landmark_norm = landmark / [w, h] + landmark_resized = landmark_norm * [new_w, new_h] + + return landmark_resized + +def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio): + """ + Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method. + + Parameters: + - drive_idx (int): The current drive index. + - T (int): Total number of frames or a specific range limit. + - sample_method (str): Sampling method, which can be "random" or other methods. + - landmarks_list (list): List of facial landmarks. + - image_shapes (list): List of image shapes. + - top_k_ratio (float): Ratio for selecting top k similar frames. + + Returns: + - src_idx (int): The calculated source index. + """ + if sample_method == "random": + src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T) + elif sample_method == "pose_similarity": + top_k = int(top_k_ratio*len(landmarks_list)) + try: + top_k = int(top_k_ratio*len(landmarks_list)) + # facial contour + landmark_start_idx = 0 + landmark_end_idx = 16 + pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True) + src_idx = random.choice(pose_similarity_list) + while abs(src_idx-drive_idx)<5: + src_idx = random.choice(pose_similarity_list) + except Exception as e: + print(e) + return None + elif sample_method=="pose_similarity_and_closed_mouth": + # facial contour + landmark_start_idx = 0 + landmark_end_idx = 16 + try: + top_k = int(top_k_ratio*len(landmarks_list)) + closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k) + #print("closed_mouth_list",closed_mouth_list) + pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True) + #print("pose_similarity_list",pose_similarity_list) + common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list))) + if len(common_list) == 0: + src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T) + else: + src_idx = random.choice(common_list) + + while abs(src_idx-drive_idx) <5: + src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T) + + except Exception as e: + print(e) + return None + + elif sample_method=="pose_similarity_and_mouth_dissimilarity": + top_k = int(top_k_ratio*len(landmarks_list)) + try: + top_k = int(top_k_ratio*len(landmarks_list)) + + # facial contour for 68 landmarks format + landmark_start_idx = 0 + landmark_end_idx = 16 + + pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True) + + # Mouth inner coutour for 68 landmarks format + landmark_start_idx = 60 + landmark_end_idx = 67 + + 
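            # ascending=False ranks frames by how different their inner-mouth landmarks
            # are from the driving frame, so the chosen reference tends to have a
            # clearly different mouth shape.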
mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False) + + common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list))) + if len(common_list) == 0: + src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T) + else: + src_idx = random.choice(common_list) + + while abs(src_idx-drive_idx) <5: + src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T) + + except Exception as e: + print(e) + return None + + else: + raise ValueError(f"Unknown sample_method: {sample_method}") + return src_idx diff --git a/musetalk/loss/basic_loss.py b/musetalk/loss/basic_loss.py new file mode 100755 index 0000000..4159525 --- /dev/null +++ b/musetalk/loss/basic_loss.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from omegaconf import OmegaConf +import torch +import torch.nn.functional as F +from torch import nn, optim +from torch.optim.lr_scheduler import CosineAnnealingLR +from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel +import musetalk.loss.vgg_face as vgg_face + +class Interpolate(nn.Module): + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): + super(Interpolate, self).__init__() + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, input): + return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners) + +def set_requires_grad(net, requires_grad=False): + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + +if __name__ == "__main__": + cfg = OmegaConf.load("config/audio_adapter/E7.yaml") + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + pyramid_scale = [1, 0.5, 0.25, 0.125] + vgg_IN = vgg_face.Vgg19().to(device) + pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device) + vgg_IN.eval() + downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False) + + image = torch.rand(8, 3, 256, 256).to(device) + image_pred = torch.rand(8, 3, 256, 256).to(device) + pyramide_real = pyramid(downsampler(image)) + pyramide_generated = pyramid(downsampler(image_pred)) + + + loss_IN = 0 + for scale in cfg.loss_params.pyramid_scale: + x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)]) + y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)]) + for i, weight in enumerate(cfg.loss_params.vgg_layer_weight): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + loss_IN += weight * value + loss_IN /= sum(cfg.loss_params.vgg_layer_weight) # ε―ΉvggδΈεŒε±‚ε–ε‡ε€ΌοΌŒι‡‘ε­—ε‘”lossζ˜―ζ―ε±‚ε  + print(loss_IN) + + #print(cfg.model_params.discriminator_params) + + discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device) + discriminator_full = DiscriminatorFullModel(discriminator) + disc_scales = cfg.model_params.discriminator_params.scales + # Prepare optimizer and loss function + optimizer_D = optim.AdamW(discriminator.parameters(), + lr=cfg.discriminator_train_params.lr, + weight_decay=cfg.discriminator_train_params.weight_decay, + betas=cfg.discriminator_train_params.betas, + eps=cfg.discriminator_train_params.eps) + scheduler_D = CosineAnnealingLR(optimizer_D, + T_max=cfg.discriminator_train_params.epochs, + eta_min=1e-6) + + discriminator.train() + + set_requires_grad(discriminator, False) + + loss_G = 0. 
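    # Generator-side LSGAN term: run the frozen discriminator on both image pyramids
    # and push its prediction maps for the generated images toward 1.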
+ discriminator_maps_generated = discriminator(pyramide_generated) + discriminator_maps_real = discriminator(pyramide_real) + + for scale in disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + loss_G += value + + print(loss_G) diff --git a/musetalk/loss/conv.py b/musetalk/loss/conv.py new file mode 100755 index 0000000..92751db --- /dev/null +++ b/musetalk/loss/conv.py @@ -0,0 +1,44 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class nonorm_Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + ) + self.act = nn.LeakyReLU(0.01, inplace=True) + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) + +class Conv2dTranspose(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) \ No newline at end of file diff --git a/musetalk/loss/discriminator.py b/musetalk/loss/discriminator.py new file mode 100755 index 0000000..6a506ed --- /dev/null +++ b/musetalk/loss/discriminator.py @@ -0,0 +1,145 @@ +from torch import nn +import torch.nn.functional as F +import torch +from musetalk.loss.vgg_face import ImagePyramide + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). 
+ """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, x): + feature_maps = [] + out = x + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + #print(key) + #print(x) + feature_maps, prediction_map = disc(x[key]) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict + + + +class DiscriminatorFullModel(torch.nn.Module): + """ + Merge all discriminator related updates into single model for better multi-gpu usage + """ + + def __init__(self, discriminator): + super(DiscriminatorFullModel, self).__init__() + self.discriminator = discriminator + self.scales = self.discriminator.scales + print("scales",self.scales) + self.pyramid = ImagePyramide(self.scales, 3) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.zero_tensor = None + + def get_zero_tensor(self, input): + if self.zero_tensor is None: + self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda() + self.zero_tensor.requires_grad_(False) + return self.zero_tensor.expand_as(input) + + def forward(self, x, generated, gan_mode='ls'): + pyramide_real = self.pyramid(x) + pyramide_generated = self.pyramid(generated.detach()) + + discriminator_maps_generated = self.discriminator(pyramide_generated) + discriminator_maps_real = self.discriminator(pyramide_real) + + value_total = 0 + for scale in self.scales: + key = 'prediction_map_%s' % scale + if gan_mode == 'hinge': + value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, 
self.get_zero_tensor(discriminator_maps_generated[key]))) + elif gan_mode == 'ls': + value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean() + else: + raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode'])) + + value_total += value + + return value_total + +def main(): + discriminator = MultiScaleDiscriminator(scales=[1], + block_expansion=32, + max_features=512, + num_blocks=4, + sn=True, + image_channel=3, + estimate_jacobian=False) \ No newline at end of file diff --git a/musetalk/loss/resnet.py b/musetalk/loss/resnet.py new file mode 100755 index 0000000..b9b1082 --- /dev/null +++ b/musetalk/loss/resnet.py @@ -0,0 +1,152 @@ +import torch.nn as nn +import math + +__all__ = ['ResNet', 'resnet50'] + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, include_top=True): + self.inplanes = 64 + super(ResNet, self).__init__() + self.include_top = include_top + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = x * 255. + x = x.flip(1) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + + if not self.include_top: + return x + + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + +def resnet50(**kwargs): + """Constructs a ResNet-50 model. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + return model \ No newline at end of file diff --git a/musetalk/loss/syncnet.py b/musetalk/loss/syncnet.py new file mode 100755 index 0000000..8593efd --- /dev/null +++ b/musetalk/loss/syncnet.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .conv import Conv2d + +logloss = nn.BCELoss(reduction="none") +def cosine_loss(a, v, y): + d = nn.functional.cosine_similarity(a, v) + d = d.clamp(0,1) # cosine_similarityηš„ε–ε€ΌθŒƒε›΄ζ˜―γ€-1,1γ€‘οΌŒBCEε¦‚ζžœθΎ“ε…₯θ΄Ÿζ•°δΌšζŠ₯ι”™RuntimeError: CUDA error: device-side assert triggered + loss = logloss(d.unsqueeze(1), y).squeeze() + loss = loss.mean() + return loss, d + +def get_sync_loss( + audio_embed, + gt_frames, + pred_frames, + syncnet, + adapted_weight, + frames_left_index=0, + frames_right_index=16, +): + # 跟gt_framesεšιšζœΊηš„ζ’ε…₯δΊ€ζ’οΌŒθŠ‚ηœζ˜Ύε­˜εΌ€ι”€ + assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3 + # 3ι€šι“ε›Ύεƒ + frames_sync_loss = torch.cat( + [gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]], + axis=1 + ) + vision_embed = syncnet.get_image_embed(frames_sync_loss) + y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device) + loss, score = cosine_loss(audio_embed, vision_embed, y) + return loss, score + +class SyncNet_color(nn.Module): + def __init__(self): + super(SyncNet_color, self).__init__() + + self.face_encoder = nn.Sequential( + Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), + + Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=2, padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(512, 512, kernel_size=3, stride=2, 
padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) + face_embedding = self.face_encoder(face_sequences) + audio_embedding = self.audio_encoder(audio_sequences) + + audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) + face_embedding = face_embedding.view(face_embedding.size(0), -1) + + audio_embedding = F.normalize(audio_embedding, p=2, dim=1) + face_embedding = F.normalize(face_embedding, p=2, dim=1) + + + return audio_embedding, face_embedding \ No newline at end of file diff --git a/musetalk/loss/vgg_face.py b/musetalk/loss/vgg_face.py new file mode 100755 index 0000000..b41faad --- /dev/null +++ b/musetalk/loss/vgg_face.py @@ -0,0 +1,237 @@ +''' + This part of code contains a pretrained vgg_face model. + ref link: https://github.com/prlz77/vgg-face.pytorch +''' +import torch +import torch.nn.functional as F +import torch.utils.model_zoo +import pickle +from musetalk.loss import resnet as ResNet + + +MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth" +VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl' + +# It was 93.5940, 104.7624, 129.1863 before dividing by 255 +MEAN_RGB = [ + 0.367035294117647, + 0.41083294117647057, + 0.5066129411764705 +] +def load_state_dict(model, fname): + """ + Set parameters converted from Caffe models authors of VGGFace2 provide. + See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. + + Arguments: + model: model + fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. + """ + with open(fname, 'rb') as f: + weights = pickle.load(f, encoding='latin1') + + own_state = model.state_dict() + for name, param in weights.items(): + if name in own_state: + try: + own_state[name].copy_(torch.from_numpy(param)) + except Exception: + raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ + 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) + else: + raise KeyError('unexpected key "{}" in state_dict'.format(name)) + + +def vggface2(pretrained=True): + vggface = ResNet.resnet50(num_classes=8631, include_top=True) + load_state_dict(vggface, VGG_FACE_PATH) + return vggface + +def vggface(pretrained=False, **kwargs): + """VGGFace model. 
+ + Args: + pretrained (bool): If True, returns pre-trained model + """ + model = VggFace(**kwargs) + if pretrained: + state = torch.utils.model_zoo.load_url(MODEL_URL) + model.load_state_dict(state) + return model + + +class VggFace(torch.nn.Module): + def __init__(self, classes=2622): + """VGGFace model. + + Face recognition network. It takes as input a Bx3x224x224 + batch of face images and gives as output a BxC score vector + (C is the number of identities). + Input images need to be scaled in the 0-1 range and then + normalized with respect to the mean RGB used during training. + + Args: + classes (int): number of identities recognized by the + network + + """ + super().__init__() + self.conv1 = _ConvBlock(3, 64, 64) + self.conv2 = _ConvBlock(64, 128, 128) + self.conv3 = _ConvBlock(128, 256, 256, 256) + self.conv4 = _ConvBlock(256, 512, 512, 512) + self.conv5 = _ConvBlock(512, 512, 512, 512) + self.dropout = torch.nn.Dropout(0.5) + self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096) + self.fc2 = torch.nn.Linear(4096, 4096) + self.fc3 = torch.nn.Linear(4096, classes) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = x.view(x.size(0), -1) + x = self.dropout(F.relu(self.fc1(x))) + x = self.dropout(F.relu(self.fc2(x))) + x = self.fc3(x) + return x + + +class _ConvBlock(torch.nn.Module): + """A Convolutional block.""" + + def __init__(self, *units): + """Create a block with len(units) - 1 convolutions. + + convolution number i transforms the number of channels from + units[i - 1] to units[i] channels. + + """ + super().__init__() + self.convs = torch.nn.ModuleList([ + torch.nn.Conv2d(in_, out, 3, 1, 1) + for in_, out in zip(units[:-1], units[1:]) + ]) + + def forward(self, x): + # Each convolution is followed by a ReLU, then the block is + # concluded by a max pooling. + for c in self.convs: + x = F.relu(c(x)) + return F.max_pool2d(x, 2, 2, 0, ceil_mode=True) + + + +import numpy as np +from torchvision import models +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss. 
+ """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +from torch import nn +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. 
+ """ + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + return out_dict \ No newline at end of file diff --git a/musetalk/models/syncnet.py b/musetalk/models/syncnet.py new file mode 100755 index 0000000..185b236 --- /dev/null +++ b/musetalk/models/syncnet.py @@ -0,0 +1,240 @@ +""" +This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py). +""" + +import torch +from torch import nn +from einops import rearrange +from torch.nn import functional as F + +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.attention import Attention as CrossAttention, FeedForward +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange + + +class SyncNet(nn.Module): + def __init__(self, config): + super().__init__() + self.audio_encoder = DownEncoder2D( + in_channels=config["audio_encoder"]["in_channels"], + block_out_channels=config["audio_encoder"]["block_out_channels"], + downsample_factors=config["audio_encoder"]["downsample_factors"], + dropout=config["audio_encoder"]["dropout"], + attn_blocks=config["audio_encoder"]["attn_blocks"], + ) + + self.visual_encoder = DownEncoder2D( + in_channels=config["visual_encoder"]["in_channels"], + block_out_channels=config["visual_encoder"]["block_out_channels"], + downsample_factors=config["visual_encoder"]["downsample_factors"], + dropout=config["visual_encoder"]["dropout"], + attn_blocks=config["visual_encoder"]["attn_blocks"], + ) + + self.eval() + + def forward(self, image_sequences, audio_sequences): + vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1) + audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1) + + vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c) + audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c) + + # Make them unit vectors + vision_embeds = F.normalize(vision_embeds, p=2, dim=1) + audio_embeds = F.normalize(audio_embeds, p=2, dim=1) + + return vision_embeds, audio_embeds + + def get_image_embed(self, image_sequences): + vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1) + + vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c) + + # Make them unit vectors + vision_embeds = F.normalize(vision_embeds, p=2, dim=1) + + return vision_embeds + + def get_audio_embed(self, audio_sequences): + audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1) + + audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c) + + audio_embeds = F.normalize(audio_embeds, p=2, dim=1) + + return audio_embeds + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + norm_num_groups: int = 32, + eps: float = 1e-6, + act_fn: str = "silu", + downsample_factor=2, + ): + super().__init__() + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True) + 
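+        # A second GroupNorm -> activation -> conv pair follows; a 1x1 shortcut
+        # matches channel counts when needed, and an asymmetrically padded
+        # strided conv applies the configured downsample_factor.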
self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if act_fn == "relu": + self.act_fn = nn.ReLU() + elif act_fn == "silu": + self.act_fn = nn.SiLU() + + if in_channels != out_channels: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + if isinstance(downsample_factor, list): + downsample_factor = tuple(downsample_factor) + + if downsample_factor == 1: + self.downsample_conv = None + else: + self.downsample_conv = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0 + ) + self.pad = (0, 1, 0, 1) + if isinstance(downsample_factor, tuple): + if downsample_factor[0] == 1: + self.pad = (0, 1, 1, 1) # The padding order is from back to front + elif downsample_factor[1] == 1: + self.pad = (1, 1, 0, 1) + + def forward(self, input_tensor): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.act_fn(hidden_states) + + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.act_fn(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + hidden_states += input_tensor + + if self.downsample_conv is not None: + hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0) + hidden_states = self.downsample_conv(hidden_states) + + return hidden_states + + +class AttentionBlock2D(nn.Module): + def __init__(self, query_dim, norm_num_groups=32, dropout=0.0): + super().__init__() + if not is_xformers_available(): + raise ModuleNotFoundError( + "You have to install xformers to enable memory efficient attetion", name="xformers" + ) + # inner_dim = dim_head * heads + self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True) + self.norm2 = nn.LayerNorm(query_dim) + self.norm3 = nn.LayerNorm(query_dim) + + self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu") + + self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0) + self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0) + + self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True) + self.attn._use_memory_efficient_attention_xformers = True + + def forward(self, hidden_states): + assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}." 
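+        # Spatial self-attention: flatten the (h, w) grid into h*w tokens,
+        # apply LayerNorm + attention + feed-forward with residual connections,
+        # then reshape back to (b, c, h, w) and add the block input.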
+ + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.conv_in(hidden_states) + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +class DownEncoder2D(nn.Module): + def __init__( + self, + in_channels=4 * 16, + block_out_channels=[64, 128, 256, 256], + downsample_factors=[2, 2, 2, 2], + layers_per_block=2, + norm_num_groups=32, + attn_blocks=[1, 1, 1, 1], + dropout: float = 0.0, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + # in + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + # down + self.down_blocks = nn.ModuleList([]) + + output_channels = block_out_channels[0] + for i, block_out_channel in enumerate(block_out_channels): + input_channels = output_channels + output_channels = block_out_channel + # is_final_block = i == len(block_out_channels) - 1 + + down_block = ResnetBlock2D( + in_channels=input_channels, + out_channels=output_channels, + downsample_factor=downsample_factors[i], + norm_num_groups=norm_num_groups, + dropout=dropout, + act_fn=act_fn, + ) + + self.down_blocks.append(down_block) + + if attn_blocks[i] == 1: + attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout) + self.down_blocks.append(attention_block) + + # out + self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.act_fn_out = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = self.conv_in(hidden_states) + + # down + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + # post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.act_fn_out(hidden_states) + + return hidden_states diff --git a/musetalk/utils/training_utils.py b/musetalk/utils/training_utils.py new file mode 100644 index 0000000..010f01f --- /dev/null +++ b/musetalk/utils/training_utils.py @@ -0,0 +1,337 @@ +import os +import json +import logging +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim.lr_scheduler import CosineAnnealingLR +from diffusers import AutoencoderKL, UNet2DConditionModel +from transformers import WhisperModel +from diffusers.optimization import get_scheduler +from omegaconf import OmegaConf +from einops import rearrange + +from musetalk.models.syncnet import SyncNet +from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel +from musetalk.loss.basic_loss import Interpolate +import musetalk.loss.vgg_face as vgg_face +from musetalk.data.dataset import PortraitDataset +from musetalk.utils.utils import ( + get_image_pred, + process_audio_features, + process_and_save_images +) + +class Net(nn.Module): + def __init__( + self, + unet: UNet2DConditionModel, + ): + super().__init__() + self.unet = unet + + def forward( + self, + input_latents, + timesteps, + audio_prompts, + ): + model_pred = self.unet( + input_latents, + timesteps, + encoder_hidden_states=audio_prompts + ).sample + return model_pred + +logger = 
logging.getLogger(__name__) + +def initialize_models_and_optimizers(cfg, accelerator, weight_dtype): + """Initialize models and optimizers""" + model_dict = { + 'vae': None, + 'unet': None, + 'net': None, + 'wav2vec': None, + 'optimizer': None, + 'lr_scheduler': None, + 'scheduler_max_steps': None, + 'trainable_params': None + } + + model_dict['vae'] = AutoencoderKL.from_pretrained( + cfg.pretrained_model_name_or_path, + subfolder=cfg.vae_type, + ) + + unet_config_file = os.path.join( + cfg.pretrained_model_name_or_path, + cfg.unet_sub_folder + "/musetalk.json" + ) + + with open(unet_config_file, 'r') as f: + unet_config = json.load(f) + model_dict['unet'] = UNet2DConditionModel(**unet_config) + + if not cfg.random_init_unet: + pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin") + print(f"### Loading existing unet weights from {pretrained_unet_path}. ###") + checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device) + model_dict['unet'].load_state_dict(checkpoint) + + unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()] + logger.info(f"unet {sum(unet_params) / 1e6}M-parameter") + + model_dict['vae'].requires_grad_(False) + model_dict['unet'].requires_grad_(True) + + model_dict['vae'].to(accelerator.device, dtype=weight_dtype) + + model_dict['net'] = Net(model_dict['unet']) + + model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to( + device="cuda", dtype=weight_dtype).eval() + model_dict['wav2vec'].requires_grad_(False) + + if cfg.solver.gradient_checkpointing: + model_dict['unet'].enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters())) + if accelerator.is_main_process: + print('trainable params') + for n, p in model_dict['net'].named_parameters(): + if p.requires_grad: + print(n) + + model_dict['optimizer'] = optimizer_cls( + model_dict['trainable_params'], + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps + model_dict['lr_scheduler'] = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=model_dict['optimizer'], + num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps, + num_training_steps=model_dict['scheduler_max_steps'], + ) + + return model_dict + +def initialize_dataloaders(cfg): + """Initialize training and validation dataloaders""" + dataloader_dict = { + 'train_dataset': None, + 'val_dataset': None, + 'train_dataloader': None, + 'val_dataloader': None + } + + dataloader_dict['train_dataset'] = PortraitDataset(cfg={ + 'image_size': cfg.data.image_size, + 'T': cfg.data.n_sample_frames, + "sample_method": cfg.data.sample_method, + 'top_k_ratio': cfg.data.top_k_ratio, + "contorl_face_min_size": cfg.data.contorl_face_min_size, + "dataset_key": cfg.data.dataset_key, + "padding_pixel_mouth": cfg.padding_pixel_mouth, + "whisper_path": cfg.whisper_path, + "min_face_size": cfg.data.min_face_size, + "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean, + "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std, + "crop_type": cfg.crop_type, + "random_margin_method": cfg.random_margin_method, + }) + + dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader( + dataloader_dict['train_dataset'], + batch_size=cfg.data.train_bs, + shuffle=True, + num_workers=cfg.data.num_workers, + ) + + dataloader_dict['val_dataset'] = PortraitDataset(cfg={ + 'image_size': cfg.data.image_size, + 'T': cfg.data.n_sample_frames, + "sample_method": cfg.data.sample_method, + 'top_k_ratio': cfg.data.top_k_ratio, + "contorl_face_min_size": cfg.data.contorl_face_min_size, + "dataset_key": cfg.data.dataset_key, + "padding_pixel_mouth": cfg.padding_pixel_mouth, + "whisper_path": cfg.whisper_path, + "min_face_size": cfg.data.min_face_size, + "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean, + "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std, + "crop_type": cfg.crop_type, + "random_margin_method": cfg.random_margin_method, + }) + + dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader( + dataloader_dict['val_dataset'], + batch_size=cfg.data.train_bs, + shuffle=True, + num_workers=1, + ) + + return dataloader_dict + +def initialize_loss_functions(cfg, accelerator, scheduler_max_steps): + """Initialize loss functions and discriminators""" + loss_dict = { + 'L1_loss': nn.L1Loss(reduction='mean'), + 'discriminator': None, + 'mouth_discriminator': None, + 'optimizer_D': None, + 'mouth_optimizer_D': None, + 'scheduler_D': None, + 'mouth_scheduler_D': None, + 'disc_scales': None, + 'discriminator_full': None, + 'mouth_discriminator_full': None + } + + if cfg.loss_params.gan_loss > 0: + loss_dict['discriminator'] = MultiScaleDiscriminator( + **cfg.model_params.discriminator_params).to(accelerator.device) + 
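+        # DiscriminatorFullModel wraps the discriminator so the image pyramids
+        # and the real/fake loss are computed inside one module, which keeps
+        # the whole discriminator update friendly to multi-GPU training.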
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator']) + loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales + loss_dict['optimizer_D'] = optim.AdamW( + loss_dict['discriminator'].parameters(), + lr=cfg.discriminator_train_params.lr, + weight_decay=cfg.discriminator_train_params.weight_decay, + betas=cfg.discriminator_train_params.betas, + eps=cfg.discriminator_train_params.eps) + loss_dict['scheduler_D'] = CosineAnnealingLR( + loss_dict['optimizer_D'], + T_max=scheduler_max_steps, + eta_min=1e-6 + ) + + if cfg.loss_params.mouth_gan_loss > 0: + loss_dict['mouth_discriminator'] = MultiScaleDiscriminator( + **cfg.model_params.discriminator_params).to(accelerator.device) + loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator']) + loss_dict['mouth_optimizer_D'] = optim.AdamW( + loss_dict['mouth_discriminator'].parameters(), + lr=cfg.discriminator_train_params.lr, + weight_decay=cfg.discriminator_train_params.weight_decay, + betas=cfg.discriminator_train_params.betas, + eps=cfg.discriminator_train_params.eps) + loss_dict['mouth_scheduler_D'] = CosineAnnealingLR( + loss_dict['mouth_optimizer_D'], + T_max=scheduler_max_steps, + eta_min=1e-6 + ) + + return loss_dict + +def initialize_syncnet(cfg, accelerator, weight_dtype): + """Initialize SyncNet model""" + if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight: + if cfg.data.n_sample_frames != 16: + raise ValueError( + f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16." + ) + syncnet_config = OmegaConf.load(cfg.syncnet_config_path) + syncnet = SyncNet(OmegaConf.to_container( + syncnet_config.model)).to(accelerator.device) + print( + f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}") + checkpoint = torch.load( + syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device) + syncnet.load_state_dict(checkpoint["state_dict"]) + syncnet.to(dtype=weight_dtype) + syncnet.requires_grad_(False) + syncnet.eval() + return syncnet + return None + +def initialize_vgg(cfg, accelerator): + """Initialize VGG model""" + if cfg.loss_params.vgg_loss > 0: + vgg_IN = vgg_face.Vgg19().to(accelerator.device,) + pyramid = vgg_face.ImagePyramide( + cfg.loss_params.pyramid_scale, 3).to(accelerator.device) + vgg_IN.eval() + downsampler = Interpolate( + size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device) + return vgg_IN, pyramid, downsampler + return None, None, None + +def validation( + cfg, + val_dataloader, + net, + vae, + wav2vec, + accelerator, + save_dir, + global_step, + weight_dtype, + syncnet_score=1, +): + """Validation function for model evaluation""" + net.eval() # Set the model to evaluation mode + for batch in val_dataloader: + # The same ref_latents + ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to( + accelerator.device, non_blocking=True + ) + pixel_values = batch["pixel_values_vid"].to(weight_dtype).to( + accelerator.device, non_blocking=True + ) + bsz, num_frames, c, h, w = ref_pixel_values.shape + + audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype) + # audio feature for unet + audio_prompts = rearrange( + audio_prompts, + 'b f c h w-> (b f) c h w' + ) + audio_prompts = rearrange( + audio_prompts, + '(b f) c h w -> (b f) (c h) w', + b=bsz + ) + # different masked_latents + image_pred_train = get_image_pred( + pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype) + image_pred_infer = 
get_image_pred( + ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype) + + process_and_save_images( + batch, + image_pred_train, + image_pred_infer, + save_dir, + global_step, + accelerator, + cfg.num_images_to_keep, + syncnet_score + ) + # only infer 1 image in validation + break + net.train() # Set the model back to training mode diff --git a/musetalk/utils/utils.py b/musetalk/utils/utils.py index 911fc40..6b14eff 100755 --- a/musetalk/utils/utils.py +++ b/musetalk/utils/utils.py @@ -2,6 +2,11 @@ import os import cv2 import numpy as np import torch +from typing import Union, List +import torch.nn.functional as F +from einops import rearrange +import shutil +import os.path as osp ffmpeg_path = os.getenv('FFMPEG_PATH') if ffmpeg_path is None: @@ -11,7 +16,6 @@ elif ffmpeg_path not in os.getenv('PATH'): os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}" -from musetalk.whisper.audio2feature import Audio2Feature from musetalk.models.vae import VAE from musetalk.models.unet import UNet,PositionalEncoding @@ -76,3 +80,248 @@ def datagen( latent_batch = torch.cat(latent_batch, dim=0) yield whisper_batch.to(device), latent_batch.to(device) + +def cast_training_params( + model: Union[torch.nn.Module, List[torch.nn.Module]], + dtype=torch.float32, +): + if not isinstance(model, list): + model = [model] + for m in model: + for param in m.parameters(): + # only upcast trainable parameters into fp32 + if param.requires_grad: + param.data = param.to(dtype) + +def rand_log_normal( + shape, + loc=0., + scale=1., + device='cpu', + dtype=torch.float32, + generator=None +): + """Draws samples from an lognormal distribution.""" + rnd_normal = torch.randn( + shape, device=device, dtype=dtype, generator=generator) # N(0, I) + sigma = (rnd_normal * scale + loc).exp() + return sigma + +def get_mouth_region(frames, image_pred, pixel_values_face_mask): + # Initialize lists to store the results for each image in the batch + mouth_real_list = [] + mouth_generated_list = [] + + # Process each image in the batch + for b in range(frames.shape[0]): + # Find the non-zero area in the face mask + non_zero_indices = torch.nonzero(pixel_values_face_mask[b]) + # If there are no non-zero indices, skip this image + if non_zero_indices.numel() == 0: + continue + + min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max( + non_zero_indices[:, 1]) + min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max( + non_zero_indices[:, 2]) + + # Crop the frames and image_pred according to the non-zero area + frames_cropped = frames[b, :, min_y:max_y, min_x:max_x] + image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x] + # Resize the cropped images to 256*256 + frames_resized = F.interpolate(frames_cropped.unsqueeze( + 0), size=(256, 256), mode='bilinear', align_corners=False) + image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze( + 0), size=(256, 256), mode='bilinear', align_corners=False) + + # Append the resized images to the result lists + mouth_real_list.append(frames_resized) + mouth_generated_list.append(image_pred_resized) + + # Convert the lists to tensors if they are not empty + mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None + mouth_generated = torch.cat( + mouth_generated_list, dim=0) if mouth_generated_list else None + + return mouth_real, mouth_generated + +def get_image_pred(pixel_values, + ref_pixel_values, + audio_prompts, + vae, + net, + weight_dtype): + with torch.no_grad(): + bsz, num_frames, c, h, w = pixel_values.shape + + 
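+        # Build the inpainting condition: blank the lower half of each target
+        # frame (set to -1 in [-1, 1] pixel space) so the mouth region must be
+        # re-synthesised by the UNet from the audio prompt and the reference
+        # latents, which are concatenated along the channel dimension.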
masked_pixel_values = pixel_values.clone() + masked_pixel_values[:, :, :, h//2:, :] = -1 + + masked_frames = rearrange( + masked_pixel_values, 'b f c h w -> (b f) c h w') + masked_latents = vae.encode(masked_frames).latent_dist.mode() + masked_latents = masked_latents * vae.config.scaling_factor + masked_latents = masked_latents.float() + + ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w') + ref_latents = vae.encode(ref_frames).latent_dist.mode() + ref_latents = ref_latents * vae.config.scaling_factor + ref_latents = ref_latents.float() + + input_latents = torch.cat([masked_latents, ref_latents], dim=1) + input_latents = input_latents.to(weight_dtype) + timesteps = torch.tensor([0], device=input_latents.device) + latents_pred = net( + input_latents, + timesteps, + audio_prompts, + ) + latents_pred = (1 / vae.config.scaling_factor) * latents_pred + image_pred = vae.decode(latents_pred).sample + image_pred = image_pred.float() + + return image_pred + +def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype): + with torch.no_grad(): + audio_feature_length_per_frame = 2 * \ + (cfg.data.audio_padding_length_left + + cfg.data.audio_padding_length_right + 1) + audio_feats = batch['audio_feature'].to(weight_dtype) + audio_feats = wav2vec.encoder( + audio_feats, output_hidden_states=True).hidden_states + audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384] + + start_ts = batch['audio_offset'] + step_ts = batch['audio_step'] + audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]), + audio_feats, + torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1) + audio_prompts = [] + for bb in range(bsz): + audio_feats_list = [] + for f in range(num_frames): + cur_t = (start_ts[bb] + f * step_ts[bb]) * 2 + audio_clip = audio_feats[bb:bb+1, + cur_t: cur_t+audio_feature_length_per_frame] + + audio_feats_list.append(audio_clip) + audio_feats_list = torch.stack(audio_feats_list, 1) + audio_prompts.append(audio_feats_list) + audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384 + return audio_prompts + +def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None): + save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth") + + if total_limit is not None: + checkpoints = os.listdir(save_dir) + checkpoints = [d for d in checkpoints if d.endswith(".pth")] + checkpoints = [d for d in checkpoints if name in d] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + + if len(checkpoints) >= total_limit: + num_to_remove = len(checkpoints) - total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join( + save_dir, removing_checkpoint) + os.remove(removing_checkpoint) + + state_dict = model.state_dict() + torch.save(state_dict, save_path) + +def save_models(accelerator, net, save_dir, global_step, cfg, logger=None): + unwarp_net = accelerator.unwrap_model(net) + save_checkpoint( + unwarp_net.unet, + save_dir, + global_step, + name="unet", + total_limit=cfg.total_limit, + logger=logger + ) + +def delete_additional_ckpt(base_path, num_keep): + dirs = [] + for d in os.listdir(base_path): + if 
d.startswith("checkpoint-"): + dirs.append(d) + num_tot = len(dirs) + if num_tot <= num_keep: + return + # ensure ckpt is sorted and delete the ealier! + del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] + for d in del_dirs: + path_to_dir = osp.join(base_path, d) + if osp.exists(path_to_dir): + shutil.rmtree(path_to_dir) + +def seed_everything(seed): + import random + + import numpy as np + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed % (2**32)) + random.seed(seed) + +def process_and_save_images( + batch, + image_pred, + image_pred_infer, + save_dir, + global_step, + accelerator, + num_images_to_keep=10, + syncnet_score=1 +): + # Rearrange the tensors + print("image_pred.shape: ", image_pred.shape) + pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w") + pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w') + + # Create masked pixel values + masked_pixel_values = batch["pixel_values_vid"].clone() + _, _, _, h, _ = batch["pixel_values_vid"].shape + masked_pixel_values[:, :, :, h//2:, :] = -1 + masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w') + + # Keep only the specified number of images + pixel_values = pixel_values[:num_images_to_keep, :, :, :] + masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :] + pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :] + image_pred = image_pred.detach()[:num_images_to_keep, :, :, :] + image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :] + + # Concatenate images + concat = torch.cat([ + masked_pixel_values * 0.5 + 0.5, + pixel_values_ref_img * 0.5 + 0.5, + image_pred * 0.5 + 0.5, + pixel_values * 0.5 + 0.5, + image_pred_infer * 0.5 + 0.5, + ], dim=2) + print("concat.shape: ", concat.shape) + + # Create the save directory if it doesn't exist + os.makedirs(f'{save_dir}/samples/', exist_ok=True) + + # Try to save the concatenated image + try: + # Concatenate images horizontally and convert to numpy array + final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255 + # Save the image + cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image) + print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg") + except Exception as e: + print(f"Failed to save image: {e}") \ No newline at end of file diff --git a/scripts/preprocess.py b/scripts/preprocess.py new file mode 100755 index 0000000..c493e3c --- /dev/null +++ b/scripts/preprocess.py @@ -0,0 +1,322 @@ +import os +import argparse +import subprocess +from omegaconf import OmegaConf +from typing import Tuple, List, Union +import decord +import json +import cv2 +from musetalk.utils.face_detection import FaceAlignment,LandmarksType +from mmpose.apis import inference_topdown, init_model +from mmpose.structures import merge_data_samples +import torch +import numpy as np +from tqdm import tqdm + +ffmpeg_path = "./ffmpeg-4.4-amd64-static/" +if ffmpeg_path not in os.getenv('PATH'): + print("add ffmpeg to path") + os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}" + +class AnalyzeFace: + def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str): + """ + Initialize the AnalyzeFace class with the given device, config file, and checkpoint file. 
+ + Parameters: + device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu'). + config_file (str): Path to the mmpose model configuration file. + checkpoint_file (str): Path to the mmpose model checkpoint file. + """ + self.device = device + self.dwpose = init_model(config_file, checkpoint_file, device=self.device) + self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device) + + def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]: + """ + Detect faces and keypoints in the given image. + + Parameters: + im (np.ndarray): The input image. + maxface (bool): Whether to detect the maximum face. Default is True. + + Returns: + Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints. + """ + try: + # Ensure the input image has the correct shape + if im.ndim == 3: + im = np.expand_dims(im, axis=0) + elif im.ndim != 4 or im.shape[0] != 1: + raise ValueError("Input image must have shape (1, H, W, C)") + + bbox = self.facedet.get_detections_for_batch(np.asarray(im)) + results = inference_topdown(self.dwpose, np.asarray(im)[0]) + results = merge_data_samples(results) + keypoints = results.pred_instances.keypoints + face_land_mark= keypoints[0][23:91] + face_land_mark = face_land_mark.astype(np.int32) + + return face_land_mark, bbox + + except Exception as e: + print(f"Error during face analysis: {e}") + return np.array([]),[] + +def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None: + + """ + Convert video files to a specified format and save them to the destination path. + + Parameters: + org_path (str): The directory containing the original video files. + dst_path (str): The directory where the converted video files will be saved. + vid_list (List[str]): A list of video file names to process. + + Returns: + None + """ + for idx, vid in enumerate(vid_list): + if vid.endswith('.mp4'): + org_vid_path = os.path.join(org_path, vid) + dst_vid_path = os.path.join(dst_path, vid) + + if org_vid_path != dst_vid_path: + cmd = [ + "ffmpeg", "-hide_banner", "-y", "-i", org_vid_path, + "-r", "25", "-crf", "15", "-c:v", "libx264", + "-pix_fmt", "yuv420p", dst_vid_path + ] + subprocess.run(cmd, check=True) + + if idx % 1000 == 0: + print(f"### {idx} videos converted ###") + +def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None: + """ + Segment video files into smaller clips of specified duration. + + Parameters: + org_path (str): The directory containing the original video files. + dst_path (str): The directory where the segmented video files will be saved. + vid_list (List[str]): A list of video file names to process. + segment_duration (int): The duration of each segment in seconds. Default is 30 seconds. + + Returns: + None + """ + for idx, vid in enumerate(vid_list): + if vid.endswith('.mp4'): + input_file = os.path.join(org_path, vid) + original_filename = os.path.basename(input_file) + + command = [ + 'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0', + '-segment_time', str(segment_duration), '-f', 'segment', + '-reset_timestamps', '1', + os.path.join(dst_path, f'clip%03d_{original_filename}') + ] + + subprocess.run(command, check=True) + +def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None: + """ + Extract audio from video files and save as WAV format. + + Parameters: + org_path (str): The directory containing the original video files. 
+ dst_path (str): The directory where the extracted audio files will be saved. + vid_list (List[str]): A list of video file names to process. + + Returns: + None + """ + for idx, vid in enumerate(vid_list): + if vid.endswith('.mp4'): + video_path = os.path.join(org_path, vid) + audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav") + try: + command = [ + 'ffmpeg', '-hide_banner', '-y', '-i', video_path, + '-vn', '-acodec', 'pcm_s16le', '-f', 'wav', + '-ar', '16000', '-ac', '1', audio_output_path, + ] + + subprocess.run(command, check=True) + print(f"Audio saved to: {audio_output_path}") + except subprocess.CalledProcessError as e: + print(f"Error extracting audio from {vid}: {e}") + +def split_data(video_files: List[str], val_list_hdtf: List[str]) -> (List[str], List[str]): + """ + Split video files into training and validation sets based on val_list_hdtf. + + Parameters: + video_files (List[str]): A list of video file names. + val_list_hdtf (List[str]): A list of validation file identifiers. + + Returns: + (List[str], List[str]): A tuple containing the training and validation file lists. + """ + val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)] + train_files = [f for f in video_files if f not in val_files] + return train_files, val_files + +def save_list_to_file(file_path: str, data_list: List[str]) -> None: + """ + Save a list of strings to a file, each string on a new line. + + Parameters: + file_path (str): The path to the file where the list will be saved. + data_list (List[str]): The list of strings to save. + + Returns: + None + """ + with open(file_path, 'w') as file: + for item in data_list: + file.write(f"{item}\n") + +def generate_train_list(cfg): + train_file_path = cfg.video_clip_file_list_train + val_file_path = cfg.video_clip_file_list_val + val_list_hdtf = cfg.val_list_hdtf + + meta_list = os.listdir(cfg.meta_root) + + sorted_meta_list = sorted(meta_list) + train_files, val_files = split_data(meta_list, val_list_hdtf) + + save_list_to_file(train_file_path, train_files) + save_list_to_file(val_file_path, val_files) + + print(val_list_hdtf) + +def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None: + """ + Convert video files to a specified format and save them to the destination path. + + Parameters: + org_path (str): The directory containing the original video files. + dst_path (str): The directory where the meta json will be saved. + vid_list (List[str]): A list of video file names to process. 
+ + Returns: + None + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' + checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' + + analyze_face = AnalyzeFace(device, config_file, checkpoint_file) + + for vid in tqdm(vid_list, desc="Processing videos"): + #vid = "clip005_WDA_BernieSanders_000.mp4" + #print(vid) + if vid.endswith('.mp4'): + vid_path = os.path.join(org_path, vid) + wav_path = vid_path.replace(".mp4",".wav") + vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json") + if os.path.exists(vid_meta): + continue + print('process video {}'.format(vid)) + + total_bbox_list = [] + total_pts_list = [] + isvalid = True + + # process + try: + cap = decord.VideoReader(vid_path, fault_tol=1) + except Exception as e: + print(e) + continue + + total_frames = len(cap) + for frame_idx in range(total_frames): + frame = cap[frame_idx] + if frame_idx==0: + video_height,video_width,_ = frame.shape + frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB) + pts_list, bbox_list = analyze_face(frame_bgr) + + if len(bbox_list)>0 and None not in bbox_list: + bbox = bbox_list[0] + else: + isvalid = False + bbox = [] + print(f"set isvalid to False as broken img in {frame_idx} of {vid}") + break + + #print(pts_list) + if len(pts_list)>0 and pts_list is not None: + pts = pts_list.tolist() + else: + isvalid = False + pts = [] + break + + if frame_idx==0: + x1,y1,x2,y2 = bbox + face_height, face_width = y2-y1,x2-x1 + + total_pts_list.append(pts) + total_bbox_list.append(bbox) + + meta_data = { + "mp4_path": vid_path, + "wav_path": wav_path, + "video_size": [video_height, video_width], + "face_size": [face_height, face_width], + "frames": total_frames, + "face_list": total_bbox_list, + "landmark_list": total_pts_list, + "isvalid":isvalid, + } + with open(vid_meta, 'w') as f: + json.dump(meta_data, f, indent=4) + + + +def main(cfg): + # Ensure all necessary directories exist + os.makedirs(cfg.video_root_25fps, exist_ok=True) + os.makedirs(cfg.video_audio_clip_root, exist_ok=True) + os.makedirs(cfg.meta_root, exist_ok=True) + os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True) + os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True) + os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True) + + vid_list = os.listdir(cfg.video_root_raw) + sorted_vid_list = sorted(vid_list) + + # Save video file list + with open(cfg.video_file_list, 'w') as file: + for vid in sorted_vid_list: + file.write(vid + '\n') + + # 1. Convert videos to 25 FPS + convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list) + + # 2. Segment videos into 30-second clips + segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second) + + # 3. Extract audio + clip_vid_list = os.listdir(cfg.video_audio_clip_root) + extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list) + + # 4. Generate video metadata + analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list) + + # 5. 
Generate training and validation set lists + generate_train_list(cfg) + print("done") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml") + args = parser.parse_args() + config = OmegaConf.load(args.config) + + main(config) + \ No newline at end of file diff --git a/train.py b/train.py new file mode 100755 index 0000000..2c9f0ec --- /dev/null +++ b/train.py @@ -0,0 +1,580 @@ +import argparse +import diffusers +import logging +import math +import os +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +import warnings +import random + +from accelerate import Accelerator +from accelerate.utils import LoggerType +from accelerate import InitProcessGroupKwargs +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from datetime import datetime +from datetime import timedelta + +from diffusers.utils import check_min_version +from einops import rearrange +from omegaconf import OmegaConf +from tqdm.auto import tqdm + +from musetalk.utils.utils import ( + delete_additional_ckpt, + seed_everything, + get_mouth_region, + process_audio_features, + save_models +) +from musetalk.loss.basic_loss import set_requires_grad +from musetalk.loss.syncnet import get_sync_loss +from musetalk.utils.training_utils import ( + initialize_models_and_optimizers, + initialize_dataloaders, + initialize_loss_functions, + initialize_syncnet, + initialize_vgg, + validation +) + +logger = get_logger(__name__, log_level="INFO") +warnings.filterwarnings("ignore") +check_min_version("0.10.0.dev0") + +def main(cfg): + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + os.makedirs(save_dir, exist_ok=True) + + kwargs = DistributedDataParallelKwargs() + process_group_kwargs = InitProcessGroupKwargs( + timeout=timedelta(seconds=5400)) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + log_with=["tensorboard", LoggerType.TENSORBOARD], + project_dir=os.path.join(save_dir, "./tensorboard"), + kwargs_handlers=[kwargs, process_group_kwargs], + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + print('cfg.seed', cfg.seed, accelerator.process_index) + seed_everything(cfg.seed + accelerator.process_index) + + weight_dtype = torch.float32 + + model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype) + dataloader_dict = initialize_dataloaders(cfg) + loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps']) + syncnet = initialize_syncnet(cfg, accelerator, weight_dtype) + vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator) + + # Prepare everything with our `accelerator`. 
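+    # accelerator.prepare wraps the network, optimizer, LR scheduler and both dataloaders
+    # for distributed (DDP) execution and moves them to the correct device.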
+ model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare( + model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] + ) + print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader'])) + + # Calculate training steps and epochs + num_update_steps_per_epoch = math.ceil( + len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps + ) + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # Initialize trackers on the main process + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + cfg.exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + + # Calculate total batch size + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + # Log training information + logger.info("***** Running training *****") + logger.info(f"Num Epochs = {num_train_epochs}") + logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}") + logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Load checkpoint if resuming training + if cfg.resume_from_checkpoint: + resume_dir = save_dir + dirs = os.listdir(resume_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + if len(dirs) > 0: + path = dirs[-1] + accelerator.load_state(os.path.join(resume_dir, path)) + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + resume_step = global_step % num_update_steps_per_epoch + + # Initialize progress bar + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + # Log model types + print("log type of models") + print("unet", model_dict['unet'].dtype) + print("vae", model_dict['vae'].dtype) + print("wav2vec", model_dict['wav2vec'].dtype) + + def get_ganloss_weight(step): + """Calculate GAN loss weight based on training step""" + if step < cfg.discriminator_train_params.start_gan: + return 0.0 + else: + return 1.0 + + # Training loop + for epoch in range(first_epoch, num_train_epochs): + # Set models to training mode + model_dict['unet'].train() + if cfg.loss_params.gan_loss > 0: + loss_dict['discriminator'].train() + if cfg.loss_params.mouth_gan_loss > 0: + loss_dict['mouth_discriminator'].train() + + # Initialize loss accumulators + train_loss = 0.0 + train_loss_D = 0.0 + train_loss_D_mouth = 0.0 + l1_loss_accum = 0.0 + vgg_loss_accum = 0.0 + gan_loss_accum = 0.0 + gan_loss_accum_mouth = 0.0 + fm_loss_accum = 0.0 + sync_loss_accum = 0.0 + adapted_weight_accum = 0.0 + + t_data_start = time.time() + for step, batch in enumerate(dataloader_dict['train_dataloader']): + t_data = time.time() - t_data_start + t_model_start = time.time() + + with torch.no_grad(): + # Process input data + pixel_values = batch["pixel_values_vid"].to(weight_dtype).to( + 
accelerator.device, + non_blocking=True + ) + bsz, num_frames, c, h, w = pixel_values.shape + + # Process reference images + ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to( + accelerator.device, + non_blocking=True + ) + + # Get face mask for GAN + pixel_values_face_mask = batch['pixel_values_face_mask'] + + # Process audio features + audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype) + + # Initialize adapted weight + adapted_weight = 1 + + # Process sync loss if enabled + if cfg.loss_params.sync_loss > 0: + mels = batch['mel'] + # Prepare frames for latentsync (combine channels and frames) + gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w') + # Use lower half of face for latentsync + height = gt_frames.shape[2] + gt_frames = gt_frames[:, :, height // 2:, :] + + # Get audio embeddings + audio_embed = syncnet.get_audio_embed(mels) + + # Calculate adapted weight based on audio-visual similarity + if cfg.use_adapted_weight: + vision_embed_gt = syncnet.get_vision_embed(gt_frames) + image_audio_sim_gt = F.cosine_similarity( + audio_embed, + vision_embed_gt, + dim=1 + )[0] + + if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65: + if cfg.adapted_weight_type == "cut_off": + adapted_weight = 0.0 # Skip this batch + print( + f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.") + elif cfg.adapted_weight_type == "linear": + adapted_weight = image_audio_sim_gt + else: + print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}") + adapted_weight = 1 + + # Random frame selection for memory efficiency + max_start = 16 - cfg.num_backward_frames + frames_left_index = random.randint(0, max_start) if max_start > 0 else 0 + frames_right_index = frames_left_index + cfg.num_backward_frames + else: + frames_left_index = 0 + frames_right_index = cfg.data.n_sample_frames + + # Extract frames for backward pass + pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...] + ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...] + pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...] + audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...] 
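+                # From here on, only the selected frame window (target frames, reference
+                # frames, face masks and audio features) is used, which bounds the memory
+                # needed for the VAE encoding and the UNet forward/backward pass below.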
+ + # Encode target images + frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w') + latents = model_dict['vae'].encode(frames).latent_dist.mode() + latents = latents * model_dict['vae'].config.scaling_factor + latents = latents.float() + + # Create masked images + masked_pixel_values = pixel_values_backward.clone() + masked_pixel_values[:, :, :, h//2:, :] = -1 + masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w') + masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode() + masked_latents = masked_latents * model_dict['vae'].config.scaling_factor + masked_latents = masked_latents.float() + + # Encode reference images + ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w') + ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode() + ref_latents = ref_latents * model_dict['vae'].config.scaling_factor + ref_latents = ref_latents.float() + + # Prepare face mask and audio features + pixel_values_face_mask_backward = rearrange( + pixel_values_face_mask_backward, + "b f c h w -> (b f) c h w" + ) + audio_prompts_backward = rearrange( + audio_prompts_backward, + 'b f c h w-> (b f) c h w' + ) + audio_prompts_backward = rearrange( + audio_prompts_backward, + '(b f) c h w -> (b f) (c h) w', + b=bsz + ) + + # Apply reference dropout (currently inactive) + dropout = nn.Dropout(p=cfg.ref_dropout_rate) + ref_latents = dropout(ref_latents) + + # Prepare model inputs + input_latents = torch.cat([masked_latents, ref_latents], dim=1) + input_latents = input_latents.to(weight_dtype) + timesteps = torch.tensor([0], device=input_latents.device) + + # Forward pass + latents_pred = model_dict['net']( + input_latents, + timesteps, + audio_prompts_backward, + ) + latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred + image_pred = model_dict['vae'].decode(latents_pred).sample + + # Convert to float + image_pred = image_pred.float() + frames = frames.float() + + # Calculate L1 loss + l1_loss = loss_dict['L1_loss'](frames, image_pred) + l1_loss_accum += l1_loss.item() + loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight + + # Process mouth GAN loss if enabled + if cfg.loss_params.mouth_gan_loss > 0: + frames_mouth, image_pred_mouth = get_mouth_region( + frames, + image_pred, + pixel_values_face_mask_backward + ) + pyramide_real_mouth = pyramid(downsampler(frames_mouth)) + pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth)) + + # Process VGG loss if enabled + if cfg.loss_params.vgg_loss > 0: + pyramide_real = pyramid(downsampler(frames)) + pyramide_generated = pyramid(downsampler(image_pred)) + + loss_IN = 0 + for scale in cfg.loss_params.pyramid_scale: + x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)]) + y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)]) + for i, weight in enumerate(cfg.loss_params.vgg_layer_weight): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + loss_IN += weight * value + loss_IN /= sum(cfg.loss_params.vgg_layer_weight) + loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight + vgg_loss_accum += loss_IN.item() + + # Process GAN loss if enabled + if cfg.loss_params.gan_loss > 0: + set_requires_grad(loss_dict['discriminator'], False) + loss_G = 0. 
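+                # LSGAN-style generator term: the discriminator maps for the generated
+                # pyramid are pushed towards 1, i.e. minimise (1 - D(fake))^2 at each scale.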
+ discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated) + discriminator_maps_real = loss_dict['discriminator'](pyramide_real) + + for scale in loss_dict['disc_scales']: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + loss_G += value + gan_loss_accum += loss_G.item() + + loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight + + # Process feature matching loss if enabled + if cfg.loss_params.fm_loss[0] > 0: + L_feature_matching = 0. + for scale in loss_dict['disc_scales']: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + value = torch.abs(a - b).mean() + L_feature_matching += value * cfg.loss_params.fm_loss[i] + loss += L_feature_matching * adapted_weight + fm_loss_accum += L_feature_matching.item() + + # Process mouth GAN loss if enabled + if cfg.loss_params.mouth_gan_loss > 0: + set_requires_grad(loss_dict['mouth_discriminator'], False) + loss_G = 0. + mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth) + mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth) + + for scale in loss_dict['disc_scales']: + key = 'prediction_map_%s' % scale + value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean() + loss_G += value + gan_loss_accum_mouth += loss_G.item() + + loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight + + # Process feature matching loss for mouth if enabled + if cfg.loss_params.fm_loss[0] > 0: + L_feature_matching = 0. + for scale in loss_dict['disc_scales']: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])): + value = torch.abs(a - b).mean() + L_feature_matching += value * cfg.loss_params.fm_loss[i] + loss += L_feature_matching * adapted_weight + fm_loss_accum += L_feature_matching.item() + + # Process sync loss if enabled + if cfg.loss_params.sync_loss > 0: + pred_frames = rearrange( + image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1]) + pred_frames = pred_frames[:, :, height // 2 :, :] + sync_loss, image_audio_sim_pred = get_sync_loss( + audio_embed, + gt_frames, + pred_frames, + syncnet, + adapted_weight, + frames_left_index=frames_left_index, + frames_right_index=frames_right_index, + ) + sync_loss_accum += sync_loss.item() + loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight + + # Backward pass + avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() + accelerator.backward(loss) + + # Train discriminator if GAN loss is enabled + if cfg.loss_params.gan_loss > 0: + set_requires_grad(loss_dict['discriminator'], True) + loss_D = loss_dict['discriminator_full'](frames, image_pred.detach()) + avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean() + train_loss_D += avg_loss_D.item() / 1 + loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight + accelerator.backward(loss_D) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm) + if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0: + loss_dict['optimizer_D'].step() + loss_dict['scheduler_D'].step() + loss_dict['optimizer_D'].zero_grad() + + # Train mouth discriminator if mouth GAN loss is enabled + if 
cfg.loss_params.mouth_gan_loss > 0: + set_requires_grad(loss_dict['mouth_discriminator'], True) + mouth_loss_D = loss_dict['mouth_discriminator_full']( + frames_mouth, image_pred_mouth.detach()) + avg_mouth_loss_D = accelerator.gather( + mouth_loss_D.repeat(cfg.data.train_bs)).mean() + train_loss_D_mouth += avg_mouth_loss_D.item() / 1 + mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight + accelerator.backward(mouth_loss_D) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm) + if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0: + loss_dict['mouth_optimizer_D'].step() + loss_dict['mouth_scheduler_D'].step() + loss_dict['mouth_optimizer_D'].zero_grad() + + # Update main model + if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0: + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + model_dict['trainable_params'], + cfg.solver.max_grad_norm, + ) + model_dict['optimizer'].step() + model_dict['lr_scheduler'].step() + model_dict['optimizer'].zero_grad() + + # Update progress and log metrics + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({ + "train_loss": train_loss, + "train_loss_D": train_loss_D, + "train_loss_D_mouth": train_loss_D_mouth, + "l1_loss": l1_loss_accum, + "vgg_loss": vgg_loss_accum, + "gan_loss": gan_loss_accum, + "fm_loss": fm_loss_accum, + "sync_loss": sync_loss_accum, + "adapted_weight": adapted_weight_accum, + "lr": model_dict['lr_scheduler'].get_last_lr()[0], + }, step=global_step) + + # Reset loss accumulators + train_loss = 0.0 + l1_loss_accum = 0.0 + vgg_loss_accum = 0.0 + gan_loss_accum = 0.0 + fm_loss_accum = 0.0 + sync_loss_accum = 0.0 + adapted_weight_accum = 0.0 + train_loss_D = 0.0 + train_loss_D_mouth = 0.0 + + # Run validation if needed + if global_step % cfg.val_freq == 0 or global_step == 10: + try: + validation( + cfg, + dataloader_dict['val_dataloader'], + model_dict['net'], + model_dict['vae'], + model_dict['wav2vec'], + accelerator, + save_dir, + global_step, + weight_dtype, + syncnet_score=adapted_weight, + ) + except Exception as e: + print(f"An error occurred during validation: {e}") + + # Save checkpoint if needed + if global_step % cfg.checkpointing_steps == 0: + save_path = os.path.join(save_dir, f"checkpoint-{global_step}") + try: + start_time = time.time() + if accelerator.is_main_process: + save_models( + accelerator, + model_dict['net'], + save_dir, + global_step, + cfg, + logger=logger + ) + delete_additional_ckpt(save_dir, cfg.total_limit) + elapsed_time = time.time() - start_time + if elapsed_time > 300: + print(f"Skipping storage as it took too long in step {global_step}.") + else: + print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.") + except Exception as e: + print(f"Error when saving model in step {global_step}:", e) + + # Update progress bar + t_model = time.time() - t_model_start + logs = { + "step_loss": loss.detach().item(), + "lr": model_dict['lr_scheduler'].get_last_lr()[0], + "td": f"{t_data:.2f}s", + "tm": f"{t_model:.2f}s", + } + t_data_start = time.time() + progress_bar.set_postfix(**logs) + + if global_step >= cfg.solver.max_train_steps: + break + + # Save model after each epoch + if (epoch + 1) % cfg.save_model_epoch_interval == 0: + try: + start_time = time.time() + if accelerator.is_main_process: + save_models(accelerator, model_dict['net'], save_dir, global_step, cfg) + 
accelerator.save_state(os.path.join(save_dir, f"checkpoint-{global_step}"))
+                elapsed_time = time.time() - start_time
+                if elapsed_time > 120:
+                    print(f"Skipping storage as it took too long in step {global_step}.")
+                else:
+                    print(f"Model saved successfully in {elapsed_time}s.")
+            except Exception as e:
+                print(f"Error when saving model in step {global_step}:", e)
+            accelerator.wait_for_everyone()
+
+    # End training
+    accelerator.end_training()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
+    args = parser.parse_args()
+    config = OmegaConf.load(args.config)
+    main(config)
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..d4ebf3b
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# MuseTalk Training Script
+# This script launches either of the two MuseTalk training stages
+# Usage: sh train.sh [stage1|stage2]
+# Example: sh train.sh stage1  # To run stage 1 training
+# Example: sh train.sh stage2  # To run stage 2 training
+
+# Check if stage argument is provided
+if [ $# -ne 1 ]; then
+    echo "Error: Please specify the training stage"
+    echo "Usage: sh train.sh [stage1|stage2]"
+    exit 1
+fi
+
+STAGE=$1
+
+# Validate stage argument
+if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
+    echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
+    exit 1
+fi
+
+# Launch distributed training using accelerate
+# --config_file: Path to the GPU configuration file
+# --main_process_port: Port number for the main process, used for distributed training communication
+# train.py: Training script
+# --config: Path to the training configuration file
+echo "Starting $STAGE training..."
+accelerate launch --config_file ./configs/training/gpu.yaml \
+    --main_process_port 29502 \
+    train.py --config ./configs/training/$STAGE.yaml
+
+echo "Training completed for $STAGE"
\ No newline at end of file