<enhance>: modified inference code

1. bbox_shift can now be set per task in configs/inference/test.yaml
2. pip-installing whisper is no longer needed: the audio encoder is imported from the vendored copy under musetalk/whisper
zkangchen
2024-04-03 14:35:55 +08:00
parent dde2ee49ef
commit bc1379abad
18 changed files with 28 additions and 96 deletions

.gitignore

@@ -4,7 +4,7 @@
 .vscode/
 *.pyc
 .ipynb_checkpoints
-models/
+models
 results/
-data/audio/*.WAV
+data/audio/*.wav
 data/video/*.mp4

README.md

@@ -175,11 +175,6 @@ We recommend a python version >=3.10 and cuda version =11.7. Then build environm
 ```shell
 pip install -r requirements.txt
 ```
-### whisper
-install whisper to extract audio feature (only encoder)
-```
-pip install --editable ./musetalk/whisper
-```
 ### mmlab packages
 ```bash
@@ -256,13 +251,13 @@ As a complete solution to virtual human generation, you are suggested to first a
 # Note
-If you want to launch online video chats, you are suggested to generate videos using MuseV and apply necessary pre-processing such as face detection in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
+If you want to launch online video chats, you are suggested to generate videos using MuseV and apply necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.

 # Acknowledgement
-1. We thank open-source components like [whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
-1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers).
-1. MuseTalk has been built on `HDTF` datasets.
+1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
+1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
+1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
 Thanks for open-sourcing!

configs/inference/test.yaml

@@ -1,9 +1,10 @@
 task_0:
-  video_path: "data/video/monalisa.mp4"
-  audio_path: "data/audio/monalisa.wav"
+  video_path: "data/video/yongen.mp4"
+  audio_path: "data/audio/yongen.wav"
 task_1:
   video_path: "data/video/sun.mp4"
   audio_path: "data/audio/sun.wav"
+  bbox_shift: -7
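
The matching change in scripts/inference.py reads the per-task value with `inference_config[task_id].get("bbox_shift", args.bbox_shift)`, so a task-level bbox_shift overrides the CLI flag. A minimal sketch of that resolution, assuming PyYAML for loading (the repo may use a different loader):

```python
import yaml  # assumption: PyYAML is available

CLI_DEFAULT_BBOX_SHIFT = 0  # mirrors the argparse default for --bbox_shift

config_text = """
task_0:
  video_path: "data/video/yongen.mp4"
  audio_path: "data/audio/yongen.wav"
task_1:
  video_path: "data/video/sun.mp4"
  audio_path: "data/audio/sun.wav"
  bbox_shift: -7
"""

inference_config = yaml.safe_load(config_text)
for task_id, task in inference_config.items():
    # A per-task bbox_shift wins; tasks without the key fall back to the CLI flag.
    bbox_shift = task.get("bbox_shift", CLI_DEFAULT_BBOX_SHIFT)
    print(task_id, bbox_shift)  # task_0 -> 0, task_1 -> -7
```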

Binary file not shown.

data/audio/yongen.wav (new file)
Binary file not shown.

Binary file not shown.

data/video/yongen.mp4 (new file)
Binary file not shown.

musetalk/utils/blending.py

@@ -52,7 +52,6 @@ def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
     blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
     mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
     mask_image = Image.fromarray(mask_array)
-    mask_image.save("./debug_mask.png")
     face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
     body.paste(face_large, crop_box[:2], mask_image)
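
For context on the surrounding get_image code: the face-parsing mask is blurred with an odd Gaussian kernel scaled to 10% of the crop height, then used as the alpha mask when pasting the face back, which feathers the seam. A self-contained sketch of that step with synthetic stand-in images; the kernel formula comes from the hunk above, while every shape and value here is hypothetical:

```python
import cv2
import numpy as np
from PIL import Image

ori_shape = (320, 256)                 # (h, w) of the enlarged face crop -- hypothetical
mask = np.zeros(ori_shape, dtype=np.uint8)
mask[80:300, 40:216] = 255             # stand-in for the face-parsing output

# // 2 * 2 rounds down to an even number; +1 makes it odd, as GaussianBlur requires.
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1   # 33 for a 320 px crop
mask_array = cv2.GaussianBlur(mask, (blur_kernel_size, blur_kernel_size), 0)
mask_image = Image.fromarray(mask_array)  # grayscale alpha: 255 takes the face, 0 keeps the body

body = Image.new("RGB", (512, 512), "gray")
face_large = Image.new("RGB", ori_shape[::-1], "white")   # PIL size is (w, h)
body.paste(face_large, (100, 100), mask_image)            # feathered edge instead of a hard seam
```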

musetalk/whisper/audio2feature.py

@@ -1,7 +1,5 @@
 import os
-#import whisper
-from whisper import load_model
-#import whisper.whispher as whiisper
+from .whisper import load_model
 import soundfile as sf
 import numpy as np
 import time

@@ -9,11 +7,12 @@ import sys
 sys.path.append("..")

 class Audio2Feature():
-    def __init__(self, whisper_model_type="tiny",model_path="./checkpoints/wisper_tiny.pt"):
+    def __init__(self,
+                 whisper_model_type="tiny",
+                 model_path="./models/whisper/tiny.pt"):
         self.whisper_model_type = whisper_model_type
         self.model = load_model(model_path) #

     def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
         """
         Get sliced features based on a given index
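
The body of get_sliced_feature is cut off above. Purely for orientation, here is an illustrative reimplementation of the indexing its signature and docstring suggest, assuming Whisper's encoder rate of 50 feature frames per second (so two feature frames per 25 fps video frame); treat it as a sketch, not the repository's exact code:

```python
import numpy as np

def get_sliced_feature_sketch(feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
    """Slice a window of Whisper encoder frames around one video frame (sketch)."""
    center = int(vid_idx * 50 / fps)                  # feature frame aligned to the video frame
    left = max(center - audio_feat_length[0] * 2, 0)  # 2 feature frames per context video frame
    right = min(center + (audio_feat_length[1] + 1) * 2, len(feature_array))
    return feature_array[left:right]

features = np.random.randn(1500, 384)  # 30 s of hypothetical whisper-tiny features (width 384)
print(get_sliced_feature_sketch(features, vid_idx=100).shape)  # (10, 384)
```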

musetalk/whisper/requirements.txt (deleted)

@@ -1,6 +0,0 @@
-numpy
-torch
-tqdm
-more-itertools
-transformers>=4.19.0
-ffmpeg-python==0.2.0

musetalk/whisper/setup.py (deleted)

@@ -1,24 +0,0 @@
-import os
-
-import pkg_resources
-from setuptools import setup, find_packages
-
-setup(
-    name="whisper",
-    py_modules=["whisper"],
-    version="1.0",
-    description="",
-    author="OpenAI",
-    packages=find_packages(exclude=["tests*"]),
-    install_requires=[
-        str(r)
-        for r in pkg_resources.parse_requirements(
-            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
-        )
-    ],
-    entry_points = {
-        'console_scripts': ['whisper=whisper.transcribe:cli'],
-    },
-    include_package_data=True,
-    extras_require={'dev': ['pytest']},
-)

musetalk/whisper/whisper.egg-info/PKG-INFO (deleted)

@@ -1,5 +0,0 @@
-Metadata-Version: 2.1
-Name: whisper
-Version: 1.0
-Author: OpenAI
-Provides-Extra: dev

musetalk/whisper/whisper.egg-info/SOURCES.txt (deleted)

@@ -1,18 +0,0 @@
-setup.py
-whisper/__init__.py
-whisper/__main__.py
-whisper/audio.py
-whisper/decoding.py
-whisper/model.py
-whisper/tokenizer.py
-whisper/transcribe.py
-whisper/utils.py
-whisper.egg-info/PKG-INFO
-whisper.egg-info/SOURCES.txt
-whisper.egg-info/dependency_links.txt
-whisper.egg-info/entry_points.txt
-whisper.egg-info/requires.txt
-whisper.egg-info/top_level.txt
-whisper/normalizers/__init__.py
-whisper/normalizers/basic.py
-whisper/normalizers/english.py

musetalk/whisper/whisper.egg-info/entry_points.txt (deleted)

@@ -1,2 +0,0 @@
-[console_scripts]
-whisper = whisper.transcribe:cli

musetalk/whisper/whisper.egg-info/requires.txt (deleted)

@@ -1,9 +0,0 @@
-numpy
-torch
-tqdm
-more-itertools
-transformers>=4.19.0
-ffmpeg-python==0.2.0
-
-[dev]
-pytest

musetalk/whisper/whisper.egg-info/top_level.txt (deleted)

@@ -1 +0,0 @@
-whisper

scripts/inference.py

@@ -13,6 +13,7 @@ from musetalk.utils.utils import get_file_type,get_video_fps,datagen
 from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image
 from musetalk.utils.utils import load_all_model
+import shutil

 # load model weights
 audio_processor,vae,unet,pe = load_all_model()

@@ -26,6 +27,7 @@ def main(args):
     for task_id in inference_config:
         video_path = inference_config[task_id]["video_path"]
         audio_path = inference_config[task_id]["audio_path"]
+        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
         input_basename = os.path.basename(video_path).split('.')[0]
         audio_basename = os.path.basename(audio_path).split('.')[0]

@@ -42,7 +44,7 @@ def main(args):
     if get_file_type(video_path)=="video":
         save_dir_full = os.path.join(args.result_dir, input_basename)
         os.makedirs(save_dir_full,exist_ok = True)
-        cmd = f"ffmpeg -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
+        cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
         os.system(cmd)
         input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
         fps = get_video_fps(video_path)

@@ -62,7 +64,7 @@ def main(args):
         frame_list = read_imgs(input_img_list)
     else:
         print("extracting landmarks...time consuming")
-        coord_list, frame_list = get_landmark_and_bbox(input_img_list,args.bbox_shift)
+        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
         with open(crop_coord_save_path, 'wb') as f:
             pickle.dump(coord_list, f)

@@ -117,24 +119,26 @@ def main(args):
     print(cmd_img2video)
     os.system(cmd_img2video)

-    cmd_combine_audio = f"ffmpeg -i {audio_path} -i temp.mp4 {output_vid_name} -y"
+    cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
     print(cmd_combine_audio)
     os.system(cmd_combine_audio)

-    os.system("rm temp.mp4")
-    os.system(f"rm -rf {result_img_save_path}")
+    os.remove("temp.mp4")
+    shutil.rmtree(result_img_save_path)
     print(f"result is save to {output_vid_name}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--inference_config",type=str, default="configs/inference/test_img.yaml")
-    parser.add_argument("--bbox_shift",type=int, default=0)
+    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
+    parser.add_argument("--bbox_shift", type=int, default=0)
     parser.add_argument("--result_dir", default='./results', help="path to output")
-    parser.add_argument("--fps",type=int, default=25)
-    parser.add_argument("--batch_size",type=int, default=8)
-    parser.add_argument("--output_vid_name",type=str,default='')
-    parser.add_argument("--use_saved_coord",action="store_true", help='use saved coordinate to save time')
+    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--output_vid_name", type=str,default='')
+    parser.add_argument("--use_saved_coord",
+                        action="store_true",
+                        help='use saved coordinate to save time')

     args = parser.parse_args()
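
The commit already quiets ffmpeg with -v fatal and swaps shell rm calls for os.remove/shutil.rmtree. A further step, not part of this commit, would be to drop os.system for the ffmpeg calls as well; a hedged sketch using subprocess.run with list arguments, which also survives paths containing spaces:

```python
import subprocess

def extract_frames(video_path: str, save_dir_full: str) -> None:
    # Same as: ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png
    subprocess.run(
        ["ffmpeg", "-v", "fatal", "-i", video_path,
         "-start_number", "0", f"{save_dir_full}/%08d.png"],
        check=True,  # fail loudly instead of continuing after an ffmpeg error
    )

def mux_audio(audio_path: str, output_vid_name: str) -> None:
    # Same as: ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}
    subprocess.run(
        ["ffmpeg", "-y", "-v", "fatal", "-i", audio_path,
         "-i", "temp.mp4", output_vid_name],
        check=True,
    )
```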