feat: windows infer & gradio (#312)

* fix: windows infer

* docs: update readme

* docs: update readme

* feat: v1.5 gradio for windows&linux

* fix: dependencies

* feat: windows infer & gradio

---------

Co-authored-by: NeRF-Factory <zzhizhou66@gmail.com>
This commit is contained in:
Zhizhou Zhong
2025-04-12 23:22:22 +08:00
committed by GitHub
parent 36163fccbd
commit 67e7ee3c73
14 changed files with 613 additions and 245 deletions

7
.gitignore vendored

@@ -5,11 +5,14 @@
*.pyc
.ipynb_checkpoints
results/
./models
models/
**/__pycache__/
*.py[cod]
*$py.class
dataset/
ffmpeg*
ffmprobe*
ffplay*
debug
exp_out
exp_out
.gradio

169
README.md

@@ -146,50 +146,87 @@ We also hope you note that we have not verified, maintained, or updated third-pa
## Installation
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
### Build environment
We recommend a python version >=3.10 and cuda version =11.7. Then build environment as follows:
### Build environment
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
```shell
conda create -n MuseTalk python==3.10
conda activate MuseTalk
```
### Install PyTorch 2.0.1
Choose one of the following installation methods:
```shell
# Option 1: Using pip
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
# Option 2: Using conda
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
```
### Install Dependencies
Install the remaining required packages:
```shell
pip install -r requirements.txt
```
### mmlab packages
### Install MMLab Packages
Install the MMLab ecosystem packages:
```bash
pip install --no-cache-dir -U openmim
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
pip install --no-cache-dir -U openmim
mim install mmengine
mim install "mmcv==2.0.1"
mim install "mmdet==3.1.0"
mim install "mmpose==1.1.0"
```
### Download ffmpeg-static
Download the ffmpeg-static and
```
### Setup FFmpeg
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
2. Configure FFmpeg based on your operating system:
For Linux:
```bash
export FFMPEG_PATH=/path/to/ffmpeg
```
for example:
```
# Example:
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
```
### Download weights
You can download weights manually as follows:
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
For Windows:
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
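If you prefer to verify the setup from Python rather than the command prompt, a minimal sketch in the spirit of the `test_ffmpeg.py` script added in this commit could look like this (the `ffmpeg_dir` value below is a placeholder for your actual installation path):
```python
# Minimal check: prepend an FFmpeg bin directory to PATH, then confirm ffmpeg runs.
import os
import subprocess
import sys

ffmpeg_dir = r"C:\ffmpeg-master-latest-win64-gpl-shared\bin"  # placeholder path
path_separator = ";" if sys.platform == "win32" else ":"
os.environ["PATH"] = f"{ffmpeg_dir}{path_separator}{os.environ['PATH']}"

try:
    result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True, check=True)
    print(result.stdout.splitlines()[0])  # first line contains the version string
except (FileNotFoundError, subprocess.CalledProcessError) as exc:
    print(f"FFmpeg is not usable: {exc}")
```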
### Download weights
You can download weights in two ways:
#### Option 1: Using Download Scripts
We provide two scripts for automatic downloading:
For Linux:
```bash
# !pip install -U "huggingface_hub[cli]"
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download TMElyralab/MuseTalk --local-dir models/
sh ./download_weights.sh
```
For Windows:
```batch
# Run the script
download_weights.bat
```
#### Option 2: Manual Download
You can also download the weights manually from the following links:
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
2. Download the weights of other components:
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
Finally, these weights should be organized in `models` as follows:
```
@@ -207,7 +244,7 @@ Finally, these weights should be organized in `models` as follows:
├── face-parse-bisent
│ ├── 79999_iter.pth
│ └── resnet18-5c106cde.pth
├── sd-vae-ft-mse
├── sd-vae
│ ├── config.json
│ └── diffusion_pytorch_model.bin
└── whisper
@@ -221,21 +258,60 @@ Finally, these weights should be organized in `models` as follows:
### Inference
We provide inference scripts for both versions of MuseTalk:
#### MuseTalk 1.5 (Recommended)
#### Prerequisites
Before running inference, please ensure ffmpeg is installed and accessible:
```bash
# Run MuseTalk 1.5 inference
sh inference.sh v1.5 normal
# Check ffmpeg installation
ffmpeg -version
```
If ffmpeg is not found, please install it first:
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
- Linux: `sudo apt-get install ffmpeg`
#### MuseTalk 1.0
#### Normal Inference
##### Linux Environment
```bash
# Run MuseTalk 1.0 inference
# MuseTalk 1.5 (Recommended)
sh inference.sh v1.5 normal
# MuseTalk 1.0
sh inference.sh v1.0 normal
```
The inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
##### Windows Environment
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
```bash
# MuseTalk 1.5 (Recommended)
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
# For MuseTalk 1.0, change:
# - models\musetalkV15 -> models\musetalk
# - unet.pth -> pytorch_model.bin
# - --version v15 -> --version v1
```
#### Real-time Inference
##### Linux Environment
```bash
# MuseTalk 1.5 (Recommended)
sh inference.sh v1.5 realtime
# MuseTalk 1.0
sh inference.sh v1.0 realtime
```
##### Windows Environment
```bash
# MuseTalk 1.5 (Recommended)
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
# For MuseTalk 1.0, change:
# - models\musetalkV15 -> models\musetalk
# - unet.pth -> pytorch_model.bin
# - --version v15 -> --version v1
```
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
- `video_path`: Path to the input video, image file, or directory of images
@@ -243,21 +319,6 @@ The configuration file `configs/inference/test.yaml` contains the inference sett
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
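As an illustration only, the 25fps conversion can also be scripted; the sketch below shells out to ffmpeg with placeholder file names:
```python
# Sketch: resample a video to 25 fps with ffmpeg, copying the audio stream unchanged.
# "input.mp4" and "input_25fps.mp4" are placeholder paths.
import subprocess

subprocess.run(
    [
        "ffmpeg", "-y",
        "-i", "input.mp4",
        "-vf", "fps=25",   # resample the video track to 25 fps
        "-c:a", "copy",    # keep the original audio stream
        "input_25fps.mp4",
    ],
    check=True,
)
```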
#### Real-time Inference
For real-time inference, use the following command:
```bash
# Run real-time inference
sh inference.sh v1.5 realtime # For MuseTalk 1.5
# or
sh inference.sh v1.0 realtime # For MuseTalk 1.0
```
The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
- `preparation`: Set to `True` for new avatar preparation
- `video_path`: Path to the input video
- `bbox_shift`: Adjustable parameter for mouth region control
- `audio_clips`: List of audio clips for generation
Important notes for real-time inference:
1. Set `preparation` to `True` when processing a new avatar
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
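For orientation, the fields listed above can be inspected before launching real-time inference. The sketch below assumes the config groups them under one top-level entry per avatar; verify the layout against your own `realtime.yaml` before relying on it:
```python
# Sketch: load configs/inference/realtime.yaml and print the per-avatar settings.
# The one-entry-per-avatar layout is an assumption; adapt the access pattern to your file.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/inference/realtime.yaml")
for avatar_id, avatar_cfg in cfg.items():
    print(f"{avatar_id}:")
    print(f"  preparation: {avatar_cfg.get('preparation')}")  # True when preparing a new avatar
    print(f"  video_path:  {avatar_cfg.get('video_path')}")
    print(f"  bbox_shift:  {avatar_cfg.get('bbox_shift')}")
    print(f"  audio_clips: {avatar_cfg.get('audio_clips')}")
```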
@@ -269,6 +330,18 @@ For faster generation without saving images, you can use:
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## Gradio Demo
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
![para](assets/figs/gradio_2.png)
As a reference for minimum hardware requirements, we tested the system on a Windows machine with an NVIDIA GeForce RTX 3050 Ti Laptop GPU (4GB VRAM). In fp16 mode, generating an 8-second video takes approximately 5 minutes. ![speed](assets/figs/gradio.png)
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
```bash
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
```
## Training
### Data Preparation

476
app.py

@@ -4,7 +4,6 @@ import pdb
import re
import gradio as gr
import spaces
import numpy as np
import sys
import subprocess
@@ -28,11 +27,101 @@ import gdown
import imageio
import ffmpeg
from moviepy.editor import *
from transformers import WhisperModel
ProjectDir = os.path.abspath(os.path.dirname(__file__))
CheckpointsDir = os.path.join(ProjectDir, "models")
@torch.no_grad()
def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
left_cheek_width=90, right_cheek_width=90):
"""Debug inpainting parameters, only process the first frame"""
# Set default parameters
args_dict = {
"result_dir": './results/debug',
"fps": 25,
"batch_size": 1,
"output_vid_name": '',
"use_saved_coord": False,
"audio_padding_length_left": 2,
"audio_padding_length_right": 2,
"version": "v15",
"extra_margin": extra_margin,
"parsing_mode": parsing_mode,
"left_cheek_width": left_cheek_width,
"right_cheek_width": right_cheek_width
}
args = Namespace(**args_dict)
# Create debug directory
os.makedirs(args.result_dir, exist_ok=True)
# Read first frame
if get_file_type(video_path) == "video":
reader = imageio.get_reader(video_path)
first_frame = reader.get_data(0)
reader.close()
else:
first_frame = cv2.imread(video_path)
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
# Save first frame
debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
# Get face coordinates
coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
bbox = coord_list[0]
frame = frame_list[0]
if bbox == coord_placeholder:
return None, "No face detected, please adjust bbox_shift parameter"
# Initialize face parser
fp = FaceParsing(
left_cheek_width=args.left_cheek_width,
right_cheek_width=args.right_cheek_width
)
# Process first frame
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
# Generate random audio features
random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
audio_feature = pe(random_audio)
# Get latents
latents = vae.get_latents_for_unet(crop_frame)
latents = latents.to(dtype=weight_dtype)
# Generate prediction results
pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
recon = vae.decode_latents(pred_latents)
# Inpaint back to original image
res_frame = recon[0]
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
# Save results (no need to convert color space again since get_image already returns RGB format)
debug_result_path = os.path.join(args.result_dir, "debug_result.png")
cv2.imwrite(debug_result_path, combine_frame)
# Create information text
info_text = f"Parameter information:\n" + \
f"bbox_shift: {bbox_shift}\n" + \
f"extra_margin: {extra_margin}\n" + \
f"parsing_mode: {parsing_mode}\n" + \
f"left_cheek_width: {left_cheek_width}\n" + \
f"right_cheek_width: {right_cheek_width}\n" + \
f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
def print_directory_contents(path):
for child in os.listdir(path):
child_path = os.path.join(path, child)
@@ -40,119 +129,107 @@ def print_directory_contents(path):
print(child_path)
def download_model():
if not os.path.exists(CheckpointsDir):
os.makedirs(CheckpointsDir)
print("Checkpoint Not Downloaded, start downloading...")
tic = time.time()
snapshot_download(
repo_id="TMElyralab/MuseTalk",
local_dir=CheckpointsDir,
max_workers=8,
local_dir_use_symlinks=True,
force_download=True, resume_download=False
)
# weight
os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
snapshot_download(
repo_id="stabilityai/sd-vae-ft-mse",
local_dir=CheckpointsDir+'/sd-vae-ft-mse',
max_workers=8,
local_dir_use_symlinks=True,
force_download=True, resume_download=False
)
#dwpose
os.makedirs(f"{CheckpointsDir}/dwpose/")
snapshot_download(
repo_id="yzd-v/DWPose",
local_dir=CheckpointsDir+'/dwpose',
max_workers=8,
local_dir_use_symlinks=True,
force_download=True, resume_download=False
)
#whisper
url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
response = requests.get(url)
# Make sure the request succeeded
if response.status_code == 200:
# Location where the file will be saved
file_path = f"{CheckpointsDir}/whisper/tiny.pt"
os.makedirs(f"{CheckpointsDir}/whisper/")
# Write the file content to the specified location
with open(file_path, "wb") as f:
f.write(response.content)
# Check whether the required model files exist
required_models = {
"MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
"MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
"SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
"Whisper": f"{CheckpointsDir}/whisper/config.json",
"DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
"SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
"Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
"ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
}
missing_models = []
for model_name, model_path in required_models.items():
if not os.path.exists(model_path):
missing_models.append(model_name)
if missing_models:
# All messages in English
print("The following required model files are missing:")
for model in missing_models:
print(f"- {model}")
print("\nPlease run the download script to download the missing models:")
if sys.platform == "win32":
print("Windows: Run download_weights.bat")
else:
print(f"请求失败,状态码:{response.status_code}")
#gdown face parse
url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
gdown.download(url, file_path, quiet=False)
#resnet
url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
response = requests.get(url)
# Make sure the request succeeded
if response.status_code == 200:
# Location where the file will be saved
file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
# Write the file content to the specified location
with open(file_path, "wb") as f:
f.write(response.content)
else:
print(f"请求失败,状态码:{response.status_code}")
toc = time.time()
print(f"download cost {toc-tic} seconds")
print_directory_contents(CheckpointsDir)
print("Linux/Mac: Run ./download_weights.sh")
sys.exit(1)
else:
print("Already download the model.")
print("All required model files exist.")
download_model() # for huggingface deployment.
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
@spaces.GPU(duration=600)
@torch.no_grad()
def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
# Set default parameters, aligned with inference.py
args_dict = {
"result_dir": './results/output',
"fps": 25,
"batch_size": 8,
"output_vid_name": '',
"use_saved_coord": False,
"audio_padding_length_left": 2,
"audio_padding_length_right": 2,
"version": "v15", # Fixed use v15 version
"extra_margin": extra_margin,
"parsing_mode": parsing_mode,
"left_cheek_width": left_cheek_width,
"right_cheek_width": right_cheek_width
}
args = Namespace(**args_dict)
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path,exist_ok =True)
# Check ffmpeg
if not fast_check_ffmpeg():
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
if args.output_vid_name=="":
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
# Create temporary directory
temp_dir = os.path.join(args.result_dir, f"{args.version}")
os.makedirs(temp_dir, exist_ok=True)
# Set result save path
result_img_save_path = os.path.join(temp_dir, output_basename)
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
os.makedirs(result_img_save_path, exist_ok=True)
if args.output_vid_name == "":
output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
else:
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
############################################## extract frames from source video ##############################################
if get_file_type(video_path)=="video":
save_dir_full = os.path.join(args.result_dir, input_basename)
os.makedirs(save_dir_full,exist_ok = True)
# cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
# os.system(cmd)
# Read video
if get_file_type(video_path) == "video":
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
# Read video
reader = imageio.get_reader(video_path)
# Save images
# Save images
for i, im in enumerate(reader):
imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
@@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
#print(input_img_list)
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
############################################## preprocess input image ##############################################
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
@@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
# Initialize face parser
fp = FaceParsing(
left_cheek_width=args.left_cheek_width,
right_cheek_width=args.right_cheek_width
)
i = 0
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
@@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
audio_feature_batch = pe(whisper_batch)
# Ensure latent_batch is consistent with model weight type
latent_batch = latent_batch.to(dtype=weight_dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
@@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
# print(bbox)
continue
combine_frame = get_image(ori_frame,res_frame,bbox)
# Use v15 version blending
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
# cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
# print(cmd_img2video)
# os.system(cmd_img2video)
# Frame rate
# Frame rate
fps = 25
# Image path
# Output video path
# Output video path
output_video = 'temp.mp4'
# Read images
# Read images
def is_valid_image(file):
pattern = re.compile(r'\d{8}\.png')
return pattern.match(file)
@@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
images.append(imageio.imread(filename))
# Save video
# Save video
imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
# cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
# print(cmd_combine_audio)
# os.system(cmd_combine_audio)
input_video = './temp.mp4'
# Check if the input_video and audio_path exist
if not os.path.exists(input_video):
@@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
# Read video
# Read video
reader = imageio.get_reader(input_video)
fps = reader.get_meta_data()['fps'] # Get the original video's frame rate
reader.close() # Otherwise this errors on Win11: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
# Store frames in a list
fps = reader.get_meta_data()['fps'] # Get original video frame rate
reader.close() # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
# Store frames in list
frames = images
# Save the video and add audio
# imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])
# input_video = ffmpeg.input(input_video)
# input_audio = ffmpeg.input(audio_path)
print(len(frames))
# imageio.mimwrite(
# output_video,
# frames,
# 'FFMPEG',
# fps=25,
# codec='libx264',
# audio_codec='aac',
# input_params=['-i', audio_path],
# output_params=['-y'], # Add the '-y' flag to overwrite the output file if it exists
# )
# writer = imageio.get_writer(output_vid_name, fps = 25, codec='libx264', quality=10, pixelformat='yuvj444p')
# for im in frames:
# writer.append_data(im)
# writer.close()
# Load the video
video_clip = VideoFileClip(input_video)
@@ -315,11 +388,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
# load model weights
audio_processor,vae,unet,pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae, unet, pe = load_all_model(
unet_model_path="./models/musetalkV15/unet.pth",
vae_type="sd-vae",
unet_config="./models/musetalkV15/musetalk.json",
device=device
)
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
parser.add_argument("--share", action="store_true", help="Create a public link")
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
args = parser.parse_args()
# Set data type
if args.use_float16:
# Convert models to half precision for better performance
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
weight_dtype = torch.float16
else:
weight_dtype = torch.float32
# Move models to specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
timesteps = torch.tensor([0], device=device)
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
whisper = WhisperModel.from_pretrained("./models/whisper")
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
def check_video(video):
@@ -340,9 +447,6 @@ def check_video(video):
output_video = os.path.join('./results/input', output_file_name)
# # Run the ffmpeg command to change the frame rate to 25fps
# command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"
# read video
reader = imageio.get_reader(video)
fps = reader.get_meta_data()['fps'] # get fps from original video
@@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
with gr.Blocks(css=css) as demo:
gr.Markdown(
"<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
"""<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
</br>\
Yue Zhang <sup>\*</sup>,\
Minhao Liu<sup>\*</sup>,\
Yue Zhang <sup>*</sup>,\
Zhizhou Zhong <sup>*</sup>,\
Minhao Liu<sup>*</sup>,\
Zhaokang Chen,\
Bin Wu<sup>†</sup>,\
Yubin Zeng,\
Chao Zhang,\
Yingjie He,\
Chao Zhan,\
Wenjiang Zhou\
Junxin Huang,\
Wenjiang Zhou <br>\
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
Lyra Lab, Tencent Music Entertainment\
</h2> \
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
<a style='font-size:18px;color: #000000' href='https://huggingface.co/TMElyralab/MuseTalk'>[Huggingface]</a>\
<a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
<a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
)
with gr.Row():
with gr.Column():
audio = gr.Audio(label="Driven Audio",type="filepath")
audio = gr.Audio(label="Drving Audio",type="filepath")
video = gr.Video(label="Reference Video",sources=['upload'])
bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
bbox_shift_scale = gr.Textbox(label="BBox_shift recommend value lower bound,The corresponding bbox range is generated after the initial result is generated. \n If the result is not good, it can be adjusted according to this reference value", value="",interactive=False)
extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' parameters determine the range of left and right cheeks editing when parsing model is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
btn = gr.Button("Generate")
out1 = gr.Video()
with gr.Row():
debug_btn = gr.Button("1. Test Inpainting ")
btn = gr.Button("2. Generate")
with gr.Column():
debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
debug_info = gr.Textbox(label="Parameter Information", lines=5)
out1 = gr.Video()
video.change(
fn=check_video, inputs=[video], outputs=[video]
@@ -412,15 +527,44 @@ with gr.Blocks(css=css) as demo:
audio,
video,
bbox_shift,
extra_margin,
parsing_mode,
left_cheek_width,
right_cheek_width
],
outputs=[out1,bbox_shift_scale]
)
debug_btn.click(
fn=debug_inpainting,
inputs=[
video,
bbox_shift,
extra_margin,
parsing_mode,
left_cheek_width,
right_cheek_width
],
outputs=[debug_image, debug_info]
)
# Set the IP and port
ip_address = "0.0.0.0" # Replace with your desired IP address
port_number = 7860 # Replace with your desired port number
# Check ffmpeg and add to PATH
if not fast_check_ffmpeg():
print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
# Choose path separator based on operating system
path_separator = ';' if sys.platform == 'win32' else ':'
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
if not fast_check_ffmpeg():
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
# Solve asynchronous IO issues on Windows
if sys.platform == 'win32':
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Start Gradio application
demo.queue().launch(
share=False , debug=True, server_name=ip_address, server_port=port_number
share=args.share,
debug=True,
server_name=args.ip,
server_port=args.port
)

BIN
assets/figs/gradio.png Normal file (binary image, 14 KiB)

BIN
assets/figs/gradio_2.png Normal file (binary image, 73 KiB)

45
download_weights.bat Normal file

@@ -0,0 +1,45 @@
@echo off
setlocal
:: Set the checkpoints directory
set CheckpointsDir=models
:: Create necessary directories
mkdir %CheckpointsDir%\musetalk
mkdir %CheckpointsDir%\musetalkV15
mkdir %CheckpointsDir%\syncnet
mkdir %CheckpointsDir%\dwpose
mkdir %CheckpointsDir%\face-parse-bisent
mkdir %CheckpointsDir%\sd-vae
mkdir %CheckpointsDir%\whisper
:: Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown
:: Set HuggingFace endpoint
set HF_ENDPOINT=https://hf-mirror.com
:: Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
:: Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
:: Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
:: Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
:: Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
:: Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth
:: Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth
echo All weights have been downloaded successfully!
endlocal

37
download_weights.sh Normal file

@@ -0,0 +1,37 @@
#!/bin/bash
# Set the checkpoints directory
CheckpointsDir="models"
# Create necessary directories
mkdir -p $CheckpointsDir/{musetalk,musetalkV15,syncnet,dwpose,face-parse-bisent,sd-vae,whisper}
# Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown
# Set HuggingFace endpoint
export HF_ENDPOINT=https://hf-mirror.com
# Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir $CheckpointsDir
# Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir $CheckpointsDir/sd-vae --include "config.json" "diffusion_pytorch_model.bin"
# Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir $CheckpointsDir/whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
# Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir $CheckpointsDir/dwpose --include "dw-ll_ucoco_384.pth"
# Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir $CheckpointsDir/syncnet --include "latentsync_syncnet.pt"
# Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
# Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
echo "All weights have been downloaded successfully!"

View File

@@ -49,8 +49,9 @@ class AudioProcessor:
whisper_feature = []
# Process multiple 30s mel input features
for input_feature in whisper_input_features:
audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
input_feature = input_feature.to(device).to(weight_dtype)
audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
audio_feats = torch.stack(audio_feats, dim=2)
whisper_feature.append(audio_feats)
whisper_feature = torch.cat(whisper_feature, dim=1)

View File

@@ -8,26 +8,18 @@ from einops import rearrange
import shutil
import os.path as osp
ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in os.getenv('PATH'):
print("add ffmpeg to path")
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
def load_all_model(
unet_model_path="./models/musetalk/pytorch_model.bin",
vae_type="sd-vae-ft-mse",
unet_config="./models/musetalk/musetalk.json",
unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
vae_type="sd-vae",
unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
device=None,
):
vae = VAE(
model_path = f"./models/{vae_type}/",
model_path = os.path.join("models", vae_type),
)
print(f"load unet model from {unet_model_path}")
unet = UNet(

View File

@@ -1,15 +1,15 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
diffusers==0.27.2
diffusers==0.30.2
accelerate==0.28.0
numpy==1.23.5
tensorflow==2.12.0
tensorboard==2.12.0
opencv-python==4.9.0.80
soundfile==0.12.1
transformers==4.39.2
huggingface_hub==0.25.0
huggingface_hub==0.30.2
librosa==0.11.0
einops==0.8.1
gradio==5.24.0
gdown
requests
@@ -17,6 +17,4 @@ imageio[ffmpeg]
omegaconf
ffmpeg-python
gradio
spaces
moviepy

View File

@@ -8,9 +8,11 @@ import shutil
import pickle
import argparse
import numpy as np
import subprocess
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import sys
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
@@ -18,16 +20,26 @@ from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
@torch.no_grad()
def main(args):
# Configure ffmpeg path
if args.ffmpeg_path not in os.getenv('PATH'):
if not fast_check_ffmpeg():
print("Adding ffmpeg to PATH")
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
# Choose path separator based on operating system
path_separator = ';' if sys.platform == 'win32' else ':'
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
if not fast_check_ffmpeg():
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,

View File

@@ -12,11 +12,23 @@ from mmpose.structures import merge_data_samples
import torch
import numpy as np
from tqdm import tqdm
import sys
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
if ffmpeg_path not in os.getenv('PATH'):
print("add ffmpeg to path")
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
if not fast_check_ffmpeg():
print("Adding ffmpeg to PATH")
# Choose path separator based on operating system
path_separator = ';' if sys.platform == 'win32' else ':'
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
if not fast_check_ffmpeg():
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
class AnalyzeFace:
def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):

View File

@@ -23,6 +23,15 @@ import shutil
import threading
import queue
import time
import subprocess
def fast_check_ffmpeg():
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
return True
except:
return False
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
@@ -318,7 +327,7 @@ if __name__ == "__main__":
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
@@ -332,6 +341,15 @@ if __name__ == "__main__":
args = parser.parse_args()
# Configure ffmpeg path
if not fast_check_ffmpeg():
print("Adding ffmpeg to PATH")
# Choose path separator based on operating system
path_separator = ';' if sys.platform == 'win32' else ':'
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
if not fast_check_ffmpeg():
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

33
test_ffmpeg.py Normal file

@@ -0,0 +1,33 @@
import os
import subprocess
import sys
def test_ffmpeg(ffmpeg_path):
print(f"Testing ffmpeg path: {ffmpeg_path}")
# Choose path separator based on operating system
path_separator = ';' if sys.platform == 'win32' else ':'
# Add ffmpeg path to environment variable
os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"
try:
# Try to run ffmpeg
result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
print("FFmpeg test successful!")
print("FFmpeg version information:")
print(result.stdout)
return True
except Exception as e:
print("FFmpeg test failed!")
print(f"Error message: {str(e)}")
return False
if __name__ == "__main__":
# Default ffmpeg path, can be modified as needed
default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"
# Use command line argument if provided, otherwise use default path
ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path
test_ffmpeg(ffmpeg_path)