diff --git a/.gitignore b/.gitignore index 8f69ccd..b0f4f3c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,8 @@ .vscode/ *.pyc .ipynb_checkpoints -models results/ -data/audio/*.wav -data/video/*.mp4 +./models +**/__pycache__/ +*.py[cod] +*$py.class \ No newline at end of file diff --git a/README.md b/README.md index 9d11e1c..4a281d5 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,7 @@ You can download weights manually as follows: 2. Download the weights of other components: - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) - - [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt) + - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main) - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main) - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch) - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth) @@ -201,7 +201,10 @@ Finally, these weights should be organized in `models` as follows: │ ├── config.json │ └── diffusion_pytorch_model.bin └── whisper - └── tiny.pt + ├── config.json + ├── pytorch_model.bin + └── preprocessor_config.json + ``` ## Quickstart @@ -210,7 +213,7 @@ We provide inference scripts for both versions of MuseTalk: #### MuseTalk 1.5 (Recommended) ```bash -python3 -m scripts.inference_alpha --inference_config configs/inference/test.yaml --unet_model_path ./models/musetalkV15/unet.pth +sh inference.sh v1.5 ``` This inference script supports both MuseTalk 1.5 and 1.0 models: - For MuseTalk 1.5: Use the command above with the V1.5 model path @@ -221,7 +224,7 @@ The video_path should be either a video file, an image file or a directory of im #### MuseTalk 1.0 ```bash -python3 -m scripts.inference --inference_config configs/inference/test.yaml +sh inference.sh v1.0 ``` You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
diff --git a/inference.sh b/inference.sh new file mode 100644 index 0000000..e2d6f4b --- /dev/null +++ b/inference.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# This script runs inference based on the version specified by the user. +# Usage: +# To run v1.0 inference: sh inference.sh v1.0 +# To run v1.5 inference: sh inference.sh v1.5 + +# Check if the correct number of arguments is provided +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + echo "Example: $0 v1.0 or $0 v1.5" + exit 1 +fi + +# Get the version from the user input +version=$1 +config_path="./configs/inference/test.yaml" + +# Define the model paths based on the version +if [ "$version" = "v1.0" ]; then + model_dir="./models/musetalk" + unet_model_path="$model_dir/pytorch_model.bin" +elif [ "$version" = "v1.5" ]; then + model_dir="./models/musetalkV15" + unet_model_path="$model_dir/unet.pth" +else + echo "Invalid version specified. Please use v1.0 or v1.5." + exit 1 +fi + +# Run inference based on the version +if [ "$version" = "v1.0" ]; then + python3 -m scripts.inference \ + --inference_config "$config_path" \ + --result_dir "./results/test" \ + --unet_model_path "$unet_model_path" +elif [ "$version" = "v1.5" ]; then + python3 -m scripts.inference_alpha \ + --inference_config "$config_path" \ + --result_dir "./results/test" \ + --unet_model_path "$unet_model_path" +fi \ No newline at end of file diff --git a/musetalk/utils/audio_processor.py b/musetalk/utils/audio_processor.py index 13e46c1..740aacb 100755 --- a/musetalk/utils/audio_processor.py +++ b/musetalk/utils/audio_processor.py @@ -91,7 +91,7 @@ class AudioProcessor: if __name__ == "__main__": audio_processor = AudioProcessor() - wav_path = "/cfs-workspace/users/gozhong/codes/musetalk_opensource2/data/audio/2.wav" + wav_path = "./2.wav" audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path) print("Audio Feature shape:", audio_feature.shape) print("librosa_feature_length:", librosa_feature_length) diff --git a/scripts/inference.py b/scripts/inference.py index fe2a234..671dc43 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -1,32 +1,58 @@ -import argparse import os -from omegaconf import OmegaConf -import numpy as np import cv2 -import torch -import glob -import pickle -from tqdm import tqdm import copy - -from musetalk.utils.utils import get_file_type,get_video_fps,datagen -from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder -from musetalk.utils.blending import get_image -from musetalk.utils.utils import load_all_model +import glob +import torch import shutil +import pickle +import argparse +import numpy as np +from tqdm import tqdm +from omegaconf import OmegaConf +from transformers import WhisperModel + +from musetalk.utils.blending import get_image +from musetalk.utils.face_parsing import FaceParsing +from musetalk.utils.audio_processor import AudioProcessor +from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model +from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder + -# load model weights -audio_processor, vae, unet, pe = load_all_model() -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -timesteps = torch.tensor([0], device=device) @torch.no_grad() def main(args): - global pe + # Configure ffmpeg path + if args.ffmpeg_path not in os.getenv('PATH'): + print("Adding ffmpeg to PATH") + os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}" + + # Set computing device + device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") + + # Load model weights + vae, unet, pe = load_all_model( + unet_model_path=args.unet_model_path, + vae_type=args.vae_type, + unet_config=args.unet_config, + device=device + ) + timesteps = torch.tensor([0], device=device) + + if args.use_float16 is True: pe = pe.half() vae.vae = vae.vae.half() unet.model = unet.model.half() + + # Initialize audio processor and Whisper model + audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir) + weight_dtype = unet.model.dtype + whisper = WhisperModel.from_pretrained(args.whisper_dir) + whisper = whisper.to(device=device, dtype=weight_dtype).eval() + whisper.requires_grad_(False) + + # Initialize face parser + fp = FaceParsing() inference_config = OmegaConf.load(args.inference_config) print(inference_config) @@ -64,10 +90,20 @@ def main(args): else: raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") - #print(input_img_list) ############################################## extract audio feature ############################################## - whisper_feature = audio_processor.audio2feat(audio_path) - whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + # Extract audio features + whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, + device, + weight_dtype, + whisper, + librosa_length, + fps=fps, + audio_padding_length_left=args.audio_padding_length_left, + audio_padding_length_right=args.audio_padding_length_right, + ) + ############################################## preprocess input image ############################################## if os.path.exists(crop_coord_save_path) and args.use_saved_coord: print("using extracted coordinates") @@ -102,10 +138,7 @@ def main(args): gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) res_frame_list = [] for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): - audio_feature_batch = torch.from_numpy(whisper_batch) - audio_feature_batch = audio_feature_batch.to(device=unet.device, - dtype=unet.model.dtype) # torch, B, 5*N,384 - audio_feature_batch = pe(audio_feature_batch) + audio_feature_batch = pe(whisper_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample @@ -122,10 +155,10 @@ def main(args): try: res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: -# print(bbox) continue - combine_frame = get_image(ori_frame,res_frame,bbox) + # Merge results + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp) cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" @@ -142,11 +175,11 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml") parser.add_argument("--bbox_shift", type=int, default=0) parser.add_argument("--result_dir", default='./results', help="path to output") - - parser.add_argument("--fps", type=int, default=25) + parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--output_vid_name", type=str, default=None) parser.add_argument("--use_saved_coord", @@ -156,6 +189,12 @@ if __name__ == "__main__": action="store_true", help="Whether use float16 to speed up inference", ) - + parser.add_argument("--fps", type=int, default=25, help="Video frames per second") + parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights") + parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") + parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file") + parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") + parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio") + parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio") args = parser.parse_args() main(args) diff --git a/scripts/inference_alpha.py b/scripts/inference_alpha.py index b48f3ea..498be0f 100644 --- a/scripts/inference_alpha.py +++ b/scripts/inference_alpha.py @@ -72,8 +72,7 @@ def main(args): audio_path = inference_config[task_id]["audio_path"] if "result_name" in inference_config[task_id]: args.output_vid_name = inference_config[task_id]["result_name"] - bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) - + bbox_shift = args.bbox_shift # Set output paths input_basename = os.path.basename(video_path).split('.')[0] audio_basename = os.path.basename(audio_path).split('.')[0] @@ -228,12 +227,12 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--ffmpeg_path", type=str, default="/cfs-workspace/users/gozhong/ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") + parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable") parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use") parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model") parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file") - parser.add_argument("--unet_model_path", type=str, default="/cfs-datasets/users/gozhong/codes/musetalk_exp/exp_out/stage1_bs40/unet-20000.pth", help="Path to UNet model weights") - parser.add_argument("--whisper_dir", type=str, default="/cfs-datasets/public_models/whisper-tiny", help="Directory containing Whisper model") + parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights") + parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model") parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file") parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value") parser.add_argument("--result_dir", default='./results', help="Directory for output results")