mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
feat: real-time infer (#286)
* feat: realtime infer * cchore: infer script
This commit is contained in:
82
README.md
82
README.md
@@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
|
||||
- [x] codes for real-time inference.
|
||||
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
|
||||
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
|
||||
- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
|
||||
- [ ] training and dataloader code (Expected completion on 04/04/2025).
|
||||
- [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
|
||||
|
||||
|
||||
|
||||
# Getting Started
|
||||
@@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk:
|
||||
|
||||
#### MuseTalk 1.5 (Recommended)
|
||||
```bash
|
||||
sh inference.sh v1.5
|
||||
# Run MuseTalk 1.5 inference
|
||||
sh inference.sh v1.5 normal
|
||||
```
|
||||
This inference script supports both MuseTalk 1.5 and 1.0 models:
|
||||
- For MuseTalk 1.5: Use the command above with the V1.5 model path
|
||||
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
|
||||
|
||||
configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path.
|
||||
The video_path should be either a video file, an image file or a directory of images.
|
||||
|
||||
#### MuseTalk 1.0
|
||||
```bash
|
||||
sh inference.sh v1.0
|
||||
# Run MuseTalk 1.0 inference
|
||||
sh inference.sh v1.0 normal
|
||||
```
|
||||
You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
|
||||
<details close>
|
||||
|
||||
The inference script supports both MuseTalk 1.5 and 1.0 models:
|
||||
- For MuseTalk 1.5: Use the command above with the V1.5 model path
|
||||
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
|
||||
|
||||
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
|
||||
- `video_path`: Path to the input video, image file, or directory of images
|
||||
- `audio_path`: Path to the input audio file
|
||||
|
||||
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
|
||||
|
||||
#### Real-time Inference
|
||||
For real-time inference, use the following command:
|
||||
```bash
|
||||
# Run real-time inference
|
||||
sh inference.sh v1.5 realtime # For MuseTalk 1.5
|
||||
# or
|
||||
sh inference.sh v1.0 realtime # For MuseTalk 1.0
|
||||
```
|
||||
|
||||
The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
|
||||
- `preparation`: Set to `True` for new avatar preparation
|
||||
- `video_path`: Path to the input video
|
||||
- `bbox_shift`: Adjustable parameter for mouth region control
|
||||
- `audio_clips`: List of audio clips for generation
|
||||
|
||||
Important notes for real-time inference:
|
||||
1. Set `preparation` to `True` when processing a new avatar
|
||||
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
|
||||
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
|
||||
4. Set `preparation` to `False` for generating more videos with the same avatar
|
||||
|
||||
For faster generation without saving images, you can use:
|
||||
```bash
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
## TestCases For 1.0
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
@@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
|
||||
```
|
||||
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
#### Combining MuseV and MuseTalk
|
||||
|
||||
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
|
||||
|
||||
#### Real-time inference
|
||||
|
||||
<details close>
|
||||
Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
||||
|
||||
```
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
|
||||
```
|
||||
configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
|
||||
|
||||
1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
|
||||
1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
|
||||
```
|
||||
Inferring using: data/audio/yongen.wav
|
||||
```
|
||||
1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
|
||||
1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
|
||||
|
||||
##### Note for Real-time inference
|
||||
1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
|
||||
1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
|
||||
```
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
</details>
|
||||
|
||||
# Acknowledgement
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
|
||||
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
avator_1:
|
||||
preparation: False
|
||||
preparation: True # your can set it to False if you want to use the existing avator, it will save time
|
||||
bbox_shift: 5
|
||||
video_path: "data/video/sun.mp4"
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_clips:
|
||||
audio_0: "data/audio/yongen.wav"
|
||||
audio_1: "data/audio/sun.wav"
|
||||
audio_1: "data/audio/eng.wav"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ task_0:
|
||||
audio_path: "data/audio/yongen.wav"
|
||||
|
||||
task_1:
|
||||
video_path: "data/video/sun.mp4"
|
||||
audio_path: "data/audio/sun.wav"
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/eng.wav"
|
||||
bbox_shift: -7
|
||||
|
||||
|
||||
|
||||
BIN
data/audio/eng.wav
Executable file
BIN
data/audio/eng.wav
Executable file
Binary file not shown.
66
inference.sh
66
inference.sh
@@ -1,46 +1,72 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script runs inference based on the version specified by the user.
|
||||
# This script runs inference based on the version and mode specified by the user.
|
||||
# Usage:
|
||||
# To run v1.0 inference: sh inference.sh v1.0
|
||||
# To run v1.5 inference: sh inference.sh v1.5
|
||||
# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
|
||||
# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
|
||||
|
||||
# Check if the correct number of arguments is provided
|
||||
if [ "$#" -ne 1 ]; then
|
||||
echo "Usage: $0 <version>"
|
||||
echo "Example: $0 v1.0 or $0 v1.5"
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <version> <mode>"
|
||||
echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the version from the user input
|
||||
# Get the version and mode from the user input
|
||||
version=$1
|
||||
mode=$2
|
||||
|
||||
# Validate mode
|
||||
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
|
||||
echo "Invalid mode specified. Please use 'normal' or 'realtime'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set config path based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
config_path="./configs/inference/test.yaml"
|
||||
result_dir="./results/test"
|
||||
else
|
||||
config_path="./configs/inference/realtime.yaml"
|
||||
result_dir="./results/realtime"
|
||||
fi
|
||||
|
||||
# Define the model paths based on the version
|
||||
if [ "$version" = "v1.0" ]; then
|
||||
model_dir="./models/musetalk"
|
||||
unet_model_path="$model_dir/pytorch_model.bin"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v1"
|
||||
elif [ "$version" = "v1.5" ]; then
|
||||
model_dir="./models/musetalkV15"
|
||||
unet_model_path="$model_dir/unet.pth"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v15"
|
||||
else
|
||||
echo "Invalid version specified. Please use v1.0 or v1.5."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run inference based on the version
|
||||
if [ "$version" = "v1.0" ]; then
|
||||
python3 -m scripts.inference \
|
||||
--inference_config "$config_path" \
|
||||
--result_dir "./results/test" \
|
||||
--unet_model_path "$unet_model_path" \
|
||||
--unet_config "$unet_config"
|
||||
elif [ "$version" = "v1.5" ]; then
|
||||
python3 -m scripts.inference_alpha \
|
||||
--inference_config "$config_path" \
|
||||
--result_dir "./results/test" \
|
||||
--unet_model_path "$unet_model_path" \
|
||||
--unet_config "$unet_config"
|
||||
# Set script name based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
script_name="scripts.inference"
|
||||
else
|
||||
script_name="scripts.realtime_inference"
|
||||
fi
|
||||
|
||||
# Base command arguments
|
||||
cmd_args="--inference_config $config_path \
|
||||
--result_dir $result_dir \
|
||||
--unet_model_path $unet_model_path \
|
||||
--unet_config $unet_config \
|
||||
--version $version_arg \
|
||||
|
||||
# Add realtime-specific arguments if in realtime mode
|
||||
if [ "$mode" = "realtime" ]; then
|
||||
cmd_args="$cmd_args \
|
||||
--fps 25 \
|
||||
--version $version_arg \
|
||||
fi
|
||||
|
||||
# Run inference
|
||||
python3 -m $script_name $cmd_args
|
||||
@@ -11,7 +11,7 @@ class AudioProcessor:
|
||||
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
|
||||
|
||||
def get_audio_feature(self, wav_path, start_index=0):
|
||||
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
@@ -27,6 +27,8 @@ class AudioProcessor:
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
if weight_dtype is not None:
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
features.append(audio_feature)
|
||||
|
||||
return features, len(librosa_output)
|
||||
|
||||
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
|
||||
|
||||
def get_crop_box(box, expand):
|
||||
x, y, x1, y1 = box
|
||||
x_c, y_c = (x+x1)//2, (y+y1)//2
|
||||
@@ -11,7 +12,8 @@ def get_crop_box(box, expand):
|
||||
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
||||
return crop_box, s
|
||||
|
||||
def face_seg(image, mode="jaw", fp=None):
|
||||
|
||||
def face_seg(image, mode="raw", fp=None):
|
||||
"""
|
||||
对图像进行面部解析,生成面部区域的掩码。
|
||||
|
||||
@@ -86,13 +88,11 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode=
|
||||
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
|
||||
# 不用掩码,完全用infer
|
||||
#face_large.save("debug/checkpoint_6_face_large.png")
|
||||
|
||||
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
|
||||
|
||||
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
|
||||
|
||||
|
||||
def get_image_blending(image, face, face_box, mask_array, crop_box):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
face = Image.fromarray(face[:,:,::-1])
|
||||
@@ -108,7 +108,8 @@ def get_image_blending(image,face,face_box,mask_array,crop_box):
|
||||
body = np.array(body)
|
||||
return body[:,:,::-1]
|
||||
|
||||
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
|
||||
|
||||
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
@@ -119,7 +120,7 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
mask_image = face_seg(face_large)
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
@@ -74,7 +74,7 @@ class FaceParsing():
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, image, size=(512, 512), mode="jaw"):
|
||||
def __call__(self, image, size=(512, 512), mode="raw"):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import os
|
||||
import cv2
|
||||
import math
|
||||
import copy
|
||||
import glob
|
||||
import torch
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import argparse
|
||||
@@ -17,8 +18,6 @@ from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
# Configure ffmpeg path
|
||||
@@ -38,12 +37,17 @@ def main(args):
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
|
||||
if args.use_float16 is True:
|
||||
# Convert models to half precision if float16 is enabled
|
||||
if args.use_float16:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
@@ -51,46 +55,73 @@ def main(args):
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser
|
||||
# Initialize face parser with configurable parameters based on version
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else: # v1
|
||||
fp = FaceParsing()
|
||||
|
||||
# Load inference configuration
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
print("Loaded inference config:", inference_config)
|
||||
|
||||
# Process each task
|
||||
for task_id in inference_config:
|
||||
try:
|
||||
# Get task configuration
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
if "result_name" in inference_config[task_id]:
|
||||
args.output_vid_name = inference_config[task_id]["result_name"]
|
||||
|
||||
# Set bbox_shift based on version
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0 # v15 uses fixed bbox_shift
|
||||
else:
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
|
||||
|
||||
# Set output paths
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
|
||||
# Create temporary directories
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save paths
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
os.makedirs(result_img_save_path, exist_ok=True)
|
||||
|
||||
# Set output video paths
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
|
||||
|
||||
# Extract frames from source video
|
||||
if get_file_type(video_path) == "video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
os.makedirs(save_dir_full, exist_ok=True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path) == "image":
|
||||
input_img_list = [video_path, ]
|
||||
input_img_list = [video_path]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
elif os.path.isdir(video_path):
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
|
||||
############################################## extract audio feature ##############################################
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
@@ -104,40 +135,56 @@ def main(args):
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
# Preprocess input images
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
print("Using saved coordinates")
|
||||
with open(crop_coord_save_path, 'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
print("Extracting landmarks... time-consuming operation")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
i = 0
|
||||
print(f"Number of frames: {len(frame_list)}")
|
||||
|
||||
# Process each frame
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# to smooth the first and the last frame
|
||||
# Smooth first and last frames
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
|
||||
# Batch inference
|
||||
print("Starting inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
total = int(np.ceil(float(video_num) / batch_size))
|
||||
|
||||
# Execute inference
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
@@ -146,55 +193,72 @@ def main(args):
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
# Pad generated images to original video size
|
||||
print("Padding generated images to original video size")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
# Merge results
|
||||
# Merge results with version-specific parameters
|
||||
if args.version == "v15":
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
else:
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
print(cmd_img2video)
|
||||
# Save prediction results
|
||||
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
|
||||
print("Video generation command:", cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
print(cmd_combine_audio)
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
|
||||
print("Audio combination command:", cmd_combine_audio)
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
# Clean up temporary files
|
||||
shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
os.remove(temp_vid_path)
|
||||
|
||||
shutil.rmtree(save_dir_full)
|
||||
if not args.saved_coord:
|
||||
os.remove(crop_coord_save_path)
|
||||
|
||||
print(f"Results saved to {output_vid_name}")
|
||||
except Exception as e:
|
||||
print("Error occurred during processing:", e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import argparse
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import WhisperModel
|
||||
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
# Configure ffmpeg path
|
||||
if args.ffmpeg_path not in os.getenv('PATH'):
|
||||
print("Adding ffmpeg to PATH")
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
# Convert models to half precision if float16 is enabled
|
||||
if args.use_float16:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)
|
||||
|
||||
# Load inference configuration
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print("Loaded inference config:", inference_config)
|
||||
|
||||
# Process each task
|
||||
for task_id in inference_config:
|
||||
try:
|
||||
# Get task configuration
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
if "result_name" in inference_config[task_id]:
|
||||
args.output_vid_name = inference_config[task_id]["result_name"]
|
||||
bbox_shift = args.bbox_shift
|
||||
# Set output paths
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
# Create temporary directories
|
||||
temp_dir = os.path.join(args.result_dir, "frames_result")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save paths
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path, exist_ok=True)
|
||||
# Set output video paths
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
|
||||
|
||||
# Skip if output file already exists
|
||||
if os.path.exists(output_vid_name):
|
||||
print(f"{output_vid_name} already exists, skipping!")
|
||||
continue
|
||||
|
||||
# Extract frames from source video
|
||||
if get_file_type(video_path) == "video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
os.makedirs(save_dir_full, exist_ok=True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path) == "image":
|
||||
input_img_list = [video_path]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path):
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
# Preprocess input images
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("Using saved coordinates")
|
||||
with open(crop_coord_save_path, 'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("Extracting landmarks... time-consuming operation")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
print(f"Number of frames: {len(frame_list)}")
|
||||
|
||||
# Process each frame
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# Smooth first and last frames
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
# Batch inference
|
||||
print("Starting inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
total = int(np.ceil(float(video_num) / batch_size))
|
||||
|
||||
# Execute inference
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
# Pad generated images to original video size
|
||||
print("Padding generated images to original video size")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
# Merge results
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
|
||||
|
||||
# Save prediction results
|
||||
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
|
||||
print("Video generation command:", cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
|
||||
print("Audio combination command:", cmd_combine_audio)
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
# Clean up temporary files
|
||||
shutil.rmtree(result_img_save_path)
|
||||
os.remove(temp_vid_path)
|
||||
|
||||
shutil.rmtree(save_dir_full)
|
||||
if not args.saved_coord:
|
||||
os.remove(crop_coord_save_path)
|
||||
|
||||
print(f"Results saved to {output_vid_name}")
|
||||
except Exception as e:
|
||||
print("Error occurred during processing:", e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -10,24 +10,20 @@ import sys
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
import json
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
from transformers import WhisperModel
|
||||
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.utils import datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
|
||||
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
|
||||
from musetalk.utils.utils import load_all_model
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
|
||||
import shutil
|
||||
import threading
|
||||
import queue
|
||||
|
||||
import time
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
|
||||
cap = cv2.VideoCapture(vid_path)
|
||||
@@ -42,6 +38,7 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def osmakedirs(path_list):
|
||||
for path in path_list:
|
||||
os.makedirs(path) if not os.path.exists(path) else None
|
||||
@@ -53,7 +50,13 @@ class Avatar:
|
||||
self.avatar_id = avatar_id
|
||||
self.video_path = video_path
|
||||
self.bbox_shift = bbox_shift
|
||||
self.avatar_path = f"./results/avatars/{avatar_id}"
|
||||
# 根据版本设置不同的基础路径
|
||||
if args.version == "v15":
|
||||
self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
|
||||
else: # v1
|
||||
self.base_path = f"./results/avatars/{avatar_id}"
|
||||
|
||||
self.avatar_path = self.base_path
|
||||
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
|
||||
self.coords_path = f"{self.avatar_path}/coords.pkl"
|
||||
self.latents_out_path = f"{self.avatar_path}/latents.pt"
|
||||
@@ -64,7 +67,8 @@ class Avatar:
|
||||
self.avatar_info = {
|
||||
"avatar_id": avatar_id,
|
||||
"video_path": video_path,
|
||||
"bbox_shift":bbox_shift
|
||||
"bbox_shift": bbox_shift,
|
||||
"version": args.version
|
||||
}
|
||||
self.preparation = preparation
|
||||
self.batch_size = batch_size
|
||||
@@ -159,6 +163,10 @@ class Avatar:
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
coord_list[idx] = [x1, y1, x2, y2] # 更新coord_list中的bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(resized_crop_frame)
|
||||
@@ -173,8 +181,13 @@ class Avatar:
|
||||
for i, frame in enumerate(tqdm(self.frame_list_cycle)):
|
||||
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
|
||||
|
||||
face_box = self.coord_list_cycle[i]
|
||||
mask,crop_box = get_image_prepare_material(frame,face_box)
|
||||
x1, y1, x2, y2 = self.coord_list_cycle[i]
|
||||
if args.version == "v15":
|
||||
mode = args.parsing_mode
|
||||
else:
|
||||
mode = "raw"
|
||||
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
|
||||
|
||||
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
|
||||
self.mask_coords_list_cycle += [crop_box]
|
||||
self.mask_list_cycle.append(mask)
|
||||
@@ -186,12 +199,8 @@ class Avatar:
|
||||
pickle.dump(self.coord_list_cycle, f)
|
||||
|
||||
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
||||
#
|
||||
|
||||
def process_frames(self,
|
||||
res_frame_queue,
|
||||
video_len,
|
||||
skip_save_images):
|
||||
def process_frames(self, res_frame_queue, video_len, skip_save_images):
|
||||
print(video_len)
|
||||
while True:
|
||||
if self.idx >= video_len - 1:
|
||||
@@ -211,30 +220,35 @@ class Avatar:
|
||||
continue
|
||||
mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
|
||||
mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
|
||||
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||
|
||||
if skip_save_images is False:
|
||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
|
||||
self.idx = self.idx + 1
|
||||
|
||||
def inference(self,
|
||||
audio_path,
|
||||
out_vid_name,
|
||||
fps,
|
||||
skip_save_images):
|
||||
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
|
||||
os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
|
||||
print("start inference")
|
||||
############################################## extract audio feature ##############################################
|
||||
start_time = time.time()
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
res_frame_queue = queue.Queue()
|
||||
self.idx = 0
|
||||
# # Create a sub-thread and start it
|
||||
# Create a sub-thread and start it
|
||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
||||
process_thread.start()
|
||||
|
||||
@@ -245,15 +259,13 @@ class Avatar:
|
||||
res_frame_list = []
|
||||
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype)
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
audio_feature_batch = pe(whisper_batch.to(device))
|
||||
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature_batch).sample
|
||||
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_queue.put(res_frame)
|
||||
@@ -271,7 +283,7 @@ class Avatar:
|
||||
|
||||
if out_vid_name is not None and args.skip_save_images is False:
|
||||
# optional
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
print(cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
@@ -292,18 +304,27 @@ if __name__ == "__main__":
|
||||
'''
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config",
|
||||
type=str,
|
||||
default="configs/inference/realtime.yaml",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type=int,
|
||||
default=25,
|
||||
)
|
||||
parser.add_argument("--batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--skip_save_images",
|
||||
action="store_true",
|
||||
help="Whether skip saving images for better generation speed calculation",
|
||||
@@ -311,13 +332,47 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
pe = pe.half().to(device)
|
||||
vae.vae = vae.vae.half().to(device)
|
||||
unet.model = unet.model.half().to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser with configurable parameters based on version
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else: # v1
|
||||
fp = FaceParsing()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
|
||||
|
||||
for avatar_id in inference_config:
|
||||
data_preparation = inference_config[avatar_id]["preparation"]
|
||||
video_path = inference_config[avatar_id]["video_path"]
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0
|
||||
else:
|
||||
bbox_shift = inference_config[avatar_id]["bbox_shift"]
|
||||
avatar = Avatar(
|
||||
avatar_id=avatar_id,
|
||||
|
||||
Reference in New Issue
Block a user