Merge pull request #281 from zzzweakman/fix/infer

fix: infer bug
Octpus authored 2025-03-31 21:32:52 +08:00 · committed by GitHub
6 changed files with 126 additions and 42 deletions

.gitignore (7 changes)

@@ -4,7 +4,8 @@
.vscode/
*.pyc
.ipynb_checkpoints
models
results/
data/audio/*.wav
data/video/*.mp4
./models
**/__pycache__/
*.py[cod]
*$py.class

README.md

@@ -177,7 +177,7 @@ You can download weights manually as follows:
2. Download the weights of other components:
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
- [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
@@ -201,7 +201,10 @@ Finally, these weights should be organized in `models` as follows:
│   ├── config.json
│   └── diffusion_pytorch_model.bin
└── whisper
    └── tiny.pt
    ├── config.json
    ├── pytorch_model.bin
    └── preprocessor_config.json
```
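One way to pull the Whisper files into that layout is the Hugging Face CLI; this is a convenience sketch assuming a recent `huggingface_hub` install, not part of the official setup steps:

```bash
# Fetch openai/whisper-tiny (includes config.json, pytorch_model.bin and
# preprocessor_config.json) directly into the expected directory.
huggingface-cli download openai/whisper-tiny --local-dir ./models/whisper
```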
## Quickstart
@@ -210,7 +213,7 @@ We provide inference scripts for both versions of MuseTalk:
#### MuseTalk 1.5 (Recommended)
```bash
python3 -m scripts.inference_alpha --inference_config configs/inference/test.yaml --unet_model_path ./models/musetalkV15/unet.pth
sh inference.sh v1.5
```
This inference script supports both MuseTalk 1.5 and 1.0 models:
- For MuseTalk 1.5: Use the command above with the V1.5 model path
@@ -221,7 +224,7 @@ The video_path should be either a video file, an image file or a directory of im
#### MuseTalk 1.0
```bash
python3 -m scripts.inference --inference_config configs/inference/test.yaml
sh inference.sh v1.0
```
We recommend providing input video at `25fps`, the same frame rate used when training the model. If your video's frame rate is well below 25fps, apply frame interpolation or convert it directly to 25fps with ffmpeg, for example as sketched below.
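A minimal sketch of both options with stock ffmpeg (filenames are illustrative; `minterpolate` is one interpolation filter among several):

```bash
# Plain resample to 25fps (frames are dropped or duplicated as needed)
ffmpeg -i input.mp4 -r 25 input_25fps.mp4

# Motion-interpolated resample to 25fps (slower, smoother)
ffmpeg -i input.mp4 -vf "minterpolate=fps=25" input_25fps.mp4
```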
<details close>

inference.sh (new file, 42 lines)

@@ -0,0 +1,42 @@
#!/bin/bash
# This script runs inference based on the version specified by the user.
# Usage:
#   To run v1.0 inference: sh inference.sh v1.0
#   To run v1.5 inference: sh inference.sh v1.5

# Check if the correct number of arguments is provided
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <version>"
    echo "Example: $0 v1.0 or $0 v1.5"
    exit 1
fi

# Get the version from the user input
version=$1
config_path="./configs/inference/test.yaml"

# Define the model paths based on the version
if [ "$version" = "v1.0" ]; then
    model_dir="./models/musetalk"
    unet_model_path="$model_dir/pytorch_model.bin"
elif [ "$version" = "v1.5" ]; then
    model_dir="./models/musetalkV15"
    unet_model_path="$model_dir/unet.pth"
else
    echo "Invalid version specified. Please use v1.0 or v1.5."
    exit 1
fi

# Run inference based on the version
if [ "$version" = "v1.0" ]; then
    python3 -m scripts.inference \
        --inference_config "$config_path" \
        --result_dir "./results/test" \
        --unet_model_path "$unet_model_path"
elif [ "$version" = "v1.5" ]; then
    python3 -m scripts.inference_alpha \
        --inference_config "$config_path" \
        --result_dir "./results/test" \
        --unet_model_path "$unet_model_path"
fi
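For reference, a minimal run under the default layout (model and result paths are the ones hard-coded in the script above):

```bash
# v1.5 pipeline; outputs are written under ./results/test (the script's --result_dir)
sh inference.sh v1.5

# v1.0 pipeline; expects ./models/musetalk/pytorch_model.bin to exist
sh inference.sh v1.0
```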

musetalk/utils/audio_processor.py

@@ -91,7 +91,7 @@ class AudioProcessor:
if __name__ == "__main__":
    audio_processor = AudioProcessor()
    wav_path = "/cfs-workspace/users/gozhong/codes/musetalk_opensource2/data/audio/2.wav"
    wav_path = "./2.wav"
    audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
    print("Audio Feature shape:", audio_feature.shape)
    print("librosa_feature_length:", librosa_feature_length)

scripts/inference.py

@@ -1,32 +1,58 @@
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
from tqdm import tqdm
import copy
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import glob
import torch
import shutil
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
    global pe
    # Configure ffmpeg path
    if args.ffmpeg_path not in os.getenv('PATH'):
        print("Adding ffmpeg to PATH")
        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"

    # Set computing device
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    # Load model weights
    vae, unet, pe = load_all_model(
        unet_model_path=args.unet_model_path,
        vae_type=args.vae_type,
        unet_config=args.unet_config,
        device=device
    )
    timesteps = torch.tensor([0], device=device)

    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    # Initialize audio processor and Whisper model
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(args.whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    # Initialize face parser
    fp = FaceParsing()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
@@ -64,10 +90,20 @@ def main(args):
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        #print(input_img_list)

        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
        # Extract audio features
        whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
        whisper_chunks = audio_processor.get_whisper_chunk(
            whisper_input_features,
            device,
            weight_dtype,
            whisper,
            librosa_length,
            fps=fps,
            audio_padding_length_left=args.audio_padding_length_left,
            audio_padding_length_right=args.audio_padding_length_right,
        )

        ############################################## preprocess input image ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
@@ -102,10 +138,7 @@ def main(args):
        gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
        res_frame_list = []
        for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype) # torch, B, 5*N,384
            audio_feature_batch = pe(audio_feature_batch)
            audio_feature_batch = pe(whisper_batch)

            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
@@ -122,10 +155,10 @@ def main(args):
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
            except:
                # print(bbox)
                continue
            combine_frame = get_image(ori_frame,res_frame,bbox)
            # Merge results
            combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)

        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
@@ -142,11 +175,11 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
@@ -156,6 +189,12 @@ if __name__ == "__main__":
                        action="store_true",
                        help="Whether use float16 to speed up inference",
    )
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")

    args = parser.parse_args()
    main(args)
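For reference, a direct invocation of this updated v1.0 entry point with the new arguments, using only the defaults declared above (paths assume the model layout from the README; adjust them to your checkout):

```bash
python3 -m scripts.inference \
    --inference_config configs/inference/test.yaml \
    --unet_model_path ./models/musetalk/pytorch_model.bin \
    --unet_config ./models/musetalk/config.json \
    --whisper_dir ./models/whisper \
    --result_dir ./results/test
```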

scripts/inference_alpha.py

@@ -72,8 +72,7 @@ def main(args):
        audio_path = inference_config[task_id]["audio_path"]
        if "result_name" in inference_config[task_id]:
            args.output_vid_name = inference_config[task_id]["result_name"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
        bbox_shift = args.bbox_shift

        # Set output paths
        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
@@ -228,12 +227,12 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="/cfs-workspace/users/gozhong/ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="/cfs-datasets/users/gozhong/codes/musetalk_exp/exp_out/stage1_bs40/unet-20000.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="/cfs-datasets/public_models/whisper-tiny", help="Directory containing Whisper model")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")