feat: real-time infer (#286)

* feat: realtime infer

* cchore: infer script
This commit is contained in:
Zhizhou Zhong
2025-04-02 19:13:18 +08:00
committed by GitHub
parent fbe6a97dff
commit 39ccf69f36
11 changed files with 490 additions and 592 deletions

View File

@@ -1,8 +1,9 @@
import os
import cv2
import math
import copy
import glob
import torch
import glob
import shutil
import pickle
import argparse
@@ -17,18 +18,16 @@ from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
@torch.no_grad()
def main(args):
# Configure ffmpeg path
if args.ffmpeg_path not in os.getenv('PATH'):
print("Adding ffmpeg to PATH")
os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
# Set computing device
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# Load model weights
vae, unet, pe = load_all_model(
unet_model_path=args.unet_model_path,
@@ -37,164 +36,229 @@ def main(args):
device=device
)
timesteps = torch.tensor([0], device=device)
if args.use_float16 is True:
# Convert models to half precision if float16 is enabled
if args.use_float16:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
# Move models to specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
# Initialize audio processor and Whisper model
# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
weight_dtype = unet.model.dtype
whisper = WhisperModel.from_pretrained(args.whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser
fp = FaceParsing()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for task_id in inference_config:
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path,exist_ok =True)
if args.output_vid_name is None:
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
else:
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
############################################## extract frames from source video ##############################################
if get_file_type(video_path)=="video":
save_dir_full = os.path.join(args.result_dir, input_basename)
os.makedirs(save_dir_full,exist_ok = True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path)=="image":
input_img_list = [video_path, ]
fps = args.fps
elif os.path.isdir(video_path): # input img folder
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
############################################## extract audio feature ##############################################
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
# Initialize face parser with configurable parameters based on version
if args.version == "v15":
fp = FaceParsing(
left_cheek_width=args.left_cheek_width,
right_cheek_width=args.right_cheek_width
)
############################################## preprocess input image ##############################################
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
with open(crop_coord_save_path,'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("extracting landmarks...time consuming")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
else: # v1
fp = FaceParsing()
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# Load inference configuration
inference_config = OmegaConf.load(args.inference_config)
print("Loaded inference config:", inference_config)
# Process each task
for task_id in inference_config:
try:
# Get task configuration
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
if "result_name" in inference_config[task_id]:
args.output_vid_name = inference_config[task_id]["result_name"]
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
############################################## pad to full image ##############################################
print("pad talking image to original video")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
# Set bbox_shift based on version
if args.version == "v15":
bbox_shift = 0 # v15 uses fixed bbox_shift
else:
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
# Merge results
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
# Set output paths
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
# Create temporary directories
temp_dir = os.path.join(args.result_dir, f"{args.version}")
os.makedirs(temp_dir, exist_ok=True)
# Set result save paths
result_img_save_path = os.path.join(temp_dir, output_basename)
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
os.makedirs(result_img_save_path, exist_ok=True)
# Set output video paths
if args.output_vid_name is None:
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
else:
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
# Extract frames from source video
if get_file_type(video_path) == "video":
save_dir_full = os.path.join(temp_dir, input_basename)
os.makedirs(save_dir_full, exist_ok=True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path) == "image":
input_img_list = [video_path]
fps = args.fps
elif os.path.isdir(video_path):
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
# Extract audio features
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=args.audio_padding_length_left,
audio_padding_length_right=args.audio_padding_length_right,
)
# Preprocess input images
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("Using saved coordinates")
with open(crop_coord_save_path, 'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("Extracting landmarks... time-consuming operation")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
print(f"Number of frames: {len(frame_list)}")
# Process each frame
input_latent_list = []
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
if args.version == "v15":
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove("temp.mp4")
shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}")
# Smooth first and last frames
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
# Batch inference
print("Starting inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
res_frame_list = []
total = int(np.ceil(float(video_num) / batch_size))
# Execute inference
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
audio_feature_batch = pe(whisper_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
# Pad generated images to original video size
print("Padding generated images to original video size")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
if args.version == "v15":
y2 = y2 + args.extra_margin
y2 = min(y2, frame.shape[0])
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
except:
continue
# Merge results with version-specific parameters
if args.version == "v15":
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
else:
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
# Save prediction results
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
print("Video generation command:", cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
print("Audio combination command:", cmd_combine_audio)
os.system(cmd_combine_audio)
# Clean up temporary files
shutil.rmtree(result_img_save_path)
os.remove(temp_vid_path)
shutil.rmtree(save_dir_full)
if not args.saved_coord:
os.remove(crop_coord_save_path)
print(f"Results saved to {output_vid_name}")
except Exception as e:
print("Error occurred during processing:", e)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
parser.add_argument("--bbox_shift", type=int, default=0)
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--output_vid_name", type=str, default=None)
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
args = parser.parse_args()
main(args)