Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-04 09:29:20 +08:00)
feat: windows infer & gradio (#312)
* fix: windows infer
* docs: update readme
* docs: update readme
* feat: v1.5 gradio for windows & linux
* fix: dependencies
* feat: windows infer & gradio

Co-authored-by: NeRF-Factory <zzhizhou66@gmail.com>
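As a quick usage reference (not part of the commit), a minimal sketch of launching the updated Gradio app with the command-line flags this change adds to app.py; the ffmpeg directory below is simply the Windows default from the argparse block and may need adjusting:

# Assumed usage sketch; the flags below are defined in app.py's new argparse section.
import subprocess

subprocess.run([
    "python", "app.py",
    "--ffmpeg_path", r"ffmpeg-master-latest-win64-gpl-shared\bin",  # Windows ffmpeg folder (the default in the diff)
    "--ip", "0.0.0.0",      # address the Gradio server binds to
    "--port", "7860",       # server port
    "--use_float16",        # optional: half-precision inference
], check=True)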
app.py (476 changed lines)
@@ -4,7 +4,6 @@ import pdb
import re
import gradio as gr
import spaces
import numpy as np
import sys
import subprocess
@@ -28,11 +27,101 @@ import gdown
import imageio
import ffmpeg
from moviepy.editor import *

from transformers import WhisperModel

ProjectDir = os.path.abspath(os.path.dirname(__file__))
CheckpointsDir = os.path.join(ProjectDir, "models")

@torch.no_grad()
def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
                     left_cheek_width=90, right_cheek_width=90):
    """Debug inpainting parameters, only process the first frame"""
    # Set default parameters
    args_dict = {
        "result_dir": './results/debug',
        "fps": 25,
        "batch_size": 1,
        "output_vid_name": '',
        "use_saved_coord": False,
        "audio_padding_length_left": 2,
        "audio_padding_length_right": 2,
        "version": "v15",
        "extra_margin": extra_margin,
        "parsing_mode": parsing_mode,
        "left_cheek_width": left_cheek_width,
        "right_cheek_width": right_cheek_width
    }
    args = Namespace(**args_dict)

    # Create debug directory
    os.makedirs(args.result_dir, exist_ok=True)

    # Read first frame
    if get_file_type(video_path) == "video":
        reader = imageio.get_reader(video_path)
        first_frame = reader.get_data(0)
        reader.close()
    else:
        first_frame = cv2.imread(video_path)
        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    # Save first frame
    debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
    cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))

    # Get face coordinates
    coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
    bbox = coord_list[0]
    frame = frame_list[0]

    if bbox == coord_placeholder:
        return None, "No face detected, please adjust bbox_shift parameter"

    # Initialize face parser
    fp = FaceParsing(
        left_cheek_width=args.left_cheek_width,
        right_cheek_width=args.right_cheek_width
    )

    # Process first frame
    x1, y1, x2, y2 = bbox
    y2 = y2 + args.extra_margin
    y2 = min(y2, frame.shape[0])
    crop_frame = frame[y1:y2, x1:x2]
    crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)

    # Generate random audio features
    random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
    audio_feature = pe(random_audio)

    # Get latents
    latents = vae.get_latents_for_unet(crop_frame)
    latents = latents.to(dtype=weight_dtype)

    # Generate prediction results
    pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
    recon = vae.decode_latents(pred_latents)

    # Inpaint back to original image
    res_frame = recon[0]
    res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
    combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)

    # Save results (no need to convert color space again since get_image already returns RGB format)
    debug_result_path = os.path.join(args.result_dir, "debug_result.png")
    cv2.imwrite(debug_result_path, combine_frame)

    # Create information text
    info_text = f"Parameter information:\n" + \
                f"bbox_shift: {bbox_shift}\n" + \
                f"extra_margin: {extra_margin}\n" + \
                f"parsing_mode: {parsing_mode}\n" + \
                f"left_cheek_width: {left_cheek_width}\n" + \
                f"right_cheek_width: {right_cheek_width}\n" + \
                f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"

    return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text

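For context, a minimal sketch of calling the new debug helper directly (assumed usage, not shipped in the commit; the video path is a placeholder). It mirrors what the "1. Test Inpainting" button wired up later in this diff does:

# Assumed usage sketch; debug_inpainting and its defaults are defined above.
preview_frame, info_text = debug_inpainting(
    video_path="data/video/sample.mp4",  # placeholder input path
    bbox_shift=0,
    extra_margin=10,
    parsing_mode="jaw",
    left_cheek_width=90,
    right_cheek_width=90,
)
print(info_text)  # parameter summary assembled by the helper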
def print_directory_contents(path):
    for child in os.listdir(path):
        child_path = os.path.join(path, child)
@@ -40,119 +129,107 @@ def print_directory_contents(path):
        print(child_path)

def download_model():
    if not os.path.exists(CheckpointsDir):
        os.makedirs(CheckpointsDir)
        print("Checkpoint Not Downloaded, start downloading...")
        tic = time.time()
        snapshot_download(
            repo_id="TMElyralab/MuseTalk",
            local_dir=CheckpointsDir,
            max_workers=8,
            local_dir_use_symlinks=True,
            force_download=True, resume_download=False
        )
        # weight
        os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
        snapshot_download(
            repo_id="stabilityai/sd-vae-ft-mse",
            local_dir=CheckpointsDir+'/sd-vae-ft-mse',
            max_workers=8,
            local_dir_use_symlinks=True,
            force_download=True, resume_download=False
        )
        # dwpose
        os.makedirs(f"{CheckpointsDir}/dwpose/")
        snapshot_download(
            repo_id="yzd-v/DWPose",
            local_dir=CheckpointsDir+'/dwpose',
            max_workers=8,
            local_dir_use_symlinks=True,
            force_download=True, resume_download=False
        )
        # whisper tiny.pt
        url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
        response = requests.get(url)
        # Make sure the request succeeded
        if response.status_code == 200:
            # Location where the file will be saved
            file_path = f"{CheckpointsDir}/whisper/tiny.pt"
            os.makedirs(f"{CheckpointsDir}/whisper/")
            # Write the file content to the target location
            with open(file_path, "wb") as f:
                f.write(response.content)
    # Check whether the required model files exist
    required_models = {
        "MuseTalk UNet": f"{CheckpointsDir}/musetalkV15/unet.pth",
        "MuseTalk config": f"{CheckpointsDir}/musetalkV15/musetalk.json",
        "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
        "Whisper": f"{CheckpointsDir}/whisper/config.json",
        "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
        "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
        "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
        "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
    }

    missing_models = []
    for model_name, model_path in required_models.items():
        if not os.path.exists(model_path):
            missing_models.append(model_name)

    if missing_models:
        # Messages in English only
        print("The following required model files are missing:")
        for model in missing_models:
            print(f"- {model}")
        print("\nPlease run the download script to download the missing models:")
        if sys.platform == "win32":
            print("Windows: Run download_weights.bat")
        else:
            print(f"Request failed, status code: {response.status_code}")
        # gdown face parse
        url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
        os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
        file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
        gdown.download(url, file_path, quiet=False)
        # resnet
        url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
        response = requests.get(url)
        # Make sure the request succeeded
        if response.status_code == 200:
            # Location where the file will be saved
            file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
            # Write the file content to the target location
            with open(file_path, "wb") as f:
                f.write(response.content)
        else:
            print(f"Request failed, status code: {response.status_code}")

        toc = time.time()

        print(f"download cost {toc-tic} seconds")
        print_directory_contents(CheckpointsDir)

        print("Linux/Mac: Run ./download_weights.sh")
        sys.exit(1)
    else:
        print("Already downloaded the model.")

        print("All required model files exist.")


download_model()  # for huggingface deployment.

from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range

def fast_check_ffmpeg():
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    except:
        return False


@spaces.GPU(duration=600)
@torch.no_grad()
def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
    args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}  # same as the inference script
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
              left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
    # Set default parameters, aligned with inference.py
    args_dict = {
        "result_dir": './results/output',
        "fps": 25,
        "batch_size": 8,
        "output_vid_name": '',
        "use_saved_coord": False,
        "audio_padding_length_left": 2,
        "audio_padding_length_right": 2,
        "version": "v15",  # Fixed to the v1.5 version
        "extra_margin": extra_margin,
        "parsing_mode": parsing_mode,
        "left_cheek_width": left_cheek_width,
        "right_cheek_width": right_cheek_width
    }
    args = Namespace(**args_dict)

    input_basename = os.path.basename(video_path).split('.')[0]
    audio_basename = os.path.basename(audio_path).split('.')[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
    crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")  # only related to video input
    os.makedirs(result_img_save_path,exist_ok =True)
    # Check ffmpeg
    if not fast_check_ffmpeg():
        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

    if args.output_vid_name=="":
        output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
    input_basename = os.path.basename(video_path).split('.')[0]
    audio_basename = os.path.basename(audio_path).split('.')[0]
    output_basename = f"{input_basename}_{audio_basename}"

    # Create temporary directory
    temp_dir = os.path.join(args.result_dir, f"{args.version}")
    os.makedirs(temp_dir, exist_ok=True)

    # Set result save path
    result_img_save_path = os.path.join(temp_dir, output_basename)
    crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)

    if args.output_vid_name == "":
        output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
    else:
        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        output_vid_name = os.path.join(temp_dir, args.output_vid_name)

    ############################################## extract frames from source video ##############################################
    if get_file_type(video_path)=="video":
        save_dir_full = os.path.join(args.result_dir, input_basename)
        os.makedirs(save_dir_full,exist_ok = True)
        # cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
        # os.system(cmd)
        # Read video
    if get_file_type(video_path) == "video":
        save_dir_full = os.path.join(temp_dir, input_basename)
        os.makedirs(save_dir_full, exist_ok=True)
        # Read video
        reader = imageio.get_reader(video_path)

        # Save images
        # Save images
        for i, im in enumerate(reader):
            imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
@@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
        input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        fps = args.fps
    #print(input_img_list)

    ############################################## extract audio feature ##############################################
    whisper_feature = audio_processor.audio2feat(audio_path)
    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
    # Extract audio features
    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
    whisper_chunks = audio_processor.get_whisper_chunk(
        whisper_input_features,
        device,
        weight_dtype,
        whisper,
        librosa_length,
        fps=fps,
        audio_padding_length_left=args.audio_padding_length_left,
        audio_padding_length_right=args.audio_padding_length_right,
    )

    ############################################## preprocess input image ##############################################
    if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
        print("using extracted coordinates")
@@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        with open(crop_coord_save_path, 'wb') as f:
            pickle.dump(coord_list, f)
    bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
    bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)

    # Initialize face parser
    fp = FaceParsing(
        left_cheek_width=args.left_cheek_width,
        right_cheek_width=args.right_cheek_width
    )

    i = 0
    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        y2 = y2 + args.extra_margin
        y2 = min(y2, frame.shape[0])
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(crop_frame)
@@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    ############################################## inference batch by batch ##############################################
    print("start inference")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
    gen = datagen(
        whisper_chunks=whisper_chunks,
        vae_encode_latents=input_latent_list_cycle,
        batch_size=batch_size,
        delay_frame=0,
        device=device,
    )
    res_frame_list = []
    for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):

        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
        audio_feature_batch = torch.stack(tensor_list).to(unet.device)  # torch, B, 5*N, 384
        audio_feature_batch = pe(audio_feature_batch)
        audio_feature_batch = pe(whisper_batch)
        # Ensure latent_batch is consistent with model weight type
        latent_batch = latent_batch.to(dtype=weight_dtype)

        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        recon = vae.decode_latents(pred_latents)
@@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
        bbox = coord_list_cycle[i%(len(coord_list_cycle))]
        ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
        x1, y1, x2, y2 = bbox
        y2 = y2 + args.extra_margin
        y2 = min(y2, frame.shape[0])
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
        except:
            # print(bbox)
            continue

        combine_frame = get_image(ori_frame,res_frame,bbox)
        # Use v1.5 blending
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)

        cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)

    # cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
    # print(cmd_img2video)
    # os.system(cmd_img2video)
    # Frame rate
    # Frame rate
    fps = 25
    # Image path
    # Output video path
    # Output video path
    output_video = 'temp.mp4'

    # Read images
    # Read images
    def is_valid_image(file):
        pattern = re.compile(r'\d{8}\.png')
        return pattern.match(file)
@@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
        images.append(imageio.imread(filename))


    # Save video
    # Save video
    imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')

    # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
    # print(cmd_combine_audio)
    # os.system(cmd_combine_audio)

    input_video = './temp.mp4'
    # Check if the input_video and audio_path exist
    if not os.path.exists(input_video):
@@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    # Read video
    # Read video
    reader = imageio.get_reader(input_video)
    fps = reader.get_meta_data()['fps']  # get the original video frame rate
    reader.close()  # otherwise on win11 it errors: PermissionError: [WinError 32] The file is being used by another process: 'temp.mp4'
    # Store the frames in a list
    fps = reader.get_meta_data()['fps']  # Get original video frame rate
    reader.close()  # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
    # Store frames in list
    frames = images

    # Save the video and add the audio
    # imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])

    # input_video = ffmpeg.input(input_video)

    # input_audio = ffmpeg.input(audio_path)

    print(len(frames))

    # imageio.mimwrite(
    #     output_video,
    #     frames,
    #     'FFMPEG',
    #     fps=25,
    #     codec='libx264',
    #     audio_codec='aac',
    #     input_params=['-i', audio_path],
    #     output_params=['-y'],  # Add the '-y' flag to overwrite the output file if it exists
    # )
    # writer = imageio.get_writer(output_vid_name, fps = 25, codec='libx264', quality=10, pixelformat='yuvj444p')
    # for im in frames:
    #     writer.append_data(im)
    # writer.close()


    # Load the video
    video_clip = VideoFileClip(input_video)

@@ -315,11 +388,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T


# load model weights
audio_processor,vae,unet,pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae, unet, pe = load_all_model(
    unet_model_path="./models/musetalkV15/unet.pth",
    vae_type="sd-vae",
    unet_config="./models/musetalkV15/musetalk.json",
    device=device
)

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
parser.add_argument("--share", action="store_true", help="Create a public link")
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
args = parser.parse_args()

# Set data type
if args.use_float16:
    # Convert models to half precision for better performance
    pe = pe.half()
    vae.vae = vae.vae.half()
    unet.model = unet.model.half()
    weight_dtype = torch.float16
else:
    weight_dtype = torch.float32

# Move models to the specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)

timesteps = torch.tensor([0], device=device)

# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
whisper = WhisperModel.from_pretrained("./models/whisper")
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)


def check_video(video):
@@ -340,9 +447,6 @@ def check_video(video):
    output_video = os.path.join('./results/input', output_file_name)

    # # Run the ffmpeg command to change the frame rate to 25 fps
    # command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"

    # read video
    reader = imageio.get_reader(video)
    fps = reader.get_meta_data()['fps']  # get fps from original video
@@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
        """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
        <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
        </br>\
        Yue Zhang <sup>\*</sup>,\
        Minhao Liu<sup>\*</sup>,\
        Yue Zhang <sup>*</sup>,\
        Zhizhou Zhong <sup>*</sup>,\
        Minhao Liu<sup>*</sup>,\
        Zhaokang Chen,\
        Bin Wu<sup>†</sup>,\
        Yubin Zeng,\
        Chao Zhang,\
        Yingjie He,\
        Chao Zhan,\
        Wenjiang Zhou\
        Junxin Huang,\
        Wenjiang Zhou <br>\
        (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
        Lyra Lab, Tencent Music Entertainment\
        </h2> \
        <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
        <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
        <a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
        <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
        <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
    )

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio",type="filepath")
            audio = gr.Audio(label="Driving Audio", type="filepath")
            video = gr.Video(label="Reference Video", sources=['upload'])
            bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
            bbox_shift_scale = gr.Textbox(label="Recommended lower bound for BBox_shift. The corresponding bbox range is shown after the first result is generated. \n If the result is not good, adjust BBox_shift using this reference value", value="", interactive=False)
            extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
            parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
            left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
            right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
            bbox_shift_scale = gr.Textbox(label="The 'left_cheek_width' and 'right_cheek_width' parameters set the editable range of the left and right cheeks when the parsing mode is 'jaw'. The 'extra_margin' parameter sets the movement range of the jaw. Adjust these three parameters freely to obtain better inpainting results.")

            btn = gr.Button("Generate")
            out1 = gr.Video()
            with gr.Row():
                debug_btn = gr.Button("1. Test Inpainting")
                btn = gr.Button("2. Generate")
        with gr.Column():
            debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
            debug_info = gr.Textbox(label="Parameter Information", lines=5)
            out1 = gr.Video()

    video.change(
        fn=check_video, inputs=[video], outputs=[video]
@@ -412,15 +527,44 @@ with gr.Blocks(css=css) as demo:
            audio,
            video,
            bbox_shift,
            extra_margin,
            parsing_mode,
            left_cheek_width,
            right_cheek_width
        ],
        outputs=[out1, bbox_shift_scale]
    )
    debug_btn.click(
        fn=debug_inpainting,
        inputs=[
            video,
            bbox_shift,
            extra_margin,
            parsing_mode,
            left_cheek_width,
            right_cheek_width
        ],
        outputs=[debug_image, debug_info]
    )

# Set the IP and port
ip_address = "0.0.0.0"  # Replace with your desired IP address
port_number = 7860  # Replace with your desired port number
# Check ffmpeg and add it to PATH
if not fast_check_ffmpeg():
    print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
    # Choose the path separator according to the operating system
    path_separator = ';' if sys.platform == 'win32' else ':'
    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
    if not fast_check_ffmpeg():
        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

# Work around asynchronous IO issues on Windows
if sys.platform == 'win32':
    import asyncio
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Start the Gradio application
demo.queue().launch(
    share=False, debug=True, server_name=ip_address, server_port=port_number
    share=args.share,
    debug=True,
    server_name=args.ip,
    server_port=args.port
)