Add codes for real time inference

This commit is contained in:
czk32611
2024-04-18 12:05:22 +08:00
parent 955ca416ea
commit 0387c39a93
4 changed files with 373 additions and 5 deletions

View File

@@ -11,7 +11,7 @@ Chao Zhan,
Wenjiang Zhou
(<sup>*</sup>Equal Contribution, <sup></sup>Corresponding Author, benbinwu@tencent.com)
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[gradio](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **Project (comming soon)** **Technical report (comming soon)**
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **Project (comming soon)** **Technical report (comming soon)**
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
@@ -28,12 +28,13 @@ We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+
# News
- [04/02/2024] Release MuseTalk project and pretrained models.
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
- [04/17/2024] :mega: We release a pipeline that utilizes MuseTalk for real-time inference.
## Model
![Model Structure](assets/figs/musetalk_arc.jpg)
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is `Not` a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with `a single step`.
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
## Cases
### MuseV + MuseTalk make human photos alive
@@ -162,7 +163,7 @@ Note that although we use a very similar architecture as Stable Diffusion, MuseT
# TODO:
- [x] trained models and inference codes.
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
- [ ] codes for real-time inference.
- [x] codes for real-time inference.
- [ ] technical report.
- [ ] training codes.
- [ ] a better model (may take longer).
@@ -262,9 +263,30 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
# Note
#### :new: Real-time inference
If you want to launch online video chats, you are suggested to generate videos using MuseV and apply necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
```
configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
```
Inferring using: data/audio/yongen.wav
```
1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve up to 50fps on an NVIDIA Tesla V100.
```
2%|██▍ | 3/141 [00:00<00:32, 4.30it/s] # inference process
Generating the 6-th frame with FPS: 48.58 # playing process
Generating the 7-th frame with FPS: 48.74
Generating the 8-th frame with FPS: 49.17
3%|███▎ | 4/141 [00:00<00:32, 4.21it/s]
```
1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
# Acknowledgement

View File

@@ -0,0 +1,10 @@
avator_1:
preparation: False
bbox_shift: 5
video_path: "data/video/sun.mp4"
audio_clips:
audio_0: "data/audio/yongen.wav"
audio_1: "data/audio/sun.wav"

View File

@@ -57,3 +57,44 @@ def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
body.paste(face_large, crop_box[:2], mask_image)
body = np.array(body)
return body[:,:,::-1]
def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
body = Image.fromarray(image[:,:,::-1])
x, y, x1, y1 = face_box
#print(x1-x,y1-y)
crop_box, s = get_crop_box(face_box, expand)
x_s, y_s, x_e, y_e = crop_box
face_large = body.crop(crop_box)
ori_shape = face_large.size
mask_image = face_seg(face_large)
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
mask_image = Image.new('L', ori_shape, 0)
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
# keep upper_boundary_ratio of talking area
width, height = mask_image.size
top_boundary = int(height * upper_boundary_ratio)
modified_mask_image = Image.new('L', ori_shape, 0)
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
return mask_array,crop_box
def get_image_blending(image,face,face_box,mask_array,crop_box):
body = Image.fromarray(image[:,:,::-1])
face = Image.fromarray(face[:,:,::-1])
x, y, x1, y1 = face_box
x_s, y_s, x_e, y_e = crop_box
face_large = body.crop(crop_box)
mask_image = Image.fromarray(mask_array)
mask_image = mask_image.convert("L")
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
body.paste(face_large, crop_box[:2], mask_image)
body = np.array(body)
return body[:,:,::-1]

View File

@@ -0,0 +1,295 @@
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
import sys
from tqdm import tqdm
import copy
import json
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model
import shutil
import threading
import queue
import time
# load model weights
audio_processor,vae,unet,pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
cap = cv2.VideoCapture(vid_path)
count = 0
while True:
if count > cut_frame:
break
ret, frame = cap.read()
if ret:
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
count += 1
else:
break
def osmakedirs(path_list):
for path in path_list:
os.makedirs(path) if not os.path.exists(path) else None
@torch.no_grad()
class Avatar:
def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
self.avatar_id = avatar_id
self.video_path = video_path
self.bbox_shift = bbox_shift
self.avatar_path = f"./results/avatars/{avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.latents_out_path= f"{self.avatar_path}/latents.pt"
self.video_out_path = f"{self.avatar_path}/vid_output/"
self.mask_out_path =f"{self.avatar_path}/mask"
self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
self.avatar_info = {
"avatar_id":avatar_id,
"video_path":video_path,
"bbox_shift":bbox_shift
}
self.preparation = preparation
self.batch_size = batch_size
self.idx = 0
self.init()
def init(self):
if self.preparation:
if os.path.exists(self.avatar_path):
response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)")
if response.lower() == "y":
shutil.rmtree(self.avatar_path)
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
self.prepare_material()
else:
self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.frame_list_cycle = read_imgs(input_img_list)
with open(self.mask_coords_path, 'rb') as f:
self.mask_coords_list_cycle = pickle.load(f)
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list)
else:
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
self.prepare_material()
else:
with open(self.avatar_info_path, "r") as f:
avatar_info = json.load(f)
if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
if response.lower() == "c":
shutil.rmtree(self.avatar_path)
print("*********************************")
print(f" creating avator: {self.avatar_id}")
print("*********************************")
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
self.prepare_material()
else:
sys.exit()
else:
self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.frame_list_cycle = read_imgs(input_img_list)
with open(self.mask_coords_path, 'rb') as f:
self.mask_coords_list_cycle = pickle.load(f)
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list)
def prepare_material(self):
print("preparing data materials ... ...")
with open(self.avatar_info_path, "w") as f:
json.dump(self.avatar_info, f)
if os.path.isfile(self.video_path):
video2imgs(self.video_path, self.full_imgs_path, ext = 'png')
else:
print(f"copy files in {self.video_path}")
files = os.listdir(self.video_path)
files.sort()
files = [file for file in files if file.split(".")[-1]=="png"]
for filename in files:
shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
print("extracting landmarks...")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
input_latent_list = []
idx = -1
# maker if the bbox is not sufficient
coord_placeholder = (0.0,0.0,0.0,0.0)
for bbox, frame in zip(coord_list, frame_list):
idx = idx + 1
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(resized_crop_frame)
input_latent_list.append(latents)
self.frame_list_cycle = frame_list + frame_list[::-1]
self.coord_list_cycle = coord_list + coord_list[::-1]
self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
self.mask_coords_list_cycle = []
self.mask_list_cycle = []
for i,frame in enumerate(tqdm(self.frame_list_cycle)):
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
face_box = self.coord_list_cycle[i]
mask,crop_box = get_image_prepare_material(frame,face_box)
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
self.mask_coords_list_cycle += [crop_box]
self.mask_list_cycle.append(mask)
with open(self.mask_coords_path, 'wb') as f:
pickle.dump(self.mask_coords_list_cycle, f)
with open(self.coords_path, 'wb') as f:
pickle.dump(self.coord_list_cycle, f)
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
#
def process_frames(self, res_frame_queue,video_len):
print(video_len)
while True:
if self.idx>=video_len-1:
break
try:
start = time.time()
res_frame = res_frame_queue.get(block=True, timeout=1)
except queue.Empty:
continue
bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
#combine_frame = get_image(ori_frame,res_frame,bbox)
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
fps = 1/(time.time()-start)
print(f"Generating the {self.idx}-th frame with FPS: {fps:.2f}")
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
self.idx = self.idx + 1
def inference(self, audio_path, out_vid_name, fps):
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
############################################## inference batch by batch ##############################################
video_num = len(whisper_chunks)
print("start inference")
res_frame_queue = queue.Queue()
self.idx = 0
# # Create a sub-thread and start it
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
process_thread.start()
start_time = time.time()
gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size)
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
start_time = time.time()
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
start_time = time.time()
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_queue.put(res_frame)
# Close the queue and sub-thread after all tasks are completed
process_thread.join()
if out_vid_name is not None:
# optional
cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on
cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove(f"{self.avatar_path}/temp.mp4")
shutil.rmtree(f"{self.avatar_path}/tmp")
print(f"result is save to {output_vid}")
if __name__ == "__main__":
'''
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
'''
parser = argparse.ArgumentParser()
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--batch_size", type=int, default=4)
args = parser.parse_args()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for avatar_id in inference_config:
data_preparation = inference_config[avatar_id]["preparation"]
video_path = inference_config[avatar_id]["video_path"]
bbox_shift = inference_config[avatar_id]["bbox_shift"]
avatar = Avatar(
avatar_id = avatar_id,
video_path = video_path,
bbox_shift = bbox_shift,
batch_size = args.batch_size,
preparation= data_preparation)
audio_clips = inference_config[avatar_id]["audio_clips"]
for audio_num, audio_path in audio_clips.items():
print("Inferring using:",audio_path)
avatar.inference(audio_path, audio_num, args.fps)