feat: real-time infer (#286)

* feat: realtime infer

* cchore: infer script
This commit is contained in:
Zhizhou Zhong
2025-04-02 19:13:18 +08:00
committed by GitHub
parent fbe6a97dff
commit 39ccf69f36
11 changed files with 490 additions and 592 deletions

View File

@@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
- [ ] training and dataloader code (Expected completion on 04/04/2025).
- [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
# Getting Started
@@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk:
#### MuseTalk 1.5 (Recommended)
```bash
sh inference.sh v1.5
# Run MuseTalk 1.5 inference
sh inference.sh v1.5 normal
```
This inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path.
The video_path should be either a video file, an image file or a directory of images.
#### MuseTalk 1.0
```bash
sh inference.sh v1.0
# Run MuseTalk 1.0 inference
sh inference.sh v1.0 normal
```
You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
<details close>
The inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
- `video_path`: Path to the input video, image file, or directory of images
- `audio_path`: Path to the input audio file
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
#### Real-time Inference
For real-time inference, use the following command:
```bash
# Run real-time inference
sh inference.sh v1.5 realtime # For MuseTalk 1.5
# or
sh inference.sh v1.0 realtime # For MuseTalk 1.0
```
The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
- `preparation`: Set to `True` for new avatar preparation
- `video_path`: Path to the input video
- `bbox_shift`: Adjustable parameter for mouth region control
- `audio_clips`: List of audio clips for generation
Important notes for real-time inference:
1. Set `preparation` to `True` when processing a new avatar
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
4. Set `preparation` to `False` for generating more videos with the same avatar
For faster generation without saving images, you can use:
```bash
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
@@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
```
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
</details>
#### Combining MuseV and MuseTalk
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
#### Real-time inference
<details close>
Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
```
configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
```
Inferring using: data/audio/yongen.wav
```
1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
##### Note for Real-time inference
1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
</details>
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).

View File

@@ -1,10 +1,10 @@
avator_1:
preparation: False
preparation: True # your can set it to False if you want to use the existing avator, it will save time
bbox_shift: 5
video_path: "data/video/sun.mp4"
video_path: "data/video/yongen.mp4"
audio_clips:
audio_0: "data/audio/yongen.wav"
audio_1: "data/audio/sun.wav"
audio_1: "data/audio/eng.wav"

View File

@@ -3,8 +3,8 @@ task_0:
audio_path: "data/audio/yongen.wav"
task_1:
video_path: "data/video/sun.mp4"
audio_path: "data/audio/sun.wav"
video_path: "data/video/yongen.mp4"
audio_path: "data/audio/eng.wav"
bbox_shift: -7

BIN
data/audio/eng.wav Executable file

Binary file not shown.

View File

@@ -1,46 +1,72 @@
#!/bin/bash
# This script runs inference based on the version specified by the user.
# This script runs inference based on the version and mode specified by the user.
# Usage:
# To run v1.0 inference: sh inference.sh v1.0
# To run v1.5 inference: sh inference.sh v1.5
# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
# Check if the correct number of arguments is provided
if [ "$#" -ne 1 ]; then
echo "Usage: $0 <version>"
echo "Example: $0 v1.0 or $0 v1.5"
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <version> <mode>"
echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
exit 1
fi
# Get the version from the user input
# Get the version and mode from the user input
version=$1
config_path="./configs/inference/test.yaml"
mode=$2
# Validate mode
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
echo "Invalid mode specified. Please use 'normal' or 'realtime'."
exit 1
fi
# Set config path based on mode
if [ "$mode" = "normal" ]; then
config_path="./configs/inference/test.yaml"
result_dir="./results/test"
else
config_path="./configs/inference/realtime.yaml"
result_dir="./results/realtime"
fi
# Define the model paths based on the version
if [ "$version" = "v1.0" ]; then
model_dir="./models/musetalk"
unet_model_path="$model_dir/pytorch_model.bin"
unet_config="$model_dir/musetalk.json"
version_arg="v1"
elif [ "$version" = "v1.5" ]; then
model_dir="./models/musetalkV15"
unet_model_path="$model_dir/unet.pth"
unet_config="$model_dir/musetalk.json"
version_arg="v15"
else
echo "Invalid version specified. Please use v1.0 or v1.5."
exit 1
fi
# Run inference based on the version
if [ "$version" = "v1.0" ]; then
python3 -m scripts.inference \
--inference_config "$config_path" \
--result_dir "./results/test" \
--unet_model_path "$unet_model_path" \
--unet_config "$unet_config"
elif [ "$version" = "v1.5" ]; then
python3 -m scripts.inference_alpha \
--inference_config "$config_path" \
--result_dir "./results/test" \
--unet_model_path "$unet_model_path" \
--unet_config "$unet_config"
fi
# Set script name based on mode
if [ "$mode" = "normal" ]; then
script_name="scripts.inference"
else
script_name="scripts.realtime_inference"
fi
# Base command arguments
cmd_args="--inference_config $config_path \
--result_dir $result_dir \
--unet_model_path $unet_model_path \
--unet_config $unet_config \
--version $version_arg \
# Add realtime-specific arguments if in realtime mode
if [ "$mode" = "realtime" ]; then
cmd_args="$cmd_args \
--fps 25 \
--version $version_arg \
fi
# Run inference
python3 -m $script_name $cmd_args

View File

@@ -11,7 +11,7 @@ class AudioProcessor:
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
def get_audio_feature(self, wav_path, start_index=0):
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
if not os.path.exists(wav_path):
return None
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
@@ -27,6 +27,8 @@ class AudioProcessor:
return_tensors="pt",
sampling_rate=sampling_rate
).input_features
if weight_dtype is not None:
audio_feature = audio_feature.to(dtype=weight_dtype)
features.append(audio_feature)
return features, len(librosa_output)

View File

@@ -3,6 +3,7 @@ import numpy as np
import cv2
import copy
def get_crop_box(box, expand):
x, y, x1, y1 = box
x_c, y_c = (x+x1)//2, (y+y1)//2
@@ -11,7 +12,8 @@ def get_crop_box(box, expand):
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
return crop_box, s
def face_seg(image, mode="jaw", fp=None):
def face_seg(image, mode="raw", fp=None):
"""
对图像进行面部解析,生成面部区域的掩码。
@@ -86,14 +88,12 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode=
body.paste(face_large, crop_box[:2], mask_image)
# 不用掩码完全用infer
#face_large.save("debug/checkpoint_6_face_large.png")
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
return body[:, :, ::-1] # 返回处理后的图像BGR 转 RGB
def get_image_blending(image,face,face_box,mask_array,crop_box):
def get_image_blending(image, face, face_box, mask_array, crop_box):
body = Image.fromarray(image[:,:,::-1])
face = Image.fromarray(face[:,:,::-1])
@@ -108,7 +108,8 @@ def get_image_blending(image,face,face_box,mask_array,crop_box):
body = np.array(body)
return body[:,:,::-1]
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
body = Image.fromarray(image[:,:,::-1])
x, y, x1, y1 = face_box
@@ -119,7 +120,7 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_image = face_seg(face_large, mode=mode, fp=fp)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
@@ -132,4 +133,4 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array,crop_box
return mask_array, crop_box

View File

@@ -74,7 +74,7 @@ class FaceParsing():
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512), mode="jaw"):
def __call__(self, image, size=(512, 512), mode="raw"):
if isinstance(image, str):
image = Image.open(image)

View File

@@ -1,8 +1,9 @@
import os
import cv2
import math
import copy
import glob
import torch
import glob
import shutil
import pickle
import argparse
@@ -17,18 +18,16 @@ from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
@torch.no_grad()
def main(args):
# Configure ffmpeg path
if args.ffmpeg_path not in os.getenv('PATH'):
print("Adding ffmpeg to PATH")
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,
@@ -37,164 +36,229 @@ def main(args):
device=device
)
timesteps = torch.tensor([0], device=device)
if args.use_float16 is True:
# Convert models to half precision if float16 is enabled
if args.use_float16:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
# Move models to specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
# Initialize audio processor and Whisper model
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
weight_dtype = unet.model.dtype
whisper = WhisperModel.from_pretrained(args.whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser
fp = FaceParsing()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for task_id in inference_config:
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path,exist_ok =True)
if args.output_vid_name is None:
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
else:
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
############################################## extract frames from source video ##############################################
if get_file_type(video_path)=="video":
save_dir_full = os.path.join(args.result_dir, input_basename)
os.makedirs(save_dir_full,exist_ok = True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path)=="image":
input_img_list = [video_path, ]
fps = args.fps
elif os.path.isdir(video_path): # input img folder
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
############################################## extract audio feature ##############################################
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
# Initialize face parser with configurable parameters based on version
if args.version == "v15":
fp = FaceParsing(
left_cheek_width=args.left_cheek_width,
right_cheek_width=args.right_cheek_width
)
############################################## preprocess input image ##############################################
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
with open(crop_coord_save_path,'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("extracting landmarks...time consuming")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
else: # v1
fp = FaceParsing()
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# Load inference configuration
inference_config = OmegaConf.load(args.inference_config)
print("Loaded inference config:", inference_config)
# Process each task
for task_id in inference_config:
try:
# Get task configuration
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
if "result_name" in inference_config[task_id]:
args.output_vid_name = inference_config[task_id]["result_name"]
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
############################################## pad to full image ##############################################
print("pad talking image to original video")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
# Set bbox_shift based on version
if args.version == "v15":
bbox_shift = 0 # v15 uses fixed bbox_shift
else:
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
# Merge results
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
# Set output paths
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
# Create temporary directories
temp_dir = os.path.join(args.result_dir, f"{args.version}")
os.makedirs(temp_dir, exist_ok=True)
# Set result save paths
result_img_save_path = os.path.join(temp_dir, output_basename)
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
os.makedirs(result_img_save_path, exist_ok=True)
# Set output video paths
if args.output_vid_name is None:
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
else:
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
# Extract frames from source video
if get_file_type(video_path) == "video":
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path) == "image":
input_img_list = [video_path]
fps = args.fps
elif os.path.isdir(video_path):
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
# Preprocess input images
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("Using saved coordinates")
with open(crop_coord_save_path, 'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("Extracting landmarks... time-consuming operation")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
print(f"Number of frames: {len(frame_list)}")
# Process each frame
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
if args.version == "v15":
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove("temp.mp4")
shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}")
# Smooth first and last frames
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
# Batch inference
print("Starting inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
res_frame_list = []
total = int(np.ceil(float(video_num) / batch_size))
# Execute inference
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
# Pad generated images to original video size
print("Padding generated images to original video size")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
if args.version == "v15":
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
except:
continue
# Merge results with version-specific parameters
if args.version == "v15":
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
else:
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
# Save prediction results
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
print("Video generation command:", cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
print("Audio combination command:", cmd_combine_audio)
os.system(cmd_combine_audio)
# Clean up temporary files
shutil.rmtree(result_img_save_path)
os.remove(temp_vid_path)
shutil.rmtree(save_dir_full)
if not args.saved_coord:
os.remove(crop_coord_save_path)
print(f"Results saved to {output_vid_name}")
except Exception as e:
print("Error occurred during processing:", e)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
parser.add_argument("--bbox_shift", type=int, default=0)
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--output_vid_name", type=str, default=None)
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
args = parser.parse_args()
main(args)

View File

@@ -1,252 +0,0 @@
import os
import cv2
import math
import copy
import torch
import glob
import shutil
import pickle
import argparse
import subprocess
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
@torch.no_grad()
def main(args):
# Configure ffmpeg path
if args.ffmpeg_path not in os.getenv('PATH'):
print("Adding ffmpeg to PATH")
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,
vae_type=args.vae_type,
unet_config=args.unet_config,
device=device
)
timesteps = torch.tensor([0], device=device)
# Convert models to half precision if float16 is enabled
if args.use_float16:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
# Move models to specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
weight_dtype = unet.model.dtype
whisper = WhisperModel.from_pretrained(args.whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser
fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)
# Load inference configuration
inference_config = OmegaConf.load(args.inference_config)
print("Loaded inference config:", inference_config)
# Process each task
for task_id in inference_config:
try:
# Get task configuration
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
if "result_name" in inference_config[task_id]:
args.output_vid_name = inference_config[task_id]["result_name"]
bbox_shift = args.bbox_shift
# Set output paths
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
# Create temporary directories
temp_dir = os.path.join(args.result_dir, "frames_result")
os.makedirs(temp_dir, exist_ok=True)
# Set result save paths
result_img_save_path = os.path.join(temp_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path, exist_ok=True)
# Set output video paths
if args.output_vid_name is None:
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
else:
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
# Skip if output file already exists
if os.path.exists(output_vid_name):
print(f"{output_vid_name} already exists, skipping!")
continue
# Extract frames from source video
if get_file_type(video_path) == "video":
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path) == "image":
input_img_list = [video_path]
fps = args.fps
elif os.path.isdir(video_path):
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
# Preprocess input images
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("Using saved coordinates")
with open(crop_coord_save_path, 'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("Extracting landmarks... time-consuming operation")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
print(f"Number of frames: {len(frame_list)}")
# Process each frame
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
# Smooth first and last frames
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
# Batch inference
print("Starting inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
res_frame_list = []
total = int(np.ceil(float(video_num) / batch_size))
# Execute inference
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
# Pad generated images to original video size
print("Padding generated images to original video size")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
except:
continue
# Merge results
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
# Save prediction results
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
print("Video generation command:", cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
print("Audio combination command:", cmd_combine_audio)
os.system(cmd_combine_audio)
# Clean up temporary files
shutil.rmtree(result_img_save_path)
os.remove(temp_vid_path)
shutil.rmtree(save_dir_full)
if not args.saved_coord:
os.remove(crop_coord_save_path)
print(f"Results saved to {output_vid_name}")
except Exception as e:
print("Error occurred during processing:", e)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
args = parser.parse_args()
main(args)

View File

@@ -10,26 +10,22 @@ import sys
from tqdm import tqdm
import copy
import json
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model
import shutil
from transformers import WhisperModel
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.utils import datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.utils import load_all_model
from musetalk.utils.audio_processor import AudioProcessor
import shutil
import threading
import queue
import time
# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
cap = cv2.VideoCapture(vid_path)
count = 0
while True:
@@ -42,35 +38,43 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
else:
break
def osmakedirs(path_list):
for path in path_list:
os.makedirs(path) if not os.path.exists(path) else None
@torch.no_grad()
@torch.no_grad()
class Avatar:
def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
self.avatar_id = avatar_id
self.video_path = video_path
self.bbox_shift = bbox_shift
self.avatar_path = f"./results/avatars/{avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
# 根据版本设置不同的基础路径
if args.version == "v15":
self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
else: # v1
self.base_path = f"./results/avatars/{avatar_id}"
self.avatar_path = self.base_path
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.latents_out_path= f"{self.avatar_path}/latents.pt"
self.latents_out_path = f"{self.avatar_path}/latents.pt"
self.video_out_path = f"{self.avatar_path}/vid_output/"
self.mask_out_path =f"{self.avatar_path}/mask"
self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
self.mask_out_path = f"{self.avatar_path}/mask"
self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
self.avatar_info = {
"avatar_id":avatar_id,
"video_path":video_path,
"bbox_shift":bbox_shift
"avatar_id": avatar_id,
"video_path": video_path,
"bbox_shift": bbox_shift,
"version": args.version
}
self.preparation = preparation
self.batch_size = batch_size
self.idx = 0
self.init()
def init(self):
if self.preparation:
if os.path.exists(self.avatar_path):
@@ -80,7 +84,7 @@ class Avatar:
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
self.prepare_material()
else:
self.input_latent_list_cycle = torch.load(self.latents_out_path)
@@ -98,16 +102,16 @@ class Avatar:
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
self.prepare_material()
else:
else:
if not os.path.exists(self.avatar_path):
print(f"{self.avatar_id} does not exist, you should set preparation to True")
sys.exit()
with open(self.avatar_info_path, "r") as f:
avatar_info = json.load(f)
if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
if response.lower() == "c":
@@ -115,11 +119,11 @@ class Avatar:
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
self.prepare_material()
else:
sys.exit()
else:
else:
self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
@@ -131,36 +135,40 @@ class Avatar:
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list)
def prepare_material(self):
print("preparing data materials ... ...")
with open(self.avatar_info_path, "w") as f:
json.dump(self.avatar_info, f)
if os.path.isfile(self.video_path):
video2imgs(self.video_path, self.full_imgs_path, ext = 'png')
video2imgs(self.video_path, self.full_imgs_path, ext='png')
else:
print(f"copy files in {self.video_path}")
files = os.listdir(self.video_path)
files.sort()
files = [file for file in files if file.split(".")[-1]=="png"]
files = [file for file in files if file.split(".")[-1] == "png"]
for filename in files:
shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
print("extracting landmarks...")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
input_latent_list = []
idx = -1
# maker if the bbox is not sufficient
coord_placeholder = (0.0,0.0,0.0,0.0)
# maker if the bbox is not sufficient
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
for bbox, frame in zip(coord_list, frame_list):
idx = idx + 1
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
if args.version == "v15":
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
coord_list[idx] = [x1, y1, x2, y2] # 更新coord_list中的bbox
crop_frame = frame[y1:y2, x1:x2]
resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(resized_crop_frame)
input_latent_list.append(latents)
@@ -170,112 +178,116 @@ class Avatar:
self.mask_coords_list_cycle = []
self.mask_list_cycle = []
for i,frame in enumerate(tqdm(self.frame_list_cycle)):
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
face_box = self.coord_list_cycle[i]
mask,crop_box = get_image_prepare_material(frame,face_box)
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
for i, frame in enumerate(tqdm(self.frame_list_cycle)):
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
x1, y1, x2, y2 = self.coord_list_cycle[i]
if args.version == "v15":
mode = args.parsing_mode
else:
mode = "raw"
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
self.mask_coords_list_cycle += [crop_box]
self.mask_list_cycle.append(mask)
with open(self.mask_coords_path, 'wb') as f:
pickle.dump(self.mask_coords_list_cycle, f)
with open(self.coords_path, 'wb') as f:
pickle.dump(self.coord_list_cycle, f)
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
#
def process_frames(self,
res_frame_queue,
video_len,
skip_save_images):
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
def process_frames(self, res_frame_queue, video_len, skip_save_images):
print(video_len)
while True:
if self.idx>=video_len-1:
if self.idx >= video_len - 1:
break
try:
start = time.time()
res_frame = res_frame_queue.get(block=True, timeout=1)
except queue.Empty:
continue
bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))]
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except:
continue
mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
#combine_frame = get_image(ori_frame,res_frame,bbox)
mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
if skip_save_images is False:
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
self.idx = self.idx + 1
def inference(self,
audio_path,
out_vid_name,
fps,
skip_save_images):
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
print("start inference")
############################################## extract audio feature ##############################################
start_time = time.time()
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
############################################## inference batch by batch ##############################################
video_num = len(whisper_chunks)
video_num = len(whisper_chunks)
res_frame_queue = queue.Queue()
self.idx = 0
# # Create a sub-thread and start it
# Create a sub-thread and start it
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
process_thread.start()
gen = datagen(whisper_chunks,
self.input_latent_list_cycle,
self.batch_size)
self.input_latent_list_cycle,
self.batch_size)
start_time = time.time()
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
audio_feature_batch = pe(whisper_batch.to(device))
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_queue.put(res_frame)
# Close the queue and sub-thread after all tasks are completed
process_thread.join()
if args.skip_save_images is True:
print('Total process time of {} frames without saving images = {}s'.format(
video_num,
time.time()-start_time))
video_num,
time.time() - start_time))
else:
print('Total process time of {} frames including saving images = {}s'.format(
video_num,
time.time()-start_time))
video_num,
time.time() - start_time))
if out_vid_name is not None and args.skip_save_images is False:
if out_vid_name is not None and args.skip_save_images is False:
# optional
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on
output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4") # on
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
@@ -284,52 +296,95 @@ class Avatar:
shutil.rmtree(f"{self.avatar_path}/tmp")
print(f"result is save to {output_vid}")
print("\n")
if __name__ == "__main__":
'''
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
'''
parser = argparse.ArgumentParser()
parser.add_argument("--inference_config",
type=str,
default="configs/inference/realtime.yaml",
)
parser.add_argument("--fps",
type=int,
default=25,
)
parser.add_argument("--batch_size",
type=int,
default=4,
)
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
parser.add_argument("--skip_save_images",
action="store_true",
help="Whether skip saving images for better generation speed calculation",
)
action="store_true",
help="Whether skip saving images for better generation speed calculation",
)
args = parser.parse_args()
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,
vae_type=args.vae_type,
unet_config=args.unet_config,
device=device
)
timesteps = torch.tensor([0], device=device)
pe = pe.half().to(device)
vae.vae = vae.vae.half().to(device)
unet.model = unet.model.half().to(device)
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
weight_dtype = unet.model.dtype
whisper = WhisperModel.from_pretrained(args.whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser with configurable parameters based on version
if args.version == "v15":
fp = FaceParsing(
left_cheek_width=args.left_cheek_width,
right_cheek_width=args.right_cheek_width
)
else: # v1
fp = FaceParsing()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for avatar_id in inference_config:
data_preparation = inference_config[avatar_id]["preparation"]
video_path = inference_config[avatar_id]["video_path"]
bbox_shift = inference_config[avatar_id]["bbox_shift"]
if args.version == "v15":
bbox_shift = 0
else:
bbox_shift = inference_config[avatar_id]["bbox_shift"]
avatar = Avatar(
avatar_id = avatar_id,
video_path = video_path,
bbox_shift = bbox_shift,
batch_size = args.batch_size,
preparation= data_preparation)
avatar_id=avatar_id,
video_path=video_path,
bbox_shift=bbox_shift,
batch_size=args.batch_size,
preparation=data_preparation)
audio_clips = inference_config[avatar_id]["audio_clips"]
for audio_num, audio_path in audio_clips.items():
print("Inferring using:",audio_path)
avatar.inference(audio_path,
audio_num,
args.fps,
args.skip_save_images)
print("Inferring using:", audio_path)
avatar.inference(audio_path,
audio_num,
args.fps,
args.skip_save_images)