Merge pull request #275 from TMElyralab/v15

V15
This commit is contained in:
Octpus
2025-03-28 17:39:51 +08:00
committed by GitHub
46 changed files with 733 additions and 205 deletions

345
README.md

@@ -1,15 +1,16 @@
# MuseTalk
MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting
</br>
Yue Zhang <sup>\*</sup>,
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
Yue Zhang<sup>\*</sup>,
Zhizhou Zhong<sup>\*</sup>,
Minhao Liu<sup>\*</sup>,
Zhaokang Chen,
Bin Wu<sup>†</sup>,
Yubin Zeng,
Chao Zhan,
Yingjie He,
Junxin Huang,
Yingjie He,
Wenjiang Zhou
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
@@ -19,7 +20,10 @@ Lyra Lab, Tencent Music Entertainment
We introduce `MuseTalk`, a **real-time, high-quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied to input videos, e.g., those generated by [MuseV](https://github.com/TMElyralab/MuseV), as part of a complete virtual human solution.
:new: Update: We are thrilled to announce that [MusePose](https://github.com/TMElyralab/MusePose/) has been released. MusePose is an image-to-video generation framework for virtual humans driven by control signals such as pose. Together with MuseV and MuseTalk, we hope the community will join us in marching towards the vision of a virtual human generated end-to-end with native full-body movement and interaction.
## 🔥 Updates
We're excited to unveil MuseTalk 1.5.
This version **(1)** integrates perceptual loss, GAN loss, and sync loss into training, significantly boosting overall performance, and **(2)** adopts a two-stage training strategy with a spatio-temporal data sampling approach to balance visual quality and lip-sync accuracy.
Learn more details [here](https://arxiv.org/abs/2410.10122).
# Overview
`MuseTalk` is a real-time, high-quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
@@ -28,23 +32,200 @@ We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+
1. supports audio in various languages, such as Chinese, English, and Japanese.
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
1. supports modification of the center point of the face region, which **SIGNIFICANTLY** affects generation results.
1. checkpoint available, trained on the HDTF dataset.
1. training code (coming soon).
1. checkpoint available, trained on the HDTF and a private dataset.
# News
- [04/02/2024] Release MuseTalk project and pretrained models.
- [03/28/2025] :mega: We are thrilled to announce the release of version 1.5. It is a significant improvement over version 1.0, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We have updated the [technical report](https://arxiv.org/abs/2410.10122) with more details.
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
- [10/18/2024] :mega: We release the [technical report](https://arxiv.org/abs/2410.10122). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
- [04/02/2024] Release MuseTalk project and pretrained models.
## Model
![Model Structure](assets/figs/musetalk_arc.jpg)
![Model Structure](https://github.com/user-attachments/assets/02f4a214-1bdd-4326-983c-e70b478accba)
MuseTalk was trained in latent space, where the images were encoded by a frozen VAE and the audio was encoded by a frozen `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of `stable-diffusion-v1-4`, with the audio embeddings fused into the image embeddings by cross-attention.
Note that although we use an architecture very similar to Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space in a single step.
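For intuition, the single-step pass can be sketched as follows (a minimal sketch condensed from `scripts/inference_alpha.py`; batching, dtype casting, and pre/post-processing are omitted, and the helper name `lipsync_frame` is ours):
```python
import torch

@torch.no_grad()
def lipsync_frame(vae, unet, pe, crop_frame, whisper_chunk):
    # crop_frame: 256x256 face crop; whisper_chunk: audio features aligned to this frame.
    timesteps = torch.tensor([0], device=unet.device)   # a single fixed step, not a diffusion schedule
    latents = vae.get_latents_for_unet(crop_frame)       # frozen VAE encoder -> image latents
    audio_emb = pe(whisper_chunk)                        # positional encoding of audio features
    pred_latents = unet.model(latents, timesteps,
                              encoder_hidden_states=audio_emb).sample  # one UNet pass, audio via cross-attention
    return vae.decode_latents(pred_latents)              # frozen VAE decoder -> generated face crop
```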
## Cases
### MuseV + MuseTalk make human photos alive
<table>
<tr>
<td width="33%">
### Input Video
---
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
---
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
---
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
---
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
---
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
---
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
</td>
<td width="33%">
### MuseTalk 1.0
---
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
---
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
---
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
---
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
---
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
---
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
</td>
<td width="33%">
### MuseTalk 1.5
---
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
---
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
---
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
---
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
---
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
---
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
</td>
</tr>
</table>
# TODO:
- [x] trained models and inference codes.
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [ ] training and dataloader code (Expected completion on 04/04/2025).
- [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
# Getting Started
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
## Third party integration
Thanks to the community for the third-party integrations, which make installation and use more convenient for everyone.
Please note that we have not verified, maintained, or updated these third-party integrations; refer to each project for specific results.
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
## Installation
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
### Build environment
We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows:
```shell
pip install -r requirements.txt
```
### mmlab packages
```bash
pip install --no-cache-dir -U openmim
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
```
### Download ffmpeg-static
Download ffmpeg-static and set:
```
export FFMPEG_PATH=/path/to/ffmpeg
```
for example:
```
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
```
### Download weights
You can download weights manually as follows:
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
2. Download the weights of other components:
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
- [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
Finally, these weights should be organized in `models` as follows:
```
./models/
├── musetalk
│ └── musetalk.json
│ └── pytorch_model.bin
├── musetalkV15
│ └── musetalk.json
│ └── unet.pth
├── dwpose
│ └── dw-ll_ucoco_384.pth
├── face-parse-bisent
│ ├── 79999_iter.pth
│ └── resnet18-5c106cde.pth
├── sd-vae-ft-mse
│ ├── config.json
│ └── diffusion_pytorch_model.bin
└── whisper
└── tiny.pt
```
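If you prefer to script the downloads, a minimal sketch using `huggingface_hub` (pinned in `requirements.txt`) is shown below. The exact sub-folder layout inside the Hugging Face repos may differ from the tree above, so verify the result; the whisper `tiny.pt`, face-parse-bisent, and resnet18 weights still need to be fetched from the links listed.
```python
from huggingface_hub import snapshot_download

# MuseTalk checkpoints (expected to provide the musetalk/ and musetalkV15/ folders).
snapshot_download(repo_id="TMElyralab/MuseTalk", local_dir="./models")

# VAE used by MuseTalk.
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./models/sd-vae-ft-mse")

# DWPose checkpoint (only dw-ll_ucoco_384.pth is needed).
snapshot_download(repo_id="yzd-v/DWPose", local_dir="./models/dwpose")
```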
## Quickstart
### Inference
We provide inference scripts for both versions of MuseTalk:
#### MuseTalk 1.5 (Recommended)
```bash
python3 -m scripts.inference_alpha --inference_config configs/inference/test.yaml --unet_model_path ./models/musetalkV15/unet.pth
```
This inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
`configs/inference/test.yaml` is the path to the inference configuration file, which specifies `video_path` and `audio_path`.
`video_path` can be a video file, an image file, or a directory of images.
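As a rough illustration of the expected config structure (the task name and paths below are placeholders; `result_name` and `bbox_shift` are optional per-task keys read by `scripts/inference_alpha.py`):
```python
from omegaconf import OmegaConf

# Placeholder task entry; add one block per video/audio pair.
config = OmegaConf.create({
    "task_0": {
        "video_path": "data/video/input.mp4",   # video file, image file, or directory of images
        "audio_path": "data/audio/input.wav",
        # "result_name": "output.mp4",          # optional: name of the generated video
        # "bbox_shift": -7,                     # optional: per-task override of --bbox_shift
    }
})
OmegaConf.save(config, "configs/inference/test.yaml")
```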
#### MuseTalk 1.0
```bash
python3 -m scripts.inference --inference_config configs/inference/test.yaml
```
We recommend using input videos at `25fps`, the same fps used when training the model. If your video's frame rate is much lower than 25fps, apply frame interpolation or convert the video directly to 25fps using ffmpeg.
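For instance, a generic ffmpeg invocation (not a script shipped with this repo) to resample a clip to 25fps could be wrapped like this; the file names are placeholders:
```python
import subprocess

# Resample the video stream to 25fps and keep the original audio.
subprocess.run(
    ["ffmpeg", "-y", "-i", "input.mp4", "-vf", "fps=25", "-c:a", "copy", "input_25fps.mp4"],
    check=True,
)
```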
<details close>
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td width="33%">Image</td>
@@ -130,132 +311,7 @@ Note that although we use a very similar architecture as Stable Diffusion, MuseT
</tr>
</table >
* The character of the last two rows, `Xinying Sun`, is a supermodel KOL. You can follow her on [douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8).
## Video dubbing
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td width="70%">MuseTalk</td>
<td width="30%">Original videos</td>
</tr>
<tr>
<td>
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4d7c5fa1-3550-4d52-8ed2-52f158150f24 controls preload></video>
</td>
<td>
<a href="//www.bilibili.com/video/BV1wT411b7HU">Link</a>
</td>
</tr>
</table>
* For video dubbing, we applied a self-developed tool which can identify the talking person.
## Some interesting videos!
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td width="50%">Image</td>
<td width="50%">MuseV + MuseTalk</td>
</tr>
<tr>
<td>
<img src=assets/demo/video1/video1.png width="95%">
</td>
<td>
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1f02f9c6-8b98-475e-86b8-82ebee82fe0d controls preload></video>
</td>
</tr>
</table>
# TODO:
- [x] trained models and inference codes.
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
- [x] codes for real-time inference.
- [ ] technical report.
- [ ] training codes.
- [ ] a better model (may take longer).
# Getting Started
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
## Third party integration
Thanks to the community for the third-party integrations, which make installation and use more convenient for everyone.
Please note that we have not verified, maintained, or updated these third-party integrations; refer to each project for specific results.
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
## Installation
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
### Build environment
We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows:
```shell
pip install -r requirements.txt
```
### mmlab packages
```bash
pip install --no-cache-dir -U openmim
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
```
### Download ffmpeg-static
Download ffmpeg-static and set:
```
export FFMPEG_PATH=/path/to/ffmpeg
```
for example:
```
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
```
### Download weights
You can download weights manually as follows:
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
2. Download the weights of other components:
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
- [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
Finally, these weights should be organized in `models` as follows:
```
./models/
├── musetalk
│ └── musetalk.json
│ └── pytorch_model.bin
├── dwpose
│ └── dw-ll_ucoco_384.pth
├── face-parse-bisent
│ ├── 79999_iter.pth
│ └── resnet18-5c106cde.pth
├── sd-vae-ft-mse
│ ├── config.json
│ └── diffusion_pytorch_model.bin
└── whisper
└── tiny.pt
```
## Quickstart
### Inference
Here, we provide the inference script.
```
python -m scripts.inference --inference_config configs/inference/test.yaml
```
`configs/inference/test.yaml` is the path to the inference configuration file, which specifies `video_path` and `audio_path`.
`video_path` can be a video file, an image file, or a directory of images.
We recommend using input videos at `25fps`, the same fps used when training the model. If your video's frame rate is much lower than 25fps, apply frame interpolation or convert the video directly to 25fps using ffmpeg.
#### Use of bbox_shift to have adjustable results
#### Use of bbox_shift to have adjustable results (for 1.0)
:mag_right: We have found that upper-bound of the mask has an important impact on mouth openness. Thus, to control the mask region, we suggest using the `bbox_shift` parameter. Positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease mouth openness.
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script within this range.
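In practice this is two passes over the same config, sketched below with placeholder values (the admissible range is printed by the first run; `-7` is only an example):
```python
import subprocess

cfg = "configs/inference/test.yaml"

# Pass 1: default bbox_shift (0); the script reports the adjustable value range.
subprocess.run(["python", "-m", "scripts.inference", "--inference_config", cfg], check=True)

# Pass 2: re-run with a value chosen from that range
# (positive -> more mouth openness, negative -> less).
subprocess.run(["python", "-m", "scripts.inference", "--inference_config", cfg,
                "--bbox_shift", "-7"], check=True)
```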
@@ -266,12 +322,16 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
```
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
</details>
#### Combining MuseV and MuseTalk
As a complete virtual-human generation solution, we suggest first applying [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring to [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase the frame rate. Then, you can use `MuseTalk` to generate a lip-synced video by referring to [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
#### :new: Real-time inference
#### Real-time inference
<details close>
Here, we provide the inference script. This script first applies the necessary pre-processing, such as face detection, face parsing and VAE encoding, in advance. During inference, only the UNet and the VAE decoder are involved, which makes MuseTalk real-time.
```
@@ -293,6 +353,7 @@ configs/inference/realtime.yaml is the path to the real-time inference configura
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
</details>
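Conceptually, the real-time path splits the work as in the sketch below (illustrative only, not the actual `scripts.realtime_inference` implementation; it reuses the model objects returned by `load_all_model`):
```python
from itertools import cycle
import torch

@torch.no_grad()
def realtime_loop(vae, unet, pe, face_crops, whisper_chunks):
    # Offline, once per avatar: face detection/parsing already done, cache a VAE latent per crop.
    cached_latents = [vae.get_latents_for_unet(crop) for crop in face_crops]
    timesteps = torch.tensor([0], device=unet.device)
    # Online, per audio chunk: only the UNet and the VAE decoder run, which keeps it real-time.
    for whisper_chunk, latents in zip(whisper_chunks, cycle(cached_latents)):
        audio_emb = pe(whisper_chunk)
        pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_emb).sample
        yield vae.decode_latents(pred_latents)  # caller blends the crop back into the full frame
```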
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
@@ -312,10 +373,10 @@ If you need higher resolution, you could apply super resolution models such as [
# Citation
```bib
@article{musetalk,
title={MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting},
author={Zhang, Yue and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
journal={arxiv},
year={2024}
year={2025}
}
```
# Disclaimer/License

musetalk/models/unet.py

@@ -31,12 +31,16 @@ class UNet():
unet_config,
model_path,
use_float16=False,
device=None
):
with open(unet_config, 'r') as f:
unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is not None:
self.device = device
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:

0
musetalk/utils/__init__.py Normal file → Executable file

musetalk/utils/audio_processor.py Normal file

@@ -0,0 +1,99 @@
import os
import math
import librosa
import numpy as np
import torch
from einops import rearrange
from transformers import AutoFeatureExtractor
class AudioProcessor:
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
def get_audio_feature(self, wav_path, start_index=0):
if not os.path.exists(wav_path):
return None
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
assert sampling_rate == 16000
# Split audio into 30s segments
segment_length = 30 * sampling_rate
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
features = []
for segment in segments:
audio_feature = self.feature_extractor(
segment,
return_tensors="pt",
sampling_rate=sampling_rate
).input_features
features.append(audio_feature)
return features, len(librosa_output)
def get_whisper_chunk(
self,
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=25,
audio_padding_length_left=2,
audio_padding_length_right=2,
):
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
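# Whisper features arrive at ~50 frames per second, i.e. 2 per video frame at 25 fps, so each
# video frame consumes 2 * (left + 1 + right) audio feature frames as its context window.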
whisper_feature = []
# Process multiple 30s mel input features
for input_feature in whisper_input_features:
audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
whisper_feature.append(audio_feats)
whisper_feature = torch.cat(whisper_feature, dim=1)
# Trim the last segment to remove padding
sr = 16000
audio_fps = 50
fps = int(fps)
whisper_idx_multiplier = audio_fps / fps
num_frames = math.floor((librosa_length / sr)) * fps
actual_length = math.floor((librosa_length / sr)) * audio_fps
whisper_feature = whisper_feature[:,:actual_length,...]
# Calculate padding amount
padding_nums = math.floor(whisper_idx_multiplier)
# Add padding at start and end
whisper_feature = torch.cat([
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
whisper_feature,
# Add extra padding to prevent out of bounds
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
], 1)
audio_prompts = []
for frame_index in range(num_frames):
try:
audio_index = math.floor(frame_index * whisper_idx_multiplier)
audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
assert audio_clip.shape[1] == audio_feature_length_per_frame
audio_prompts.append(audio_clip)
except Exception as e:
print(f"Error occurred: {e}")
print(f"whisper_feature.shape: {whisper_feature.shape}")
print(f"audio_clip.shape: {audio_clip.shape}")
print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
exit()
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
return audio_prompts
if __name__ == "__main__":
audio_processor = AudioProcessor()
wav_path = "/cfs-workspace/users/gozhong/codes/musetalk_opensource2/data/audio/2.wav"
audio_features, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
print("Number of 30s feature segments:", len(audio_features))  # get_audio_feature returns a list of segment features
print("First segment feature shape:", audio_features[0].shape)
print("librosa_feature_length:", librosa_feature_length)

123
musetalk/utils/blending.py Normal file → Executable file

@@ -2,9 +2,6 @@ from PIL import Image
import numpy as np
import cv2
import copy
from face_parsing import FaceParsing
fp = FaceParsing()
def get_crop_box(box, expand):
x, y, x1, y1 = box
@@ -14,46 +11,98 @@ def get_crop_box(box, expand):
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
return crop_box, s
def face_seg(image):
seg_image = fp(image)
def face_seg(image, mode="jaw", fp=None):
"""
Run face parsing on the image and generate a mask of the face region.
Args:
image (PIL.Image): Input image.
mode (str): Mask construction mode ("raw", "jaw", or "neck").
fp (FaceParsing): Face parsing model instance.
Returns:
PIL.Image: Mask image of the face region.
"""
seg_image = fp(image, mode=mode) # parse the face with the FaceParsing model
if seg_image is None:
print("error, no person_segment")
print("error, no person_segment") # 如果没有检测到面部,返回错误
return None
seg_image = seg_image.resize(image.size)
seg_image = seg_image.resize(image.size) # resize the mask to the input image size
return seg_image
def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
#print(image.shape)
#print(face.shape)
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
"""
Paste the cropped face image back into the original image and post-process the result.
Args:
image (numpy.ndarray): Original image (the full body frame).
face (numpy.ndarray): Cropped face image.
face_box (tuple): Face bounding-box coordinates (x, y, x1, y1).
upper_boundary_ratio (float): Controls how much of the face region is kept.
expand (float): Expansion factor used to enlarge the crop box.
mode: How the blending mask is constructed.
fp (FaceParsing): Face parsing model instance.
Returns:
numpy.ndarray: Processed image.
"""
# Convert the numpy arrays to PIL images
body = Image.fromarray(image[:, :, ::-1]) # full body image (the whole frame)
face = Image.fromarray(face[:, :, ::-1]) # face image
x, y, x1, y1 = face_box # face bounding-box coordinates
crop_box, s = get_crop_box(face_box, expand) # compute the expanded crop box
x_s, y_s, x_e, y_e = crop_box # crop-box coordinates
face_position = (x, y) # face position in the original image
# Crop the expanded face region from the body image (leaving a margin below the jaw)
face_large = body.crop(crop_box)
ori_shape = face_large.size # original size of the cropped region
# Run face parsing on the cropped face region to generate the mask
mask_image = face_seg(face_large, mode=mode, fp=fp)
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # crop the mask down to the face region
mask_image = Image.new('L', ori_shape, 0) # create an all-black mask image
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # paste the face mask onto the black canvas
# Keep only the part of the mask below upper_boundary_ratio (controls the talking area)
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio) # boundary above which the mask is discarded
modified_mask_image = Image.new('L', ori_shape, 0) # new all-black mask
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # paste the kept (lower) part of the mask
# Gaussian-blur the mask so its edges blend smoothly
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # odd blur kernel size
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # Gaussian blur
#mask_array = np.array(modified_mask_image)
mask_image = Image.fromarray(mask_array) # convert the blurred mask back to a PIL image
# Paste the cropped face back into the expanded face region
face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
body.paste(face_large, crop_box[:2], mask_image)
# (alternative: skip the mask and use the inferred face directly)
#face_large.save("debug/checkpoint_6_face_large.png")
body = np.array(body) # convert the PIL image back to a numpy array
return body[:, :, ::-1] # return the processed image (RGB back to BGR)
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = Image.fromarray(image[:,:,::-1])
face = Image.fromarray(face[:,:,::-1])
x, y, x1, y1 = face_box
#print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_position = (x, y)
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
mask_image = Image.fromarray(mask_array)
mask_image = mask_image.convert("L")
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
body.paste(face_large, crop_box[:2], mask_image)
body = np.array(body)
@@ -84,17 +133,3 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array,crop_box
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = image
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
mask_image = (mask_image/255).astype(np.float32)
body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
return body

0
musetalk/utils/dwpose/default_runtime.py Normal file → Executable file

0
musetalk/utils/face_detection/README.md Normal file → Executable file

0
musetalk/utils/face_detection/__init__.py Normal file → Executable file

0
musetalk/utils/face_detection/api.py Normal file → Executable file

0
musetalk/utils/face_detection/detection/__init__.py Normal file → Executable file

0
musetalk/utils/face_detection/detection/core.py Normal file → Executable file

0
musetalk/utils/face_detection/detection/sfd/bbox.py Normal file → Executable file

0
musetalk/utils/face_detection/detection/sfd/detect.py Normal file → Executable file

0
musetalk/utils/face_detection/models.py Normal file → Executable file

0
musetalk/utils/face_detection/utils.py Normal file → Executable file

@@ -8,9 +8,53 @@ from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self):
def __init__(self, left_cheek_width=80, right_cheek_width=80):
self.net = self.model_init()
self.preprocess = self.image_preprocess()
# Ensure all size parameters are integers
cone_height = 21
tail_height = 12
total_size = cone_height + tail_height
# Create kernel with explicit integer dimensions
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
center_x = total_size // 2 # Ensure center coordinates are integers
# Cone part
for row in range(cone_height):
if row < cone_height//2:
continue
width = int(2 * (row - cone_height//2) + 1)
start = int(center_x - (width // 2))
end = int(center_x + (width // 2) + 1)
kernel[row, start:end] = 1
# Vertical extension part
if cone_height > 0:
base_width = int(kernel[cone_height-1].sum())
else:
base_width = 1
for row in range(cone_height, total_size):
start = max(0, int(center_x - (base_width//2)))
end = min(total_size, int(center_x + (base_width//2) + 1))
kernel[row, start:end] = 1
self.kernel = kernel
# Modify cheek erosion kernel to be flatter ellipse
self.cheek_kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (35, 3))
# Add cheek area mask (protect chin area)
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
"""Create cheek area mask (1/4 area on both sides)"""
mask = np.zeros((512, 512), dtype=np.uint8)
center = 512 // 2
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
return mask
def model_init(self,
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
@@ -30,7 +74,7 @@ class FaceParsing():
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512)):
def __call__(self, image, size=(512, 512), mode="jaw"):
if isinstance(image, str):
image = Image.open(image)
@@ -44,8 +88,25 @@ class FaceParsing():
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing>13)] = 0
parsing[np.where(parsing>=1)] = 255
# Add 14:neck, remove 10:nose and 7:8:9
if mode == "neck":
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
parsing[np.where(parsing!=255)] = 0
elif mode == "jaw":
face_region = np.isin(parsing, [1])*255
face_region = face_region.astype(np.uint8)
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
parsing[np.isin(parsing, [11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
else:
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing

0
musetalk/utils/preprocessing.py Normal file → Executable file

42
musetalk/utils/utils.py Normal file → Executable file

@@ -15,13 +15,24 @@ from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
def load_all_model():
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
vae = VAE(model_path = "./models/sd-vae-ft-mse/")
unet = UNet(unet_config="./models/musetalk/musetalk.json",
model_path ="./models/musetalk/pytorch_model.bin")
def load_all_model(
unet_model_path="./models/musetalk/pytorch_model.bin",
vae_type="sd-vae-ft-mse",
unet_config="./models/musetalk/musetalk.json",
device=None,
):
vae = VAE(
model_path = f"./models/{vae_type}/",
)
print(f"load unet model from {unet_model_path}")
unet = UNet(
unet_config=unet_config,
model_path=unet_model_path,
device=device
)
pe = PositionalEncoding(d_model=384)
return audio_processor,vae,unet,pe
return vae, unet, pe
def get_file_type(video_path):
_, ext = os.path.splitext(video_path)
@@ -39,10 +50,13 @@ def get_video_fps(video_path):
video.release()
return fps
def datagen(whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0):
def datagen(
whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0,
device="cuda:0",
):
whisper_batch, latent_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(vae_encode_latents)
@@ -51,14 +65,14 @@ def datagen(whisper_chunks,
latent_batch.append(latent)
if len(latent_batch) >= batch_size:
whisper_batch = np.stack(whisper_batch)
whisper_batch = torch.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
whisper_batch, latent_batch = [], []
whisper_batch, latent_batch = [], []
# the last batch may be smaller than the batch size
if len(latent_batch) > 0:
whisper_batch = np.stack(whisper_batch)
whisper_batch = torch.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
yield whisper_batch.to(device), latent_batch.to(device)

0
musetalk/whisper/audio2feature.py Normal file → Executable file

0
musetalk/whisper/whisper/__init__.py Normal file → Executable file

0
musetalk/whisper/whisper/__main__.py Normal file → Executable file

0
musetalk/whisper/whisper/assets/gpt2/merges.txt Normal file → Executable file

0
musetalk/whisper/whisper/assets/gpt2/vocab.json Normal file → Executable file

0
musetalk/whisper/whisper/audio.py Normal file → Executable file

0
musetalk/whisper/whisper/decoding.py Normal file → Executable file

0
musetalk/whisper/whisper/model.py Normal file → Executable file

0
musetalk/whisper/whisper/normalizers/__init__.py Normal file → Executable file

0
musetalk/whisper/whisper/normalizers/basic.py Normal file → Executable file

0
musetalk/whisper/whisper/normalizers/english.json Normal file → Executable file

0
musetalk/whisper/whisper/normalizers/english.py Normal file → Executable file

0
musetalk/whisper/whisper/tokenizer.py Normal file → Executable file

0
musetalk/whisper/whisper/transcribe.py Normal file → Executable file

0
musetalk/whisper/whisper/utils.py Normal file → Executable file

requirements.txt

@@ -9,6 +9,7 @@ tensorboard==2.12.0
opencv-python==4.9.0.80
soundfile==0.12.1
transformers==4.39.2
huggingface_hub==0.25.0
gdown
requests

253
scripts/inference_alpha.py Normal file

@@ -0,0 +1,253 @@
import os
import cv2
import math
import copy
import torch
import glob
import shutil
import pickle
import argparse
import subprocess
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
@torch.no_grad()
def main(args):
# Configure ffmpeg path
if args.ffmpeg_path not in os.getenv('PATH'):
print("Adding ffmpeg to PATH")
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,
vae_type=args.vae_type,
unet_config=args.unet_config,
device=device
)
timesteps = torch.tensor([0], device=device)
# Convert models to half precision if float16 is enabled
if args.use_float16:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
# Move models to specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
weight_dtype = unet.model.dtype
whisper = WhisperModel.from_pretrained(args.whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser
fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)
# Load inference configuration
inference_config = OmegaConf.load(args.inference_config)
print("Loaded inference config:", inference_config)
# Process each task
for task_id in inference_config:
try:
# Get task configuration
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
if "result_name" in inference_config[task_id]:
args.output_vid_name = inference_config[task_id]["result_name"]
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
# Set output paths
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
# Create temporary directories
temp_dir = os.path.join(args.result_dir, "frames_result")
os.makedirs(temp_dir, exist_ok=True)
# Set result save paths
result_img_save_path = os.path.join(temp_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path, exist_ok=True)
# Set output video paths
if args.output_vid_name is None:
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
else:
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
# Skip if output file already exists
if os.path.exists(output_vid_name):
print(f"{output_vid_name} already exists, skipping!")
continue
# Extract frames from source video
if get_file_type(video_path) == "video":
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path) == "image":
input_img_list = [video_path]
fps = args.fps
elif os.path.isdir(video_path):
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
# Preprocess input images
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("Using saved coordinates")
with open(crop_coord_save_path, 'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("Extracting landmarks... time-consuming operation")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
print(f"Number of frames: {len(frame_list)}")
# Process each frame
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
# Smooth first and last frames
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
# Batch inference
print("Starting inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
res_frame_list = []
total = int(np.ceil(float(video_num) / batch_size))
# Execute inference
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
# Pad generated images to original video size
print("Padding generated images to original video size")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, ori_frame.shape[0])  # clamp to the current frame's height
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
except:
continue
# Merge results
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
# Save prediction results
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
print("Video generation command:", cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
print("Audio combination command:", cmd_combine_audio)
os.system(cmd_combine_audio)
# Clean up temporary files
shutil.rmtree(result_img_save_path)
os.remove(temp_vid_path)
shutil.rmtree(save_dir_full)
if not args.saved_coord:
os.remove(crop_coord_save_path)
print(f"Results saved to {output_vid_name}")
except Exception as e:
print("Error occurred during processing:", e)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default="/cfs-workspace/users/gozhong/ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="/cfs-datasets/users/gozhong/codes/musetalk_exp/exp_out/stage1_bs40/unet-20000.pth", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="/cfs-datasets/public_models/whisper-tiny", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
args = parser.parse_args()
main(args)