mirror of https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00

feat: real-time infer (#286)

* feat: realtime infer
* chore: infer script

README.md (82 changed lines)
@@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
 - [x] codes for real-time inference.
 - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
 - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
+- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
 - [ ] training and dataloader code (Expected completion on 04/04/2025).
-- [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
 
 
 # Getting Started
@@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk:
 
 #### MuseTalk 1.5 (Recommended)
 ```bash
-sh inference.sh v1.5
+# Run MuseTalk 1.5 inference
+sh inference.sh v1.5 normal
 ```
-This inference script supports both MuseTalk 1.5 and 1.0 models:
-- For MuseTalk 1.5: Use the command above with the V1.5 model path
-- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
-
-configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path.
-The video_path should be either a video file, an image file or a directory of images.
 
 #### MuseTalk 1.0
 ```bash
-sh inference.sh v1.0
+# Run MuseTalk 1.0 inference
+sh inference.sh v1.0 normal
 ```
-You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
-<details close>
+
+The inference script supports both MuseTalk 1.5 and 1.0 models:
+- For MuseTalk 1.5: Use the command above with the V1.5 model path
+- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
+
+The configuration file `configs/inference/test.yaml` contains the inference settings, including:
+- `video_path`: Path to the input video, image file, or directory of images
+- `audio_path`: Path to the input audio file
+
+Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
+
+#### Real-time Inference
+For real-time inference, use the following command:
+```bash
+# Run real-time inference
+sh inference.sh v1.5 realtime  # For MuseTalk 1.5
+# or
+sh inference.sh v1.0 realtime  # For MuseTalk 1.0
+```
+
+The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
+- `preparation`: Set to `True` for new avatar preparation
+- `video_path`: Path to the input video
+- `bbox_shift`: Adjustable parameter for mouth region control
+- `audio_clips`: List of audio clips for generation
+
+Important notes for real-time inference:
+1. Set `preparation` to `True` when processing a new avatar
+2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
+3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
+4. Set `preparation` to `False` for generating more videos with the same avatar
+
+For faster generation without saving images, you can use:
+```bash
+python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
+```
 
 ## TestCases For 1.0
 <table class="center">
 <tr style="font-weight: bolder;text-align:center;">
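The preparation/reuse cycle described in the real-time notes above can be sketched in plain Python. `Avatar`, `prepare_materials`, and `generate` are illustrative names only, not the actual classes in `scripts.realtime_inference`:

```python
# Sketch of the avatar preparation/reuse flow described above.
# All names here are hypothetical stand-ins for the real implementation.

class Avatar:
    def __init__(self, avatar_id, video_path, bbox_shift, preparation):
        self.avatar_id = avatar_id
        self.video_path = video_path
        self.bbox_shift = bbox_shift
        self.prepared = False
        if preparation:
            # New avatar: face detection, face parsing, VAE encoding happen once.
            self.prepare_materials()
        else:
            # preparation: False reuses materials cached from an earlier run.
            self.prepared = True

    def prepare_materials(self):
        # The real script precomputes frames, landmarks and latents here;
        # this sketch only records that preparation happened.
        self.prepared = True

    def generate(self, audio_clips):
        # After preparation, one output video per configured audio clip.
        assert self.prepared, "run preparation first"
        return [f"{self.avatar_id}_{name}.mp4" for name in audio_clips]

avatar = Avatar("avator_1", "data/video/yongen.mp4", bbox_shift=5, preparation=True)
videos = avatar.generate({"audio_0": "data/audio/yongen.wav"})
```

Reusing the same avatar with `preparation=False` skips the expensive step, which is why repeated generation with one avatar is so much faster.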
@@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
 ```
 :pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
 
-</details>
 
 #### Combining MuseV and MuseTalk
 
 As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
 
-#### Real-time inference
-
-<details close>
-Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
-
-```
-python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
-```
-configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path`, `bbox_shift` and `audio_clips`.
-
-1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
-1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
-```
-Inferring using: data/audio/yongen.wav
-```
-1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
-1. Set `preparation` to `False` and run this script if you want to generate more videos using the same avatar.
-
-##### Note for Real-time inference
-1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
-1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
-```
-python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
-```
-</details>
 
 # Acknowledgement
 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
configs/inference/realtime.yaml
@@ -1,10 +1,10 @@
 avator_1:
-  preparation: False
+  preparation: True  # you can set it to False to reuse an existing avatar, which saves time
   bbox_shift: 5
-  video_path: "data/video/sun.mp4"
+  video_path: "data/video/yongen.mp4"
   audio_clips:
     audio_0: "data/audio/yongen.wav"
-    audio_1: "data/audio/sun.wav"
+    audio_1: "data/audio/eng.wav"
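The config above is keyed by avatar id, with one entry per clip under `audio_clips`. A minimal sketch of how such a mapping expands into (avatar, clip) generation jobs, using a plain dict in place of the OmegaConf-parsed YAML:

```python
# Plain-dict stand-in for the parsed realtime.yaml shown above.
config = {
    "avator_1": {
        "preparation": True,
        "bbox_shift": 5,
        "video_path": "data/video/yongen.mp4",
        "audio_clips": {
            "audio_0": "data/audio/yongen.wav",
            "audio_1": "data/audio/eng.wav",
        },
    }
}

# Each avatar is prepared once; each configured clip becomes one generation job.
jobs = []
for avatar_id, task in config.items():
    for clip_name, wav_path in task["audio_clips"].items():
        jobs.append((avatar_id, clip_name, wav_path))
```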
configs/inference/test.yaml
@@ -3,8 +3,8 @@ task_0:
   audio_path: "data/audio/yongen.wav"
 
 task_1:
-  video_path: "data/video/sun.mp4"
-  audio_path: "data/audio/sun.wav"
+  video_path: "data/video/yongen.mp4"
+  audio_path: "data/audio/eng.wav"
   bbox_shift: -7
data/audio/eng.wav (BIN, new executable file; binary file not shown)

inference.sh (70 changed lines)
@@ -1,46 +1,72 @@
 #!/bin/bash
 
-# This script runs inference based on the version specified by the user.
+# This script runs inference based on the version and mode specified by the user.
 # Usage:
-#   To run v1.0 inference: sh inference.sh v1.0
-#   To run v1.5 inference: sh inference.sh v1.5
+#   To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
+#   To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
 
 # Check if the correct number of arguments is provided
-if [ "$#" -ne 1 ]; then
-    echo "Usage: $0 <version>"
-    echo "Example: $0 v1.0 or $0 v1.5"
+if [ "$#" -ne 2 ]; then
+    echo "Usage: $0 <version> <mode>"
+    echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
     exit 1
 fi
 
-# Get the version from the user input
+# Get the version and mode from the user input
 version=$1
-config_path="./configs/inference/test.yaml"
+mode=$2
+
+# Validate mode
+if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
+    echo "Invalid mode specified. Please use 'normal' or 'realtime'."
+    exit 1
+fi
+
+# Set config path based on mode
+if [ "$mode" = "normal" ]; then
+    config_path="./configs/inference/test.yaml"
+    result_dir="./results/test"
+else
+    config_path="./configs/inference/realtime.yaml"
+    result_dir="./results/realtime"
+fi
 
 # Define the model paths based on the version
 if [ "$version" = "v1.0" ]; then
     model_dir="./models/musetalk"
     unet_model_path="$model_dir/pytorch_model.bin"
     unet_config="$model_dir/musetalk.json"
+    version_arg="v1"
 elif [ "$version" = "v1.5" ]; then
     model_dir="./models/musetalkV15"
     unet_model_path="$model_dir/unet.pth"
     unet_config="$model_dir/musetalk.json"
+    version_arg="v15"
 else
     echo "Invalid version specified. Please use v1.0 or v1.5."
     exit 1
 fi
 
-# Run inference based on the version
-if [ "$version" = "v1.0" ]; then
-    python3 -m scripts.inference \
-        --inference_config "$config_path" \
-        --result_dir "./results/test" \
-        --unet_model_path "$unet_model_path" \
-        --unet_config "$unet_config"
-elif [ "$version" = "v1.5" ]; then
-    python3 -m scripts.inference_alpha \
-        --inference_config "$config_path" \
-        --result_dir "./results/test" \
-        --unet_model_path "$unet_model_path" \
-        --unet_config "$unet_config"
-fi
+# Set script name based on mode
+if [ "$mode" = "normal" ]; then
+    script_name="scripts.inference"
+else
+    script_name="scripts.realtime_inference"
+fi
+
+# Base command arguments
+cmd_args="--inference_config $config_path \
+    --result_dir $result_dir \
+    --unet_model_path $unet_model_path \
+    --unet_config $unet_config \
+    --version $version_arg"
+
+# Add realtime-specific arguments if in realtime mode
+if [ "$mode" = "realtime" ]; then
+    cmd_args="$cmd_args --fps 25"
+fi
+
+# Run inference
+python3 -m $script_name $cmd_args
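For clarity, the version/mode dispatch in the script above can be mirrored as a pure function (a sketch for illustration, not part of the repo):

```python
def resolve_run(version: str, mode: str) -> dict:
    """Mirror of inference.sh: map (version, mode) to module, config and model paths."""
    if mode not in ("normal", "realtime"):
        raise ValueError("mode must be 'normal' or 'realtime'")
    if version == "v1.0":
        model_dir, unet_file, version_arg = "./models/musetalk", "pytorch_model.bin", "v1"
    elif version == "v1.5":
        model_dir, unet_file, version_arg = "./models/musetalkV15", "unet.pth", "v15"
    else:
        raise ValueError("version must be 'v1.0' or 'v1.5'")
    normal = mode == "normal"
    return {
        "script": "scripts.inference" if normal else "scripts.realtime_inference",
        "config": "./configs/inference/test.yaml" if normal else "./configs/inference/realtime.yaml",
        "result_dir": "./results/test" if normal else "./results/realtime",
        "unet_model_path": f"{model_dir}/{unet_file}",
        "unet_config": f"{model_dir}/musetalk.json",
        "version_arg": version_arg,
    }
```

Keeping the mapping in one place makes it easy to see that mode selects the script and config while version selects the weights.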
@@ -11,7 +11,7 @@ class AudioProcessor:
     def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
         self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
 
-    def get_audio_feature(self, wav_path, start_index=0):
+    def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
         if not os.path.exists(wav_path):
             return None
         librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
@@ -27,6 +27,8 @@ class AudioProcessor:
                 return_tensors="pt",
                 sampling_rate=sampling_rate
             ).input_features
+            if weight_dtype is not None:
+                audio_feature = audio_feature.to(dtype=weight_dtype)
             features.append(audio_feature)
 
         return features, len(librosa_output)
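The new `weight_dtype` parameter exists so the extracted audio features can be cast to the UNet's dtype (e.g. float16 when `--use_float16` is set); feeding float32 features into a half-precision model would otherwise fail inside the model. A toy stand-in (not `torch.Tensor`) showing just the guard added above:

```python
# Toy illustration of the weight_dtype plumbing; FakeTensor is a stand-in,
# not torch.Tensor, and get_audio_feature here is a simplified sketch.

class FakeTensor:
    def __init__(self, data, dtype="float32"):
        self.data, self.dtype = data, dtype

    def to(self, dtype):
        # Mimics tensor.to(dtype=...): returns a cast copy.
        return FakeTensor(self.data, dtype)

def get_audio_feature(raw, weight_dtype=None):
    feature = FakeTensor(raw)      # feature extractor output is float32
    if weight_dtype is not None:   # same guard as in the diff above
        feature = feature.to(dtype=weight_dtype)
    return feature

feat = get_audio_feature([0.1, 0.2], weight_dtype="float16")
```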
@@ -3,6 +3,7 @@ import numpy as np
 import cv2
 import copy
 
+
 def get_crop_box(box, expand):
     x, y, x1, y1 = box
     x_c, y_c = (x+x1)//2, (y+y1)//2
@@ -11,7 +12,8 @@ def get_crop_box(box, expand):
     crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
     return crop_box, s
 
-def face_seg(image, mode="jaw", fp=None):
+
+def face_seg(image, mode="raw", fp=None):
     """
     Perform face parsing on the image and generate a mask of the facial region.
 
@@ -86,14 +88,12 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode=
 
     body.paste(face_large, crop_box[:2], mask_image)
 
-    # no mask; use the inferred result directly
-    #face_large.save("debug/checkpoint_6_face_large.png")
 
     body = np.array(body)  # convert the PIL image back to a numpy array
 
     return body[:, :, ::-1]  # return the processed image (BGR to RGB)
 
-def get_image_blending(image,face,face_box,mask_array,crop_box):
+
+def get_image_blending(image, face, face_box, mask_array, crop_box):
     body = Image.fromarray(image[:,:,::-1])
     face = Image.fromarray(face[:,:,::-1])
@@ -108,7 +108,8 @@ def get_image_blending(image, face, face_box, mask_array, crop_box):
     body = np.array(body)
     return body[:,:,::-1]
 
-def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
+
+def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
     body = Image.fromarray(image[:,:,::-1])
 
     x, y, x1, y1 = face_box
@@ -119,7 +120,7 @@ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=
     face_large = body.crop(crop_box)
     ori_shape = face_large.size
 
-    mask_image = face_seg(face_large)
+    mask_image = face_seg(face_large, mode=mode, fp=fp)
     mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
     mask_image = Image.new('L', ori_shape, 0)
     mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
@@ -132,4 +133,4 @@ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=
 
     blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
     mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
-    return mask_array,crop_box
+    return mask_array, crop_box
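`get_crop_box` computes a square crop centred on the face box. The line defining `s` does not appear in the hunks above, so the formula below (half the larger side, scaled by `expand`) is an assumption for illustration only:

```python
def get_crop_box(box, expand):
    # Square crop centred on the face box, enlarged by `expand`.
    # The `s` formula is an assumed reconstruction; only the crop_box
    # and return lines are shown in the diff above.
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s

# A 100x100 face box at (10, 20) with expand=1.5 yields a 150x150 crop.
crop_box, s = get_crop_box((10, 20, 110, 120), expand=1.5)
```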
@@ -74,7 +74,7 @@ class FaceParsing():
             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
         ])
 
-    def __call__(self, image, size=(512, 512), mode="jaw"):
+    def __call__(self, image, size=(512, 512), mode="raw"):
         if isinstance(image, str):
             image = Image.open(image)
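The `transforms.Normalize` call above uses the standard ImageNet statistics; per channel the operation is `(pixel - mean) / std`. A stdlib sketch of the same arithmetic (not the torchvision implementation, which applies it tensor-wide):

```python
# Standard ImageNet per-channel statistics, as in the Normalize call above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """rgb: per-channel floats in [0, 1] -> normalized channel values."""
    return tuple((c - m) / s for c, m, s in zip(rgb, MEAN, STD))

# A pixel equal to the channel means maps to zero in every channel.
out = normalize_pixel((0.485, 0.456, 0.406))
```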
@@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
import math
|
||||||
import copy
|
import copy
|
||||||
import glob
|
|
||||||
import torch
|
import torch
|
||||||
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
import pickle
|
import pickle
|
||||||
import argparse
|
import argparse
|
||||||
@@ -17,18 +18,16 @@ from musetalk.utils.audio_processor import AudioProcessor
|
|||||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main(args):
|
def main(args):
|
||||||
# Configure ffmpeg path
|
# Configure ffmpeg path
|
||||||
if args.ffmpeg_path not in os.getenv('PATH'):
|
if args.ffmpeg_path not in os.getenv('PATH'):
|
||||||
print("Adding ffmpeg to PATH")
|
print("Adding ffmpeg to PATH")
|
||||||
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
|
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
|
||||||
|
|
||||||
# Set computing device
|
# Set computing device
|
||||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Load model weights
|
# Load model weights
|
||||||
vae, unet, pe = load_all_model(
|
vae, unet, pe = load_all_model(
|
||||||
unet_model_path=args.unet_model_path,
|
unet_model_path=args.unet_model_path,
|
||||||
@@ -37,164 +36,229 @@ def main(args):
|
|||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
timesteps = torch.tensor([0], device=device)
|
timesteps = torch.tensor([0], device=device)
|
||||||
|
|
||||||
|
# Convert models to half precision if float16 is enabled
|
||||||
if args.use_float16 is True:
|
if args.use_float16:
|
||||||
pe = pe.half()
|
pe = pe.half()
|
||||||
vae.vae = vae.vae.half()
|
vae.vae = vae.vae.half()
|
||||||
unet.model = unet.model.half()
|
unet.model = unet.model.half()
|
||||||
|
|
||||||
|
# Move models to specified device
|
||||||
|
pe = pe.to(device)
|
||||||
|
vae.vae = vae.vae.to(device)
|
||||||
|
unet.model = unet.model.to(device)
|
||||||
|
|
||||||
# Initialize audio processor and Whisper model
|
# Initialize audio processor and Whisper model
|
||||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||||
weight_dtype = unet.model.dtype
|
weight_dtype = unet.model.dtype
|
||||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||||
whisper.requires_grad_(False)
|
whisper.requires_grad_(False)
|
||||||
|
|
||||||
# Initialize face parser
|
# Initialize face parser with configurable parameters based on version
|
||||||
fp = FaceParsing()
|
if args.version == "v15":
|
||||||
|
fp = FaceParsing(
|
||||||
inference_config = OmegaConf.load(args.inference_config)
|
left_cheek_width=args.left_cheek_width,
|
||||||
print(inference_config)
|
right_cheek_width=args.right_cheek_width
|
||||||
for task_id in inference_config:
|
|
||||||
video_path = inference_config[task_id]["video_path"]
|
|
||||||
audio_path = inference_config[task_id]["audio_path"]
|
|
||||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
|
||||||
|
|
||||||
input_basename = os.path.basename(video_path).split('.')[0]
|
|
||||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
|
||||||
output_basename = f"{input_basename}_{audio_basename}"
|
|
||||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
|
||||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
|
||||||
os.makedirs(result_img_save_path,exist_ok =True)
|
|
||||||
|
|
||||||
if args.output_vid_name is None:
|
|
||||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
|
||||||
else:
|
|
||||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
|
||||||
############################################## extract frames from source video ##############################################
|
|
||||||
if get_file_type(video_path)=="video":
|
|
||||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
|
||||||
os.makedirs(save_dir_full,exist_ok = True)
|
|
||||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
|
||||||
os.system(cmd)
|
|
||||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
|
||||||
fps = get_video_fps(video_path)
|
|
||||||
elif get_file_type(video_path)=="image":
|
|
||||||
input_img_list = [video_path, ]
|
|
||||||
fps = args.fps
|
|
||||||
elif os.path.isdir(video_path): # input img folder
|
|
||||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
|
||||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
|
||||||
fps = args.fps
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
|
||||||
|
|
||||||
############################################## extract audio feature ##############################################
|
|
||||||
# Extract audio features
|
|
||||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
|
||||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
|
||||||
whisper_input_features,
|
|
||||||
device,
|
|
||||||
weight_dtype,
|
|
||||||
whisper,
|
|
||||||
librosa_length,
|
|
||||||
fps=fps,
|
|
||||||
audio_padding_length_left=args.audio_padding_length_left,
|
|
||||||
audio_padding_length_right=args.audio_padding_length_right,
|
|
||||||
)
|
)
|
||||||
|
else: # v1
|
||||||
############################################## preprocess input image ##############################################
|
fp = FaceParsing()
|
||||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
|
||||||
print("using extracted coordinates")
|
|
||||||
with open(crop_coord_save_path,'rb') as f:
|
|
||||||
coord_list = pickle.load(f)
|
|
||||||
frame_list = read_imgs(input_img_list)
|
|
||||||
else:
|
|
||||||
print("extracting landmarks...time consuming")
|
|
||||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
|
||||||
with open(crop_coord_save_path, 'wb') as f:
|
|
||||||
pickle.dump(coord_list, f)
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
input_latent_list = []
|
|
||||||
for bbox, frame in zip(coord_list, frame_list):
|
|
||||||
if bbox == coord_placeholder:
|
|
||||||
continue
|
|
||||||
x1, y1, x2, y2 = bbox
|
|
||||||
crop_frame = frame[y1:y2, x1:x2]
|
|
||||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
|
||||||
latents = vae.get_latents_for_unet(crop_frame)
|
|
||||||
input_latent_list.append(latents)
|
|
||||||
|
|
||||||
# to smooth the first and the last frame
|
# Load inference configuration
|
||||||
frame_list_cycle = frame_list + frame_list[::-1]
|
inference_config = OmegaConf.load(args.inference_config)
|
||||||
coord_list_cycle = coord_list + coord_list[::-1]
|
print("Loaded inference config:", inference_config)
|
||||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
|
||||||
############################################## inference batch by batch ##############################################
|
# Process each task
|
||||||
print("start inference")
|
for task_id in inference_config:
|
||||||
video_num = len(whisper_chunks)
|
try:
|
||||||
batch_size = args.batch_size
|
# Get task configuration
|
||||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
video_path = inference_config[task_id]["video_path"]
|
||||||
res_frame_list = []
|
audio_path = inference_config[task_id]["audio_path"]
|
||||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
if "result_name" in inference_config[task_id]:
|
||||||
audio_feature_batch = pe(whisper_batch)
|
args.output_vid_name = inference_config[task_id]["result_name"]
|
||||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
|
||||||
|
|
||||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
# Set bbox_shift based on version
|
||||||
recon = vae.decode_latents(pred_latents)
|
if args.version == "v15":
|
||||||
for res_frame in recon:
|
bbox_shift = 0 # v15 uses fixed bbox_shift
|
||||||
res_frame_list.append(res_frame)
|
else:
|
||||||
|
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
|
||||||
############################################## pad to full image ##############################################
|
|
||||||
print("pad talking image to original video")
|
|
||||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
|
||||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
|
||||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
|
||||||
x1, y1, x2, y2 = bbox
|
|
||||||
try:
|
|
||||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Merge results
|
# Set output paths
|
||||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
|
input_basename = os.path.basename(video_path).split('.')[0]
|
||||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||||
|
output_basename = f"{input_basename}_{audio_basename}"
|
||||||
|
|
||||||
|
# Create temporary directories
|
||||||
|
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||||
|
            os.makedirs(temp_dir, exist_ok=True)

            # Set result save paths
            result_img_save_path = os.path.join(temp_dir, output_basename)
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")
            os.makedirs(result_img_save_path, exist_ok=True)

            # Set output video paths
            if args.output_vid_name is None:
                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
            else:
                output_vid_name = os.path.join(temp_dir, args.output_vid_name)
            output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")

            # Extract frames from source video
            if get_file_type(video_path) == "video":
                save_dir_full = os.path.join(temp_dir, input_basename)
                os.makedirs(save_dir_full, exist_ok=True)
                cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
                os.system(cmd)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
                fps = get_video_fps(video_path)
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
                fps = args.fps
            elif os.path.isdir(video_path):
                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                fps = args.fps
            else:
                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

            # Extract audio features
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features,
                device,
                weight_dtype,
                whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=args.audio_padding_length_left,
                audio_padding_length_right=args.audio_padding_length_right,
            )

            # Preprocess input images
            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks... time-consuming operation")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)

            print(f"Number of frames: {len(frame_list)}")

            # Process each frame
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)
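The v1.5 branch above widens the face crop downward by `extra_margin` pixels and clamps it to the frame height. Isolated as a pure function (a sketch with illustrative names, not the project's API):

```python
# Hypothetical helper mirroring the two-line margin/clamp step above.
def extend_bbox_bottom(y2, extra_margin, frame_height):
    # Extend the lower edge of the crop, but never past the last image row.
    return min(y2 + extra_margin, frame_height)

print(extend_bbox_bottom(250, 10, 256))  # clamped to 256
print(extend_bbox_bottom(100, 10, 256))  # 110, no clamping needed
```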
            # Smooth first and last frames
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
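The palindrome cycling above (list plus its reverse, later indexed with modulo) lets any frame index wrap around the clip without a visible jump. A minimal sketch of the same trick, with illustrative names:

```python
# Append the reversed list so playback "bounces" instead of jumping at the end.
def make_cycle(frames):
    return frames + frames[::-1]

def frame_at(cycle, i):
    # Modulo lookup, as in coord_list_cycle[i % len(coord_list_cycle)].
    return cycle[i % len(cycle)]

cycle = make_cycle([0, 1, 2])   # -> [0, 1, 2, 2, 1, 0]
print(frame_at(cycle, 7))       # index 7 wraps to cycle[1] == 1
```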
            # Batch inference
            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )

            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            # Execute inference
            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)

                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
                recon = vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            # Pad generated images to original video size
            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i % (len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    continue

                # Merge results with version-specific parameters
                if args.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # Save prediction results
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
            print("Video generation command:", cmd_img2video)
            os.system(cmd_img2video)

            cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
            print("Audio combination command:", cmd_combine_audio)
            os.system(cmd_combine_audio)

            # Clean up temporary files
            shutil.rmtree(result_img_save_path)
            os.remove(temp_vid_path)

            shutil.rmtree(save_dir_full)
            if not args.saved_coord:
                os.remove(crop_coord_save_path)

            print(f"Results saved to {output_vid_name}")
        except Exception as e:
            print("Error occurred during processing:", e)
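The `datagen` call above pairs audio chunks with the cycled latents and yields fixed-size batches. A simplified stand-in (a hypothetical sketch, not the project's actual implementation, which also moves tensors to the device) looks like:

```python
def simple_datagen(whisper_chunks, vae_encode_latents, batch_size):
    # Pair each audio chunk with a latent from the cycled list (wrapping with
    # modulo), then yield fixed-size batches; the final batch may be shorter.
    batch_audio, batch_latent = [], []
    for i, chunk in enumerate(whisper_chunks):
        batch_audio.append(chunk)
        batch_latent.append(vae_encode_latents[i % len(vae_encode_latents)])
        if len(batch_audio) == batch_size:
            yield batch_audio, batch_latent
            batch_audio, batch_latent = [], []
    if batch_audio:
        yield batch_audio, batch_latent

batches = list(simple_datagen(list(range(5)), ['a', 'b'], batch_size=2))
print(len(batches))  # 5 chunks with batch_size 2 -> 3 batches (2, 2, 1)
```

This is also where `total = ceil(video_num / batch_size)` in the progress bar comes from: the last, possibly partial, batch still counts as one.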
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
    parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
    args = parser.parse_args()
    main(args)
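A quick way to sanity-check the defaults above without running inference is to build the same parser (an illustrative subset is shown here, mirroring the script's argument names) and parse an empty argv:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"])
parser.add_argument("--use_float16", action="store_true")
args = parser.parse_args([])          # parse defaults, ignoring sys.argv
print(args.batch_size, args.version)  # -> 8 v15
```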
@@ -1,252 +0,0 @@
import os
import cv2
import math
import copy
import torch
import glob
import shutil
import pickle
import argparse
import subprocess
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel

from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


@torch.no_grad()
def main(args):
    # Configure ffmpeg path
    if args.ffmpeg_path not in os.getenv('PATH'):
        print("Adding ffmpeg to PATH")
        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"

    # Set computing device
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    # Load model weights
    vae, unet, pe = load_all_model(
        unet_model_path=args.unet_model_path,
        vae_type=args.vae_type,
        unet_config=args.unet_config,
        device=device
    )
    timesteps = torch.tensor([0], device=device)

    # Convert models to half precision if float16 is enabled
    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    # Move models to specified device
    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)

    # Initialize audio processor and Whisper model
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(args.whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    # Initialize face parser
    fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)

    # Load inference configuration
    inference_config = OmegaConf.load(args.inference_config)
    print("Loaded inference config:", inference_config)

    # Process each task
    for task_id in inference_config:
        try:
            # Get task configuration
            video_path = inference_config[task_id]["video_path"]
            audio_path = inference_config[task_id]["audio_path"]
            if "result_name" in inference_config[task_id]:
                args.output_vid_name = inference_config[task_id]["result_name"]
            bbox_shift = args.bbox_shift

            # Set output paths
            input_basename = os.path.basename(video_path).split('.')[0]
            audio_basename = os.path.basename(audio_path).split('.')[0]
            output_basename = f"{input_basename}_{audio_basename}"

            # Create temporary directories
            temp_dir = os.path.join(args.result_dir, "frames_result")
            os.makedirs(temp_dir, exist_ok=True)

            # Set result save paths
            result_img_save_path = os.path.join(temp_dir, output_basename)  # related to video & audio inputs
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")  # only related to video input
            os.makedirs(result_img_save_path, exist_ok=True)

            # Set output video paths
            if args.output_vid_name is None:
                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
            else:
                output_vid_name = os.path.join(temp_dir, args.output_vid_name)
            output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")

            # Skip if output file already exists
            if os.path.exists(output_vid_name):
                print(f"{output_vid_name} already exists, skipping!")
                continue

            # Extract frames from source video
            if get_file_type(video_path) == "video":
                save_dir_full = os.path.join(temp_dir, input_basename)
                os.makedirs(save_dir_full, exist_ok=True)
                cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
                os.system(cmd)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
                fps = get_video_fps(video_path)
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
                fps = args.fps
            elif os.path.isdir(video_path):
                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                fps = args.fps
            else:
                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

            # Extract audio features
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features,
                device,
                weight_dtype,
                whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=args.audio_padding_length_left,
                audio_padding_length_right=args.audio_padding_length_right,
            )

            # Preprocess input images
            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks... time-consuming operation")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)

            print(f"Number of frames: {len(frame_list)}")

            # Process each frame
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                y2 = y2 + args.extra_margin
                y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            # Smooth first and last frames
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # Batch inference
            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )

            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            # Execute inference
            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)

                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
                recon = vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            # Pad generated images to original video size
            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i % (len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                y2 = y2 + args.extra_margin
                y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    continue

                # Merge results
                combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # Save prediction results
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
            print("Video generation command:", cmd_img2video)
            os.system(cmd_img2video)

            cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
            print("Audio combination command:", cmd_combine_audio)
            os.system(cmd_combine_audio)

            # Clean up temporary files
            shutil.rmtree(result_img_save_path)
            os.remove(temp_vid_path)

            shutil.rmtree(save_dir_full)
            if not args.saved_coord:
                os.remove(crop_coord_save_path)

            print(f"Results saved to {output_vid_name}")
        except Exception as e:
            print("Error occurred during processing:", e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
    args = parser.parse_args()
    main(args)
@@ -10,26 +10,22 @@ import sys
from tqdm import tqdm
import copy
import json
from transformers import WhisperModel

from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.utils import datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.utils import load_all_model
from musetalk.utils.audio_processor import AudioProcessor

import shutil
import threading
import queue

import time


def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
@@ -42,35 +38,43 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
        else:
            break


def osmakedirs(path_list):
    for path in path_list:
        os.makedirs(path) if not os.path.exists(path) else None


@torch.no_grad()
class Avatar:
    def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
        self.avatar_id = avatar_id
        self.video_path = video_path
        self.bbox_shift = bbox_shift
        # Choose the base path according to the model version
        if args.version == "v15":
            self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
        else:  # v1
            self.base_path = f"./results/avatars/{avatar_id}"

        self.avatar_path = self.base_path
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        self.latents_out_path = f"{self.avatar_path}/latents.pt"
        self.video_out_path = f"{self.avatar_path}/vid_output/"
        self.mask_out_path = f"{self.avatar_path}/mask"
        self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
        self.avatar_info = {
            "avatar_id": avatar_id,
            "video_path": video_path,
            "bbox_shift": bbox_shift,
            "version": args.version
        }
        self.preparation = preparation
        self.batch_size = batch_size
        self.idx = 0
        self.init()

    def init(self):
        if self.preparation:
            if os.path.exists(self.avatar_path):
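The version branch above only changes where avatar assets are rooted on disk. The rule in isolation (an illustrative reimplementation, not the project's API):

```python
# Hypothetical mirror of the base-path selection in Avatar.__init__.
def avatar_base_path(avatar_id, version):
    if version == "v15":
        return f"./results/{version}/avatars/{avatar_id}"
    return f"./results/avatars/{avatar_id}"  # v1 keeps the legacy layout

print(avatar_base_path("demo", "v15"))  # ./results/v15/avatars/demo
print(avatar_base_path("demo", "v1"))   # ./results/avatars/demo
```

Keeping v1.5 avatars under a version-specific root means previously prepared v1 avatars are not reused (or overwritten) by the new model.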
@@ -80,7 +84,7 @@ class Avatar:
                    print("*********************************")
                    print(f"  creating avator: {self.avatar_id}")
                    print("*********************************")
                    osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
                    self.prepare_material()
                else:
                    self.input_latent_list_cycle = torch.load(self.latents_out_path)
@@ -98,16 +102,16 @@ class Avatar:
                print("*********************************")
                print(f"  creating avator: {self.avatar_id}")
                print("*********************************")
                osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
                self.prepare_material()
        else:
            if not os.path.exists(self.avatar_path):
                print(f"{self.avatar_id} does not exist, you should set preparation to True")
                sys.exit()

            with open(self.avatar_info_path, "r") as f:
                avatar_info = json.load(f)

            if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
                response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
                if response.lower() == "c":
@@ -115,11 +119,11 @@ class Avatar:
                    print("*********************************")
                    print(f"  creating avator: {self.avatar_id}")
                    print("*********************************")
                    osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
                    self.prepare_material()
                else:
                    sys.exit()
            else:
                self.input_latent_list_cycle = torch.load(self.latents_out_path)
                with open(self.coords_path, 'rb') as f:
                    self.coord_list_cycle = pickle.load(f)
@@ -131,36 +135,40 @@ class Avatar:
                input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
                input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                self.mask_list_cycle = read_imgs(input_mask_list)

    def prepare_material(self):
        print("preparing data materials ... ...")
        with open(self.avatar_info_path, "w") as f:
            json.dump(self.avatar_info, f)

        if os.path.isfile(self.video_path):
            video2imgs(self.video_path, self.full_imgs_path, ext='png')
        else:
            print(f"copy files in {self.video_path}")
            files = os.listdir(self.video_path)
            files.sort()
            files = [file for file in files if file.split(".")[-1] == "png"]
            for filename in files:
                shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
        input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))

        print("extracting landmarks...")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
        input_latent_list = []
        idx = -1
        # marker if the bbox is not sufficient
        coord_placeholder = (0.0, 0.0, 0.0, 0.0)
        for bbox, frame in zip(coord_list, frame_list):
            idx = idx + 1
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            if args.version == "v15":
                y2 = y2 + args.extra_margin
                y2 = min(y2, frame.shape[0])
                coord_list[idx] = [x1, y1, x2, y2]  # update the bbox in coord_list
            crop_frame = frame[y1:y2, x1:x2]
            resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(resized_crop_frame)
            input_latent_list.append(latents)
@@ -170,112 +178,116 @@ class Avatar:
|
|||||||
self.mask_coords_list_cycle = []
|
self.mask_coords_list_cycle = []
|
||||||
self.mask_list_cycle = []
|
self.mask_list_cycle = []
|
||||||
|
|
||||||
for i,frame in enumerate(tqdm(self.frame_list_cycle)):
|
for i, frame in enumerate(tqdm(self.frame_list_cycle)):
|
||||||
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
|
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
|
||||||
|
|
||||||
face_box = self.coord_list_cycle[i]
|
x1, y1, x2, y2 = self.coord_list_cycle[i]
|
||||||
mask,crop_box = get_image_prepare_material(frame,face_box)
|
if args.version == "v15":
|
||||||
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
|
mode = args.parsing_mode
|
||||||
|
else:
|
||||||
|
mode = "raw"
|
||||||
|
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
|
||||||
|
|
||||||
|
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
|
||||||
self.mask_coords_list_cycle += [crop_box]
|
self.mask_coords_list_cycle += [crop_box]
|
||||||
self.mask_list_cycle.append(mask)
|
self.mask_list_cycle.append(mask)
|
||||||
|
|
||||||
with open(self.mask_coords_path, 'wb') as f:
|
with open(self.mask_coords_path, 'wb') as f:
|
||||||
pickle.dump(self.mask_coords_list_cycle, f)
|
pickle.dump(self.mask_coords_list_cycle, f)
|
||||||
|
|
||||||
with open(self.coords_path, 'wb') as f:
|
with open(self.coords_path, 'wb') as f:
|
||||||
pickle.dump(self.coord_list_cycle, f)
|
pickle.dump(self.coord_list_cycle, f)
|
||||||
|
|
||||||
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
||||||
#
|
|
||||||
|
def process_frames(self, res_frame_queue, video_len, skip_save_images):
|
||||||
def process_frames(self,
|
|
||||||
res_frame_queue,
|
|
||||||
video_len,
|
|
||||||
skip_save_images):
|
|
||||||
print(video_len)
|
print(video_len)
|
||||||
while True:
|
while True:
|
||||||
if self.idx>=video_len-1:
|
if self.idx >= video_len - 1:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
res_frame = res_frame_queue.get(block=True, timeout=1)
|
res_frame = res_frame_queue.get(block=True, timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
|
bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))]
|
||||||
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
|
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))])
|
||||||
x1, y1, x2, y2 = bbox
|
x1, y1, x2, y2 = bbox
|
||||||
try:
|
try:
|
||||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
|
mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
|
||||||
mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
|
mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
|
||||||
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
|
||||||
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||||
|
|
||||||
if skip_save_images is False:
|
if skip_save_images is False:
|
||||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
|
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
|
||||||
self.idx = self.idx + 1
|
self.idx = self.idx + 1
|
||||||
|
|
||||||
def inference(self,
|
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
|
||||||
audio_path,
|
os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
|
||||||
out_vid_name,
|
|
||||||
fps,
|
|
||||||
skip_save_images):
|
|
||||||
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
|
|
||||||
print("start inference")
|
print("start inference")
|
||||||
############################################## extract audio feature ##############################################
|
############################################## extract audio feature ##############################################
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
# Extract audio features
|
||||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
|
||||||
|
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||||
|
whisper_input_features,
|
||||||
|
device,
|
||||||
|
weight_dtype,
|
||||||
|
whisper,
|
||||||
|
librosa_length,
|
||||||
|
fps=fps,
|
||||||
|
audio_padding_length_left=args.audio_padding_length_left,
|
||||||
|
audio_padding_length_right=args.audio_padding_length_right,
|
||||||
|
)
|
||||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||||
############################################## inference batch by batch ##############################################
|
############################################## inference batch by batch ##############################################
|
||||||
video_num = len(whisper_chunks)
|
video_num = len(whisper_chunks)
|
||||||
res_frame_queue = queue.Queue()
|
res_frame_queue = queue.Queue()
|
||||||
self.idx = 0
|
self.idx = 0
|
||||||
# # Create a sub-thread and start it
|
# Create a sub-thread and start it
|
||||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
||||||
process_thread.start()
|
process_thread.start()
|
||||||
|
|
||||||
gen = datagen(whisper_chunks,
|
gen = datagen(whisper_chunks,
|
||||||
self.input_latent_list_cycle,
|
self.input_latent_list_cycle,
|
||||||
self.batch_size)
|
self.batch_size)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
res_frame_list = []
|
res_frame_list = []
|
||||||
|
|
||||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
|
|
||||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
|
||||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
|
||||||
dtype=unet.model.dtype)
|
|
||||||
audio_feature_batch = pe(audio_feature_batch)
|
|
||||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
|
||||||
|
|
||||||
pred_latents = unet.model(latent_batch,
|
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
|
||||||
timesteps,
|
audio_feature_batch = pe(whisper_batch.to(device))
|
||||||
encoder_hidden_states=audio_feature_batch).sample
|
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
|
||||||
|
|
||||||
|
pred_latents = unet.model(latent_batch,
|
||||||
|
timesteps,
|
||||||
|
encoder_hidden_states=audio_feature_batch).sample
|
||||||
|
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
|
||||||
recon = vae.decode_latents(pred_latents)
|
recon = vae.decode_latents(pred_latents)
|
||||||
for res_frame in recon:
|
for res_frame in recon:
|
||||||
res_frame_queue.put(res_frame)
|
res_frame_queue.put(res_frame)
|
||||||
# Close the queue and sub-thread after all tasks are completed
|
# Close the queue and sub-thread after all tasks are completed
|
||||||
process_thread.join()
|
process_thread.join()
|
||||||
|
|
||||||
if args.skip_save_images is True:
|
if args.skip_save_images is True:
|
||||||
print('Total process time of {} frames without saving images = {}s'.format(
|
print('Total process time of {} frames without saving images = {}s'.format(
|
||||||
video_num,
|
video_num,
|
||||||
time.time()-start_time))
|
time.time() - start_time))
|
||||||
else:
|
else:
|
||||||
print('Total process time of {} frames including saving images = {}s'.format(
|
print('Total process time of {} frames including saving images = {}s'.format(
|
||||||
video_num,
|
video_num,
|
||||||
time.time()-start_time))
|
time.time() - start_time))
|
||||||
|
|
||||||
if out_vid_name is not None and args.skip_save_images is False:
|
if out_vid_name is not None and args.skip_save_images is False:
|
||||||
# optional
|
# optional
|
||||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||||
print(cmd_img2video)
|
print(cmd_img2video)
|
||||||
os.system(cmd_img2video)
|
os.system(cmd_img2video)
|
||||||
|
|
||||||
output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on
|
output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4") # on
|
||||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
|
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
|
||||||
print(cmd_combine_audio)
|
print(cmd_combine_audio)
|
||||||
os.system(cmd_combine_audio)
|
os.system(cmd_combine_audio)
|
||||||
@@ -284,52 +296,95 @@ class Avatar:
|
|||||||
shutil.rmtree(f"{self.avatar_path}/tmp")
|
shutil.rmtree(f"{self.avatar_path}/tmp")
|
||||||
print(f"result is save to {output_vid}")
|
print(f"result is save to {output_vid}")
|
||||||
print("\n")
|
print("\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
'''
|
'''
|
||||||
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--inference_config",
|
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
|
||||||
type=str,
|
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||||
default="configs/inference/realtime.yaml",
|
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||||
)
|
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||||
parser.add_argument("--fps",
|
parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
|
||||||
type=int,
|
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
|
||||||
default=25,
|
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||||
)
|
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
|
||||||
parser.add_argument("--batch_size",
|
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||||
type=int,
|
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||||
default=4,
|
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||||
)
|
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||||
|
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||||
|
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
|
||||||
|
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||||
|
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||||
|
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||||
|
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||||
|
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||||
|
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||||
parser.add_argument("--skip_save_images",
|
parser.add_argument("--skip_save_images",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether skip saving images for better generation speed calculation",
|
help="Whether skip saving images for better generation speed calculation",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set computing device
|
||||||
|
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
vae, unet, pe = load_all_model(
|
||||||
|
unet_model_path=args.unet_model_path,
|
||||||
|
vae_type=args.vae_type,
|
||||||
|
unet_config=args.unet_config,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
timesteps = torch.tensor([0], device=device)
|
||||||
|
|
||||||
|
pe = pe.half().to(device)
|
||||||
|
vae.vae = vae.vae.half().to(device)
|
||||||
|
unet.model = unet.model.half().to(device)
|
||||||
|
|
||||||
|
# Initialize audio processor and Whisper model
|
||||||
|
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||||
|
weight_dtype = unet.model.dtype
|
||||||
|
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||||
|
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||||
|
whisper.requires_grad_(False)
|
||||||
|
|
||||||
|
# Initialize face parser with configurable parameters based on version
|
||||||
|
if args.version == "v15":
|
||||||
|
fp = FaceParsing(
|
||||||
|
left_cheek_width=args.left_cheek_width,
|
||||||
|
right_cheek_width=args.right_cheek_width
|
||||||
|
)
|
||||||
|
else: # v1
|
||||||
|
fp = FaceParsing()
|
||||||
|
|
||||||
inference_config = OmegaConf.load(args.inference_config)
|
inference_config = OmegaConf.load(args.inference_config)
|
||||||
print(inference_config)
|
print(inference_config)
|
||||||
|
|
||||||
|
|
||||||
for avatar_id in inference_config:
|
for avatar_id in inference_config:
|
||||||
data_preparation = inference_config[avatar_id]["preparation"]
|
data_preparation = inference_config[avatar_id]["preparation"]
|
||||||
video_path = inference_config[avatar_id]["video_path"]
|
video_path = inference_config[avatar_id]["video_path"]
|
||||||
bbox_shift = inference_config[avatar_id]["bbox_shift"]
|
if args.version == "v15":
|
||||||
|
bbox_shift = 0
|
||||||
|
else:
|
||||||
|
bbox_shift = inference_config[avatar_id]["bbox_shift"]
|
||||||
avatar = Avatar(
|
avatar = Avatar(
|
||||||
avatar_id = avatar_id,
|
avatar_id=avatar_id,
|
||||||
video_path = video_path,
|
video_path=video_path,
|
||||||
bbox_shift = bbox_shift,
|
bbox_shift=bbox_shift,
|
||||||
batch_size = args.batch_size,
|
batch_size=args.batch_size,
|
||||||
preparation= data_preparation)
|
preparation=data_preparation)
|
||||||
|
|
||||||
audio_clips = inference_config[avatar_id]["audio_clips"]
|
audio_clips = inference_config[avatar_id]["audio_clips"]
|
||||||
for audio_num, audio_path in audio_clips.items():
|
for audio_num, audio_path in audio_clips.items():
|
||||||
print("Inferring using:",audio_path)
|
print("Inferring using:", audio_path)
|
||||||
avatar.inference(audio_path,
|
avatar.inference(audio_path,
|
||||||
audio_num,
|
audio_num,
|
||||||
args.fps,
|
args.fps,
|
||||||
args.skip_save_images)
|
args.skip_save_images)
|
||||||
|
|||||||
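The `inference`/`process_frames` pair in the diff is a standard producer–consumer setup: the main thread runs batched UNet/VAE inference and pushes decoded frames into a `queue.Queue`, while a worker thread pops, blends, and saves them, re-checking its exit condition via a short `get` timeout. A minimal self-contained sketch of that pattern (the doubling step stands in for the real blending/`imwrite`, and the `video_len - 1` exit rule mirrors the script above; function names here are illustrative, not MuseTalk's API):

```python
import queue
import threading

def process_frames(res_frame_queue, video_len, out_frames):
    """Consumer: drain generated frames, 'blend' and collect each one."""
    idx = 0
    while True:
        if idx >= video_len - 1:  # same exit rule as the script above
            break
        try:
            # Short timeout so the exit condition is re-checked regularly
            # instead of blocking forever on an empty queue.
            res_frame = res_frame_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        out_frames.append(res_frame * 2)  # stand-in for blending + imwrite
        idx += 1

def run_pipeline(frames):
    """Producer: feed frames to the worker, as the batched UNet/VAE loop does."""
    res_frame_queue = queue.Queue()
    out_frames = []
    worker = threading.Thread(target=process_frames,
                              args=(res_frame_queue, len(frames), out_frames))
    worker.start()
    for frame in frames:
        res_frame_queue.put(frame)
    worker.join()
    return out_frames
```

This overlap is what makes the script real-time: GPU inference and CPU-side blending/saving proceed concurrently instead of in sequence.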
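On the audio side, `get_whisper_chunk` pairs each video frame with a window of Whisper features plus `audio_padding_length_left`/`audio_padding_length_right` frames of context. A rough sketch of that kind of fps-based windowing (the feature rate, window arithmetic, and function name are illustrative assumptions, not MuseTalk's exact implementation):

```python
import numpy as np

def chunk_audio_features(feature_array, fps, audio_fps=50,
                         pad_left=2, pad_right=2):
    """Slice a (T, D) feature array into one window per video frame.

    audio_fps is the assumed feature rate (rows of features per second);
    video frame i gets the rows around its timestamp, extended by
    pad_left/pad_right frames' worth of context, with edges replicated.
    """
    t, d = feature_array.shape
    step = audio_fps / fps                 # feature rows per video frame
    num_frames = int(t / step)
    chunks = []
    for i in range(num_frames):
        start = int(round((i - pad_left) * step))
        end = int(round((i + 1 + pad_right) * step))
        idx = np.clip(np.arange(start, end), 0, t - 1)  # replicate edges
        chunks.append(feature_array[idx])
    return chunks
```

Each chunk then conditions the UNet for exactly one output frame, which is why `video_num = len(whisper_chunks)` in the diff.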