diff --git a/README.md b/README.md index 80cc4d6..2f4664d 100644 --- a/README.md +++ b/README.md @@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde - [x] codes for real-time inference. - [x] [technical report](https://arxiv.org/abs/2410.10122v2). - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122). +- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon). - [ ] training and dataloader code (Expected completion on 04/04/2025). -- [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon). - # Getting Started @@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk: #### MuseTalk 1.5 (Recommended) ```bash -sh inference.sh v1.5 +# Run MuseTalk 1.5 inference +sh inference.sh v1.5 normal ``` -This inference script supports both MuseTalk 1.5 and 1.0 models: -- For MuseTalk 1.5: Use the command above with the V1.5 model path -- For MuseTalk 1.0: Use the same script but point to the V1.0 model path - -configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path. -The video_path should be either a video file, an image file or a directory of images. #### MuseTalk 1.0 ```bash -sh inference.sh v1.0 +# Run MuseTalk 1.0 inference +sh inference.sh v1.0 normal ``` -You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg. -
+ +The inference script supports both MuseTalk 1.5 and 1.0 models: +- For MuseTalk 1.5: Use the command above with the V1.5 model path +- For MuseTalk 1.0: Use the same script but point to the V1.0 model path + +The configuration file `configs/inference/test.yaml` contains the inference settings, including: +- `video_path`: Path to the input video, image file, or directory of images +- `audio_path`: Path to the input audio file + +Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg. + +#### Real-time Inference +For real-time inference, use the following command: +```bash +# Run real-time inference +sh inference.sh v1.5 realtime # For MuseTalk 1.5 +# or +sh inference.sh v1.0 realtime # For MuseTalk 1.0 +``` + +The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes: +- `preparation`: Set to `True` for new avatar preparation +- `video_path`: Path to the input video +- `bbox_shift`: Adjustable parameter for mouth region control +- `audio_clips`: List of audio clips for generation + +Important notes for real-time inference: +1. Set `preparation` to `True` when processing a new avatar +2. After preparation, the avatar will generate videos using audio clips from `audio_clips` +3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100 +4. Set `preparation` to `False` for generating more videos with the same avatar + +For faster generation without saving images, you can use: +```bash +python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images +``` + ## TestCases For 1.0 @@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo ``` :pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md). - - #### Combining MuseV and MuseTalk As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference). -#### Real-time inference - -
-Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time. - -``` -python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4 -``` -configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`. - -1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.) -1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video. - ``` - Inferring using: data/audio/yongen.wav - ``` -1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100. -1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar. - -##### Note for Real-time inference -1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process. -1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run -``` -python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images -``` -
- # Acknowledgement 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch). 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings). diff --git a/configs/inference/realtime.yaml b/configs/inference/realtime.yaml index d4092ac..9319e98 100644 --- a/configs/inference/realtime.yaml +++ b/configs/inference/realtime.yaml @@ -1,10 +1,10 @@ avator_1: - preparation: False + preparation: True # your can set it to False if you want to use the existing avator, it will save time bbox_shift: 5 - video_path: "data/video/sun.mp4" + video_path: "data/video/yongen.mp4" audio_clips: audio_0: "data/audio/yongen.wav" - audio_1: "data/audio/sun.wav" + audio_1: "data/audio/eng.wav" diff --git a/configs/inference/test.yaml b/configs/inference/test.yaml index 4ec8aca..4a64f3b 100644 --- a/configs/inference/test.yaml +++ b/configs/inference/test.yaml @@ -3,8 +3,8 @@ task_0: audio_path: "data/audio/yongen.wav" task_1: - video_path: "data/video/sun.mp4" - audio_path: "data/audio/sun.wav" + video_path: "data/video/yongen.mp4" + audio_path: "data/audio/eng.wav" bbox_shift: -7 diff --git a/data/audio/eng.wav b/data/audio/eng.wav new file mode 100755 index 0000000..61c55b6 Binary files /dev/null and b/data/audio/eng.wav differ diff --git a/inference.sh b/inference.sh index d7d8426..355cfdd 100644 --- a/inference.sh +++ b/inference.sh @@ -1,46 +1,72 @@ #!/bin/bash -# This script runs inference based on the version specified by the user. +# This script runs inference based on the version and mode specified by the user. # Usage: -# To run v1.0 inference: sh inference.sh v1.0 -# To run v1.5 inference: sh inference.sh v1.5 +# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime] +# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime] # Check if the correct number of arguments is provided -if [ "$#" -ne 1 ]; then - echo "Usage: $0 " - echo "Example: $0 v1.0 or $0 v1.5" +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "Example: $0 v1.0 normal or $0 v1.5 realtime" exit 1 fi -# Get the version from the user input +# Get the version and mode from the user input version=$1 -config_path="./configs/inference/test.yaml" +mode=$2 + +# Validate mode +if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then + echo "Invalid mode specified. Please use 'normal' or 'realtime'." + exit 1 +fi + +# Set config path based on mode +if [ "$mode" = "normal" ]; then + config_path="./configs/inference/test.yaml" + result_dir="./results/test" +else + config_path="./configs/inference/realtime.yaml" + result_dir="./results/realtime" +fi # Define the model paths based on the version if [ "$version" = "v1.0" ]; then model_dir="./models/musetalk" unet_model_path="$model_dir/pytorch_model.bin" unet_config="$model_dir/musetalk.json" + version_arg="v1" elif [ "$version" = "v1.5" ]; then model_dir="./models/musetalkV15" unet_model_path="$model_dir/unet.pth" unet_config="$model_dir/musetalk.json" + version_arg="v15" else echo "Invalid version specified. Please use v1.0 or v1.5." exit 1 fi -# Run inference based on the version -if [ "$version" = "v1.0" ]; then - python3 -m scripts.inference \ - --inference_config "$config_path" \ - --result_dir "./results/test" \ - --unet_model_path "$unet_model_path" \ - --unet_config "$unet_config" -elif [ "$version" = "v1.5" ]; then - python3 -m scripts.inference_alpha \ - --inference_config "$config_path" \ - --result_dir "./results/test" \ - --unet_model_path "$unet_model_path" \ - --unet_config "$unet_config" -fi \ No newline at end of file +# Set script name based on mode +if [ "$mode" = "normal" ]; then + script_name="scripts.inference" +else + script_name="scripts.realtime_inference" +fi + +# Base command arguments +cmd_args="--inference_config $config_path \ + --result_dir $result_dir \ + --unet_model_path $unet_model_path \ + --unet_config $unet_config \ + --version $version_arg \ + +# Add realtime-specific arguments if in realtime mode +if [ "$mode" = "realtime" ]; then + cmd_args="$cmd_args \ + --fps 25 \ + --version $version_arg \ +fi + +# Run inference +python3 -m $script_name $cmd_args \ No newline at end of file diff --git a/musetalk/utils/audio_processor.py b/musetalk/utils/audio_processor.py index 740aacb..1c41ceb 100755 --- a/musetalk/utils/audio_processor.py +++ b/musetalk/utils/audio_processor.py @@ -11,7 +11,7 @@ class AudioProcessor: def __init__(self, feature_extractor_path="openai/whisper-tiny/"): self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path) - def get_audio_feature(self, wav_path, start_index=0): + def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None): if not os.path.exists(wav_path): return None librosa_output, sampling_rate = librosa.load(wav_path, sr=16000) @@ -27,6 +27,8 @@ class AudioProcessor: return_tensors="pt", sampling_rate=sampling_rate ).input_features + if weight_dtype is not None: + audio_feature = audio_feature.to(dtype=weight_dtype) features.append(audio_feature) return features, len(librosa_output) diff --git a/musetalk/utils/blending.py b/musetalk/utils/blending.py index c68a972..fa3effc 100755 --- a/musetalk/utils/blending.py +++ b/musetalk/utils/blending.py @@ -3,6 +3,7 @@ import numpy as np import cv2 import copy + def get_crop_box(box, expand): x, y, x1, y1 = box x_c, y_c = (x+x1)//2, (y+y1)//2 @@ -11,7 +12,8 @@ def get_crop_box(box, expand): crop_box = [x_c-s, y_c-s, x_c+s, y_c+s] return crop_box, s -def face_seg(image, mode="jaw", fp=None): + +def face_seg(image, mode="raw", fp=None): """ 对图像进行面部解析,生成面部区域的掩码。 @@ -86,14 +88,12 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode= body.paste(face_large, crop_box[:2], mask_image) - # 不用掩码,完全用infer - #face_large.save("debug/checkpoint_6_face_large.png") - body = np.array(body) # 将 PIL 图像转换回 numpy 数组 return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB) -def get_image_blending(image,face,face_box,mask_array,crop_box): + +def get_image_blending(image, face, face_box, mask_array, crop_box): body = Image.fromarray(image[:,:,::-1]) face = Image.fromarray(face[:,:,::-1]) @@ -108,7 +108,8 @@ def get_image_blending(image,face,face_box,mask_array,crop_box): body = np.array(body) return body[:,:,::-1] -def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2): + +def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"): body = Image.fromarray(image[:,:,::-1]) x, y, x1, y1 = face_box @@ -119,7 +120,7 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand= face_large = body.crop(crop_box) ori_shape = face_large.size - mask_image = face_seg(face_large) + mask_image = face_seg(face_large, mode=mode, fp=fp) mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s)) mask_image = Image.new('L', ori_shape, 0) mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s)) @@ -132,4 +133,4 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand= blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1 mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) - return mask_array,crop_box + return mask_array, crop_box diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index a33eda5..09c1c02 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -74,7 +74,7 @@ class FaceParsing(): transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) - def __call__(self, image, size=(512, 512), mode="jaw"): + def __call__(self, image, size=(512, 512), mode="raw"): if isinstance(image, str): image = Image.open(image) diff --git a/scripts/inference.py b/scripts/inference.py index 671dc43..41da8dd 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -1,8 +1,9 @@ import os import cv2 +import math import copy -import glob import torch +import glob import shutil import pickle import argparse @@ -17,18 +18,16 @@ from musetalk.utils.audio_processor import AudioProcessor from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder - - @torch.no_grad() def main(args): # Configure ffmpeg path if args.ffmpeg_path not in os.getenv('PATH'): print("Adding ffmpeg to PATH") os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}" - + # Set computing device device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") - + # Load model weights vae, unet, pe = load_all_model( unet_model_path=args.unet_model_path, @@ -37,164 +36,229 @@ def main(args): device=device ) timesteps = torch.tensor([0], device=device) - - - if args.use_float16 is True: + + # Convert models to half precision if float16 is enabled + if args.use_float16: pe = pe.half() vae.vae = vae.vae.half() unet.model = unet.model.half() + + # Move models to specified device + pe = pe.to(device) + vae.vae = vae.vae.to(device) + unet.model = unet.model.to(device) - # Initialize audio processor and Whisper model + # Initialize audio processor and Whisper model audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir) weight_dtype = unet.model.dtype whisper = WhisperModel.from_pretrained(args.whisper_dir) whisper = whisper.to(device=device, dtype=weight_dtype).eval() whisper.requires_grad_(False) - # Initialize face parser - fp = FaceParsing() - - inference_config = OmegaConf.load(args.inference_config) - print(inference_config) - for task_id in inference_config: - video_path = inference_config[task_id]["video_path"] - audio_path = inference_config[task_id]["audio_path"] - bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) - - input_basename = os.path.basename(video_path).split('.')[0] - audio_basename = os.path.basename(audio_path).split('.')[0] - output_basename = f"{input_basename}_{audio_basename}" - result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs - crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input - os.makedirs(result_img_save_path,exist_ok =True) - - if args.output_vid_name is None: - output_vid_name = os.path.join(args.result_dir, output_basename+".mp4") - else: - output_vid_name = os.path.join(args.result_dir, args.output_vid_name) - ############################################## extract frames from source video ############################################## - if get_file_type(video_path)=="video": - save_dir_full = os.path.join(args.result_dir, input_basename) - os.makedirs(save_dir_full,exist_ok = True) - cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" - os.system(cmd) - input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) - fps = get_video_fps(video_path) - elif get_file_type(video_path)=="image": - input_img_list = [video_path, ] - fps = args.fps - elif os.path.isdir(video_path): # input img folder - input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) - input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) - fps = args.fps - else: - raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") - - ############################################## extract audio feature ############################################## - # Extract audio features - whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) - whisper_chunks = audio_processor.get_whisper_chunk( - whisper_input_features, - device, - weight_dtype, - whisper, - librosa_length, - fps=fps, - audio_padding_length_left=args.audio_padding_length_left, - audio_padding_length_right=args.audio_padding_length_right, + # Initialize face parser with configurable parameters based on version + if args.version == "v15": + fp = FaceParsing( + left_cheek_width=args.left_cheek_width, + right_cheek_width=args.right_cheek_width ) - - ############################################## preprocess input image ############################################## - if os.path.exists(crop_coord_save_path) and args.use_saved_coord: - print("using extracted coordinates") - with open(crop_coord_save_path,'rb') as f: - coord_list = pickle.load(f) - frame_list = read_imgs(input_img_list) - else: - print("extracting landmarks...time consuming") - coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) - with open(crop_coord_save_path, 'wb') as f: - pickle.dump(coord_list, f) - - i = 0 - input_latent_list = [] - for bbox, frame in zip(coord_list, frame_list): - if bbox == coord_placeholder: - continue - x1, y1, x2, y2 = bbox - crop_frame = frame[y1:y2, x1:x2] - crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) - latents = vae.get_latents_for_unet(crop_frame) - input_latent_list.append(latents) + else: # v1 + fp = FaceParsing() - # to smooth the first and the last frame - frame_list_cycle = frame_list + frame_list[::-1] - coord_list_cycle = coord_list + coord_list[::-1] - input_latent_list_cycle = input_latent_list + input_latent_list[::-1] - ############################################## inference batch by batch ############################################## - print("start inference") - video_num = len(whisper_chunks) - batch_size = args.batch_size - gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) - res_frame_list = [] - for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): - audio_feature_batch = pe(whisper_batch) - latent_batch = latent_batch.to(dtype=unet.model.dtype) + # Load inference configuration + inference_config = OmegaConf.load(args.inference_config) + print("Loaded inference config:", inference_config) + + # Process each task + for task_id in inference_config: + try: + # Get task configuration + video_path = inference_config[task_id]["video_path"] + audio_path = inference_config[task_id]["audio_path"] + if "result_name" in inference_config[task_id]: + args.output_vid_name = inference_config[task_id]["result_name"] - pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample - recon = vae.decode_latents(pred_latents) - for res_frame in recon: - res_frame_list.append(res_frame) - - ############################################## pad to full image ############################################## - print("pad talking image to original video") - for i, res_frame in enumerate(tqdm(res_frame_list)): - bbox = coord_list_cycle[i%(len(coord_list_cycle))] - ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) - x1, y1, x2, y2 = bbox - try: - res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) - except: - continue + # Set bbox_shift based on version + if args.version == "v15": + bbox_shift = 0 # v15 uses fixed bbox_shift + else: + bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default - # Merge results - combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp) - cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) + # Set output paths + input_basename = os.path.basename(video_path).split('.')[0] + audio_basename = os.path.basename(audio_path).split('.')[0] + output_basename = f"{input_basename}_{audio_basename}" + + # Create temporary directories + temp_dir = os.path.join(args.result_dir, f"{args.version}") + os.makedirs(temp_dir, exist_ok=True) + + # Set result save paths + result_img_save_path = os.path.join(temp_dir, output_basename) + crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") + os.makedirs(result_img_save_path, exist_ok=True) + + # Set output video paths + if args.output_vid_name is None: + output_vid_name = os.path.join(temp_dir, output_basename + ".mp4") + else: + output_vid_name = os.path.join(temp_dir, args.output_vid_name) + output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4") + + # Extract frames from source video + if get_file_type(video_path) == "video": + save_dir_full = os.path.join(temp_dir, input_basename) + os.makedirs(save_dir_full, exist_ok=True) + cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" + os.system(cmd) + input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) + fps = get_video_fps(video_path) + elif get_file_type(video_path) == "image": + input_img_list = [video_path] + fps = args.fps + elif os.path.isdir(video_path): + input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + fps = args.fps + else: + raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") - cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" - print(cmd_img2video) - os.system(cmd_img2video) + # Extract audio features + whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, + device, + weight_dtype, + whisper, + librosa_length, + fps=fps, + audio_padding_length_left=args.audio_padding_length_left, + audio_padding_length_right=args.audio_padding_length_right, + ) + + # Preprocess input images + if os.path.exists(crop_coord_save_path) and args.use_saved_coord: + print("Using saved coordinates") + with open(crop_coord_save_path, 'rb') as f: + coord_list = pickle.load(f) + frame_list = read_imgs(input_img_list) + else: + print("Extracting landmarks... time-consuming operation") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) + with open(crop_coord_save_path, 'wb') as f: + pickle.dump(coord_list, f) + + print(f"Number of frames: {len(frame_list)}") + + # Process each frame + input_latent_list = [] + for bbox, frame in zip(coord_list, frame_list): + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + if args.version == "v15": + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) + crop_frame = frame[y1:y2, x1:x2] + crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(crop_frame) + input_latent_list.append(latents) - cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" - print(cmd_combine_audio) - os.system(cmd_combine_audio) - - os.remove("temp.mp4") - shutil.rmtree(result_img_save_path) - print(f"result is save to {output_vid_name}") + # Smooth first and last frames + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + + # Batch inference + print("Starting inference") + video_num = len(whisper_chunks) + batch_size = args.batch_size + gen = datagen( + whisper_chunks=whisper_chunks, + vae_encode_latents=input_latent_list_cycle, + batch_size=batch_size, + delay_frame=0, + device=device, + ) + + res_frame_list = [] + total = int(np.ceil(float(video_num) / batch_size)) + + # Execute inference + for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)): + audio_feature_batch = pe(whisper_batch) + latent_batch = latent_batch.to(dtype=unet.model.dtype) + + pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample + recon = vae.decode_latents(pred_latents) + for res_frame in recon: + res_frame_list.append(res_frame) + + # Pad generated images to original video size + print("Padding generated images to original video size") + for i, res_frame in enumerate(tqdm(res_frame_list)): + bbox = coord_list_cycle[i%(len(coord_list_cycle))] + ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) + x1, y1, x2, y2 = bbox + if args.version == "v15": + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) + try: + res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1)) + except: + continue + + # Merge results with version-specific parameters + if args.version == "v15": + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp) + else: + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp) + cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame) + + # Save prediction results + temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4" + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}" + print("Video generation command:", cmd_img2video) + os.system(cmd_img2video) + + cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}" + print("Audio combination command:", cmd_combine_audio) + os.system(cmd_combine_audio) + + # Clean up temporary files + shutil.rmtree(result_img_save_path) + os.remove(temp_vid_path) + + shutil.rmtree(save_dir_full) + if not args.saved_coord: + os.remove(crop_coord_save_path) + + print(f"Results saved to {output_vid_name}") + except Exception as e: + print("Error occurred during processing:", e) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") - parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml") - parser.add_argument("--bbox_shift", type=int, default=0) - parser.add_argument("--result_dir", default='./results', help="path to output") parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") - parser.add_argument("--batch_size", type=int, default=8) - parser.add_argument("--output_vid_name", type=str, default=None) - parser.add_argument("--use_saved_coord", - action="store_true", - help='use saved coordinate to save time') - parser.add_argument("--use_float16", - action="store_true", - help="Whether use float16 to speed up inference", - ) - parser.add_argument("--fps", type=int, default=25, help="Video frames per second") - parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights") parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file") + parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights") parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") + parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file") + parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value") + parser.add_argument("--result_dir", default='./results', help="Directory for output results") + parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping") + parser.add_argument("--fps", type=int, default=25, help="Video frames per second") parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio") parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference") + parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file") + parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time') + parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use') + parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference") + parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode") + parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region") + parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region") + parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use") args = parser.parse_args() main(args) diff --git a/scripts/inference_alpha.py b/scripts/inference_alpha.py deleted file mode 100644 index 498be0f..0000000 --- a/scripts/inference_alpha.py +++ /dev/null @@ -1,252 +0,0 @@ -import os -import cv2 -import math -import copy -import torch -import glob -import shutil -import pickle -import argparse -import subprocess -import numpy as np -from tqdm import tqdm -from omegaconf import OmegaConf -from transformers import WhisperModel - -from musetalk.utils.blending import get_image -from musetalk.utils.face_parsing import FaceParsing -from musetalk.utils.audio_processor import AudioProcessor -from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model -from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder - - -@torch.no_grad() -def main(args): - # Configure ffmpeg path - if args.ffmpeg_path not in os.getenv('PATH'): - print("Adding ffmpeg to PATH") - os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}" - - # Set computing device - device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") - - # Load model weights - vae, unet, pe = load_all_model( - unet_model_path=args.unet_model_path, - vae_type=args.vae_type, - unet_config=args.unet_config, - device=device - ) - timesteps = torch.tensor([0], device=device) - - # Convert models to half precision if float16 is enabled - if args.use_float16: - pe = pe.half() - vae.vae = vae.vae.half() - unet.model = unet.model.half() - - # Move models to specified device - pe = pe.to(device) - vae.vae = vae.vae.to(device) - unet.model = unet.model.to(device) - - # Initialize audio processor and Whisper model - audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir) - weight_dtype = unet.model.dtype - whisper = WhisperModel.from_pretrained(args.whisper_dir) - whisper = whisper.to(device=device, dtype=weight_dtype).eval() - whisper.requires_grad_(False) - - # Initialize face parser - fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width) - - # Load inference configuration - inference_config = OmegaConf.load(args.inference_config) - print("Loaded inference config:", inference_config) - - # Process each task - for task_id in inference_config: - try: - # Get task configuration - video_path = inference_config[task_id]["video_path"] - audio_path = inference_config[task_id]["audio_path"] - if "result_name" in inference_config[task_id]: - args.output_vid_name = inference_config[task_id]["result_name"] - bbox_shift = args.bbox_shift - # Set output paths - input_basename = os.path.basename(video_path).split('.')[0] - audio_basename = os.path.basename(audio_path).split('.')[0] - output_basename = f"{input_basename}_{audio_basename}" - - # Create temporary directories - temp_dir = os.path.join(args.result_dir, "frames_result") - os.makedirs(temp_dir, exist_ok=True) - - # Set result save paths - result_img_save_path = os.path.join(temp_dir, output_basename) # related to video & audio inputs - crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") # only related to video input - os.makedirs(result_img_save_path, exist_ok=True) - # Set output video paths - if args.output_vid_name is None: - output_vid_name = os.path.join(temp_dir, output_basename + ".mp4") - else: - output_vid_name = os.path.join(temp_dir, args.output_vid_name) - output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4") - - # Skip if output file already exists - if os.path.exists(output_vid_name): - print(f"{output_vid_name} already exists, skipping!") - continue - - # Extract frames from source video - if get_file_type(video_path) == "video": - save_dir_full = os.path.join(temp_dir, input_basename) - os.makedirs(save_dir_full, exist_ok=True) - cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" - os.system(cmd) - input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) - fps = get_video_fps(video_path) - elif get_file_type(video_path) == "image": - input_img_list = [video_path] - fps = args.fps - elif os.path.isdir(video_path): - input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) - input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) - fps = args.fps - else: - raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") - - # Extract audio features - whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) - whisper_chunks = audio_processor.get_whisper_chunk( - whisper_input_features, - device, - weight_dtype, - whisper, - librosa_length, - fps=fps, - audio_padding_length_left=args.audio_padding_length_left, - audio_padding_length_right=args.audio_padding_length_right, - ) - - # Preprocess input images - if os.path.exists(crop_coord_save_path) and args.use_saved_coord: - print("Using saved coordinates") - with open(crop_coord_save_path, 'rb') as f: - coord_list = pickle.load(f) - frame_list = read_imgs(input_img_list) - else: - print("Extracting landmarks... time-consuming operation") - coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) - with open(crop_coord_save_path, 'wb') as f: - pickle.dump(coord_list, f) - - print(f"Number of frames: {len(frame_list)}") - - # Process each frame - input_latent_list = [] - for bbox, frame in zip(coord_list, frame_list): - if bbox == coord_placeholder: - continue - x1, y1, x2, y2 = bbox - y2 = y2 + args.extra_margin - y2 = min(y2, frame.shape[0]) - crop_frame = frame[y1:y2, x1:x2] - crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4) - latents = vae.get_latents_for_unet(crop_frame) - input_latent_list.append(latents) - - # Smooth first and last frames - frame_list_cycle = frame_list + frame_list[::-1] - coord_list_cycle = coord_list + coord_list[::-1] - input_latent_list_cycle = input_latent_list + input_latent_list[::-1] - - # Batch inference - print("Starting inference") - video_num = len(whisper_chunks) - batch_size = args.batch_size - gen = datagen( - whisper_chunks=whisper_chunks, - vae_encode_latents=input_latent_list_cycle, - batch_size=batch_size, - delay_frame=0, - device=device, - ) - - res_frame_list = [] - total = int(np.ceil(float(video_num) / batch_size)) - - # Execute inference - for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)): - audio_feature_batch = pe(whisper_batch) - latent_batch = latent_batch.to(dtype=unet.model.dtype) - - pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample - recon = vae.decode_latents(pred_latents) - for res_frame in recon: - res_frame_list.append(res_frame) - - # Pad generated images to original video size - print("Padding generated images to original video size") - for i, res_frame in enumerate(tqdm(res_frame_list)): - bbox = coord_list_cycle[i%(len(coord_list_cycle))] - ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) - x1, y1, x2, y2 = bbox - y2 = y2 + args.extra_margin - y2 = min(y2, frame.shape[0]) - try: - res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1)) - except: - continue - - # Merge results - combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp) - cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame) - - # Save prediction results - temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4" - cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}" - print("Video generation command:", cmd_img2video) - os.system(cmd_img2video) - - cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}" - print("Audio combination command:", cmd_combine_audio) - os.system(cmd_combine_audio) - - # Clean up temporary files - shutil.rmtree(result_img_save_path) - os.remove(temp_vid_path) - - shutil.rmtree(save_dir_full) - if not args.saved_coord: - os.remove(crop_coord_save_path) - - print(f"Results saved to {output_vid_name}") - except Exception as e: - print("Error occurred during processing:", e) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") - parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") - parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") - parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file") - parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights") - parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") - parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file") - parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value") - parser.add_argument("--result_dir", default='./results', help="Directory for output results") - parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping") - parser.add_argument("--fps", type=int, default=25, help="Video frames per second") - parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio") - parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio") - parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference") - parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file") - parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time') - parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use') - parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference") - parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode") - parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region") - parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region") - args = parser.parse_args() - main(args) diff --git a/scripts/realtime_inference.py b/scripts/realtime_inference.py index 18bb856..52560c5 100644 --- a/scripts/realtime_inference.py +++ b/scripts/realtime_inference.py @@ -10,26 +10,22 @@ import sys from tqdm import tqdm import copy import json -from musetalk.utils.utils import get_file_type,get_video_fps,datagen -from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder -from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending -from musetalk.utils.utils import load_all_model -import shutil +from transformers import WhisperModel +from musetalk.utils.face_parsing import FaceParsing +from musetalk.utils.utils import datagen +from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs +from musetalk.utils.blending import get_image_prepare_material, get_image_blending +from musetalk.utils.utils import load_all_model +from musetalk.utils.audio_processor import AudioProcessor + +import shutil import threading import queue - import time -# load model weights -audio_processor, vae, unet, pe = load_all_model() -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -timesteps = torch.tensor([0], device=device) -pe = pe.half() -vae.vae = vae.vae.half() -unet.model = unet.model.half() -def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000): +def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000): cap = cv2.VideoCapture(vid_path) count = 0 while True: @@ -42,35 +38,43 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000): else: break + def osmakedirs(path_list): for path in path_list: os.makedirs(path) if not os.path.exists(path) else None - -@torch.no_grad() + +@torch.no_grad() class Avatar: def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation): self.avatar_id = avatar_id self.video_path = video_path self.bbox_shift = bbox_shift - self.avatar_path = f"./results/avatars/{avatar_id}" - self.full_imgs_path = f"{self.avatar_path}/full_imgs" + # 根据版本设置不同的基础路径 + if args.version == "v15": + self.base_path = f"./results/{args.version}/avatars/{avatar_id}" + else: # v1 + self.base_path = f"./results/avatars/{avatar_id}" + + self.avatar_path = self.base_path + self.full_imgs_path = f"{self.avatar_path}/full_imgs" self.coords_path = f"{self.avatar_path}/coords.pkl" - self.latents_out_path= f"{self.avatar_path}/latents.pt" + self.latents_out_path = f"{self.avatar_path}/latents.pt" self.video_out_path = f"{self.avatar_path}/vid_output/" - self.mask_out_path =f"{self.avatar_path}/mask" - self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl" + self.mask_out_path = f"{self.avatar_path}/mask" + self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl" self.avatar_info_path = f"{self.avatar_path}/avator_info.json" self.avatar_info = { - "avatar_id":avatar_id, - "video_path":video_path, - "bbox_shift":bbox_shift + "avatar_id": avatar_id, + "video_path": video_path, + "bbox_shift": bbox_shift, + "version": args.version } self.preparation = preparation self.batch_size = batch_size self.idx = 0 self.init() - + def init(self): if self.preparation: if os.path.exists(self.avatar_path): @@ -80,7 +84,7 @@ class Avatar: print("*********************************") print(f" creating avator: {self.avatar_id}") print("*********************************") - osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path]) self.prepare_material() else: self.input_latent_list_cycle = torch.load(self.latents_out_path) @@ -98,16 +102,16 @@ class Avatar: print("*********************************") print(f" creating avator: {self.avatar_id}") print("*********************************") - osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path]) self.prepare_material() - else: + else: if not os.path.exists(self.avatar_path): print(f"{self.avatar_id} does not exist, you should set preparation to True") sys.exit() with open(self.avatar_info_path, "r") as f: avatar_info = json.load(f) - + if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']: response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)") if response.lower() == "c": @@ -115,11 +119,11 @@ class Avatar: print("*********************************") print(f" creating avator: {self.avatar_id}") print("*********************************") - osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) + osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path]) self.prepare_material() else: sys.exit() - else: + else: self.input_latent_list_cycle = torch.load(self.latents_out_path) with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) @@ -131,36 +135,40 @@ class Avatar: input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) self.mask_list_cycle = read_imgs(input_mask_list) - + def prepare_material(self): print("preparing data materials ... ...") with open(self.avatar_info_path, "w") as f: json.dump(self.avatar_info, f) - + if os.path.isfile(self.video_path): - video2imgs(self.video_path, self.full_imgs_path, ext = 'png') + video2imgs(self.video_path, self.full_imgs_path, ext='png') else: print(f"copy files in {self.video_path}") files = os.listdir(self.video_path) files.sort() - files = [file for file in files if file.split(".")[-1]=="png"] + files = [file for file in files if file.split(".")[-1] == "png"] for filename in files: shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}") input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))) - + print("extracting landmarks...") coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift) input_latent_list = [] idx = -1 - # maker if the bbox is not sufficient - coord_placeholder = (0.0,0.0,0.0,0.0) + # maker if the bbox is not sufficient + coord_placeholder = (0.0, 0.0, 0.0, 0.0) for bbox, frame in zip(coord_list, frame_list): idx = idx + 1 if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox + if args.version == "v15": + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) + coord_list[idx] = [x1, y1, x2, y2] # 更新coord_list中的bbox crop_frame = frame[y1:y2, x1:x2] - resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) latents = vae.get_latents_for_unet(resized_crop_frame) input_latent_list.append(latents) @@ -170,112 +178,116 @@ class Avatar: self.mask_coords_list_cycle = [] self.mask_list_cycle = [] - for i,frame in enumerate(tqdm(self.frame_list_cycle)): - cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame) - - face_box = self.coord_list_cycle[i] - mask,crop_box = get_image_prepare_material(frame,face_box) - cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask) + for i, frame in enumerate(tqdm(self.frame_list_cycle)): + cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame) + + x1, y1, x2, y2 = self.coord_list_cycle[i] + if args.version == "v15": + mode = args.parsing_mode + else: + mode = "raw" + mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode) + + cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask) self.mask_coords_list_cycle += [crop_box] self.mask_list_cycle.append(mask) - + with open(self.mask_coords_path, 'wb') as f: pickle.dump(self.mask_coords_list_cycle, f) with open(self.coords_path, 'wb') as f: pickle.dump(self.coord_list_cycle, f) - - torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path)) - # - - def process_frames(self, - res_frame_queue, - video_len, - skip_save_images): + + torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path)) + + def process_frames(self, res_frame_queue, video_len, skip_save_images): print(video_len) while True: - if self.idx>=video_len-1: + if self.idx >= video_len - 1: break try: start = time.time() res_frame = res_frame_queue.get(block=True, timeout=1) except queue.Empty: continue - - bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))] - ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))]) + + bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))] + ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))]) x1, y1, x2, y2 = bbox try: - res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) + res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) except: continue - mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))] - mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))] - #combine_frame = get_image(ori_frame,res_frame,bbox) + mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))] + mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))] combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) if skip_save_images is False: - cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame) + cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame) self.idx = self.idx + 1 - def inference(self, - audio_path, - out_vid_name, - fps, - skip_save_images): - os.makedirs(self.avatar_path+'/tmp',exist_ok =True) + def inference(self, audio_path, out_vid_name, fps, skip_save_images): + os.makedirs(self.avatar_path + '/tmp', exist_ok=True) print("start inference") ############################################## extract audio feature ############################################## start_time = time.time() - whisper_feature = audio_processor.audio2feat(audio_path) - whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + # Extract audio features + whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, + device, + weight_dtype, + whisper, + librosa_length, + fps=fps, + audio_padding_length_left=args.audio_padding_length_left, + audio_padding_length_right=args.audio_padding_length_right, + ) print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms") ############################################## inference batch by batch ############################################## - video_num = len(whisper_chunks) + video_num = len(whisper_chunks) res_frame_queue = queue.Queue() self.idx = 0 - # # Create a sub-thread and start it + # Create a sub-thread and start it process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images)) process_thread.start() gen = datagen(whisper_chunks, - self.input_latent_list_cycle, - self.batch_size) + self.input_latent_list_cycle, + self.batch_size) start_time = time.time() res_frame_list = [] - - for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))): - audio_feature_batch = torch.from_numpy(whisper_batch) - audio_feature_batch = audio_feature_batch.to(device=unet.device, - dtype=unet.model.dtype) - audio_feature_batch = pe(audio_feature_batch) - latent_batch = latent_batch.to(dtype=unet.model.dtype) - pred_latents = unet.model(latent_batch, - timesteps, - encoder_hidden_states=audio_feature_batch).sample + for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))): + audio_feature_batch = pe(whisper_batch.to(device)) + latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype) + + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample + pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype) recon = vae.decode_latents(pred_latents) for res_frame in recon: res_frame_queue.put(res_frame) # Close the queue and sub-thread after all tasks are completed process_thread.join() - + if args.skip_save_images is True: print('Total process time of {} frames without saving images = {}s'.format( - video_num, - time.time()-start_time)) + video_num, + time.time() - start_time)) else: print('Total process time of {} frames including saving images = {}s'.format( - video_num, - time.time()-start_time)) + video_num, + time.time() - start_time)) - if out_vid_name is not None and args.skip_save_images is False: + if out_vid_name is not None and args.skip_save_images is False: # optional - cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4" + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4" print(cmd_img2video) os.system(cmd_img2video) - output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on + output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4") # on cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}" print(cmd_combine_audio) os.system(cmd_combine_audio) @@ -284,52 +296,95 @@ class Avatar: shutil.rmtree(f"{self.avatar_path}/tmp") print(f"result is save to {output_vid}") print("\n") - + if __name__ == "__main__": ''' This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time. ''' - + parser = argparse.ArgumentParser() - parser.add_argument("--inference_config", - type=str, - default="configs/inference/realtime.yaml", - ) - parser.add_argument("--fps", - type=int, - default=25, - ) - parser.add_argument("--batch_size", - type=int, - default=4, - ) + parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15") + parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") + parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") + parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") + parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file") + parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights") + parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") + parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml") + parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value") + parser.add_argument("--result_dir", default='./results', help="Directory for output results") + parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping") + parser.add_argument("--fps", type=int, default=25, help="Video frames per second") + parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio") + parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio") + parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference") + parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file") + parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time') + parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use') + parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode") + parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region") + parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region") parser.add_argument("--skip_save_images", - action="store_true", - help="Whether skip saving images for better generation speed calculation", - ) + action="store_true", + help="Whether skip saving images for better generation speed calculation", + ) args = parser.parse_args() - + + # Set computing device + device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") + + # Load model weights + vae, unet, pe = load_all_model( + unet_model_path=args.unet_model_path, + vae_type=args.vae_type, + unet_config=args.unet_config, + device=device + ) + timesteps = torch.tensor([0], device=device) + + pe = pe.half().to(device) + vae.vae = vae.vae.half().to(device) + unet.model = unet.model.half().to(device) + + # Initialize audio processor and Whisper model + audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir) + weight_dtype = unet.model.dtype + whisper = WhisperModel.from_pretrained(args.whisper_dir) + whisper = whisper.to(device=device, dtype=weight_dtype).eval() + whisper.requires_grad_(False) + + # Initialize face parser with configurable parameters based on version + if args.version == "v15": + fp = FaceParsing( + left_cheek_width=args.left_cheek_width, + right_cheek_width=args.right_cheek_width + ) + else: # v1 + fp = FaceParsing() + inference_config = OmegaConf.load(args.inference_config) print(inference_config) - - + for avatar_id in inference_config: data_preparation = inference_config[avatar_id]["preparation"] video_path = inference_config[avatar_id]["video_path"] - bbox_shift = inference_config[avatar_id]["bbox_shift"] + if args.version == "v15": + bbox_shift = 0 + else: + bbox_shift = inference_config[avatar_id]["bbox_shift"] avatar = Avatar( - avatar_id = avatar_id, - video_path = video_path, - bbox_shift = bbox_shift, - batch_size = args.batch_size, - preparation= data_preparation) - + avatar_id=avatar_id, + video_path=video_path, + bbox_shift=bbox_shift, + batch_size=args.batch_size, + preparation=data_preparation) + audio_clips = inference_config[avatar_id]["audio_clips"] for audio_num, audio_path in audio_clips.items(): - print("Inferring using:",audio_path) - avatar.inference(audio_path, - audio_num, - args.fps, - args.skip_save_images) + print("Inferring using:", audio_path) + avatar.inference(audio_path, + audio_num, + args.fps, + args.skip_save_images)