feat: data preprocessing and training (#294)

* docs: update readme

* docs: update readme

* feat: training codes

* feat: data preprocess

* docs: release training
Zhizhou Zhong
2025-04-04 22:10:03 +08:00
committed by GitHub
parent e636166b85
commit 1ab53a626b
23 changed files with 3854 additions and 6 deletions

.gitignore vendored

@@ -8,4 +8,8 @@ results/
./models
**/__pycache__/
*.py[cod]
*$py.class
dataset/
ffmpeg*
debug
exp_out

README.md

@@ -130,8 +130,9 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [x] realtime inference code for 1.5 version.
- [x] training and data preprocessing codes.
- [ ] You are **always** welcome to submit issues and PRs to improve this repository! 😊

# Getting Started
@@ -187,6 +188,7 @@ huggingface-cli download TMElyralab/MuseTalk --local-dir models/
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)

Finally, these weights should be organized in `models` as follows:
@@ -198,6 +200,8 @@ Finally, these weights should be organized in `models` as follows:
├── musetalkV15
│   └── musetalk.json
│   └── unet.pth
├── syncnet
│   └── latentsync_syncnet.pt
├── dwpose
│   └── dw-ll_ucoco_384.pth
├── face-parse-bisent
@@ -265,6 +269,73 @@ For faster generation without saving images, you can use:
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## Training
### Data Preparation
To train MuseTalk, you need to prepare your dataset following these steps:
1. **Place your source videos**
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
2. **Run the preprocessing script**
```bash
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
```
This script will:
- Extract frames from videos
- Detect and align faces
- Generate audio features
- Create the necessary data structure for training
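The output locations for these artifacts are defined in `configs/training/preprocess.yaml` (included in this commit). As a minimal illustration, assuming PyYAML is available, the snippet below only prints where the generated files and directories will land:

```python
# Illustrative only: print the output paths that scripts.preprocess generates,
# as configured in configs/training/preprocess.yaml.
import yaml

with open("configs/training/preprocess.yaml") as f:
    cfg = yaml.safe_load(f)

generated_keys = [
    "video_root_25fps",          # videos resampled to 25 fps
    "video_audio_clip_root",     # audio/video clips
    "meta_root",                 # per-clip meta files (paths, face boxes, landmarks)
    "video_file_list",
    "video_clip_file_list_train",
    "video_clip_file_list_val",
]
for key in generated_keys:
    print(f"{key}: {cfg[key]}")
```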
### Training Process
After data preprocessing, you can start the training process:
1. **First Stage**
```bash
sh train.sh stage1
```
2. **Second Stage**
```bash
sh train.sh stage2
```
### Configuration Adjustment
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
1. **GPU Configuration** (`configs/training/gpu.yaml`):
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
- `num_processes`: Set this to match the number of GPUs you're using
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
- `random_init_unet`: Must be set to `False` to use the model from stage 1
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
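As a rough rule of thumb (the config comments note that the actual per-GPU batch is `train_bs * n_sample_frames`), the number of frames contributing to each optimizer step scales with the GPU count and the gradient accumulation steps. A small sketch using the stage 2 defaults above; `num_gpus` is whatever you set in `configs/training/gpu.yaml`:

```python
# Back-of-the-envelope effective batch size (illustrative, not repo code).
train_bs = 2            # data.train_bs (stage 2 default)
n_sample_frames = 16    # data.n_sample_frames (stage 2 default)
grad_accum = 8          # solver.gradient_accumulation_steps (stage 2 default)
num_gpus = 2            # e.g. gpu_ids: "5, 7" in configs/training/gpu.yaml

frames_per_gpu_forward = train_bs * n_sample_frames
effective_frames_per_step = frames_per_gpu_forward * num_gpus * grad_accum
print(frames_per_gpu_forward, effective_frames_per_step)  # 32 512
```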
### GPU Memory Requirements
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
#### Stage 1 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 8 | 1 | ~32GB | |
| 16 | 1 | ~45GB | |
| 32 | 1 | ~74GB | ✓ |
#### Stage 2 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 1 | 8 | ~54GB | |
| 2 | 2 | ~80GB | |
| 2 | 8 | ~85GB | ✓ |
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
@@ -368,7 +439,7 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
As a complete solution to virtual human generation, we suggest first applying [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring to [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase the frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring to [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).

# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.

configs/training/gpu.yaml Executable file

@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: True
deepspeed_config:
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: False
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
gpu_ids: "5, 7" # modify this according to your GPU number
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2 # it should be the same as the number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
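The file follows the Hugging Face Accelerate launcher config format (DeepSpeed ZeRO stage 2 on a local machine). A small illustrative check, not part of this commit, that `num_processes` matches the number of GPUs listed in `gpu_ids`:

```python
# Illustrative sanity check for configs/training/gpu.yaml.
import yaml

with open("configs/training/gpu.yaml") as f:
    cfg = yaml.safe_load(f)

gpu_ids = [g.strip() for g in str(cfg["gpu_ids"]).split(",") if g.strip()]
assert cfg["num_processes"] == len(gpu_ids), (
    f"num_processes={cfg['num_processes']} but gpu_ids lists {len(gpu_ids)} GPUs: {gpu_ids}"
)
print("OK:", gpu_ids)
```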

configs/training/preprocess.yaml

@@ -0,0 +1,31 @@
clip_len_second: 30 # the length of the video clip
video_root_raw: "./dataset/HDTF/source/" # the path of the original video
val_list_hdtf:
- RD_Radio7_000
- RD_Radio8_000
- RD_Radio9_000
- WDA_TinaSmith_000
- WDA_TomCarper_000
- WDA_TomPerez_000
- WDA_TomUdall_000
- WDA_VeronicaEscobar0_000
- WDA_VeronicaEscobar1_000
- WDA_WhipJimClyburn_000
- WDA_XavierBecerra_000
- WDA_XavierBecerra_001
- WDA_XavierBecerra_002
- WDA_ZoeLofgren_000
- WRA_SteveScalise1_000
- WRA_TimScott_000
- WRA_ToddYoung_000
- WRA_TomCotton_000
- WRA_TomPrice_000
- WRA_VickyHartzler_000
# following dir will be automatically generated
video_root_25fps: "./dataset/HDTF/video_root_25fps/"
video_file_list: "./dataset/HDTF/video_file_list.txt"
video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
meta_root: "./dataset/HDTF/meta/"
video_clip_file_list_train: "./dataset/HDTF/train.txt"
video_clip_file_list_val: "./dataset/HDTF/val.txt"

configs/training/stage1.yaml Executable file

@@ -0,0 +1,89 @@
exp_name: 'test' # Name of the experiment
output_dir: './exp_out/stage1/' # Directory to save experiment outputs
unet_sub_folder: musetalk # Subfolder name for UNet model
random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
whisper_path: "./models/whisper" # Path to the Whisper model
pretrained_model_name_or_path: "./models" # Path to pretrained models
resume_from_checkpoint: True # Whether to resume training from a checkpoint
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
vae_type: "sd-vae" # Type of VAE model to use
# Validation parameters
num_images_to_keep: 8 # Number of validation images to keep
ref_dropout_rate: 0 # Dropout rate for reference images
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
use_adapted_weight: False # Whether to use adapted weights for loss calculation
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
crop_type: "crop_resize" # Type of cropping method
random_margin_method: "normal" # Method for random margin generation
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
data:
dataset_key: "HDTF" # Dataset to use for training
train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
image_size: 256 # Size of input images
n_sample_frames: 1 # Number of frames to sample per batch
num_workers: 8 # Number of data loading workers
audio_padding_length_left: 2 # Left padding length for audio features
audio_padding_length_right: 2 # Right padding length for audio features
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
top_k_ratio: 0.51 # Ratio for top-k sampling
contorl_face_min_size: True # Whether to control minimum face size
min_face_size: 150 # Minimum face size in pixels
loss_params:
l1_loss: 1.0 # Weight for L1 loss
vgg_loss: 0.01 # Weight for VGG perceptual loss
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
gan_loss: 0 # Weight for GAN loss
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
sync_loss: 0 # Weight for sync loss
mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
model_params:
discriminator_params:
scales: [1] # Scales for discriminator
block_expansion: 32 # Expansion factor for discriminator blocks
max_features: 512 # Maximum number of features in discriminator
num_blocks: 4 # Number of blocks in discriminator
sn: True # Whether to use spectral normalization
image_channel: 3 # Number of image channels
estimate_jacobian: False # Whether to estimate Jacobian
discriminator_train_params:
lr: 0.000005 # Learning rate for discriminator
eps: 0.00000001 # Epsilon for optimizer
weight_decay: 0.01 # Weight decay for optimizer
patch_size: 1 # Size of patches for discriminator
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
epochs: 10000 # Number of training epochs
start_gan: 1000 # Step to start GAN training
solver:
gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
uncond_steps: 10 # Number of unconditional steps
mixed_precision: 'fp32' # Precision mode for training
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
gradient_checkpointing: True # Whether to use gradient checkpointing
max_train_steps: 250000 # Maximum number of training steps
max_grad_norm: 1.0 # Maximum gradient norm for clipping
# Learning rate parameters
learning_rate: 2.0e-5 # Base learning rate
scale_lr: False # Whether to scale learning rate
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
lr_scheduler: "linear" # Type of learning rate scheduler
# Optimizer parameters
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
total_limit: 10 # Maximum number of checkpoints to keep
save_model_epoch_interval: 250000 # Interval between model saves
checkpointing_steps: 10000 # Number of steps between checkpoints
val_freq: 2000 # Frequency of validation
seed: 41 # Random seed for reproducibility
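The nested sections (`data`, `loss_params`, `model_params`, `solver`) can be loaded with OmegaConf, which is how `musetalk/loss/basic_loss.py` later in this commit accesses its loss and discriminator hyper-parameters; a minimal sketch:

```python
# Illustrative only: load the stage 1 config and read a few fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/training/stage1.yaml")
print(cfg.data.train_bs, cfg.data.n_sample_frames)  # 32 1
print(cfg.solver.learning_rate)                     # 2e-05
print(list(cfg.loss_params.keys()))                 # ['l1_loss', 'vgg_loss', ...]
```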

configs/training/stage2.yaml Executable file

@@ -0,0 +1,89 @@
exp_name: 'test' # Name of the experiment
output_dir: './exp_out/stage2/' # Directory to save experiment outputs
unet_sub_folder: musetalk # Subfolder name for UNet model
random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
whisper_path: "./models/whisper" # Path to the Whisper model
pretrained_model_name_or_path: "./models" # Path to pretrained models
resume_from_checkpoint: True # Whether to resume training from a checkpoint
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
vae_type: "sd-vae" # Type of VAE model to use
# Validation parameters
num_images_to_keep: 8 # Number of validation images to keep
ref_dropout_rate: 0 # Dropout rate for reference images
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
use_adapted_weight: False # Whether to use adapted weights for loss calculation
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
crop_type: "dynamic_margin_crop_resize" # Type of cropping method
random_margin_method: "normal" # Method for random margin generation
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
data:
dataset_key: "HDTF" # Dataset to use for training
train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
image_size: 256 # Size of input images
n_sample_frames: 16 # Number of frames to sample per batch
num_workers: 8 # Number of data loading workers
audio_padding_length_left: 2 # Left padding length for audio features
audio_padding_length_right: 2 # Right padding length for audio features
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
top_k_ratio: 0.51 # Ratio for top-k sampling
contorl_face_min_size: True # Whether to control minimum face size
min_face_size: 200 # Minimum face size in pixels
loss_params:
l1_loss: 1.0 # Weight for L1 loss
vgg_loss: 0.01 # Weight for VGG perceptual loss
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
gan_loss: 0.01 # Weight for GAN loss
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
sync_loss: 0.05 # Weight for sync loss
mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
model_params:
discriminator_params:
scales: [1] # Scales for discriminator
block_expansion: 32 # Expansion factor for discriminator blocks
max_features: 512 # Maximum number of features in discriminator
num_blocks: 4 # Number of blocks in discriminator
sn: True # Whether to use spectral normalization
image_channel: 3 # Number of image channels
estimate_jacobian: False # Whether to estimate Jacobian
discriminator_train_params:
lr: 0.000005 # Learning rate for discriminator
eps: 0.00000001 # Epsilon for optimizer
weight_decay: 0.01 # Weight decay for optimizer
patch_size: 1 # Size of patches for discriminator
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
epochs: 10000 # Number of training epochs
start_gan: 1000 # Step to start GAN training
solver:
gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
uncond_steps: 10 # Number of unconditional steps
mixed_precision: 'fp32' # Precision mode for training
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
gradient_checkpointing: True # Whether to use gradient checkpointing
max_train_steps: 250000 # Maximum number of training steps
max_grad_norm: 1.0 # Maximum gradient norm for clipping
# Learning rate parameters
learning_rate: 5.0e-6 # Base learning rate
scale_lr: False # Whether to scale learning rate
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
lr_scheduler: "linear" # Type of learning rate scheduler
# Optimizer parameters
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
total_limit: 10 # Maximum number of checkpoints to keep
save_model_epoch_interval: 250000 # Interval between model saves
checkpointing_steps: 2000 # Number of steps between checkpoints
val_freq: 2000 # Frequency of validation
seed: 41 # Random seed for reproducibility

configs/training/syncnet.yaml

@@ -0,0 +1,19 @@
# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
visual_encoder: # input (48, 128, 256)
in_channels: 48
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
save_ckpt_steps: 2500
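The audio encoder's expected input length of 52 mel frames matches `syncnet_mel_step_size` in `musetalk/data/dataset.py`: 16 video frames at 25 fps, with the mel spectrogram computed at 80 frames per second (16 kHz audio, hop size 200). A quick check:

```python
import math

# 16 video frames / 25 fps * 80 mel frames per second = 51.2 -> 52 mel frames
print(math.ceil(16 / 25 * 80))  # 52
```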


@@ -59,7 +59,7 @@ cmd_args="--inference_config $config_path \
--result_dir $result_dir \
--unet_model_path $unet_model_path \
--unet_config $unet_config \
--version $version_arg"
# Add realtime-specific arguments if in realtime mode
if [ "$mode" = "realtime" ]; then

musetalk/data/audio.py Executable file

@@ -0,0 +1,168 @@
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
class HParams:
# copy from wav2lip
def __init__(self):
self.n_fft = 800
self.hop_size = 200
self.win_size = 800
self.sample_rate = 16000
self.frame_shift_ms = None
self.signal_normalization = True
self.allow_clipping_in_normalization = True
self.symmetric_mels = True
self.max_abs_value = 4.0
self.preemphasize = True
self.preemphasis = 0.97
self.min_level_db = -100
self.ref_level_db = 20
self.fmin = 55
self.fmax=7600
self.use_lws=False
self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
self.rescale=True # Whether to rescale audio prior to preprocessing
self.rescaling_max=0.9 # Rescaling value
self.use_lws=False
hp = HParams()
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
#def load_wav(path, sr):
# audio, sr_native = sf.read(path)
# if sr != sr_native:
# audio = librosa.resample(audio.T, sr_native, sr).T
# return audio
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = hp.hop_size
if hop_size is None:
assert hp.frame_shift_ms is not None
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
def _stft(y):
if hp.use_lws:
return _lws_processor().stft(y).T
else:
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
"""
pad = (fsize - fshift)
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding
"""
M = num_frames(len(x), fsize, fshift)
pad = (fsize - fshift)
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert hp.fmax <= hp.sample_rate // 2
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
fmin=hp.fmin, fmax=hp.fmax)
def _amp_to_db(x):
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
-hp.max_abs_value, hp.max_abs_value)
else:
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
if hp.symmetric_mels:
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
else:
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
def _denormalize(D):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return (((np.clip(D, -hp.max_abs_value,
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ hp.min_level_db)
else:
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
if hp.symmetric_mels:
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
else:
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
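A small usage sketch (not part of this commit) for the module above; with these wav2lip-style hyper-parameters (16 kHz sampling rate, `n_fft=800`, `hop_size=200`, `num_mels=80`), `melspectrogram` returns an array of shape `(80, T)` at roughly 80 mel frames per second. The wav path below is a placeholder:

```python
# Illustrative usage of musetalk/data/audio.py.
from musetalk.data import audio

wav = audio.load_wav("path/to/some_clip.wav", sr=16000)  # placeholder path
mel = audio.melspectrogram(wav)  # normalized mel spectrogram, shape (num_mels, T) = (80, T)
print(mel.shape)
```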

musetalk/data/dataset.py Executable file

@@ -0,0 +1,607 @@
import os
import numpy as np
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, ConcatDataset
import torchvision.transforms as transforms
from transformers import AutoFeatureExtractor
import librosa
import time
import json
import math
from decord import AudioReader, VideoReader
from decord.ndarray import cpu
from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
from musetalk.data import audio
syncnet_mel_step_size = math.ceil(16 / 5 * 16)  # = 52: 16 video frames at 25 fps x 80 mel frames per second (follows LatentSync)
class FaceDataset(Dataset):
"""Dataset class for loading and processing video data
Each video can be represented as:
- Concatenated frame images
- '.mp4' or '.gif' files
- Folder containing all frames
"""
def __init__(self,
cfg,
list_paths,
root_path='./dataset/',
repeats=None):
# Initialize dataset paths
meta_paths = []
if repeats is None:
repeats = [1] * len(list_paths)
assert len(repeats) == len(list_paths)
# Load data list
for list_path, repeat_time in zip(list_paths, repeats):
with open(list_path, 'r') as f:
num = 0
f.readline() # Skip header line
for line in f.readlines():
line_info = line.strip()
meta = line_info.split()
meta = meta[0]
meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
num += 1
print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
# Set basic attributes
self.meta_paths = meta_paths
self.root_path = root_path
self.image_size = cfg['image_size']
self.min_face_size = cfg['min_face_size']
self.T = cfg['T']
self.sample_method = cfg['sample_method']
self.top_k_ratio = cfg['top_k_ratio']
self.max_attempts = 200
self.padding_pixel_mouth = cfg['padding_pixel_mouth']
# Cropping related parameters
self.crop_type = cfg['crop_type']
self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
self.random_margin_method = cfg['random_margin_method']
# Image transformations
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.pose_to_tensor = transforms.Compose([
transforms.ToTensor(),
])
# Feature extractor
self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
self.contorl_face_min_size = cfg["contorl_face_min_size"]
print("The sample method is: ", self.sample_method)
print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
def generate_random_value(self):
"""Generate random value
Returns:
float: Generated random value
"""
if self.random_margin_method == "uniform":
random_value = np.random.uniform(
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
self.jaw2edge_margin_mean + self.jaw2edge_margin_std
)
elif self.random_margin_method == "normal":
random_value = np.random.normal(
loc=self.jaw2edge_margin_mean,
scale=self.jaw2edge_margin_std
)
random_value = np.clip(
random_value,
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
)
else:
raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
return max(0, random_value)
def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
"""Dynamically crop image with dynamic margin
Args:
img: Input image
original_bbox: Original bounding box
extra_margin: Extra margin
Returns:
tuple: (x1, y1, x2, y2, extra_margin)
"""
if extra_margin is None:
extra_margin = self.generate_random_value()
w, h = img.size
x1, y1, x2, y2 = original_bbox
y2 = min(y2 + int(extra_margin), h)
return x1, y1, x2, y2, extra_margin
def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
"""Crop and resize image
Args:
img: Input image
bbox: Bounding box
crop_type: Type of cropping
extra_margin: Extra margin
Returns:
tuple: (Processed image, extra_margin, mask_scaled_factor)
"""
mask_scaled_factor = 1.
if crop_type == 'crop_resize':
x1, y1, x2, y2 = bbox
img = img.crop((x1, y1, x2, y2))
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
elif crop_type == 'dynamic_margin_crop_resize':
x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
w_original, _ = img.size
img = img.crop((x1, y1, x2, y2))
w_cropped, _ = img.size
mask_scaled_factor = w_cropped / w_original
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
elif crop_type == 'resize':
w, h = img.size
scale = np.sqrt(self.image_size ** 2 / (h * w))
new_w = int(w * scale) / 64 * 64
new_h = int(h * scale) / 64 * 64
img = img.resize((new_w, new_h), Image.LANCZOS)
return img, extra_margin, mask_scaled_factor
def get_audio_file(self, wav_path, start_index):
"""Get audio file features
Args:
wav_path: Audio file path
start_index: Starting index
Returns:
tuple: (Audio features, start index)
"""
if not os.path.exists(wav_path):
return None
audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000)
assert sampling_rate == 16000
while start_index >= 25 * 30:
audio_input = audio_input_librosa[16000*30:]
start_index -= 25 * 30
if start_index + 2 * 25 >= 25 * 30:
start_index -= 4 * 25
audio_input = audio_input_librosa[16000*4:16000*34]
else:
audio_input = audio_input_librosa[:16000*30]
assert 2 * (start_index) >= 0
assert 2 * (start_index + 2 * 25) <= 1500
audio_input = self.feature_extractor(
audio_input,
return_tensors="pt",
sampling_rate=sampling_rate
).input_features
return audio_input, start_index
def get_audio_file_mel(self, wav_path, start_index):
"""Get mel spectrogram of audio file
Args:
wav_path: Audio file path
start_index: Starting index
Returns:
tuple: (Mel spectrogram, start index)
"""
if not os.path.exists(wav_path):
return None
audio_input, sampling_rate = librosa.load(wav_path, sr=16000)
assert sampling_rate == 16000
audio_input = self.mel_feature_extractor(audio_input)
return audio_input, start_index
def mel_feature_extractor(self, audio_input):
"""Extract mel spectrogram features
Args:
audio_input: Input audio
Returns:
ndarray: Mel spectrogram features
"""
orig_mel = audio.melspectrogram(audio_input)
return orig_mel.T
def crop_audio_window(self, spec, start_frame_num, fps=25):
"""Crop audio window
Args:
spec: Spectrogram
start_frame_num: Starting frame number
fps: Frames per second
Returns:
ndarray: Cropped spectrogram
"""
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx: end_idx, :]
def get_syncnet_input(self, video_path):
"""Get SyncNet input features
Args:
video_path: Video file path
Returns:
ndarray: SyncNet input features
"""
ar = AudioReader(video_path, sample_rate=16000)
original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
return original_mel.T
def get_resized_mouth_mask(
self,
img_resized,
landmark_array,
face_shape,
padding_pixel_mouth=0,
image_size=256,
crop_margin=0
):
landmark_array = np.array(landmark_array)
resized_landmark = resize_landmark(
landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
min_x, min_y = np.min(landmark_array, axis=0)
max_x, max_y = np.max(landmark_array, axis=0)
min_x = min_x - padding_pixel_mouth
max_x = max_x + padding_pixel_mouth
# Calculate x-axis length and use it for y-axis
width = max_x - min_x
# Calculate old center point
center_y = (max_y + min_y) / 2
# Determine new min_y and max_y based on width
min_y = center_y - width / 4
max_y = center_y + width / 4
# Adjust mask position for dynamic crop, shift y-axis
min_y = min_y - crop_margin
max_y = max_y - crop_margin
# Prevent out of bounds
min_x = max(min_x, 0)
min_y = max(min_y, 0)
max_x = min(max_x, face_shape[0])
max_y = min(max_y, face_shape[1])
mask = np.zeros_like(np.array(img_resized))
mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
return Image.fromarray(mask)
def __len__(self):
return 100000
def __getitem__(self, idx):
attempts = 0
while attempts < self.max_attempts:
try:
meta_path = random.sample(self.meta_paths, k=1)[0]
with open(meta_path, 'r') as f:
meta_data = json.load(f)
except Exception as e:
print(f"meta file error:{meta_path}")
print(e)
attempts += 1
time.sleep(0.1)
continue
video_path = meta_data["mp4_path"]
wav_path = meta_data["wav_path"]
bbox_list = meta_data["face_list"]
landmark_list = meta_data["landmark_list"]
T = self.T
s = 0
e = meta_data["frames"]
len_valid_clip = e - s
if len_valid_clip < T * 10:
attempts += 1
print(f"video {video_path} has less than {T * 10} frames")
continue
try:
cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
total_frames = len(cap)
assert total_frames == len(landmark_list)
assert total_frames == len(bbox_list)
landmark_shape = np.array(landmark_list).shape
if landmark_shape != (total_frames, 68, 2):
attempts += 1
print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
continue
except Exception as e:
print(f"video file error:{video_path}")
print(e)
attempts += 1
time.sleep(0.1)
continue
shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
landmark_list,
bbox_list
)
if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
attempts += 1
continue
step = 1
drive_idx_start = random.randint(s, e - T * step)
drive_idx_list = list(
range(drive_idx_start, drive_idx_start + T * step, step))
assert len(drive_idx_list) == T
src_idx_list = []
list_index_out_of_range = False
for drive_idx in drive_idx_list:
src_idx = get_src_idx(
drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
if src_idx is None:
list_index_out_of_range = True
break
src_idx = min(src_idx, e - 1)
src_idx = max(src_idx, s)
src_idx_list.append(src_idx)
if list_index_out_of_range:
attempts += 1
print(f"video {video_path} has invalid source index for drive frames")
continue
ref_face_valid_flag = True
extra_margin = self.generate_random_value()
# Get reference images
ref_imgs = []
for src_idx in src_idx_list:
imSrc = Image.fromarray(cap[src_idx].asnumpy())
bbox_s = bbox_list_union[src_idx]
imSrc, _, _ = self.crop_resize_img(
imSrc,
bbox_s,
self.crop_type,
extra_margin=None
)
if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
ref_face_valid_flag = False
break
ref_imgs.append(imSrc)
if not ref_face_valid_flag:
attempts += 1
print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
continue
# Get target images and masks
imSameIDs = []
bboxes = []
face_masks = []
face_mask_valid = True
target_face_valid_flag = True
for drive_idx in drive_idx_list:
imSameID = Image.fromarray(cap[drive_idx].asnumpy())
bbox_s = bbox_list_union[drive_idx]
imSameID, _ , mask_scaled_factor = self.crop_resize_img(
imSameID,
bbox_s,
self.crop_type,
extra_margin=extra_margin
)
if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
target_face_valid_flag = False
break
crop_margin = extra_margin * mask_scaled_factor
face_mask = self.get_resized_mouth_mask(
imSameID,
shift_landmarks[drive_idx],
face_shapes[drive_idx],
self.padding_pixel_mouth,
self.image_size,
crop_margin=crop_margin
)
if np.count_nonzero(face_mask) == 0:
face_mask_valid = False
break
if face_mask.size[1] == 0 or face_mask.size[0] == 0:
print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
face_mask_valid = False
break
imSameIDs.append(imSameID)
bboxes.append(bbox_s)
face_masks.append(face_mask)
if not face_mask_valid:
attempts += 1
print(f"video {video_path} has invalid face mask")
continue
if not target_face_valid_flag:
attempts += 1
print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
continue
# Process audio features
audio_offset = drive_idx_list[0]
audio_step = step
fps = 25.0 / step
try:
audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
_, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
audio_feature_mel = self.get_syncnet_input(video_path)
except Exception as e:
print(f"audio file error:{wav_path}")
print(e)
attempts += 1
time.sleep(0.1)
continue
mel = self.crop_audio_window(audio_feature_mel, audio_offset)
if mel.shape[0] != syncnet_mel_step_size:
attempts += 1
print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
continue
mel = torch.FloatTensor(mel.T).unsqueeze(0)
# Build sample dictionary
sample = dict(
pixel_values_vid=torch.stack(
[self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
pixel_values_ref_img=torch.stack(
[self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
pixel_values_face_mask=torch.stack(
[self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
audio_feature=audio_feature[0],
audio_offset=audio_offset,
audio_step=audio_step,
mel=mel,
wav_path=wav_path,
fps=fps,
)
return sample
raise ValueError("Unable to find a valid sample after maximum attempts.")
class HDTFDataset(FaceDataset):
"""HDTF dataset class"""
def __init__(self, cfg):
root_path = './dataset/HDTF/meta'
list_paths = [
'./dataset/HDTF/train.txt',
]
repeats = [10]
super().__init__(cfg, list_paths, root_path, repeats)
print('HDTFDataset: ', len(self))
class VFHQDataset(FaceDataset):
"""VFHQ dataset class"""
def __init__(self, cfg):
root_path = './dataset/VFHQ/meta'
list_paths = [
'./dataset/VFHQ/train.txt',
]
repeats = [1]
super().__init__(cfg, list_paths, root_path, repeats)
print('VFHQDataset: ', len(self))
def PortraitDataset(cfg=None):
"""Return dataset based on configuration
Args:
cfg: Configuration dictionary
Returns:
Dataset: Combined dataset
"""
if cfg["dataset_key"] == "HDTF":
return ConcatDataset([HDTFDataset(cfg)])
elif cfg["dataset_key"] == "VFHQ":
return ConcatDataset([VFHQDataset(cfg)])
else:
print("############ use all dataset ############ ")
return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
if __name__ == '__main__':
# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Create dataset with configuration parameters
dataset = PortraitDataset(cfg={
'T': 1, # Number of frames to process at once
'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
'image_size': 256, # Size of processed images (height and width)
'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
'contorl_face_min_size': True, # Whether to enforce minimum face size
'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
'min_face_size': 200, # Minimum face size requirement for dataset
'whisper_path': "./models/whisper", # Path to Whisper model
'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
})
print(len(dataset))
import torchvision
os.makedirs('debug', exist_ok=True)
for i in range(10): # Check 10 samples
sample = dataset[0]
print(f"processing {i}")
# Get images and mask
ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
target_img = (sample['pixel_values_vid'] + 1.0) / 2
face_mask = sample['pixel_values_face_mask']
# Print dimension information
print(f"ref_img shape: {ref_img.shape}")
print(f"target_img shape: {target_img.shape}")
print(f"face_mask shape: {face_mask.shape}")
# Create visualization images
b, c, h, w = ref_img.shape
# Apply mask only to target image
target_mask = face_mask
# Keep reference image unchanged
ref_with_mask = ref_img.clone()
# Create mask overlay for target image
target_with_mask = target_img.clone()
target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
# Save original images, mask, and overlay results
# First row: original images
# Second row: mask
# Third row: overlay effect
concatenated_img = torch.cat((
ref_img, target_img, # Original images
torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
ref_with_mask, target_with_mask # Overlay effect
), dim=3)
torchvision.utils.save_image(
concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)

musetalk/data/sample_method.py Executable file

@@ -0,0 +1,233 @@
import numpy as np
import random
def summarize_tensor(x):
return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
num_landmarks = len(landmarks_list)
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
print(np.shape(landmarks_list))
## Calculate mouth opening ratios
for i, landmarks in enumerate(landmarks_list):
# Assuming landmarks are in the format [x, y] and accessible by index
mouth_top = landmarks[165] # Adjust index according to your landmarks format
mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
mouth_open_ratios[i] = mouth_open_ratio
# Calculate differences matrix
differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
print(differences_matrix.shape)
# Find top_k similar indices for each landmark set
if ascending:
top_indices = np.argsort(differences_matrix[i])[:top_k]
else:
top_indices = np.argsort(-differences_matrix[i])[:top_k]
similar_landmarks_indices = top_indices.tolist()
similar_landmarks_distances = differences_matrix_with_signs[i].tolist()  # note: keep these unsorted
return similar_landmarks_indices, similar_landmarks_distances
#############################################################################################
def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
num_landmarks = len(landmarks_list)
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
## Calculate mouth opening ratios
#print("landmarks shape",np.shape(landmarks_list))
for i, landmarks in enumerate(landmarks_list):
# Assuming landmarks are in the format [x, y] and accessible by index
#print(landmarks[165])
mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
mouth_open_ratios[i] = mouth_open_ratio
# Find top_k similar indices for each landmark set
if ascending:
top_indices = np.argsort(mouth_open_ratios)[:top_k]
else:
top_indices = np.argsort(-mouth_open_ratios)[:top_k]
return top_indices
def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
"""
Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
Parameters:
landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
start_index (int): The starting index of the facial landmarks.
end_index (int): The ending index of the facial landmarks.
top_k (int): The number of most similar landmark sets to return. Default is 50.
ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
Returns:
similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
resized_landmarks (list): A list containing the resized facial landmarks.
"""
num_landmarks = len(landmarks_list)
resized_landmarks = []
# Preprocess landmarks
for i in range(num_landmarks):
landmark_array = np.array(landmarks_list[i])
selected_landmarks = landmark_array[start_index:end_index]
resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
resized_landmarks.append(resized_landmark)
resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
# Calculate similarity
distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
if ascending:
sorted_indices = np.argsort(overall_distances)
similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
else:
sorted_indices = np.argsort(-overall_distances)
similar_landmarks_indices = sorted_indices[0:top_k].tolist()
return similar_landmarks_indices
def process_bbox_musetalk(face_array, landmark_array):
x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
x_min_lm = min([int(x) for x, y in landmark_array])
y_min_lm = min([int(y) for x, y in landmark_array])
x_max_lm = max([int(x) for x, y in landmark_array])
y_max_lm = max([int(y) for x, y in landmark_array])
x_min = min(x_min_face, x_min_lm)
y_min = min(y_min_face, y_min_lm)
x_max = max(x_max_face, x_max_lm)
y_max = max(y_max_face, y_max_lm)
x_min = max(x_min, 0)
y_min = max(y_min, 0)
return [x_min, y_min, x_max, y_max]
def shift_landmarks_to_face_coordinates(landmark_list, face_list):
"""
Translates the data in landmark_list to the coordinates of the cropped larger face.
Parameters:
landmark_list (list): A list containing multiple sets of facial landmarks.
face_list (list): A list containing multiple facial images.
Returns:
landmark_list_shift (list): The list of translated landmarks.
bbox_union (list): The list of union bounding boxes.
face_shapes (list): The list of facial shapes.
"""
landmark_list_shift = []
bbox_union = []
face_shapes = []
for i in range(len(face_list)):
landmark_array = np.array(landmark_list[i])  # convert to a numpy array (creates a copy)
face_array = face_list[i]
f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
x_min, y_min, x_max, y_max = f_landmark_bbox
landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
landmark_list_shift.append(landmark_array)
bbox_union.append(f_landmark_bbox)
face_shapes.append((x_max - x_min, y_max - y_min))
return landmark_list_shift, bbox_union, face_shapes
def resize_landmark(landmark, w, h, new_w, new_h):
landmark_norm = landmark / [w, h]
landmark_resized = landmark_norm * [new_w, new_h]
return landmark_resized
def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
"""
Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.
Parameters:
- drive_idx (int): The current drive index.
- T (int): Total number of frames or a specific range limit.
- sample_method (str): Sampling method, which can be "random" or other methods.
- landmarks_list (list): List of facial landmarks.
- image_shapes (list): List of image shapes.
- top_k_ratio (float): Ratio for selecting top k similar frames.
Returns:
- src_idx (int): The calculated source index.
"""
if sample_method == "random":
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
elif sample_method == "pose_similarity":
top_k = int(top_k_ratio*len(landmarks_list))
try:
top_k = int(top_k_ratio*len(landmarks_list))
# facial contour
landmark_start_idx = 0
landmark_end_idx = 16
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
src_idx = random.choice(pose_similarity_list)
while abs(src_idx-drive_idx)<5:
src_idx = random.choice(pose_similarity_list)
except Exception as e:
print(e)
return None
elif sample_method=="pose_similarity_and_closed_mouth":
# facial contour
landmark_start_idx = 0
landmark_end_idx = 16
try:
top_k = int(top_k_ratio*len(landmarks_list))
closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
#print("closed_mouth_list",closed_mouth_list)
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
#print("pose_similarity_list",pose_similarity_list)
common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
if len(common_list) == 0:
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
else:
src_idx = random.choice(common_list)
while abs(src_idx-drive_idx) <5:
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
except Exception as e:
print(e)
return None
elif sample_method=="pose_similarity_and_mouth_dissimilarity":
top_k = int(top_k_ratio*len(landmarks_list))
try:
top_k = int(top_k_ratio*len(landmarks_list))
# facial contour for 68 landmarks format
landmark_start_idx = 0
landmark_end_idx = 16
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
# Mouth inner contour for 68 landmarks format
landmark_start_idx = 60
landmark_end_idx = 67
mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
if len(common_list) == 0:
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
else:
src_idx = random.choice(common_list)
while abs(src_idx-drive_idx) <5:
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
except Exception as e:
print(e)
return None
else:
raise ValueError(f"Unknown sample_method: {sample_method}")
return src_idx
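A toy usage sketch (not part of this commit) for `get_src_idx` with the simple `"random"` strategy; the similarity-based strategies need real 68-point landmarks and per-frame face shapes from the preprocessed meta files, and the caller (`FaceDataset.__getitem__`) clamps the returned index into the valid frame range:

```python
# Illustrative only: call get_src_idx on synthetic landmarks.
import numpy as np
from musetalk.data.sample_method import get_src_idx

num_frames = 200
landmarks = [np.random.rand(68, 2) * 256 for _ in range(num_frames)]  # fake 68-point landmarks
shapes = [(256, 256)] * num_frames                                    # fake (w, h) face shapes

src = get_src_idx(50, 16, "random", landmarks, shapes, top_k_ratio=0.51)
print(src)  # may fall outside [0, num_frames); the dataset clamps it before use
```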

musetalk/loss/basic_loss.py Executable file

@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
import musetalk.loss.vgg_face as vgg_face
class Interpolate(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
super(Interpolate, self).__init__()
self.size = size
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, input):
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
def set_requires_grad(net, requires_grad=False):
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
if __name__ == "__main__":
cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pyramid_scale = [1, 0.5, 0.25, 0.125]
vgg_IN = vgg_face.Vgg19().to(device)
pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
vgg_IN.eval()
downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
image = torch.rand(8, 3, 256, 256).to(device)
image_pred = torch.rand(8, 3, 256, 256).to(device)
pyramide_real = pyramid(downsampler(image))
pyramide_generated = pyramid(downsampler(image_pred))
loss_IN = 0
for scale in cfg.loss_params.pyramid_scale:
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
loss_IN += weight * value
loss_IN /= sum(cfg.loss_params.vgg_layer_weight)  # average over VGG layer weights; the pyramid loss is accumulated per scale
print(loss_IN)
#print(cfg.model_params.discriminator_params)
discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
discriminator_full = DiscriminatorFullModel(discriminator)
disc_scales = cfg.model_params.discriminator_params.scales
# Prepare optimizer and loss function
optimizer_D = optim.AdamW(discriminator.parameters(),
lr=cfg.discriminator_train_params.lr,
weight_decay=cfg.discriminator_train_params.weight_decay,
betas=cfg.discriminator_train_params.betas,
eps=cfg.discriminator_train_params.eps)
scheduler_D = CosineAnnealingLR(optimizer_D,
T_max=cfg.discriminator_train_params.epochs,
eta_min=1e-6)
discriminator.train()
set_requires_grad(discriminator, False)
loss_G = 0.
discriminator_maps_generated = discriminator(pyramide_generated)
discriminator_maps_real = discriminator(pyramide_real)
for scale in disc_scales:
key = 'prediction_map_%s' % scale
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
loss_G += value
print(loss_G)

musetalk/loss/conv.py Executable file

@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act(out)
class nonorm_Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
)
self.act = nn.LeakyReLU(0.01, inplace=True)
def forward(self, x):
out = self.conv_block(x)
return self.act(out)
class Conv2dTranspose(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
def forward(self, x):
out = self.conv_block(x)
return self.act(out)

musetalk/loss/discriminator.py Executable file

@@ -0,0 +1,145 @@
from torch import nn
import torch.nn.functional as F
import torch
from musetalk.loss.vgg_face import ImagePyramide
class DownBlock2d(nn.Module):
"""
Simple block for processing video (encoder).
"""
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
if norm:
self.norm = nn.InstanceNorm2d(out_features, affine=True)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, (2, 2))
return out
class Discriminator(nn.Module):
"""
Discriminator similar to Pix2Pix
"""
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
sn=False, **kwargs):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
def forward(self, x):
feature_maps = []
out = x
for down_block in self.down_blocks:
feature_maps.append(down_block(out))
out = feature_maps[-1]
prediction_map = self.conv(out)
return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Module):
"""
Multi-scale (scale) discriminator
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
discs = {}
for scale in scales:
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
self.discs = nn.ModuleDict(discs)
def forward(self, x):
out_dict = {}
for scale, disc in self.discs.items():
scale = str(scale).replace('-', '.')
key = 'prediction_' + scale
#print(key)
#print(x)
feature_maps, prediction_map = disc(x[key])
out_dict['feature_maps_' + scale] = feature_maps
out_dict['prediction_map_' + scale] = prediction_map
return out_dict
class DiscriminatorFullModel(torch.nn.Module):
"""
Merge all discriminator related updates into single model for better multi-gpu usage
"""
def __init__(self, discriminator):
super(DiscriminatorFullModel, self).__init__()
self.discriminator = discriminator
self.scales = self.discriminator.scales
print("scales",self.scales)
self.pyramid = ImagePyramide(self.scales, 3)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.zero_tensor = None
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)
def forward(self, x, generated, gan_mode='ls'):
pyramide_real = self.pyramid(x)
pyramide_generated = self.pyramid(generated.detach())
discriminator_maps_generated = self.discriminator(pyramide_generated)
discriminator_maps_real = self.discriminator(pyramide_real)
value_total = 0
for scale in self.scales:
key = 'prediction_map_%s' % scale
if gan_mode == 'hinge':
value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
elif gan_mode == 'ls':
value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
else:
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
value_total += value
return value_total
def main():
discriminator = MultiScaleDiscriminator(scales=[1],
block_expansion=32,
max_features=512,
num_blocks=4,
sn=True,
image_channel=3,
estimate_jacobian=False)
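A hedged usage sketch of how these pieces fit together: `ImagePyramide` produces the `prediction_<scale>` inputs that `MultiScaleDiscriminator` expects, and `DiscriminatorFullModel` turns them into the LS-GAN discriminator loss. The generator-side term shown at the end is illustrative only and is not necessarily the exact weighting used in `train.py`.

```python
import torch
from musetalk.loss.vgg_face import ImagePyramide
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel

device = "cuda" if torch.cuda.is_available() else "cpu"
disc = MultiScaleDiscriminator(scales=[1], block_expansion=32, max_features=512,
                               num_blocks=4, sn=True).to(device)
disc_full = DiscriminatorFullModel(disc)

real = (torch.rand(2, 3, 256, 256) * 2 - 1).to(device)  # stand-in ground-truth frames
fake = (torch.rand(2, 3, 256, 256) * 2 - 1).to(device)  # stand-in generated frames

# Discriminator update: LS-GAN loss on real vs. (detached) generated frames.
d_loss = disc_full(real, fake, gan_mode="ls")

# Generator update (illustrative): push D's prediction on fakes towards 1.
pyramid = ImagePyramide(scales=[1], num_channels=3).to(device)
out = disc(pyramid(fake))
g_loss = ((1 - out["prediction_map_1"]) ** 2).mean()
```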

152
musetalk/loss/resnet.py Executable file
View File

@@ -0,0 +1,152 @@
import torch.nn as nn
import math
__all__ = ['ResNet', 'resnet50']
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, include_top=True):
self.inplanes = 64
super(ResNet, self).__init__()
self.include_top = include_top
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
# Inputs are expected in [0, 1]; scale to [0, 255] and flip the channel
# order (RGB -> BGR) to match the Caffe-trained VGGFace2 weights.
x = x * 255.
x = x.flip(1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
if not self.include_top:
return x
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def resnet50(**kwargs):
"""Constructs a ResNet-50 model.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model

95
musetalk/loss/syncnet.py Executable file
View File

@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn import functional as F
from .conv import Conv2d
logloss = nn.BCELoss(reduction="none")
def cosine_loss(a, v, y):
d = nn.functional.cosine_similarity(a, v)
d = d.clamp(0, 1)  # cosine_similarity lies in [-1, 1]; BCELoss raises "RuntimeError: CUDA error: device-side assert triggered" on negative inputs
loss = logloss(d.unsqueeze(1), y).squeeze()
loss = loss.mean()
return loss, d
def get_sync_loss(
audio_embed,
gt_frames,
pred_frames,
syncnet,
adapted_weight,
frames_left_index=0,
frames_right_index=16,
):
# Splice the predicted frames back into the ground-truth sequence at a random offset to save GPU memory
assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
# 3-channel images
frames_sync_loss = torch.cat(
[gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
axis=1
)
vision_embed = syncnet.get_image_embed(frames_sync_loss)
y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
loss, score = cosine_loss(audio_embed, vision_embed, y)
return loss, score
class SyncNet_color(nn.Module):
def __init__(self):
super(SyncNet_color, self).__init__()
self.face_encoder = nn.Sequential(
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
face_embedding = self.face_encoder(face_sequences)
audio_embedding = self.audio_encoder(audio_sequences)
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
face_embedding = face_embedding.view(face_embedding.size(0), -1)
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
face_embedding = F.normalize(face_embedding, p=2, dim=1)
return audio_embedding, face_embedding
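A small hedged sketch of `cosine_loss` on dummy embeddings, to make the shape contract explicit: unit-norm embeddings of shape `(B, D)` and labels of shape `(B, 1)`. The embedding dimension below is an assumption for illustration; the module path is the one shown in this diff.

```python
import torch
import torch.nn.functional as F
from musetalk.loss.syncnet import cosine_loss

audio_embed = F.normalize(torch.randn(4, 512), p=2, dim=1)   # assumed embedding dim
vision_embed = F.normalize(torch.randn(4, 512), p=2, dim=1)
y = torch.ones(4, 1)   # label 1: the audio/video pair is treated as in sync
loss, score = cosine_loss(audio_embed, vision_embed, y)
# `score` is the clamped per-sample cosine similarity; `loss` is its BCE against y.
```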

237
musetalk/loss/vgg_face.py Executable file
View File

@@ -0,0 +1,237 @@
'''
This part of code contains a pretrained vgg_face model.
ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
import pickle
from musetalk.loss import resnet as ResNet
MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
# It was 93.5940, 104.7624, 129.1863 before dividing by 255
MEAN_RGB = [
0.367035294117647,
0.41083294117647057,
0.5066129411764705
]
def load_state_dict(model, fname):
"""
Set parameters converted from Caffe models authors of VGGFace2 provide.
See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
Arguments:
model: model
fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
"""
with open(fname, 'rb') as f:
weights = pickle.load(f, encoding='latin1')
own_state = model.state_dict()
for name, param in weights.items():
if name in own_state:
try:
own_state[name].copy_(torch.from_numpy(param))
except Exception:
raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
else:
raise KeyError('unexpected key "{}" in state_dict'.format(name))
def vggface2(pretrained=True):
vggface = ResNet.resnet50(num_classes=8631, include_top=True)
load_state_dict(vggface, VGG_FACE_PATH)
return vggface
def vggface(pretrained=False, **kwargs):
"""VGGFace model.
Args:
pretrained (bool): If True, returns pre-trained model
"""
model = VggFace(**kwargs)
if pretrained:
state = torch.utils.model_zoo.load_url(MODEL_URL)
model.load_state_dict(state)
return model
class VggFace(torch.nn.Module):
def __init__(self, classes=2622):
"""VGGFace model.
Face recognition network. It takes as input a Bx3x224x224
batch of face images and gives as output a BxC score vector
(C is the number of identities).
Input images need to be scaled in the 0-1 range and then
normalized with respect to the mean RGB used during training.
Args:
classes (int): number of identities recognized by the
network
"""
super().__init__()
self.conv1 = _ConvBlock(3, 64, 64)
self.conv2 = _ConvBlock(64, 128, 128)
self.conv3 = _ConvBlock(128, 256, 256, 256)
self.conv4 = _ConvBlock(256, 512, 512, 512)
self.conv5 = _ConvBlock(512, 512, 512, 512)
self.dropout = torch.nn.Dropout(0.5)
self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
self.fc2 = torch.nn.Linear(4096, 4096)
self.fc3 = torch.nn.Linear(4096, classes)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = x.view(x.size(0), -1)
x = self.dropout(F.relu(self.fc1(x)))
x = self.dropout(F.relu(self.fc2(x)))
x = self.fc3(x)
return x
class _ConvBlock(torch.nn.Module):
"""A Convolutional block."""
def __init__(self, *units):
"""Create a block with len(units) - 1 convolutions.
convolution number i transforms the number of channels from
units[i - 1] to units[i] channels.
"""
super().__init__()
self.convs = torch.nn.ModuleList([
torch.nn.Conv2d(in_, out, 3, 1, 1)
for in_, out in zip(units[:-1], units[1:])
])
def forward(self, x):
# Each convolution is followed by a ReLU, then the block is
# concluded by a max pooling.
for c in self.convs:
x = F.relu(c(x))
return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
import numpy as np
from torchvision import models
class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
from torch import nn
class AntiAliasInterpolation2d(nn.Module):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def __init__(self, channels, scale):
super(AntiAliasInterpolation2d, self).__init__()
sigma = (1 / scale - 1) / 2
kernel_size = 2 * round(sigma * 4) + 1
self.ka = kernel_size // 2
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
kernel_size = [kernel_size, kernel_size]
sigma = [sigma, sigma]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.scale = scale
inv_scale = 1 / scale
self.int_inv_scale = int(inv_scale)
def forward(self, input):
if self.scale == 1.0:
return input
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
return out
class ImagePyramide(torch.nn.Module):
"""
Create image pyramide for computing pyramide perceptual loss.
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
self.downs = nn.ModuleDict(downs)
def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
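For orientation, a hedged sketch of a pyramid perceptual loss built from `Vgg19` and `ImagePyramide` above (an L1 distance over VGG-19 feature maps at each scale). The scales and the equal weighting here are illustrative assumptions; the actual values come from `cfg.loss_params.pyramid_scale` and the loss weights in the training config.

```python
import torch
from musetalk.loss.vgg_face import Vgg19, ImagePyramide

vgg = Vgg19().eval()                       # downloads torchvision VGG-19 weights
pyramid = ImagePyramide(scales=[1, 0.5], num_channels=3)

pred = torch.rand(2, 3, 256, 256)          # generated frames in [0, 1]
target = torch.rand(2, 3, 256, 256)        # ground-truth frames in [0, 1]

loss = 0.0
pyr_pred, pyr_tgt = pyramid(pred), pyramid(target)
for scale in ["1", "0.5"]:
    feats_pred = vgg(pyr_pred["prediction_" + scale])
    feats_tgt = vgg(pyr_tgt["prediction_" + scale])
    for fp, ft in zip(feats_pred, feats_tgt):
        loss = loss + torch.mean(torch.abs(fp - ft.detach()))
```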

240
musetalk/models/syncnet.py Executable file
View File

@@ -0,0 +1,240 @@
"""
This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
"""
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from diffusers.models.attention import Attention as CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
class SyncNet(nn.Module):
def __init__(self, config):
super().__init__()
self.audio_encoder = DownEncoder2D(
in_channels=config["audio_encoder"]["in_channels"],
block_out_channels=config["audio_encoder"]["block_out_channels"],
downsample_factors=config["audio_encoder"]["downsample_factors"],
dropout=config["audio_encoder"]["dropout"],
attn_blocks=config["audio_encoder"]["attn_blocks"],
)
self.visual_encoder = DownEncoder2D(
in_channels=config["visual_encoder"]["in_channels"],
block_out_channels=config["visual_encoder"]["block_out_channels"],
downsample_factors=config["visual_encoder"]["downsample_factors"],
dropout=config["visual_encoder"]["dropout"],
attn_blocks=config["visual_encoder"]["attn_blocks"],
)
self.eval()
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
def get_image_embed(self, image_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
return vision_embeds
def get_audio_embed(self, audio_sequences):
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return audio_embeds
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
norm_num_groups: int = 32,
eps: float = 1e-6,
act_fn: str = "silu",
downsample_factor=2,
):
super().__init__()
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
if isinstance(downsample_factor, list):
downsample_factor = tuple(downsample_factor)
if downsample_factor == 1:
self.downsample_conv = None
else:
self.downsample_conv = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
)
self.pad = (0, 1, 0, 1)
if isinstance(downsample_factor, tuple):
if downsample_factor[0] == 1:
self.pad = (0, 1, 1, 1) # The padding order is from back to front
elif downsample_factor[1] == 1:
self.pad = (1, 1, 0, 1)
def forward(self, input_tensor):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
hidden_states += input_tensor
if self.downsample_conv is not None:
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
hidden_states = self.downsample_conv(hidden_states)
return hidden_states
class AttentionBlock2D(nn.Module):
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
super().__init__()
if not is_xformers_available():
raise ModuleNotFoundError(
"You have to install xformers to enable memory efficient attetion", name="xformers"
)
# inner_dim = dim_head * heads
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
self.norm2 = nn.LayerNorm(query_dim)
self.norm3 = nn.LayerNorm(query_dim)
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
self.attn._use_memory_efficient_attention_xformers = True
def forward(self, hidden_states):
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.conv_in(hidden_states)
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DownEncoder2D(nn.Module):
def __init__(
self,
in_channels=4 * 16,
block_out_channels=[64, 128, 256, 256],
downsample_factors=[2, 2, 2, 2],
layers_per_block=2,
norm_num_groups=32,
attn_blocks=[1, 1, 1, 1],
dropout: float = 0.0,
act_fn="silu",
):
super().__init__()
self.layers_per_block = layers_per_block
# in
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# down
self.down_blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, block_out_channel in enumerate(block_out_channels):
input_channels = output_channels
output_channels = block_out_channel
# is_final_block = i == len(block_out_channels) - 1
down_block = ResnetBlock2D(
in_channels=input_channels,
out_channels=output_channels,
downsample_factor=downsample_factors[i],
norm_num_groups=norm_num_groups,
dropout=dropout,
act_fn=act_fn,
)
self.down_blocks.append(down_block)
if attn_blocks[i] == 1:
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
self.down_blocks.append(attention_block)
# out
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.act_fn_out = nn.ReLU()
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.act_fn_out(hidden_states)
return hidden_states
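The `SyncNet` constructor above expects a nested config with matching keys for the audio and visual branches. A hedged sketch of that structure follows; all values are illustrative placeholders, not the shipped configuration, which lives in the YAML pointed to by `cfg.syncnet_config_path`. Note that any block with `attn_blocks[i] == 1` requires `xformers` to be installed.

```python
from musetalk.models.syncnet import SyncNet

# Placeholder values; only the key names are taken from the constructor above.
config = {
    "audio_encoder": {
        "in_channels": 1,                        # e.g. mel-spectrogram channels (assumed)
        "block_out_channels": [64, 128, 256, 256],
        "downsample_factors": [2, 2, 2, 2],
        "dropout": 0.0,
        "attn_blocks": [0, 0, 0, 1],
    },
    "visual_encoder": {
        "in_channels": 48,                       # 16 frames x 3 channels (lower-half crops)
        "block_out_channels": [64, 128, 256, 256],
        "downsample_factors": [2, 2, 2, 2],
        "dropout": 0.0,
        "attn_blocks": [0, 0, 0, 1],
    },
}
syncnet = SyncNet(config)   # weights are loaded separately from latentsync_syncnet.pt
```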

337
musetalk/utils/training_utils.py
View File

@@ -0,0 +1,337 @@
import os
import json
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import WhisperModel
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from einops import rearrange
from musetalk.models.syncnet import SyncNet
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.basic_loss import Interpolate
import musetalk.loss.vgg_face as vgg_face
from musetalk.data.dataset import PortraitDataset
from musetalk.utils.utils import (
get_image_pred,
process_audio_features,
process_and_save_images
)
class Net(nn.Module):
def __init__(
self,
unet: UNet2DConditionModel,
):
super().__init__()
self.unet = unet
def forward(
self,
input_latents,
timesteps,
audio_prompts,
):
model_pred = self.unet(
input_latents,
timesteps,
encoder_hidden_states=audio_prompts
).sample
return model_pred
logger = logging.getLogger(__name__)
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
"""Initialize models and optimizers"""
model_dict = {
'vae': None,
'unet': None,
'net': None,
'wav2vec': None,
'optimizer': None,
'lr_scheduler': None,
'scheduler_max_steps': None,
'trainable_params': None
}
model_dict['vae'] = AutoencoderKL.from_pretrained(
cfg.pretrained_model_name_or_path,
subfolder=cfg.vae_type,
)
unet_config_file = os.path.join(
cfg.pretrained_model_name_or_path,
cfg.unet_sub_folder + "/musetalk.json"
)
with open(unet_config_file, 'r') as f:
unet_config = json.load(f)
model_dict['unet'] = UNet2DConditionModel(**unet_config)
if not cfg.random_init_unet:
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
model_dict['unet'].load_state_dict(checkpoint)
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
model_dict['vae'].requires_grad_(False)
model_dict['unet'].requires_grad_(True)
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
model_dict['net'] = Net(model_dict['unet'])
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
device="cuda", dtype=weight_dtype).eval()
model_dict['wav2vec'].requires_grad_(False)
if cfg.solver.gradient_checkpointing:
model_dict['unet'].enable_gradient_checkpointing()
if cfg.solver.scale_lr:
learning_rate = (
cfg.solver.learning_rate
* cfg.solver.gradient_accumulation_steps
* cfg.data.train_bs
* accelerator.num_processes
)
else:
learning_rate = cfg.solver.learning_rate
if cfg.solver.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
if accelerator.is_main_process:
print('trainable params')
for n, p in model_dict['net'].named_parameters():
if p.requires_grad:
print(n)
model_dict['optimizer'] = optimizer_cls(
model_dict['trainable_params'],
lr=learning_rate,
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
weight_decay=cfg.solver.adam_weight_decay,
eps=cfg.solver.adam_epsilon,
)
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
model_dict['lr_scheduler'] = get_scheduler(
cfg.solver.lr_scheduler,
optimizer=model_dict['optimizer'],
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
num_training_steps=model_dict['scheduler_max_steps'],
)
return model_dict
def initialize_dataloaders(cfg):
"""Initialize training and validation dataloaders"""
dataloader_dict = {
'train_dataset': None,
'val_dataset': None,
'train_dataloader': None,
'val_dataloader': None
}
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
'image_size': cfg.data.image_size,
'T': cfg.data.n_sample_frames,
"sample_method": cfg.data.sample_method,
'top_k_ratio': cfg.data.top_k_ratio,
"contorl_face_min_size": cfg.data.contorl_face_min_size,
"dataset_key": cfg.data.dataset_key,
"padding_pixel_mouth": cfg.padding_pixel_mouth,
"whisper_path": cfg.whisper_path,
"min_face_size": cfg.data.min_face_size,
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
"crop_type": cfg.crop_type,
"random_margin_method": cfg.random_margin_method,
})
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
dataloader_dict['train_dataset'],
batch_size=cfg.data.train_bs,
shuffle=True,
num_workers=cfg.data.num_workers,
)
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
'image_size': cfg.data.image_size,
'T': cfg.data.n_sample_frames,
"sample_method": cfg.data.sample_method,
'top_k_ratio': cfg.data.top_k_ratio,
"contorl_face_min_size": cfg.data.contorl_face_min_size,
"dataset_key": cfg.data.dataset_key,
"padding_pixel_mouth": cfg.padding_pixel_mouth,
"whisper_path": cfg.whisper_path,
"min_face_size": cfg.data.min_face_size,
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
"crop_type": cfg.crop_type,
"random_margin_method": cfg.random_margin_method,
})
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
dataloader_dict['val_dataset'],
batch_size=cfg.data.train_bs,
shuffle=True,
num_workers=1,
)
return dataloader_dict
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
"""Initialize loss functions and discriminators"""
loss_dict = {
'L1_loss': nn.L1Loss(reduction='mean'),
'discriminator': None,
'mouth_discriminator': None,
'optimizer_D': None,
'mouth_optimizer_D': None,
'scheduler_D': None,
'mouth_scheduler_D': None,
'disc_scales': None,
'discriminator_full': None,
'mouth_discriminator_full': None
}
if cfg.loss_params.gan_loss > 0:
loss_dict['discriminator'] = MultiScaleDiscriminator(
**cfg.model_params.discriminator_params).to(accelerator.device)
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
loss_dict['optimizer_D'] = optim.AdamW(
loss_dict['discriminator'].parameters(),
lr=cfg.discriminator_train_params.lr,
weight_decay=cfg.discriminator_train_params.weight_decay,
betas=cfg.discriminator_train_params.betas,
eps=cfg.discriminator_train_params.eps)
loss_dict['scheduler_D'] = CosineAnnealingLR(
loss_dict['optimizer_D'],
T_max=scheduler_max_steps,
eta_min=1e-6
)
if cfg.loss_params.mouth_gan_loss > 0:
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
**cfg.model_params.discriminator_params).to(accelerator.device)
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
loss_dict['mouth_optimizer_D'] = optim.AdamW(
loss_dict['mouth_discriminator'].parameters(),
lr=cfg.discriminator_train_params.lr,
weight_decay=cfg.discriminator_train_params.weight_decay,
betas=cfg.discriminator_train_params.betas,
eps=cfg.discriminator_train_params.eps)
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
loss_dict['mouth_optimizer_D'],
T_max=scheduler_max_steps,
eta_min=1e-6
)
return loss_dict
def initialize_syncnet(cfg, accelerator, weight_dtype):
"""Initialize SyncNet model"""
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
if cfg.data.n_sample_frames != 16:
raise ValueError(
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
)
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
syncnet = SyncNet(OmegaConf.to_container(
syncnet_config.model)).to(accelerator.device)
print(
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
checkpoint = torch.load(
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
syncnet.load_state_dict(checkpoint["state_dict"])
syncnet.to(dtype=weight_dtype)
syncnet.requires_grad_(False)
syncnet.eval()
return syncnet
return None
def initialize_vgg(cfg, accelerator):
"""Initialize VGG model"""
if cfg.loss_params.vgg_loss > 0:
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
pyramid = vgg_face.ImagePyramide(
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
vgg_IN.eval()
downsampler = Interpolate(
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
return vgg_IN, pyramid, downsampler
return None, None, None
def validation(
cfg,
val_dataloader,
net,
vae,
wav2vec,
accelerator,
save_dir,
global_step,
weight_dtype,
syncnet_score=1,
):
"""Validation function for model evaluation"""
net.eval() # Set the model to evaluation mode
for batch in val_dataloader:
# The same ref_latents
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
accelerator.device, non_blocking=True
)
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
accelerator.device, non_blocking=True
)
bsz, num_frames, c, h, w = ref_pixel_values.shape
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
# audio feature for unet
audio_prompts = rearrange(
audio_prompts,
'b f c h w-> (b f) c h w'
)
audio_prompts = rearrange(
audio_prompts,
'(b f) c h w -> (b f) (c h) w',
b=bsz
)
# different masked_latents
image_pred_train = get_image_pred(
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
image_pred_infer = get_image_pred(
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
process_and_save_images(
batch,
image_pred_train,
image_pred_infer,
save_dir,
global_step,
accelerator,
cfg.num_images_to_keep,
syncnet_score
)
# only infer 1 image in validation
break
net.train() # Set the model back to training mode

musetalk/utils/utils.py
View File

@@ -2,6 +2,11 @@ import os
import cv2
import numpy as np
import torch
from typing import Union, List
import torch.nn.functional as F
from einops import rearrange
import shutil
import os.path as osp
ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
@@ -11,7 +16,6 @@ elif ffmpeg_path not in os.getenv('PATH'):
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
@@ -76,3 +80,248 @@ def datagen(
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch.to(device), latent_batch.to(device)
def cast_training_params(
model: Union[torch.nn.Module, List[torch.nn.Module]],
dtype=torch.float32,
):
if not isinstance(model, list):
model = [model]
for m in model:
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)
def rand_log_normal(
shape,
loc=0.,
scale=1.,
device='cpu',
dtype=torch.float32,
generator=None
):
"""Draws samples from an lognormal distribution."""
rnd_normal = torch.randn(
shape, device=device, dtype=dtype, generator=generator) # N(0, I)
sigma = (rnd_normal * scale + loc).exp()
return sigma
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
# Initialize lists to store the results for each image in the batch
mouth_real_list = []
mouth_generated_list = []
# Process each image in the batch
for b in range(frames.shape[0]):
# Find the non-zero area in the face mask
non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
# If there are no non-zero indices, skip this image
if non_zero_indices.numel() == 0:
continue
min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
non_zero_indices[:, 1])
min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
non_zero_indices[:, 2])
# Crop the frames and image_pred according to the non-zero area
frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
# Resize the cropped images to 256*256
frames_resized = F.interpolate(frames_cropped.unsqueeze(
0), size=(256, 256), mode='bilinear', align_corners=False)
image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
0), size=(256, 256), mode='bilinear', align_corners=False)
# Append the resized images to the result lists
mouth_real_list.append(frames_resized)
mouth_generated_list.append(image_pred_resized)
# Convert the lists to tensors if they are not empty
mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
mouth_generated = torch.cat(
mouth_generated_list, dim=0) if mouth_generated_list else None
return mouth_real, mouth_generated
def get_image_pred(pixel_values,
ref_pixel_values,
audio_prompts,
vae,
net,
weight_dtype):
with torch.no_grad():
bsz, num_frames, c, h, w = pixel_values.shape
masked_pixel_values = pixel_values.clone()
masked_pixel_values[:, :, :, h//2:, :] = -1
masked_frames = rearrange(
masked_pixel_values, 'b f c h w -> (b f) c h w')
masked_latents = vae.encode(masked_frames).latent_dist.mode()
masked_latents = masked_latents * vae.config.scaling_factor
masked_latents = masked_latents.float()
ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
ref_latents = vae.encode(ref_frames).latent_dist.mode()
ref_latents = ref_latents * vae.config.scaling_factor
ref_latents = ref_latents.float()
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
input_latents = input_latents.to(weight_dtype)
timesteps = torch.tensor([0], device=input_latents.device)
latents_pred = net(
input_latents,
timesteps,
audio_prompts,
)
latents_pred = (1 / vae.config.scaling_factor) * latents_pred
image_pred = vae.decode(latents_pred).sample
image_pred = image_pred.float()
return image_pred
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
with torch.no_grad():
audio_feature_length_per_frame = 2 * \
(cfg.data.audio_padding_length_left +
cfg.data.audio_padding_length_right + 1)
audio_feats = batch['audio_feature'].to(weight_dtype)
audio_feats = wav2vec.encoder(
audio_feats, output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
start_ts = batch['audio_offset']
step_ts = batch['audio_step']
audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
audio_feats,
torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
audio_prompts = []
for bb in range(bsz):
audio_feats_list = []
for f in range(num_frames):
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
audio_clip = audio_feats[bb:bb+1,
cur_t: cur_t+audio_feature_length_per_frame]
audio_feats_list.append(audio_clip)
audio_feats_list = torch.stack(audio_feats_list, 1)
audio_prompts.append(audio_feats_list)
audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
return audio_prompts
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
if total_limit is not None:
checkpoints = os.listdir(save_dir)
checkpoints = [d for d in checkpoints if d.endswith(".pth")]
checkpoints = [d for d in checkpoints if name in d]
checkpoints = sorted(
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
)
if len(checkpoints) >= total_limit:
num_to_remove = len(checkpoints) - total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(
f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(
save_dir, removing_checkpoint)
os.remove(removing_checkpoint)
state_dict = model.state_dict()
torch.save(state_dict, save_path)
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
unwrapped_net = accelerator.unwrap_model(net)
save_checkpoint(
unwrapped_net.unet,
save_dir,
global_step,
name="unet",
total_limit=cfg.total_limit,
logger=logger
)
def delete_additional_ckpt(base_path, num_keep):
dirs = []
for d in os.listdir(base_path):
if d.startswith("checkpoint-"):
dirs.append(d)
num_tot = len(dirs)
if num_tot <= num_keep:
return
# ensure checkpoints are sorted and delete the earliest ones
del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
for d in del_dirs:
path_to_dir = osp.join(base_path, d)
if osp.exists(path_to_dir):
shutil.rmtree(path_to_dir)
def seed_everything(seed):
import random
import numpy as np
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed % (2**32))
random.seed(seed)
def process_and_save_images(
batch,
image_pred,
image_pred_infer,
save_dir,
global_step,
accelerator,
num_images_to_keep=10,
syncnet_score=1
):
# Rearrange the tensors
print("image_pred.shape: ", image_pred.shape)
pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
# Create masked pixel values
masked_pixel_values = batch["pixel_values_vid"].clone()
_, _, _, h, _ = batch["pixel_values_vid"].shape
masked_pixel_values[:, :, :, h//2:, :] = -1
masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
# Keep only the specified number of images
pixel_values = pixel_values[:num_images_to_keep, :, :, :]
masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
# Concatenate images
concat = torch.cat([
masked_pixel_values * 0.5 + 0.5,
pixel_values_ref_img * 0.5 + 0.5,
image_pred * 0.5 + 0.5,
pixel_values * 0.5 + 0.5,
image_pred_infer * 0.5 + 0.5,
], dim=2)
print("concat.shape: ", concat.shape)
# Create the save directory if it doesn't exist
os.makedirs(f'{save_dir}/samples/', exist_ok=True)
# Try to save the concatenated image
try:
# Concatenate images horizontally and convert to numpy array
final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
# Save the image
cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
except Exception as e:
print(f"Failed to save image: {e}")

322
scripts/preprocess.py Executable file
View File

@@ -0,0 +1,322 @@
import os
import argparse
import subprocess
from omegaconf import OmegaConf
from typing import Tuple, List, Union
import decord
import json
import cv2
from musetalk.utils.face_detection import FaceAlignment,LandmarksType
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
import torch
import numpy as np
from tqdm import tqdm
ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
if ffmpeg_path not in os.getenv('PATH'):
print("add ffmpeg to path")
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
class AnalyzeFace:
def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
"""
Initialize the AnalyzeFace class with the given device, config file, and checkpoint file.
Parameters:
device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu').
config_file (str): Path to the mmpose model configuration file.
checkpoint_file (str): Path to the mmpose model checkpoint file.
"""
self.device = device
self.dwpose = init_model(config_file, checkpoint_file, device=self.device)
self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device)
def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
"""
Detect faces and keypoints in the given image.
Parameters:
im (np.ndarray): The input image.
maxface (bool): Whether to detect the maximum face. Default is True.
Returns:
Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints.
"""
try:
# Ensure the input image has the correct shape
if im.ndim == 3:
im = np.expand_dims(im, axis=0)
elif im.ndim != 4 or im.shape[0] != 1:
raise ValueError("Input image must have shape (1, H, W, C)")
bbox = self.facedet.get_detections_for_batch(np.asarray(im))
results = inference_topdown(self.dwpose, np.asarray(im)[0])
results = merge_data_samples(results)
keypoints = results.pred_instances.keypoints
face_land_mark= keypoints[0][23:91]
face_land_mark = face_land_mark.astype(np.int32)
return face_land_mark, bbox
except Exception as e:
print(f"Error during face analysis: {e}")
return np.array([]),[]
def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
"""
Convert video files to a specified format and save them to the destination path.
Parameters:
org_path (str): The directory containing the original video files.
dst_path (str): The directory where the converted video files will be saved.
vid_list (List[str]): A list of video file names to process.
Returns:
None
"""
for idx, vid in enumerate(vid_list):
if vid.endswith('.mp4'):
org_vid_path = os.path.join(org_path, vid)
dst_vid_path = os.path.join(dst_path, vid)
if org_vid_path != dst_vid_path:
cmd = [
"ffmpeg", "-hide_banner", "-y", "-i", org_vid_path,
"-r", "25", "-crf", "15", "-c:v", "libx264",
"-pix_fmt", "yuv420p", dst_vid_path
]
subprocess.run(cmd, check=True)
if idx % 1000 == 0:
print(f"### {idx} videos converted ###")
def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None:
"""
Segment video files into smaller clips of specified duration.
Parameters:
org_path (str): The directory containing the original video files.
dst_path (str): The directory where the segmented video files will be saved.
vid_list (List[str]): A list of video file names to process.
segment_duration (int): The duration of each segment in seconds. Default is 30 seconds.
Returns:
None
"""
for idx, vid in enumerate(vid_list):
if vid.endswith('.mp4'):
input_file = os.path.join(org_path, vid)
original_filename = os.path.basename(input_file)
command = [
'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0',
'-segment_time', str(segment_duration), '-f', 'segment',
'-reset_timestamps', '1',
os.path.join(dst_path, f'clip%03d_{original_filename}')
]
subprocess.run(command, check=True)
def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None:
"""
Extract audio from video files and save as WAV format.
Parameters:
org_path (str): The directory containing the original video files.
dst_path (str): The directory where the extracted audio files will be saved.
vid_list (List[str]): A list of video file names to process.
Returns:
None
"""
for idx, vid in enumerate(vid_list):
if vid.endswith('.mp4'):
video_path = os.path.join(org_path, vid)
audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav")
try:
command = [
'ffmpeg', '-hide_banner', '-y', '-i', video_path,
'-vn', '-acodec', 'pcm_s16le', '-f', 'wav',
'-ar', '16000', '-ac', '1', audio_output_path,
]
subprocess.run(command, check=True)
print(f"Audio saved to: {audio_output_path}")
except subprocess.CalledProcessError as e:
print(f"Error extracting audio from {vid}: {e}")
def split_data(video_files: List[str], val_list_hdtf: List[str]) -> Tuple[List[str], List[str]]:
"""
Split video files into training and validation sets based on val_list_hdtf.
Parameters:
video_files (List[str]): A list of video file names.
val_list_hdtf (List[str]): A list of validation file identifiers.
Returns:
(List[str], List[str]): A tuple containing the training and validation file lists.
"""
val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)]
train_files = [f for f in video_files if f not in val_files]
return train_files, val_files
def save_list_to_file(file_path: str, data_list: List[str]) -> None:
"""
Save a list of strings to a file, each string on a new line.
Parameters:
file_path (str): The path to the file where the list will be saved.
data_list (List[str]): The list of strings to save.
Returns:
None
"""
with open(file_path, 'w') as file:
for item in data_list:
file.write(f"{item}\n")
def generate_train_list(cfg):
train_file_path = cfg.video_clip_file_list_train
val_file_path = cfg.video_clip_file_list_val
val_list_hdtf = cfg.val_list_hdtf
meta_list = os.listdir(cfg.meta_root)
sorted_meta_list = sorted(meta_list)
train_files, val_files = split_data(sorted_meta_list, val_list_hdtf)
save_list_to_file(train_file_path, train_files)
save_list_to_file(val_file_path, val_files)
print(val_list_hdtf)
def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
"""
Convert video files to a specified format and save them to the destination path.
Parameters:
org_path (str): The directory containing the original video files.
dst_path (str): The directory where the meta json will be saved.
vid_list (List[str]): A list of video file names to process.
Returns:
None
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
analyze_face = AnalyzeFace(device, config_file, checkpoint_file)
for vid in tqdm(vid_list, desc="Processing videos"):
#vid = "clip005_WDA_BernieSanders_000.mp4"
#print(vid)
if vid.endswith('.mp4'):
vid_path = os.path.join(org_path, vid)
wav_path = vid_path.replace(".mp4",".wav")
vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json")
if os.path.exists(vid_meta):
continue
print('process video {}'.format(vid))
total_bbox_list = []
total_pts_list = []
isvalid = True
# process
try:
cap = decord.VideoReader(vid_path, fault_tol=1)
except Exception as e:
print(e)
continue
total_frames = len(cap)
for frame_idx in range(total_frames):
frame = cap[frame_idx]
if frame_idx==0:
video_height,video_width,_ = frame.shape
frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB)  # decord yields RGB; this swap produces BGR (BGR2RGB and RGB2BGR are the same channel swap)
pts_list, bbox_list = analyze_face(frame_bgr)
if len(bbox_list)>0 and None not in bbox_list:
bbox = bbox_list[0]
else:
isvalid = False
bbox = []
print(f"set isvalid to False as broken img in {frame_idx} of {vid}")
break
#print(pts_list)
if pts_list is not None and len(pts_list) > 0:
pts = pts_list.tolist()
else:
isvalid = False
pts = []
break
if frame_idx==0:
x1,y1,x2,y2 = bbox
face_height, face_width = y2-y1,x2-x1
total_pts_list.append(pts)
total_bbox_list.append(bbox)
meta_data = {
"mp4_path": vid_path,
"wav_path": wav_path,
"video_size": [video_height, video_width],
"face_size": [face_height, face_width],
"frames": total_frames,
"face_list": total_bbox_list,
"landmark_list": total_pts_list,
"isvalid":isvalid,
}
with open(vid_meta, 'w') as f:
json.dump(meta_data, f, indent=4)
def main(cfg):
# Ensure all necessary directories exist
os.makedirs(cfg.video_root_25fps, exist_ok=True)
os.makedirs(cfg.video_audio_clip_root, exist_ok=True)
os.makedirs(cfg.meta_root, exist_ok=True)
os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True)
os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True)
os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True)
vid_list = os.listdir(cfg.video_root_raw)
sorted_vid_list = sorted(vid_list)
# Save video file list
with open(cfg.video_file_list, 'w') as file:
for vid in sorted_vid_list:
file.write(vid + '\n')
# 1. Convert videos to 25 FPS
convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list)
# 2. Segment videos into 30-second clips
segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second)
# 3. Extract audio
clip_vid_list = os.listdir(cfg.video_audio_clip_root)
extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list)
# 4. Generate video metadata
analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list)
# 5. Generate training and validation set lists
generate_train_list(cfg)
print("done")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml")
args = parser.parse_args()
config = OmegaConf.load(args.config)
main(config)
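The script reads all of its paths from the YAML passed via `--config`. A hedged sketch of the fields that `main(cfg)` and `generate_train_list(cfg)` actually read, with placeholder values; the shipped `configs/training/preprocess.yaml` is the reference, and only the key names here are taken from the code above.

```python
from omegaconf import OmegaConf

# Placeholder values; only the key names are taken from scripts/preprocess.py.
cfg = OmegaConf.create({
    "video_root_raw": "./dataset/HDTF/source",          # raw input videos
    "video_root_25fps": "./dataset/HDTF/25fps",         # step 1: 25 FPS re-encode
    "video_audio_clip_root": "./dataset/HDTF/clips",    # steps 2-3: 30 s clips + wav
    "meta_root": "./dataset/HDTF/meta",                 # step 4: per-clip JSON metadata
    "video_file_list": "./dataset/HDTF/video_list.txt",
    "video_clip_file_list_train": "./dataset/HDTF/train_list.txt",
    "video_clip_file_list_val": "./dataset/HDTF/val_list.txt",
    "val_list_hdtf": ["WDA_BernieSanders_000"],         # clip IDs reserved for validation
    "clip_len_second": 30,
})
main(cfg)
```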

580
train.py Executable file
View File

@@ -0,0 +1,580 @@
import argparse
import diffusers
import logging
import math
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import warnings
import random
from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from datetime import datetime
from datetime import timedelta
from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from musetalk.utils.utils import (
delete_additional_ckpt,
seed_everything,
get_mouth_region,
process_audio_features,
save_models
)
from musetalk.loss.basic_loss import set_requires_grad
from musetalk.loss.syncnet import get_sync_loss
from musetalk.utils.training_utils import (
initialize_models_and_optimizers,
initialize_dataloaders,
initialize_loss_functions,
initialize_syncnet,
initialize_vgg,
validation
)
logger = get_logger(__name__, log_level="INFO")
warnings.filterwarnings("ignore")
check_min_version("0.10.0.dev0")
def main(cfg):
exp_name = cfg.exp_name
save_dir = f"{cfg.output_dir}/{exp_name}"
os.makedirs(save_dir, exist_ok=True)
kwargs = DistributedDataParallelKwargs()
process_group_kwargs = InitProcessGroupKwargs(
timeout=timedelta(seconds=5400))
accelerator = Accelerator(
gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
log_with=["tensorboard", LoggerType.TENSORBOARD],
project_dir=os.path.join(save_dir, "./tensorboard"),
kwargs_handlers=[kwargs, process_group_kwargs],
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if cfg.seed is not None:
print('cfg.seed', cfg.seed, accelerator.process_index)
seed_everything(cfg.seed + accelerator.process_index)
weight_dtype = torch.float32
model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
dataloader_dict = initialize_dataloaders(cfg)
loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
# Prepare everything with our `accelerator`.
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
)
print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))
# Calculate training steps and epochs
num_update_steps_per_epoch = math.ceil(
len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
)
num_train_epochs = math.ceil(
cfg.solver.max_train_steps / num_update_steps_per_epoch
)
# Initialize trackers on the main process
if accelerator.is_main_process:
run_time = datetime.now().strftime("%Y%m%d-%H%M")
accelerator.init_trackers(
cfg.exp_name,
init_kwargs={"mlflow": {"run_name": run_time}},
)
# Calculate total batch size
total_batch_size = (
cfg.data.train_bs
* accelerator.num_processes
* cfg.solver.gradient_accumulation_steps
)
# Log training information
logger.info("***** Running training *****")
logger.info(f"Num Epochs = {num_train_epochs}")
logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
logger.info(
f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")
global_step = 0
first_epoch = 0
# Load checkpoint if resuming training
if cfg.resume_from_checkpoint:
resume_dir = save_dir
dirs = os.listdir(resume_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
if len(dirs) > 0:
path = dirs[-1]
accelerator.load_state(os.path.join(resume_dir, path))
accelerator.print(f"Resuming from checkpoint {path}")
global_step = int(path.split("-")[1])
first_epoch = global_step // num_update_steps_per_epoch
resume_step = global_step % num_update_steps_per_epoch
# Initialize progress bar
progress_bar = tqdm(
range(global_step, cfg.solver.max_train_steps),
disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")
# Log model types
print("log type of models")
print("unet", model_dict['unet'].dtype)
print("vae", model_dict['vae'].dtype)
print("wav2vec", model_dict['wav2vec'].dtype)
def get_ganloss_weight(step):
"""Calculate GAN loss weight based on training step"""
if step < cfg.discriminator_train_params.start_gan:
return 0.0
else:
return 1.0
# Training loop
for epoch in range(first_epoch, num_train_epochs):
# Set models to training mode
model_dict['unet'].train()
if cfg.loss_params.gan_loss > 0:
loss_dict['discriminator'].train()
if cfg.loss_params.mouth_gan_loss > 0:
loss_dict['mouth_discriminator'].train()
# Initialize loss accumulators
train_loss = 0.0
train_loss_D = 0.0
train_loss_D_mouth = 0.0
l1_loss_accum = 0.0
vgg_loss_accum = 0.0
gan_loss_accum = 0.0
gan_loss_accum_mouth = 0.0
fm_loss_accum = 0.0
sync_loss_accum = 0.0
adapted_weight_accum = 0.0
t_data_start = time.time()
for step, batch in enumerate(dataloader_dict['train_dataloader']):
t_data = time.time() - t_data_start
t_model_start = time.time()
with torch.no_grad():
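# Prepare inputs without gradient tracking: the VAE and wav2vec models are used as frozen feature extractors here.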
# Process input data
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
accelerator.device,
non_blocking=True
)
bsz, num_frames, c, h, w = pixel_values.shape
# Process reference images
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
accelerator.device,
non_blocking=True
)
# Get face mask for GAN
pixel_values_face_mask = batch['pixel_values_face_mask']
# Process audio features
audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)
# Initialize adapted weight
adapted_weight = 1
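# adapted_weight scales (or zeroes) all losses for batches whose ground-truth audio-visual sync score is unreliable; it is set below when sync loss is enabled.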
# Process sync loss if enabled
if cfg.loss_params.sync_loss > 0:
mels = batch['mel']
# Prepare frames for latentsync (combine channels and frames)
gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
# Use lower half of face for latentsync
height = gt_frames.shape[2]
gt_frames = gt_frames[:, :, height // 2:, :]
# Get audio embeddings
audio_embed = syncnet.get_audio_embed(mels)
# Calculate adapted weight based on audio-visual similarity
if cfg.use_adapted_weight:
vision_embed_gt = syncnet.get_vision_embed(gt_frames)
image_audio_sim_gt = F.cosine_similarity(
audio_embed,
vision_embed_gt,
dim=1
)[0]
if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
if cfg.adapted_weight_type == "cut_off":
adapted_weight = 0.0 # Skip this batch
print(
f"\nImage-audio similarity at step {global_step} is {image_audio_sim_gt}; setting adapted_weight to {adapted_weight}.")
elif cfg.adapted_weight_type == "linear":
adapted_weight = image_audio_sim_gt
else:
print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
adapted_weight = 1
# Random frame selection for memory efficiency
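# The dataloader is assumed to sample 16-frame clips; only a num_backward_frames window of them is decoded and back-propagated through, to bound memory.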
max_start = 16 - cfg.num_backward_frames
frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
frames_right_index = frames_left_index + cfg.num_backward_frames
else:
frames_left_index = 0
frames_right_index = cfg.data.n_sample_frames
# Extract frames for backward pass
pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
# Encode target images
frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
latents = model_dict['vae'].encode(frames).latent_dist.mode()
latents = latents * model_dict['vae'].config.scaling_factor
latents = latents.float()
# Create masked images
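# The lower half of each target frame is blanked (set to -1 in the [-1, 1] pixel range) so the network learns to inpaint the mouth region conditioned on audio.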
masked_pixel_values = pixel_values_backward.clone()
masked_pixel_values[:, :, :, h//2:, :] = -1
masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
masked_latents = masked_latents.float()
# Encode reference images
ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
ref_latents = ref_latents.float()
# Prepare face mask and audio features
pixel_values_face_mask_backward = rearrange(
pixel_values_face_mask_backward,
"b f c h w -> (b f) c h w"
)
audio_prompts_backward = rearrange(
audio_prompts_backward,
'b f c h w-> (b f) c h w'
)
audio_prompts_backward = rearrange(
audio_prompts_backward,
'(b f) c h w -> (b f) (c h) w',
b=bsz
)
# Apply reference dropout (a no-op when cfg.ref_dropout_rate is 0)
ref_latents = F.dropout(ref_latents, p=cfg.ref_dropout_rate, training=True)
# Prepare model inputs
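# The UNet input concatenates the masked target latents with the reference latents along the channel dimension (e.g. 4 + 4 = 8 channels for the SD VAE).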
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
input_latents = input_latents.to(weight_dtype)
timesteps = torch.tensor([0], device=input_latents.device)
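# The timestep is fixed at 0: the UNet runs as a single-pass generator rather than an iterative diffusion sampler.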
# Forward pass
latents_pred = model_dict['net'](
input_latents,
timesteps,
audio_prompts_backward,
)
latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
image_pred = model_dict['vae'].decode(latents_pred).sample
# Convert to float
image_pred = image_pred.float()
frames = frames.float()
# Calculate L1 loss
l1_loss = loss_dict['L1_loss'](frames, image_pred)
l1_loss_accum += l1_loss.item()
adapted_weight_accum += float(adapted_weight)
loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
# Process mouth GAN loss if enabled
if cfg.loss_params.mouth_gan_loss > 0:
frames_mouth, image_pred_mouth = get_mouth_region(
frames,
image_pred,
pixel_values_face_mask_backward
)
pyramide_real_mouth = pyramid(downsampler(frames_mouth))
pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))
# Process VGG loss if enabled
if cfg.loss_params.vgg_loss > 0:
pyramide_real = pyramid(downsampler(frames))
pyramide_generated = pyramid(downsampler(image_pred))
loss_IN = 0
for scale in cfg.loss_params.pyramid_scale:
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
loss_IN += weight * value
loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
vgg_loss_accum += loss_IN.item()
# Process GAN loss if enabled
if cfg.loss_params.gan_loss > 0:
set_requires_grad(loss_dict['discriminator'], False)
loss_G = 0.
discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
discriminator_maps_real = loss_dict['discriminator'](pyramide_real)
for scale in loss_dict['disc_scales']:
key = 'prediction_map_%s' % scale
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
loss_G += value
gan_loss_accum += loss_G.item()
loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight
# Process feature matching loss if enabled
if cfg.loss_params.fm_loss[0] > 0:
L_feature_matching = 0.
for scale in loss_dict['disc_scales']:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
value = torch.abs(a - b).mean()
L_feature_matching += value * cfg.loss_params.fm_loss[i]
loss += L_feature_matching * adapted_weight
fm_loss_accum += L_feature_matching.item()
# Process mouth GAN loss if enabled
if cfg.loss_params.mouth_gan_loss > 0:
set_requires_grad(loss_dict['mouth_discriminator'], False)
loss_G = 0.
mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)
for scale in loss_dict['disc_scales']:
key = 'prediction_map_%s' % scale
value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
loss_G += value
gan_loss_accum_mouth += loss_G.item()
loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight
# Process feature matching loss for mouth if enabled
if cfg.loss_params.fm_loss[0] > 0:
L_feature_matching = 0.
for scale in loss_dict['disc_scales']:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
value = torch.abs(a - b).mean()
L_feature_matching += value * cfg.loss_params.fm_loss[i]
loss += L_feature_matching * adapted_weight
fm_loss_accum += L_feature_matching.item()
# Process sync loss if enabled
if cfg.loss_params.sync_loss > 0:
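# SyncNet compares the lower half (mouth region) of the predicted frames against the audio embedding, mirroring the ground-truth crop above.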
pred_frames = rearrange(
image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
pred_frames = pred_frames[:, :, height // 2 :, :]
sync_loss, image_audio_sim_pred = get_sync_loss(
audio_embed,
gt_frames,
pred_frames,
syncnet,
adapted_weight,
frames_left_index=frames_left_index,
frames_right_index=frames_right_index,
)
sync_loss_accum += sync_loss.item()
loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight
# Backward pass
avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
train_loss += avg_loss.item()
accelerator.backward(loss)
# Train discriminator if GAN loss is enabled
if cfg.loss_params.gan_loss > 0:
set_requires_grad(loss_dict['discriminator'], True)
loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
train_loss_D += avg_loss_D.item()
loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
accelerator.backward(loss_D)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(
loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
loss_dict['optimizer_D'].step()
loss_dict['scheduler_D'].step()
loss_dict['optimizer_D'].zero_grad()
# Train mouth discriminator if mouth GAN loss is enabled
if cfg.loss_params.mouth_gan_loss > 0:
set_requires_grad(loss_dict['mouth_discriminator'], True)
mouth_loss_D = loss_dict['mouth_discriminator_full'](
frames_mouth, image_pred_mouth.detach())
avg_mouth_loss_D = accelerator.gather(
mouth_loss_D.repeat(cfg.data.train_bs)).mean()
train_loss_D_mouth += avg_mouth_loss_D.item()
mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
accelerator.backward(mouth_loss_D)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(
loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
loss_dict['mouth_optimizer_D'].step()
loss_dict['mouth_scheduler_D'].step()
loss_dict['mouth_optimizer_D'].zero_grad()
# Update main model
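# Gradient accumulation is handled manually: optimizers are stepped every gradient_accumulation_steps iterations instead of via accelerator.accumulate().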
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(
model_dict['trainable_params'],
cfg.solver.max_grad_norm,
)
model_dict['optimizer'].step()
model_dict['lr_scheduler'].step()
model_dict['optimizer'].zero_grad()
# Update progress and log metrics
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({
"train_loss": train_loss,
"train_loss_D": train_loss_D,
"train_loss_D_mouth": train_loss_D_mouth,
"l1_loss": l1_loss_accum,
"vgg_loss": vgg_loss_accum,
"gan_loss": gan_loss_accum,
"fm_loss": fm_loss_accum,
"sync_loss": sync_loss_accum,
"adapted_weight": adapted_weight_accum,
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
}, step=global_step)
# Reset loss accumulators
train_loss = 0.0
l1_loss_accum = 0.0
vgg_loss_accum = 0.0
gan_loss_accum = 0.0
gan_loss_accum_mouth = 0.0
fm_loss_accum = 0.0
sync_loss_accum = 0.0
adapted_weight_accum = 0.0
train_loss_D = 0.0
train_loss_D_mouth = 0.0
# Run validation if needed
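# Validation also runs once at step 10 (an early sanity check of the pipeline).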
if global_step % cfg.val_freq == 0 or global_step == 10:
try:
validation(
cfg,
dataloader_dict['val_dataloader'],
model_dict['net'],
model_dict['vae'],
model_dict['wav2vec'],
accelerator,
save_dir,
global_step,
weight_dtype,
syncnet_score=adapted_weight,
)
except Exception as e:
print(f"An error occurred during validation: {e}")
# Save checkpoint if needed
if global_step % cfg.checkpointing_steps == 0:
save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
try:
start_time = time.time()
if accelerator.is_main_process:
save_models(
accelerator,
model_dict['net'],
save_dir,
global_step,
cfg,
logger=logger
)
delete_additional_ckpt(save_dir, cfg.total_limit)
elapsed_time = time.time() - start_time
if elapsed_time > 300:
print(f"Warning: checkpoint saving at step {global_step} took {elapsed_time:.1f}s.")
else:
print(f"Checkpoint saved at {save_dir} in {elapsed_time:.1f}s.")
except Exception as e:
print(f"Error when saving model in step {global_step}:", e)
# Update progress bar
t_model = time.time() - t_model_start
logs = {
"step_loss": loss.detach().item(),
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
"td": f"{t_data:.2f}s",
"tm": f"{t_model:.2f}s",
}
t_data_start = time.time()
progress_bar.set_postfix(**logs)
if global_step >= cfg.solver.max_train_steps:
break
# Save model after each epoch
if (epoch + 1) % cfg.save_model_epoch_interval == 0:
try:
start_time = time.time()
if accelerator.is_main_process:
save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
elapsed_time = time.time() - start_time
if elapsed_time > 120:
print(f"Warning: model saving at step {global_step} took {elapsed_time:.1f}s.")
else:
print(f"Model saved successfully in {elapsed_time:.1f}s.")
except Exception as e:
print(f"Error when saving model in step {global_step}:", e)
accelerator.wait_for_everyone()
# End training
accelerator.end_training()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
args = parser.parse_args()
config = OmegaConf.load(args.config)
main(config)

34
train.sh Normal file
View File

@@ -0,0 +1,34 @@
#!/bin/bash
# MuseTalk Training Script
# This script combines both training stages for the MuseTalk model
# Usage: sh train.sh [stage1|stage2]
# Example: sh train.sh stage1 # To run stage 1 training
# Example: sh train.sh stage2 # To run stage 2 training
# Check if stage argument is provided
if [ $# -ne 1 ]; then
echo "Error: Please specify the training stage"
echo "Usage: ./train.sh [stage1|stage2]"
exit 1
fi
STAGE=$1
# Validate stage argument
if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
exit 1
fi
# Launch distributed training using accelerate
# --config_file: Path to the GPU configuration file
# --main_process_port: Port number for the main process, used for distributed training communication
# train.py: Training script
# --config: Path to the training configuration file
echo "Starting $STAGE training..."
accelerate launch --config_file ./configs/training/gpu.yaml \
--main_process_port 29502 \
train.py --config ./configs/training/$STAGE.yaml
echo "Training completed for $STAGE"