diff --git a/.gitignore b/.gitignore
index aa31084..c83750f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,11 +5,14 @@
 *.pyc
 .ipynb_checkpoints
 results/
-./models
+models/
 **/__pycache__/
 *.py[cod]
 *$py.class
 dataset/
 ffmpeg*
+ffmprobe*
+ffplay*
 debug
-exp_out
\ No newline at end of file
+exp_out
+.gradio
\ No newline at end of file
diff --git a/README.md b/README.md
index 6bf3549..49e76c4 100644
--- a/README.md
+++ b/README.md
@@ -146,50 +146,87 @@ We also hope you note that we have not verified, maintained, or updated third-pa
 ## Installation
 To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
-### Build environment
-We recommend a python version >=3.10 and cuda version =11.7. Then build environment as follows:
+### Build environment
+We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
+
+```shell
+conda create -n MuseTalk python==3.10
+conda activate MuseTalk
+```
+
+### Install PyTorch 2.0.1
+Choose one of the following installation methods:
+
+```shell
+# Option 1: Using pip
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+
+# Option 2: Using conda
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+### Install Dependencies
+Install the remaining required packages:
 ```shell
 pip install -r requirements.txt
 ```
-### mmlab packages
+### Install MMLab Packages
+Install the MMLab ecosystem packages:
+
 ```bash
-pip install --no-cache-dir -U openmim
-mim install mmengine
-mim install "mmcv>=2.0.1"
-mim install "mmdet>=3.1.0"
-mim install "mmpose>=1.1.0"
+pip install --no-cache-dir -U openmim
+mim install mmengine
+mim install "mmcv==2.0.1"
+mim install "mmdet==3.1.0"
+mim install "mmpose==1.1.0"
 ```
-### Download ffmpeg-static
-Download the ffmpeg-static and
-```
+### Setup FFmpeg
+1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
+
+2. Configure FFmpeg based on your operating system:
+
+For Linux:
+```bash
 export FFMPEG_PATH=/path/to/ffmpeg
-```
-for example:
-```
+# Example:
 export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
 ```
-### Download weights
-You can download weights manually as follows:
-1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
+For Windows:
+Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
+
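Before moving on, it is worth confirming that an FFmpeg binary is actually reachable from your environment. The short Python sketch below (not one of the repository's scripts) mirrors the `ffmpeg -version` probe that `app.py` runs at startup; it assumes that `FFMPEG_PATH`, if set, points to the directory containing the binary.

```python
import os
import shutil
import subprocess

def ffmpeg_available() -> bool:
    """Return True if an ffmpeg executable can be located and run."""
    # Prefer an explicit FFMPEG_PATH (as exported above); fall back to the system PATH.
    ffmpeg_dir = os.environ.get("FFMPEG_PATH", "")
    candidates = [os.path.join(ffmpeg_dir, "ffmpeg"), "ffmpeg"] if ffmpeg_dir else ["ffmpeg"]
    for name in candidates:
        resolved = shutil.which(name)
        if resolved is None:
            continue
        try:
            # Same probe app.py uses: run `ffmpeg -version` and check the exit code.
            subprocess.run([resolved, "-version"], capture_output=True, check=True)
            return True
        except (OSError, subprocess.CalledProcessError):
            continue
    return False

if __name__ == "__main__":
    print("ffmpeg found" if ffmpeg_available() else "ffmpeg NOT found - check FFMPEG_PATH / PATH")
```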
+### Download weights
+You can download weights in two ways:
+
+#### Option 1: Using Download Scripts
+We provide two scripts for automatic downloading:
+
+For Linux:
 ```bash
-# !pip install -U "huggingface_hub[cli]"
-export HF_ENDPOINT=https://hf-mirror.com
-huggingface-cli download TMElyralab/MuseTalk --local-dir models/
+sh ./download_weights.sh
 ```
+For Windows:
+```batch
+# Run the script
+download_weights.bat
+```
+
+#### Option 2: Manual Download
+You can also download the weights manually from the following links:
+
+1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
 2. Download the weights of other components:
-   - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
+   - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
    - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
    - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
-   - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
-   - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
    - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
-
+   - [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
+   - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
 
 Finally, these weights should be organized in `models` as follows:
 ```
@@ -207,7 +244,7 @@ Finally, these weights should be organized in `models` as follows:
 ├── face-parse-bisent
 │   ├── 79999_iter.pth
 │   └── resnet18-5c106cde.pth
-├── sd-vae-ft-mse
+├── sd-vae
 │   ├── config.json
 │   └── diffusion_pytorch_model.bin
 └── whisper
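To confirm the layout before running anything, the minimal sketch below (not a repository script) checks a few representative files from the tree above relative to the project root; `app.py` performs a similar startup check against its `required_models` table (see the app.py changes later in this diff).

```python
import os

# A few representative files from the layout above; extend the list as needed.
EXPECTED_FILES = [
    "models/musetalkV15/unet.pth",
    "models/musetalkV15/musetalk.json",
    "models/sd-vae/config.json",
    "models/whisper/config.json",
    "models/dwpose/dw-ll_ucoco_384.pth",
    "models/face-parse-bisent/79999_iter.pth",
    "models/face-parse-bisent/resnet18-5c106cde.pth",
]

missing = [path for path in EXPECTED_FILES if not os.path.exists(path)]
if missing:
    print("Missing weight files:")
    for path in missing:
        print(f"  - {path}")
else:
    print("All expected weight files are in place.")
```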
@@ -221,21 +258,60 @@ ### Inference
 We provide inference scripts for both versions of MuseTalk:
 
-#### MuseTalk 1.5 (Recommended)
+#### Prerequisites
+Before running inference, please ensure ffmpeg is installed and accessible:
 ```bash
-# Run MuseTalk 1.5 inference
-sh inference.sh v1.5 normal
+# Check ffmpeg installation
+ffmpeg -version
 ```
+If ffmpeg is not found, please install it first:
+- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
+- Linux: `sudo apt-get install ffmpeg`
 
-#### MuseTalk 1.0
+#### Normal Inference
+##### Linux Environment
 ```bash
-# Run MuseTalk 1.0 inference
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 normal
+
+# MuseTalk 1.0
 sh inference.sh v1.0 normal
 ```
-The inference script supports both MuseTalk 1.5 and 1.0 models:
-- For MuseTalk 1.5: Use the command above with the V1.5 model path
-- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
+##### Windows Environment
+
+Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
+
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
+
+#### Real-time Inference
+##### Linux Environment
+```bash
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 realtime
+
+# MuseTalk 1.0
+sh inference.sh v1.0 realtime
+```
+
+##### Windows Environment
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
 The configuration file `configs/inference/test.yaml` contains the inference settings, including:
 - `video_path`: Path to the input video, image file, or directory of images
@@ -243,21 +319,6 @@ The configuration file `configs/inference/test.yaml` contains the inference sett
 Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
 
-#### Real-time Inference
-For real-time inference, use the following command:
-```bash
-# Run real-time inference
-sh inference.sh v1.5 realtime # For MuseTalk 1.5
-# or
-sh inference.sh v1.0 realtime # For MuseTalk 1.0
-```
-
-The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
-- `preparation`: Set to `True` for new avatar preparation
-- `video_path`: Path to the input video
-- `bbox_shift`: Adjustable parameter for mouth region control
-- `audio_clips`: List of audio clips for generation
-
 Important notes for real-time inference:
 1. Set `preparation` to `True` when processing a new avatar
 2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
@@ -269,6 +330,18 @@ For faster generation without saving images, you can use:
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
 ```
 
+## Gradio Demo
+We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
+
+For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes.
+
+Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
+
+```bash
+# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
+python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+```
+
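For reference, `--use_float16` maps to a simple dtype switch inside `app.py` (shown later in this diff). The sketch below is a condensed, illustrative version of that logic; the helper name is illustrative, while `pe`, `vae.vae`, and `unet.model` follow the script.

```python
import torch

def apply_precision(use_float16: bool, pe, vae, unet):
    """Condensed view of app.py's --use_float16 handling: run the sub-models in fp16 or fp32."""
    if use_float16:
        # Half precision: lower VRAM usage and faster inference, at a small quality cost.
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()
        weight_dtype = torch.float16
    else:
        # Full precision: best quality, but more VRAM and slower inference.
        weight_dtype = torch.float32
    return pe, vae, unet, weight_dtype
```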
 ## Training
 
 ### Data Preparation
diff --git a/app.py b/app.py
index feec239..448e641 100644
--- a/app.py
+++ b/app.py
@@ -4,7 +4,6 @@
 import pdb
 import re
 import gradio as gr
-import spaces
 import numpy as np
 import sys
 import subprocess
@@ -28,11 +27,101 @@
 import gdown
 import imageio
 import ffmpeg
 from moviepy.editor import *
-
+from transformers import WhisperModel
 ProjectDir = os.path.abspath(os.path.dirname(__file__))
 CheckpointsDir = os.path.join(ProjectDir, "models")
 
+@torch.no_grad()
+def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+                     left_cheek_width=90, right_cheek_width=90):
+    """Debug inpainting parameters, only process the first frame"""
+    # Set default parameters
+    args_dict = {
+        "result_dir": './results/debug',
+        "fps": 25,
+        "batch_size": 1,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
+    args = Namespace(**args_dict)
+
+    # Create debug directory
+    os.makedirs(args.result_dir, exist_ok=True)
+
+    # Read first frame
+    if get_file_type(video_path) == "video":
+        reader = imageio.get_reader(video_path)
+        first_frame = reader.get_data(0)
+        reader.close()
+    else:
+        first_frame = cv2.imread(video_path)
+        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+    # Save first frame
+    debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
+    cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
+
+    # Get face coordinates
+    coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
+    bbox = coord_list[0]
+    frame = frame_list[0]
+
+    if bbox == coord_placeholder:
+        return None, "No face detected, please adjust bbox_shift parameter"
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
+    # Process first frame
+    x1, y1, x2, y2 = bbox
+    y2 = y2 + args.extra_margin
+    y2 = min(y2, frame.shape[0])
+    crop_frame = frame[y1:y2, x1:x2]
+    crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
+
+    # Generate random audio features
+    random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
+    audio_feature = pe(random_audio)
+
+    # Get latents
+    latents = vae.get_latents_for_unet(crop_frame)
+    latents = latents.to(dtype=weight_dtype)
+
+    # Generate prediction results
+    pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
+    recon = vae.decode_latents(pred_latents)
+
+    # Inpaint back to original image
+    res_frame = recon[0]
+    res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+    combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+    # Save results (no need to convert color space again since get_image already returns RGB format)
+    debug_result_path = os.path.join(args.result_dir, "debug_result.png")
+    cv2.imwrite(debug_result_path, combine_frame)
+
+    # Create information text
+    info_text = f"Parameter information:\n" + \
+                f"bbox_shift: {bbox_shift}\n" + \
f"extra_margin: {extra_margin}\n" + \ + f"parsing_mode: {parsing_mode}\n" + \ + f"left_cheek_width: {left_cheek_width}\n" + \ + f"right_cheek_width: {right_cheek_width}\n" + \ + f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]" + + return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text + def print_directory_contents(path): for child in os.listdir(path): child_path = os.path.join(path, child) @@ -40,119 +129,107 @@ def print_directory_contents(path): print(child_path) def download_model(): - if not os.path.exists(CheckpointsDir): - os.makedirs(CheckpointsDir) - print("Checkpoint Not Downloaded, start downloading...") - tic = time.time() - snapshot_download( - repo_id="TMElyralab/MuseTalk", - local_dir=CheckpointsDir, - max_workers=8, - local_dir_use_symlinks=True, - force_download=True, resume_download=False - ) - # weight - os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/") - snapshot_download( - repo_id="stabilityai/sd-vae-ft-mse", - local_dir=CheckpointsDir+'/sd-vae-ft-mse', - max_workers=8, - local_dir_use_symlinks=True, - force_download=True, resume_download=False - ) - #dwpose - os.makedirs(f"{CheckpointsDir}/dwpose/") - snapshot_download( - repo_id="yzd-v/DWPose", - local_dir=CheckpointsDir+'/dwpose', - max_workers=8, - local_dir_use_symlinks=True, - force_download=True, resume_download=False - ) - #vae - url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt" - response = requests.get(url) - # 确保请求成功 - if response.status_code == 200: - # 指定文件保存的位置 - file_path = f"{CheckpointsDir}/whisper/tiny.pt" - os.makedirs(f"{CheckpointsDir}/whisper/") - # 将文件内容写入指定位置 - with open(file_path, "wb") as f: - f.write(response.content) + # 检查必需的模型文件是否存在 + required_models = { + "MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth", + "MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json", + "SD VAE": f"{CheckpointsDir}/sd-vae/config.json", + "Whisper": f"{CheckpointsDir}/whisper/config.json", + "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth", + "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt", + "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth", + "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth" + } + + missing_models = [] + for model_name, model_path in required_models.items(): + if not os.path.exists(model_path): + missing_models.append(model_name) + + if missing_models: + # 全用英文 + print("The following required model files are missing:") + for model in missing_models: + print(f"- {model}") + print("\nPlease run the download script to download the missing models:") + if sys.platform == "win32": + print("Windows: Run download_weights.bat") else: - print(f"请求失败,状态码:{response.status_code}") - #gdown face parse - url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812" - os.makedirs(f"{CheckpointsDir}/face-parse-bisent/") - file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth" - gdown.download(url, file_path, quiet=False) - #resnet - url = "https://download.pytorch.org/models/resnet18-5c106cde.pth" - response = requests.get(url) - # 确保请求成功 - if response.status_code == 200: - # 指定文件保存的位置 - file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth" - # 将文件内容写入指定位置 - with open(file_path, "wb") as f: - f.write(response.content) - else: - print(f"请求失败,状态码:{response.status_code}") - - - toc = time.time() - - print(f"download cost {toc-tic} seconds") - print_directory_contents(CheckpointsDir) - + print("Linux/Mac: Run 
./download_weights.sh") + sys.exit(1) else: - print("Already download the model.") - + print("All required model files exist.") download_model() # for huggingface deployment. - -from musetalk.utils.utils import get_file_type,get_video_fps,datagen -from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range from musetalk.utils.blending import get_image -from musetalk.utils.utils import load_all_model +from musetalk.utils.face_parsing import FaceParsing +from musetalk.utils.audio_processor import AudioProcessor +from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model +from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range +def fast_check_ffmpeg(): + try: + subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) + return True + except: + return False - - -@spaces.GPU(duration=600) @torch.no_grad() -def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)): - args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script +def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw", + left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)): + # Set default parameters, aligned with inference.py + args_dict = { + "result_dir": './results/output', + "fps": 25, + "batch_size": 8, + "output_vid_name": '', + "use_saved_coord": False, + "audio_padding_length_left": 2, + "audio_padding_length_right": 2, + "version": "v15", # Fixed use v15 version + "extra_margin": extra_margin, + "parsing_mode": parsing_mode, + "left_cheek_width": left_cheek_width, + "right_cheek_width": right_cheek_width + } args = Namespace(**args_dict) - input_basename = os.path.basename(video_path).split('.')[0] - audio_basename = os.path.basename(audio_path).split('.')[0] - output_basename = f"{input_basename}_{audio_basename}" - result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs - crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input - os.makedirs(result_img_save_path,exist_ok =True) + # Check ffmpeg + if not fast_check_ffmpeg(): + print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed") - if args.output_vid_name=="": - output_vid_name = os.path.join(args.result_dir, output_basename+".mp4") + input_basename = os.path.basename(video_path).split('.')[0] + audio_basename = os.path.basename(audio_path).split('.')[0] + output_basename = f"{input_basename}_{audio_basename}" + + # Create temporary directory + temp_dir = os.path.join(args.result_dir, f"{args.version}") + os.makedirs(temp_dir, exist_ok=True) + + # Set result save path + result_img_save_path = os.path.join(temp_dir, output_basename) + crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") + os.makedirs(result_img_save_path, exist_ok=True) + + if args.output_vid_name == "": + output_vid_name = os.path.join(temp_dir, output_basename+".mp4") else: - output_vid_name = os.path.join(args.result_dir, args.output_vid_name) + output_vid_name = os.path.join(temp_dir, args.output_vid_name) + ############################################## extract frames from source video ############################################## - if get_file_type(video_path)=="video": - save_dir_full = os.path.join(args.result_dir, input_basename) 
- os.makedirs(save_dir_full,exist_ok = True) - # cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" - # os.system(cmd) - # 读取视频 + if get_file_type(video_path) == "video": + save_dir_full = os.path.join(temp_dir, input_basename) + os.makedirs(save_dir_full, exist_ok=True) + # Read video reader = imageio.get_reader(video_path) - # 保存图片 + # Save images for i, im in enumerate(reader): imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im) input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) @@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) fps = args.fps - #print(input_img_list) + ############################################## extract audio feature ############################################## - whisper_feature = audio_processor.audio2feat(audio_path) - whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + # Extract audio features + whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, + device, + weight_dtype, + whisper, + librosa_length, + fps=fps, + audio_padding_length_left=args.audio_padding_length_left, + audio_padding_length_right=args.audio_padding_length_right, + ) + ############################################## preprocess input image ############################################## if os.path.exists(crop_coord_save_path) and args.use_saved_coord: print("using extracted coordinates") @@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) with open(crop_coord_save_path, 'wb') as f: pickle.dump(coord_list, f) - bbox_shift_text=get_bbox_range(input_img_list, bbox_shift) + bbox_shift_text = get_bbox_range(input_img_list, bbox_shift) + + # Initialize face parser + fp = FaceParsing( + left_cheek_width=args.left_cheek_width, + right_cheek_width=args.right_cheek_width + ) + i = 0 input_latent_list = [] for bbox, frame in zip(coord_list, frame_list): if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) crop_frame = frame[y1:y2, x1:x2] crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) latents = vae.get_latents_for_unet(crop_frame) @@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T frame_list_cycle = frame_list + frame_list[::-1] coord_list_cycle = coord_list + coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + ############################################## inference batch by batch ############################################## print("start inference") video_num = len(whisper_chunks) batch_size = args.batch_size - gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) + gen = datagen( + whisper_chunks=whisper_chunks, + vae_encode_latents=input_latent_list_cycle, + batch_size=batch_size, + delay_frame=0, + device=device, + ) res_frame_list = [] for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): - - tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] - audio_feature_batch = 
torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384 - audio_feature_batch = pe(audio_feature_batch) + audio_feature_batch = pe(whisper_batch) + # Ensure latent_batch is consistent with model weight type + latent_batch = latent_batch.to(dtype=weight_dtype) pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) @@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T bbox = coord_list_cycle[i%(len(coord_list_cycle))] ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) x1, y1, x2, y2 = bbox + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) try: res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: - # print(bbox) continue - combine_frame = get_image(ori_frame,res_frame,bbox) + # Use v15 version blending + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp) + cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) - # cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4" - # print(cmd_img2video) - # os.system(cmd_img2video) - # 帧率 + # Frame rate fps = 25 - # 图片路径 - # 输出视频路径 + # Output video path output_video = 'temp.mp4' - # 读取图片 + # Read images def is_valid_image(file): pattern = re.compile(r'\d{8}\.png') return pattern.match(file) @@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T images.append(imageio.imread(filename)) - # 保存视频 + # Save video imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p') - # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}" - # print(cmd_combine_audio) - # os.system(cmd_combine_audio) - input_video = './temp.mp4' # Check if the input_video and audio_path exist if not os.path.exists(input_video): @@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - # 读取视频 + # Read video reader = imageio.get_reader(input_video) - fps = reader.get_meta_data()['fps'] # 获取原视频的帧率 - reader.close() # 否则在win11上会报错:PermissionError: [WinError 32] 另一个程序正在使用此文件,进程无法访问。: 'temp.mp4' - # 将帧存储在列表中 + fps = reader.get_meta_data()['fps'] # Get original video frame rate + reader.close() # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. 
: 'temp.mp4' + # Store frames in list frames = images - - # 保存视频并添加音频 - # imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path]) - - # input_video = ffmpeg.input(input_video) - - # input_audio = ffmpeg.input(audio_path) print(len(frames)) - # imageio.mimwrite( - # output_video, - # frames, - # 'FFMPEG', - # fps=25, - # codec='libx264', - # audio_codec='aac', - # input_params=['-i', audio_path], - # output_params=['-y'], # Add the '-y' flag to overwrite the output file if it exists - # ) - # writer = imageio.get_writer(output_vid_name, fps = 25, codec='libx264', quality=10, pixelformat='yuvj444p') - # for im in frames: - # writer.append_data(im) - # writer.close() - - - - # Load the video video_clip = VideoFileClip(input_video) @@ -315,11 +388,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T # load model weights -audio_processor,vae,unet,pe = load_all_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +vae, unet, pe = load_all_model( + unet_model_path="./models/musetalkV15/unet.pth", + vae_type="sd-vae", + unet_config="./models/musetalkV15/musetalk.json", + device=device +) + +# Parse command line arguments +parser = argparse.ArgumentParser() +parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable") +parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to") +parser.add_argument("--port", type=int, default=7860, help="Port to bind to") +parser.add_argument("--share", action="store_true", help="Create a public link") +parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference") +args = parser.parse_args() + +# Set data type +if args.use_float16: + # Convert models to half precision for better performance + pe = pe.half() + vae.vae = vae.vae.half() + unet.model = unet.model.half() + weight_dtype = torch.float16 +else: + weight_dtype = torch.float32 + +# Move models to specified device +pe = pe.to(device) +vae.vae = vae.vae.to(device) +unet.model = unet.model.to(device) + timesteps = torch.tensor([0], device=device) - +# Initialize audio processor and Whisper model +audio_processor = AudioProcessor(feature_extractor_path="./models/whisper") +whisper = WhisperModel.from_pretrained("./models/whisper") +whisper = whisper.to(device=device, dtype=weight_dtype).eval() +whisper.requires_grad_(False) def check_video(video): @@ -340,9 +447,6 @@ def check_video(video): output_video = os.path.join('./results/input', output_file_name) - # # Run the ffmpeg command to change the frame rate to 25fps - # command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y" - # read video reader = imageio.get_reader(video) fps = reader.get_meta_data()['fps'] # get fps from original video @@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p with gr.Blocks(css=css) as demo: gr.Markdown( - "