feat: v1.5 gradio for windows&linux

Author: zzzweakman
Date: 2025-04-11 02:43:04 +08:00
Parent: 2e5b74a257
Commit: b9b459a119
8 changed files with 330 additions and 169 deletions

.gitignore (3 lines changed)

@@ -14,4 +14,5 @@ ffmpeg*
 ffmprobe*
 ffplay*
 debug
 exp_out
+.gradio

README.md

@@ -330,6 +330,18 @@ For faster generation without saving images, you can use:
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
 ```
+
+## Gradio Demo
+We provide an intuitive web interface through Gradio so that users can easily adjust the input parameters. To save inference time, users can generate only the **first frame** to find the best lip-sync parameters, which helps reduce facial artifacts in the final output.
+
+![para](assets/figs/gradio_2.png)
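The same first-frame preview can, in principle, also be driven from Python through the `debug_inpainting` helper added to `app.py` in this commit. The sketch below is illustrative only: the video path is a placeholder, and importing `app` runs its module-level setup.

```python
# Illustrative sketch: importing app triggers its module-level setup
# (model file checks, VAE/UNet/Whisper loading, CLI argument parsing).
from app import debug_inpainting

# "data/video/sample.mp4" is a placeholder; point it at your own video or image.
preview_bgr, info = debug_inpainting(
    video_path="data/video/sample.mp4",
    bbox_shift=0,
    extra_margin=10,       # jaw movement range
    parsing_mode="jaw",    # "jaw" or "raw"
    left_cheek_width=90,
    right_cheek_width=90,
)
print(info)  # parameter summary and detected face bbox, or an error message
```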
+As a reference for minimum hardware requirements, we tested the system on Windows with an NVIDIA GeForce RTX 3050 Ti Laptop GPU (4 GB VRAM). In fp16 mode, generating an 8-second video takes approximately 5 minutes.
+
+![speed](assets/figs/gradio.png)
+
+Both Linux and Windows users can launch the demo with the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
+```bash
+# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
+python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+```
 ## Training
 ### Data Preparation
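For completeness, a sketch of launch commands that combine the options `app.py` registers in this commit (`--ffmpeg_path`, `--ip`, `--port`, `--share`, `--use_float16`); the FFmpeg paths below are examples and should match your own installation:

```bash
# Linux example: full precision (higher quality, more VRAM); adjust the FFmpeg path to your system
python app.py --ffmpeg_path /usr/bin --ip 0.0.0.0 --port 7860

# Windows example: half precision plus a public Gradio share link
python app.py --use_float16 --share --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
```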

app.py (474 lines changed)

@@ -28,11 +28,101 @@ import gdown
 import imageio
 import ffmpeg
 from moviepy.editor import *
+from transformers import WhisperModel

 ProjectDir = os.path.abspath(os.path.dirname(__file__))
 CheckpointsDir = os.path.join(ProjectDir, "models")

+@torch.no_grad()
+def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+                     left_cheek_width=90, right_cheek_width=90):
+    """Debug inpainting parameters, only process the first frame"""
+    # Set default parameters
+    args_dict = {
+        "result_dir": './results/debug',
+        "fps": 25,
+        "batch_size": 1,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
+    args = Namespace(**args_dict)
+
+    # Create debug directory
+    os.makedirs(args.result_dir, exist_ok=True)
+
+    # Read first frame
+    if get_file_type(video_path) == "video":
+        reader = imageio.get_reader(video_path)
+        first_frame = reader.get_data(0)
+        reader.close()
+    else:
+        first_frame = cv2.imread(video_path)
+        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+    # Save first frame
+    debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
+    cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
+
+    # Get face coordinates
+    coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
+    bbox = coord_list[0]
+    frame = frame_list[0]
+    if bbox == coord_placeholder:
+        return None, "No face detected, please adjust bbox_shift parameter"
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
+    # Process first frame
+    x1, y1, x2, y2 = bbox
+    y2 = y2 + args.extra_margin
+    y2 = min(y2, frame.shape[0])
+    crop_frame = frame[y1:y2, x1:x2]
+    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+
+    # Generate random audio features
+    random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
+    audio_feature = pe(random_audio)
+
+    # Get latents
+    latents = vae.get_latents_for_unet(crop_frame)
+    latents = latents.to(dtype=weight_dtype)
+
+    # Generate prediction results
+    pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
+    recon = vae.decode_latents(pred_latents)
+
+    # Inpaint back to original image
+    res_frame = recon[0]
+    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+    combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+    # Save results (no need to convert color space again since get_image already returns RGB format)
+    debug_result_path = os.path.join(args.result_dir, "debug_result.png")
+    cv2.imwrite(debug_result_path, combine_frame)
+
+    # Create information text
+    info_text = f"Parameter information:\n" + \
+                f"bbox_shift: {bbox_shift}\n" + \
+                f"extra_margin: {extra_margin}\n" + \
+                f"parsing_mode: {parsing_mode}\n" + \
+                f"left_cheek_width: {left_cheek_width}\n" + \
+                f"right_cheek_width: {right_cheek_width}\n" + \
+                f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
+
+    return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
+
 def print_directory_contents(path):
     for child in os.listdir(path):
         child_path = os.path.join(path, child)
@@ -40,119 +130,108 @@ def print_directory_contents(path):
         print(child_path)

 def download_model():
-    if not os.path.exists(CheckpointsDir):
-        os.makedirs(CheckpointsDir)
-        print("Checkpoint Not Downloaded, start downloading...")
-        tic = time.time()
-        snapshot_download(
-            repo_id="TMElyralab/MuseTalk",
-            local_dir=CheckpointsDir,
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        # weight
-        os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
-        snapshot_download(
-            repo_id="stabilityai/sd-vae-ft-mse",
-            local_dir=CheckpointsDir+'/sd-vae-ft-mse',
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        # dwpose
-        os.makedirs(f"{CheckpointsDir}/dwpose/")
-        snapshot_download(
-            repo_id="yzd-v/DWPose",
-            local_dir=CheckpointsDir+'/dwpose',
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        # whisper
-        url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
-        response = requests.get(url)
-        # Make sure the request succeeded
-        if response.status_code == 200:
-            # Target location for the file
-            file_path = f"{CheckpointsDir}/whisper/tiny.pt"
-            os.makedirs(f"{CheckpointsDir}/whisper/")
-            # Write the response content to the target location
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-        else:
-            print(f"Request failed, status code: {response.status_code}")
-        # gdown face parse
-        url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
-        os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
-        file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
-        gdown.download(url, file_path, quiet=False)
-        # resnet
-        url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
-        response = requests.get(url)
-        # Make sure the request succeeded
-        if response.status_code == 200:
-            # Target location for the file
-            file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
-            # Write the response content to the target location
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-        else:
-            print(f"Request failed, status code: {response.status_code}")
-        toc = time.time()
-        print(f"download cost {toc-tic} seconds")
-        print_directory_contents(CheckpointsDir)
-    else:
-        print("Already download the model.")
+    # Check that the required model files exist
+    required_models = {
+        "MuseTalk UNet": f"{CheckpointsDir}/musetalkV15/unet.pth",
+        "MuseTalk Config": f"{CheckpointsDir}/musetalkV15/musetalk.json",
+        "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
+        "Whisper": f"{CheckpointsDir}/whisper/config.json",
+        "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
+        "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
+        "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
+        "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
+    }
+
+    missing_models = []
+    for model_name, model_path in required_models.items():
+        if not os.path.exists(model_path):
+            missing_models.append(model_name)
+
+    if missing_models:
+        # Use English for all messages
+        print("The following required model files are missing:")
+        for model in missing_models:
+            print(f"- {model}")
+        print("\nPlease run the download script to download the missing models:")
+        if sys.platform == "win32":
+            print("Windows: Run download_weights.bat")
+        else:
+            print("Linux/Mac: Run ./download_weights.sh")
+        sys.exit(1)
+    else:
+        print("All required model files exist.")

 download_model()  # for huggingface deployment.

-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
 from musetalk.utils.blending import get_image
-from musetalk.utils.utils import load_all_model
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.audio_processor import AudioProcessor
+from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 @spaces.GPU(duration=600)
 @torch.no_grad()
-def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
-    args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}  # same as the inference script
+def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+              left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
+    # Set default parameters, aligned with inference.py
+    args_dict = {
+        "result_dir": './results/output',
+        "fps": 25,
+        "batch_size": 8,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",  # Fixed to the v15 version
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
     args = Namespace(**args_dict)
-    input_basename = os.path.basename(video_path).split('.')[0]
-    audio_basename = os.path.basename(audio_path).split('.')[0]
-    output_basename = f"{input_basename}_{audio_basename}"
-    result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
-    crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")  # only related to video input
-    os.makedirs(result_img_save_path, exist_ok=True)
-    if args.output_vid_name=="":
-        output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
+
+    # Check ffmpeg
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+    input_basename = os.path.basename(video_path).split('.')[0]
+    audio_basename = os.path.basename(audio_path).split('.')[0]
+    output_basename = f"{input_basename}_{audio_basename}"
+
+    # Create temporary directory
+    temp_dir = os.path.join(args.result_dir, f"{args.version}")
+    os.makedirs(temp_dir, exist_ok=True)
+
+    # Set result save path
+    result_img_save_path = os.path.join(temp_dir, output_basename)
+    crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
+    os.makedirs(result_img_save_path, exist_ok=True)
+
+    if args.output_vid_name == "":
+        output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
     else:
-        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
+        output_vid_name = os.path.join(temp_dir, args.output_vid_name)

     ############################################## extract frames from source video ##############################################
-    if get_file_type(video_path)=="video":
-        save_dir_full = os.path.join(args.result_dir, input_basename)
-        os.makedirs(save_dir_full,exist_ok = True)
-        # cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
-        # os.system(cmd)
-        # 读取视频
+    if get_file_type(video_path) == "video":
+        save_dir_full = os.path.join(temp_dir, input_basename)
+        os.makedirs(save_dir_full, exist_ok=True)
+        # Read video
         reader = imageio.get_reader(video_path)
-        # 保存图片
+        # Save images
         for i, im in enumerate(reader):
             imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
         input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
@@ -161,10 +240,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
         input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
         fps = args.fps
-    # print(input_img_list)

     ############################################## extract audio feature ##############################################
-    whisper_feature = audio_processor.audio2feat(audio_path)
-    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
+    # Extract audio features
+    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+    whisper_chunks = audio_processor.get_whisper_chunk(
+        whisper_input_features,
+        device,
+        weight_dtype,
+        whisper,
+        librosa_length,
+        fps=fps,
+        audio_padding_length_left=args.audio_padding_length_left,
+        audio_padding_length_right=args.audio_padding_length_right,
+    )

     ############################################## preprocess input image ##############################################
     if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
         print("using extracted coordinates")
@@ -176,13 +266,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
         with open(crop_coord_save_path, 'wb') as f:
             pickle.dump(coord_list, f)
-    bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
+    bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
     i = 0
     input_latent_list = []
     for bbox, frame in zip(coord_list, frame_list):
         if bbox == coord_placeholder:
             continue
         x1, y1, x2, y2 = bbox
+        y2 = y2 + args.extra_margin
+        y2 = min(y2, frame.shape[0])
         crop_frame = frame[y1:y2, x1:x2]
         crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
         latents = vae.get_latents_for_unet(crop_frame)
@@ -192,17 +291,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     frame_list_cycle = frame_list + frame_list[::-1]
     coord_list_cycle = coord_list + coord_list[::-1]
     input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

     ############################################## inference batch by batch ##############################################
     print("start inference")
     video_num = len(whisper_chunks)
     batch_size = args.batch_size
-    gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
+    gen = datagen(
+        whisper_chunks=whisper_chunks,
+        vae_encode_latents=input_latent_list_cycle,
+        batch_size=batch_size,
+        delay_frame=0,
+        device=device,
+    )
     res_frame_list = []
     for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
-        audio_feature_batch = torch.stack(tensor_list).to(unet.device)  # torch, B, 5*N, 384
-        audio_feature_batch = pe(audio_feature_batch)
+        audio_feature_batch = pe(whisper_batch)
+        # Ensure latent_batch is consistent with model weight type
+        latent_batch = latent_batch.to(dtype=weight_dtype)

         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
         recon = vae.decode_latents(pred_latents)
@@ -215,25 +320,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         bbox = coord_list_cycle[i%(len(coord_list_cycle))]
         ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
         x1, y1, x2, y2 = bbox
+        y2 = y2 + args.extra_margin
+        y2 = min(y2, frame.shape[0])
         try:
             res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
         except:
-            # print(bbox)
             continue
-        combine_frame = get_image(ori_frame,res_frame,bbox)
+        # Use v15 version blending
+        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)

-    # cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
-    # print(cmd_img2video)
-    # os.system(cmd_img2video)
-    # 帧率
+    # Frame rate
     fps = 25
-    # Image path
-    # 输出视频路径
+    # Output video path
     output_video = 'temp.mp4'
-    # 读取图片
+    # Read images
     def is_valid_image(file):
         pattern = re.compile(r'\d{8}\.png')
         return pattern.match(file)
@@ -247,13 +351,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         images.append(imageio.imread(filename))

-    # 保存视频
+    # Save video
     imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')

-    # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
-    # print(cmd_combine_audio)
-    # os.system(cmd_combine_audio)
-
     input_video = './temp.mp4'
     # Check if the input_video and audio_path exist
     if not os.path.exists(input_video):
@@ -261,40 +361,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     if not os.path.exists(audio_path):
         raise FileNotFoundError(f"Audio file not found: {audio_path}")

-    # 读取视频
+    # Read video
     reader = imageio.get_reader(input_video)
-    fps = reader.get_meta_data()['fps']  # 获取原视频的帧率
-    reader.close()  # 否则在win11上会报错PermissionError: [WinError 32] 另一个程序正在使用此文件,进程无法访问。: 'temp.mp4'
-    # 将帧存储在列表中
+    fps = reader.get_meta_data()['fps']  # Get original video frame rate
+    reader.close()  # Otherwise win11 raises PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
+    # Store frames in list
     frames = images
-    # Save the video and add the audio
-    # imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])
-    # input_video = ffmpeg.input(input_video)
-    # input_audio = ffmpeg.input(audio_path)
+
     print(len(frames))
-    # imageio.mimwrite(
-    #     output_video,
-    #     frames,
-    #     'FFMPEG',
-    #     fps=25,
-    #     codec='libx264',
-    #     audio_codec='aac',
-    #     input_params=['-i', audio_path],
-    #     output_params=['-y'],  # Add the '-y' flag to overwrite the output file if it exists
-    # )
-    # writer = imageio.get_writer(output_vid_name, fps=25, codec='libx264', quality=10, pixelformat='yuvj444p')
-    # for im in frames:
-    #     writer.append_data(im)
-    # writer.close()

     # Load the video
     video_clip = VideoFileClip(input_video)
@@ -315,11 +390,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
 # load model weights
-audio_processor,vae,unet,pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+vae, unet, pe = load_all_model(
+    unet_model_path="./models/musetalkV15/unet.pth",
+    vae_type="sd-vae",
+    unet_config="./models/musetalkV15/musetalk.json",
+    device=device
+)
+
+# Parse command line arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
+parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
+parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+parser.add_argument("--share", action="store_true", help="Create a public link")
+parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
+args = parser.parse_args()
+
+# Set data type
+if args.use_float16:
+    # Convert models to half precision for better performance
+    pe = pe.half()
+    vae.vae = vae.vae.half()
+    unet.model = unet.model.half()
+    weight_dtype = torch.float16
+else:
+    weight_dtype = torch.float32
+
+# Move models to specified device
+pe = pe.to(device)
+vae.vae = vae.vae.to(device)
+unet.model = unet.model.to(device)

 timesteps = torch.tensor([0], device=device)

+# Initialize audio processor and Whisper model
+audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
+whisper = WhisperModel.from_pretrained("./models/whisper")
+whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+whisper.requires_grad_(False)

 def check_video(video):
@@ -340,9 +449,6 @@ def check_video(video):
     output_video = os.path.join('./results/input', output_file_name)

-    # # Run the ffmpeg command to change the frame rate to 25fps
-    # command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"
     # read video
     reader = imageio.get_reader(video)
     fps = reader.get_meta_data()['fps']  # get fps from original video
@@ -374,34 +480,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
-        "<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
+        """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
         <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
         </br>\
-        Yue Zhang <sup>\*</sup>,\
-        Minhao Liu<sup>\*</sup>,\
+        Yue Zhang <sup>*</sup>,\
+        Zhizhou Zhong <sup>*</sup>,\
+        Minhao Liu<sup>*</sup>,\
         Zhaokang Chen,\
         Bin Wu<sup>†</sup>,\
-        Yubin Zeng,\
-        Chao Zhang,\
         Yingjie He,\
-        Chao Zhan,\
-        Wenjiang Zhou\
+        Junxin Huang,\
+        Wenjiang Zhou <br>\
         (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
         Lyra Lab, Tencent Music Entertainment\
         </h2> \
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
-        <a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
-        <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
+        <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
     )

     with gr.Row():
         with gr.Column():
-            audio = gr.Audio(label="Driven Audio",type="filepath")
+            audio = gr.Audio(label="Driving Audio",type="filepath")
             video = gr.Video(label="Reference Video",sources=['upload'])
             bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
-            bbox_shift_scale = gr.Textbox(label="BBox_shift recommended lower bound. The corresponding bbox range is generated after the initial result is produced. If the result is not good, it can be adjusted according to this reference value", value="", interactive=False)
-            btn = gr.Button("Generate")
-        out1 = gr.Video()
+            extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
+            parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
+            left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
+            right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
+            bbox_shift_scale = gr.Textbox(label="The 'left_cheek_width' and 'right_cheek_width' parameters determine the editable range of the left and right cheeks when the parsing mode is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
+            with gr.Row():
+                debug_btn = gr.Button("1. Test Inpainting")
+                btn = gr.Button("2. Generate")
+        with gr.Column():
+            debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
+            debug_info = gr.Textbox(label="Parameter Information", lines=5)
+            out1 = gr.Video()

     video.change(
         fn=check_video, inputs=[video], outputs=[video]
@@ -412,15 +529,44 @@ with gr.Blocks(css=css) as demo:
             audio,
             video,
             bbox_shift,
+            extra_margin,
+            parsing_mode,
+            left_cheek_width,
+            right_cheek_width
         ],
         outputs=[out1,bbox_shift_scale]
     )
+
+    debug_btn.click(
+        fn=debug_inpainting,
+        inputs=[
+            video,
+            bbox_shift,
+            extra_margin,
+            parsing_mode,
+            left_cheek_width,
+            right_cheek_width
+        ],
+        outputs=[debug_image, debug_info]
+    )

-# Set the IP and port
-ip_address = "0.0.0.0"  # Replace with your desired IP address
-port_number = 7860  # Replace with your desired port number
+# Check ffmpeg and add it to PATH if needed
+if not fast_check_ffmpeg():
+    print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
+    # Choose the path separator according to the operating system
+    path_separator = ';' if sys.platform == 'win32' else ':'
+    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+# Work around asynchronous IO issues on Windows
+if sys.platform == 'win32':
+    import asyncio
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+# Start the Gradio application
 demo.queue().launch(
-    share=False, debug=True, server_name=ip_address, server_port=port_number
+    share=args.share,
+    debug=True,
+    server_name=args.ip,
+    server_port=args.port
 )

assets/figs/gradio.png (new binary file, 14 KiB)

assets/figs/gradio_2.png (new binary file, 73 KiB)

musetalk/utils/audio_processor.py

@@ -49,8 +49,9 @@ class AudioProcessor:
         whisper_feature = []
         # Process multiple 30s mel input features
         for input_feature in whisper_input_features:
-            audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
-            audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
+            input_feature = input_feature.to(device).to(weight_dtype)
+            audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
+            audio_feats = torch.stack(audio_feats, dim=2)
             whisper_feature.append(audio_feats)

         whisper_feature = torch.cat(whisper_feature, dim=1)

musetalk/utils/utils.py

@@ -13,9 +13,9 @@ from musetalk.models.unet import UNet,PositionalEncoding
 def load_all_model(
-    unet_model_path=os.path.join("models", "musetalk", "pytorch_model.bin"),
+    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
     vae_type="sd-vae",
-    unet_config=os.path.join("models", "musetalk", "musetalk.json"),
+    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
     device=None,
 ):
     vae = VAE(

requirements.txt

@@ -15,6 +15,7 @@ requests
 imageio[ffmpeg]
 gradio
+spaces
 omegaconf
 ffmpeg-python
 moviepy