feat: real-time infer (#286)

* feat: realtime infer

* chore: infer script
Zhizhou Zhong
2025-04-02 19:13:18 +08:00
committed by GitHub
parent fbe6a97dff
commit 39ccf69f36
11 changed files with 490 additions and 592 deletions


@@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference).
- [ ] training and dataloader code (Expected completion on 04/04/2025).
# Getting Started
@@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk:
#### MuseTalk 1.5 (Recommended)
```bash
sh inference.sh v1.5 normal  # Run MuseTalk 1.5 inference
```
#### MuseTalk 1.0
```bash
sh inference.sh v1.0 normal  # Run MuseTalk 1.0 inference
```
The inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
- `video_path`: Path to the input video, image file, or directory of images
- `audio_path`: Path to the input audio file
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
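As an illustration, the ffmpeg conversion mentioned above could look like the following command fragment (file names `input.mp4`/`input_25fps.mp4` are placeholders; this is a sketch, not part of the repository):

```
# Resample a video to 25fps with ffmpeg's fps filter (re-encodes video, copies audio)
ffmpeg -i input.mp4 -vf "fps=25" -c:a copy input_25fps.mp4
```

For true motion-smooth results on very low frame-rate sources, a frame-interpolation filter such as `minterpolate` may be preferable to plain resampling.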
#### Real-time Inference
For real-time inference, use the following command:
```bash
# Run real-time inference
sh inference.sh v1.5 realtime # For MuseTalk 1.5
# or
sh inference.sh v1.0 realtime # For MuseTalk 1.0
```
The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
- `preparation`: Set to `True` for new avatar preparation
- `video_path`: Path to the input video
- `bbox_shift`: Adjustable parameter for mouth region control
- `audio_clips`: List of audio clips for generation
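Putting these fields together, a minimal `realtime.yaml` might look like the following (the paths match the sample assets shipped with the repo; the `avator_1` key spelling follows the repository's own config):

```yaml
avator_1:                       # one avatar per top-level key
  preparation: True             # True on the first run to build the avatar materials
  bbox_shift: 5                 # mouth-region adjustment
  video_path: "data/video/yongen.mp4"
  audio_clips:
    audio_0: "data/audio/yongen.wav"
    audio_1: "data/audio/eng.wav"
```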
Important notes for real-time inference:
1. Set `preparation` to `True` when processing a new avatar
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
4. Set `preparation` to `False` for generating more videos with the same avatar
For faster generation without saving images, you can use:
```bash
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
@@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
```
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
#### Combining MuseV and MuseTalk
As a complete solution to virtual human generation, we suggest first applying [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by following [this guide](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is recommended to increase the frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by following [this section](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
1. MuseTalk draws heavily on [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).


@@ -1,10 +1,10 @@
avator_1:
  preparation: True  # you can set it to False to reuse an existing avatar, which saves time
  bbox_shift: 5
  video_path: "data/video/yongen.mp4"
  audio_clips:
    audio_0: "data/audio/yongen.wav"
    audio_1: "data/audio/eng.wav"


@@ -3,8 +3,8 @@ task_0:
  audio_path: "data/audio/yongen.wav"
task_1:
  video_path: "data/video/yongen.mp4"
  audio_path: "data/audio/eng.wav"
  bbox_shift: -7

BIN: data/audio/eng.wav — new executable file (binary content not shown)


@@ -1,46 +1,72 @@
#!/bin/bash
# This script runs inference based on the version and mode specified by the user.
# Usage:
#   To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
#   To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]

# Check if the correct number of arguments is provided
if [ "$#" -ne 2 ]; then
    echo "Usage: $0 <version> <mode>"
    echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
    exit 1
fi

# Get the version and mode from the user input
version=$1
mode=$2

# Validate mode
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
    echo "Invalid mode specified. Please use 'normal' or 'realtime'."
    exit 1
fi

# Set config path based on mode
if [ "$mode" = "normal" ]; then
    config_path="./configs/inference/test.yaml"
    result_dir="./results/test"
else
    config_path="./configs/inference/realtime.yaml"
    result_dir="./results/realtime"
fi

# Define the model paths based on the version
if [ "$version" = "v1.0" ]; then
    model_dir="./models/musetalk"
    unet_model_path="$model_dir/pytorch_model.bin"
    unet_config="$model_dir/musetalk.json"
    version_arg="v1"
elif [ "$version" = "v1.5" ]; then
    model_dir="./models/musetalkV15"
    unet_model_path="$model_dir/unet.pth"
    unet_config="$model_dir/musetalk.json"
    version_arg="v15"
else
    echo "Invalid version specified. Please use v1.0 or v1.5."
    exit 1
fi

# Set script name based on mode
if [ "$mode" = "normal" ]; then
    script_name="scripts.inference"
else
    script_name="scripts.realtime_inference"
fi

# Base command arguments
cmd_args="--inference_config $config_path \
    --result_dir $result_dir \
    --unet_model_path $unet_model_path \
    --unet_config $unet_config \
    --version $version_arg"

# Add realtime-specific arguments if in realtime mode
if [ "$mode" = "realtime" ]; then
    cmd_args="$cmd_args --fps 25"
fi

# Run inference
python3 -m $script_name $cmd_args


@@ -11,7 +11,7 @@ class AudioProcessor:
    def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)

    def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
        if not os.path.exists(wav_path):
            return None
        librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
@@ -27,6 +27,8 @@ class AudioProcessor:
                return_tensors="pt",
                sampling_rate=sampling_rate
            ).input_features
            if weight_dtype is not None:
                audio_feature = audio_feature.to(dtype=weight_dtype)
            features.append(audio_feature)
        return features, len(librosa_output)
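The new `weight_dtype` parameter keeps the extracted audio features in the same precision as the UNet weights (e.g. float16 when `--use_float16` is set), so no cast is needed at every inference step. The pattern can be sketched like this, with NumPy arrays standing in for torch tensors:

```python
import numpy as np

def cast_features(features, weight_dtype=None):
    """Cast each feature array to weight_dtype when one is given.
    Sketch of the get_audio_feature change; NumPy stands in for torch here."""
    if weight_dtype is None:
        return features
    return [f.astype(weight_dtype) for f in features]

feats = [np.zeros((2, 3), dtype=np.float32)]
half = cast_features(feats, np.float16)
print(half[0].dtype)  # float16
```

Casting once at extraction time, rather than per batch, is the same trade-off the script makes when it halves the UNet and VAE weights up front.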


@@ -3,6 +3,7 @@ import numpy as np
import cv2
import copy

def get_crop_box(box, expand):
    x, y, x1, y1 = box
    x_c, y_c = (x+x1)//2, (y+y1)//2
@@ -11,7 +12,8 @@ def get_crop_box(box, expand):
    crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
    return crop_box, s
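For context, `get_crop_box` expands a face bounding box into a square crop centered on the box. The hunk above omits the half-size computation, so the `s = int(max(w, h) // 2 * expand)` line below is a reconstruction, not the verbatim source:

```python
def get_crop_box(box, expand):
    """Return a square crop box of half-size s centered on the face box.
    The half-size line is reconstructed; the diff hunk cuts it off."""
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2   # center of the face box
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)          # assumed expansion step
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s

print(get_crop_box((0, 0, 100, 100), 1.5))  # ([-25, -25, 125, 125], 75)
```

Note the result can extend past the image border (negative coordinates above), so callers are expected to clamp or pad when slicing.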
def face_seg(image, mode="raw", fp=None):
    """
    Run face parsing on the image and generate a mask of the facial region.
@@ -86,13 +88,11 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode=
    body.paste(face_large, crop_box[:2], mask_image)
    body = np.array(body)  # convert the PIL image back to a numpy array
    return body[:, :, ::-1]  # return the processed image (BGR to RGB)

def get_image_blending(image, face, face_box, mask_array, crop_box):
    body = Image.fromarray(image[:,:,::-1])
    face = Image.fromarray(face[:,:,::-1])
@@ -108,7 +108,8 @@ def get_image_blending(image,face,face_box,mask_array,crop_box):
    body = np.array(body)
    return body[:,:,::-1]

def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
    body = Image.fromarray(image[:,:,::-1])
    x, y, x1, y1 = face_box
@@ -119,7 +120,7 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
    face_large = body.crop(crop_box)
    ori_shape = face_large.size
    mask_image = face_seg(face_large, mode=mode, fp=fp)
    mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))


@@ -74,7 +74,7 @@ class FaceParsing():
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __call__(self, image, size=(512, 512), mode="raw"):
        if isinstance(image, str):
            image = Image.open(image)


@@ -1,8 +1,9 @@
import os
import cv2
import copy
import glob
import torch
import shutil
import pickle
import argparse
@@ -17,8 +18,6 @@ from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder

@torch.no_grad()
def main(args):
    # Configure ffmpeg path
@@ -38,12 +37,17 @@ def main(args):
    )
    timesteps = torch.tensor([0], device=device)

    # Convert models to half precision if float16 is enabled
    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    # Move models to specified device
    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)

    # Initialize audio processor and Whisper model
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
@@ -51,46 +55,73 @@ def main(args):
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)
    # Initialize face parser with configurable parameters based on version
    if args.version == "v15":
        fp = FaceParsing(
            left_cheek_width=args.left_cheek_width,
            right_cheek_width=args.right_cheek_width
        )
    else:  # v1
        fp = FaceParsing()

    # Load inference configuration
    inference_config = OmegaConf.load(args.inference_config)
    print("Loaded inference config:", inference_config)
    # Process each task
    for task_id in inference_config:
        try:
            # Get task configuration
            video_path = inference_config[task_id]["video_path"]
            audio_path = inference_config[task_id]["audio_path"]
            if "result_name" in inference_config[task_id]:
                args.output_vid_name = inference_config[task_id]["result_name"]

            # Set bbox_shift based on version
            if args.version == "v15":
                bbox_shift = 0  # v15 uses fixed bbox_shift
            else:
                bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)  # v1 uses config or default

            # Set output paths
            input_basename = os.path.basename(video_path).split('.')[0]
            audio_basename = os.path.basename(audio_path).split('.')[0]
            output_basename = f"{input_basename}_{audio_basename}"

            # Create temporary directories
            temp_dir = os.path.join(args.result_dir, f"{args.version}")
            os.makedirs(temp_dir, exist_ok=True)

            # Set result save paths
            result_img_save_path = os.path.join(temp_dir, output_basename)
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
            os.makedirs(result_img_save_path, exist_ok=True)

            # Set output video paths
            if args.output_vid_name is None:
                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
            else:
                output_vid_name = os.path.join(temp_dir, args.output_vid_name)
            output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
            # Extract frames from source video
            if get_file_type(video_path) == "video":
                save_dir_full = os.path.join(temp_dir, input_basename)
                os.makedirs(save_dir_full, exist_ok=True)
                cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
                os.system(cmd)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
                fps = get_video_fps(video_path)
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
                fps = args.fps
            elif os.path.isdir(video_path):
                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                fps = args.fps
            else:
                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
            # Extract audio features
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
@@ -104,40 +135,56 @@ def main(args):
                audio_padding_length_right=args.audio_padding_length_right,
            )

            # Preprocess input images
            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks... time-consuming operation")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)
            print(f"Number of frames: {len(frame_list)}")

            # Process each frame
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            # Smooth first and last frames
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
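The cycle trick above avoids a visible jump when the audio is longer than the source clip: appending the reversed list and indexing with `i % len(cycle)` plays the video forward, then backward, seamlessly. A minimal sketch with frame indices standing in for images:

```python
def make_cycle(frames):
    # Forward pass followed by the reversed pass, as in frame_list_cycle.
    return frames + frames[::-1]

frames = [0, 1, 2, 3]
cycle = make_cycle(frames)
# Indexing with i % len(cycle) ping-pongs through the clip without a hard cut.
order = [cycle[i % len(cycle)] for i in range(10)]
print(order)  # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```

Each frame appears twice per cycle (once forward, once reversed), so the loop point never cuts from the last frame back to the first.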
            # Batch inference
            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )
            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            # Execute inference
            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)
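`datagen` pairs each Whisper chunk with a latent from the cycled list (wrapping with modulo) and yields them in batches. The signature below matches the call shown above, but the body is an illustrative sketch with plain lists, not the repository implementation (which stacks tensors and moves them to `device`):

```python
def datagen(whisper_chunks, vae_encode_latents, batch_size=8, delay_frame=0, device=None):
    """Yield (whisper_batch, latent_batch) pairs; latents wrap around via modulo.
    Sketch only: real code stacks into tensors and transfers to `device`."""
    whisper_batch, latent_batch = [], []
    for i, chunk in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(vae_encode_latents)  # wrap over the cycled latents
        whisper_batch.append(chunk)
        latent_batch.append(vae_encode_latents[idx])
        if len(whisper_batch) == batch_size:
            yield whisper_batch, latent_batch
            whisper_batch, latent_batch = [], []
    if whisper_batch:  # final partial batch
        yield whisper_batch, latent_batch

batches = list(datagen(list(range(5)), ["a", "b"], batch_size=2))
print(len(batches))  # 3 (two full batches plus one partial)
```

Because the latent list is the ping-pong cycle from the previous step, audio of any length can be paired with a finite set of video frames.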
@@ -146,55 +193,72 @@ def main(args):
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            # Pad generated images to original video size
            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i % (len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                except:
                    continue

                # Merge results with version-specific parameters
                if args.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # Save prediction results
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
            print("Video generation command:", cmd_img2video)
            os.system(cmd_img2video)

            cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
            print("Audio combination command:", cmd_combine_audio)
            os.system(cmd_combine_audio)

            # Clean up temporary files
            shutil.rmtree(result_img_save_path)
            os.remove(temp_vid_path)
            shutil.rmtree(save_dir_full)
            if not args.saved_coord:
                os.remove(crop_coord_save_path)

            print(f"Results saved to {output_vid_name}")
        except Exception as e:
            print("Error occurred during processing:", e)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
parser.add_argument("--bbox_shift", type=int, default=0)
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--output_vid_name", type=str, default=None)
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file") parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio") parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio") parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
args = parser.parse_args()
main(args)
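The `--audio_padding_length_left/right` flags above control how many extra audio-feature frames pad each video frame's window. The real slicing lives in `AudioProcessor.get_whisper_chunk`; the sketch below is a simplified, hypothetical stand-in (the name `chunk_audio_features` is illustrative, and 50 features per second is an assumption about the Whisper feature rate, not taken from this repo):

```python
def chunk_audio_features(features, num_frames, feats_per_sec=50, fps=25,
                         pad_left=2, pad_right=2):
    """Slice a feature sequence into one padded window per video frame."""
    step = feats_per_sec / fps  # audio feature frames per video frame
    chunks = []
    for i in range(num_frames):
        center = int(i * step)
        start = max(0, center - pad_left)
        end = min(len(features), center + int(step) + pad_right)
        chunks.append(features[start:end])
    return chunks

# At 25 fps with 50 features/sec, each frame sees ~2 feature frames plus padding
chunks = chunk_audio_features(list(range(100)), num_frames=4)
```

Larger padding values widen each window, giving the UNet more acoustic context per frame at the cost of extra compute.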


@@ -1,252 +0,0 @@
import os
import cv2
import math
import copy
import torch
import glob
import shutil
import pickle
import argparse
import subprocess
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel

from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


@torch.no_grad()
def main(args):
    # Configure ffmpeg path
    if args.ffmpeg_path not in os.getenv('PATH'):
        print("Adding ffmpeg to PATH")
        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"

    # Set computing device
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    # Load model weights
    vae, unet, pe = load_all_model(
        unet_model_path=args.unet_model_path,
        vae_type=args.vae_type,
        unet_config=args.unet_config,
        device=device
    )
    timesteps = torch.tensor([0], device=device)

    # Convert models to half precision if float16 is enabled
    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    # Move models to specified device
    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)

    # Initialize audio processor and Whisper model
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(args.whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    # Initialize face parser
    fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)

    # Load inference configuration
    inference_config = OmegaConf.load(args.inference_config)
    print("Loaded inference config:", inference_config)

    # Process each task
    for task_id in inference_config:
        try:
            # Get task configuration
            video_path = inference_config[task_id]["video_path"]
            audio_path = inference_config[task_id]["audio_path"]
            if "result_name" in inference_config[task_id]:
                args.output_vid_name = inference_config[task_id]["result_name"]
            bbox_shift = args.bbox_shift

            # Set output paths
            input_basename = os.path.basename(video_path).split('.')[0]
            audio_basename = os.path.basename(audio_path).split('.')[0]
            output_basename = f"{input_basename}_{audio_basename}"

            # Create temporary directories
            temp_dir = os.path.join(args.result_dir, "frames_result")
            os.makedirs(temp_dir, exist_ok=True)

            # Set result save paths
            result_img_save_path = os.path.join(temp_dir, output_basename)  # related to video & audio inputs
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")  # only related to video input
            os.makedirs(result_img_save_path, exist_ok=True)

            # Set output video paths
            if args.output_vid_name is None:
                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
            else:
                output_vid_name = os.path.join(temp_dir, args.output_vid_name)
            output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")

            # Skip if output file already exists
            if os.path.exists(output_vid_name):
                print(f"{output_vid_name} already exists, skipping!")
                continue

            # Extract frames from source video
            if get_file_type(video_path) == "video":
                save_dir_full = os.path.join(temp_dir, input_basename)
                os.makedirs(save_dir_full, exist_ok=True)
                cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
                os.system(cmd)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
                fps = get_video_fps(video_path)
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
                fps = args.fps
            elif os.path.isdir(video_path):
                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                fps = args.fps
            else:
                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

            # Extract audio features
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features,
                device,
                weight_dtype,
                whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=args.audio_padding_length_left,
                audio_padding_length_right=args.audio_padding_length_right,
            )

            # Preprocess input images
            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks... time-consuming operation")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)
            print(f"Number of frames: {len(frame_list)}")

            # Process each frame
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                y2 = y2 + args.extra_margin
                y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            # Smooth first and last frames
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # Batch inference
            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )
            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            # Execute inference
            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)
                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
                recon = vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            # Pad generated images to original video size
            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i % (len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                y2 = y2 + args.extra_margin
                y2 = min(y2, ori_frame.shape[0])  # clamp to the current frame, not the stale loop variable
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    continue
                # Merge results
                combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # Save prediction results
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
            print("Video generation command:", cmd_img2video)
            os.system(cmd_img2video)
            cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
            print("Audio combination command:", cmd_combine_audio)
            os.system(cmd_combine_audio)

            # Clean up temporary files
            shutil.rmtree(result_img_save_path)
            os.remove(temp_vid_path)
            shutil.rmtree(save_dir_full)
            if not args.saved_coord:
                os.remove(crop_coord_save_path)
            print(f"Results saved to {output_vid_name}")
        except Exception as e:
            print("Error occurred during processing:", e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")

    args = parser.parse_args()
    main(args)


@@ -10,24 +10,20 @@ import sys
 from tqdm import tqdm
 import copy
 import json
-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
-from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
-from musetalk.utils.utils import load_all_model
-import shutil
+from transformers import WhisperModel
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.utils import datagen
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
+from musetalk.utils.blending import get_image_prepare_material, get_image_blending
+from musetalk.utils.utils import load_all_model
+from musetalk.utils.audio_processor import AudioProcessor
+import shutil
 import threading
 import queue
 import time
-# load model weights
-audio_processor, vae, unet, pe = load_all_model()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-timesteps = torch.tensor([0], device=device)
-pe = pe.half()
-vae.vae = vae.vae.half()
-unet.model = unet.model.half()
 def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
     cap = cv2.VideoCapture(vid_path)
@@ -42,6 +38,7 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
         else:
             break
 def osmakedirs(path_list):
     for path in path_list:
         os.makedirs(path) if not os.path.exists(path) else None
@@ -53,7 +50,13 @@ class Avatar:
         self.avatar_id = avatar_id
         self.video_path = video_path
         self.bbox_shift = bbox_shift
-        self.avatar_path = f"./results/avatars/{avatar_id}"
+        # Set the base path according to the model version
+        if args.version == "v15":
+            self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
+        else:  # v1
+            self.base_path = f"./results/avatars/{avatar_id}"
+        self.avatar_path = self.base_path
         self.full_imgs_path = f"{self.avatar_path}/full_imgs"
         self.coords_path = f"{self.avatar_path}/coords.pkl"
         self.latents_out_path = f"{self.avatar_path}/latents.pt"
@@ -64,7 +67,8 @@ class Avatar:
         self.avatar_info = {
             "avatar_id": avatar_id,
             "video_path": video_path,
-            "bbox_shift":bbox_shift
+            "bbox_shift": bbox_shift,
+            "version": args.version
         }
         self.preparation = preparation
         self.batch_size = batch_size
@@ -159,6 +163,10 @@ class Avatar:
             if bbox == coord_placeholder:
                 continue
             x1, y1, x2, y2 = bbox
+            if args.version == "v15":
+                y2 = y2 + args.extra_margin
+                y2 = min(y2, frame.shape[0])
+                coord_list[idx] = [x1, y1, x2, y2]  # update the bbox in coord_list
             crop_frame = frame[y1:y2, x1:x2]
             resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
             latents = vae.get_latents_for_unet(resized_crop_frame)
@@ -173,8 +181,13 @@ class Avatar:
         for i, frame in enumerate(tqdm(self.frame_list_cycle)):
             cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
-            face_box = self.coord_list_cycle[i]
-            mask,crop_box = get_image_prepare_material(frame,face_box)
+            x1, y1, x2, y2 = self.coord_list_cycle[i]
+            if args.version == "v15":
+                mode = args.parsing_mode
+            else:
+                mode = "raw"
+            mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
             cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
             self.mask_coords_list_cycle += [crop_box]
             self.mask_list_cycle.append(mask)
@@ -186,12 +199,8 @@ class Avatar:
             pickle.dump(self.coord_list_cycle, f)
         torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
-    #
-    def process_frames(self,
-                       res_frame_queue,
-                       video_len,
-                       skip_save_images):
+    def process_frames(self, res_frame_queue, video_len, skip_save_images):
         print(video_len)
         while True:
             if self.idx >= video_len - 1:
@@ -211,30 +220,35 @@ class Avatar:
                 continue
             mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
             mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
-            #combine_frame = get_image(ori_frame,res_frame,bbox)
             combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
             if skip_save_images is False:
                 cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
             self.idx = self.idx + 1
-    def inference(self,
-                  audio_path,
-                  out_vid_name,
-                  fps,
-                  skip_save_images):
+    def inference(self, audio_path, out_vid_name, fps, skip_save_images):
         os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
         print("start inference")
         ############################################## extract audio feature ##############################################
         start_time = time.time()
-        whisper_feature = audio_processor.audio2feat(audio_path)
-        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+        # Extract audio features
+        whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
+        whisper_chunks = audio_processor.get_whisper_chunk(
+            whisper_input_features,
+            device,
+            weight_dtype,
+            whisper,
+            librosa_length,
+            fps=fps,
+            audio_padding_length_left=args.audio_padding_length_left,
+            audio_padding_length_right=args.audio_padding_length_right,
+        )
         print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
         ############################################## inference batch by batch ##############################################
         video_num = len(whisper_chunks)
         res_frame_queue = queue.Queue()
         self.idx = 0
-        # # Create a sub-thread and start it
+        # Create a sub-thread and start it
         process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
         process_thread.start()
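The realtime script overlaps generation and compositing with a producer–consumer pair: the inference loop puts decoded frames on a `queue.Queue` while `process_frames` drains it on a worker thread. A minimal self-contained sketch of that pattern (the `consume` function and the `* 2` transform are illustrative stand-ins for blending and writing a frame):

```python
import queue
import threading

def consume(q, total, out):
    done = 0
    while done < total:
        item = q.get()        # blocks until the producer puts a frame
        out.append(item * 2)  # stand-in for blending/writing the frame
        done += 1

q = queue.Queue()
out = []
t = threading.Thread(target=consume, args=(q, 3, out))
t.start()
for i in range(3):
    q.put(i)                  # producer: inference results
t.join()
```

Because `Queue` is thread-safe, the GPU-bound producer never waits on disk I/O, which is what makes the per-frame latency low enough for real-time use.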
@@ -245,15 +259,13 @@ class Avatar:
         res_frame_list = []
         for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
-            audio_feature_batch = torch.from_numpy(whisper_batch)
-            audio_feature_batch = audio_feature_batch.to(device=unet.device,
-                                                         dtype=unet.model.dtype)
-            audio_feature_batch = pe(audio_feature_batch)
-            latent_batch = latent_batch.to(dtype=unet.model.dtype)
+            audio_feature_batch = pe(whisper_batch.to(device))
+            latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
             pred_latents = unet.model(latent_batch,
                                       timesteps,
                                       encoder_hidden_states=audio_feature_batch).sample
+            pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
             recon = vae.decode_latents(pred_latents)
             for res_frame in recon:
                 res_frame_queue.put(res_frame)
@@ -271,7 +283,7 @@ class Avatar:
         if out_vid_name is not None and args.skip_save_images is False:
             # optional
-            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
+            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
             print(cmd_img2video)
             os.system(cmd_img2video)
@@ -292,18 +304,27 @@ if __name__ == "__main__":
     '''
     parser = argparse.ArgumentParser()
-    parser.add_argument("--inference_config",
-                        type=str,
-                        default="configs/inference/realtime.yaml",
-                        )
-    parser.add_argument("--fps",
-                        type=int,
-                        default=25,
-                        )
-    parser.add_argument("--batch_size",
-                        type=int,
-                        default=4,
-                        )
+    parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
+    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
+    parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
+    parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
+    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
+    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
+    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
+    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
+    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
+    parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
+    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
+    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
+    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
+    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
+    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
+    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
     parser.add_argument("--skip_save_images",
                         action="store_true",
                         help="Whether skip saving images for better generation speed calculation",
@@ -311,13 +332,47 @@ if __name__ == "__main__":
     args = parser.parse_args()
+    # Set computing device
+    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
+    # Load model weights
+    vae, unet, pe = load_all_model(
+        unet_model_path=args.unet_model_path,
+        vae_type=args.vae_type,
+        unet_config=args.unet_config,
+        device=device
+    )
+    timesteps = torch.tensor([0], device=device)
+    pe = pe.half().to(device)
+    vae.vae = vae.vae.half().to(device)
+    unet.model = unet.model.half().to(device)
+    # Initialize audio processor and Whisper model
+    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
+    weight_dtype = unet.model.dtype
+    whisper = WhisperModel.from_pretrained(args.whisper_dir)
+    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+    whisper.requires_grad_(False)
+    # Initialize face parser with configurable parameters based on version
+    if args.version == "v15":
+        fp = FaceParsing(
+            left_cheek_width=args.left_cheek_width,
+            right_cheek_width=args.right_cheek_width
+        )
+    else:  # v1
+        fp = FaceParsing()
     inference_config = OmegaConf.load(args.inference_config)
     print(inference_config)
     for avatar_id in inference_config:
         data_preparation = inference_config[avatar_id]["preparation"]
         video_path = inference_config[avatar_id]["video_path"]
-        bbox_shift = inference_config[avatar_id]["bbox_shift"]
+        if args.version == "v15":
+            bbox_shift = 0
+        else:
+            bbox_shift = inference_config[avatar_id]["bbox_shift"]
         avatar = Avatar(
             avatar_id=avatar_id,