MuseTalk mirror (https://github.com/TMElyralab/MuseTalk.git)

Commit: feat: windows infer & gradio (#312)

* fix: windows infer
* docs: update readme
* docs: update readme
* feat: v1.5 gradio for windows & linux
* fix: dependencies
* feat: windows infer & gradio

Co-authored-by: NeRF-Factory <zzhizhou66@gmail.com>

Files changed in this commit:
.gitignore (vendored): 5 lines changed

@@ -5,11 +5,14 @@
 *.pyc
 .ipynb_checkpoints
 results/
-./models
+models/
 **/__pycache__/
 *.py[cod]
 *$py.class
 dataset/
 ffmpeg*
+ffmprobe*
+ffplay*
 debug
 exp_out
+.gradio
README.md: 165 lines changed

@@ -146,50 +146,87 @@ We also hope you note that we have not verified, maintained, or updated third-pa
 
 ## Installation
 
 To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
 
-### Build environment
-We recommend a python version >=3.10 and cuda version =11.7. Then build environment as follows:
+### Build environment
+We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
+
+```shell
+conda create -n MuseTalk python==3.10
+conda activate MuseTalk
+```
+
+### Install PyTorch 2.0.1
+Choose one of the following installation methods:
+
+```shell
+# Option 1: Using pip
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+
+# Option 2: Using conda
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+### Install Dependencies
+Install the remaining required packages:
 
 ```shell
 pip install -r requirements.txt
 ```
 
-### mmlab packages
+### Install MMLab Packages
+Install the MMLab ecosystem packages:
+
 ```bash
 pip install --no-cache-dir -U openmim
 mim install mmengine
-mim install "mmcv>=2.0.1"
-mim install "mmdet>=3.1.0"
-mim install "mmpose>=1.1.0"
+mim install "mmcv==2.0.1"
+mim install "mmdet==3.1.0"
+mim install "mmpose==1.1.0"
 ```
 
-### Download ffmpeg-static
-Download the ffmpeg-static and
-```
+### Setup FFmpeg
+1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
+
+2. Configure FFmpeg based on your operating system:
+
+For Linux:
+```bash
 export FFMPEG_PATH=/path/to/ffmpeg
-```
-for example:
-```
+# Example:
 export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
 ```
-### Download weights
-You can download weights manually as follows:
-
-1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
+
+For Windows:
+Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
+
+### Download weights
+You can download weights in two ways:
+
+#### Option 1: Using Download Scripts
+We provide two scripts for automatic downloading:
+
+For Linux:
 ```bash
-# !pip install -U "huggingface_hub[cli]"
-export HF_ENDPOINT=https://hf-mirror.com
-huggingface-cli download TMElyralab/MuseTalk --local-dir models/
+sh ./download_weights.sh
 ```
 
+For Windows:
+```batch
+# Run the script
+download_weights.bat
+```
+
+#### Option 2: Manual Download
+You can also download the weights manually from the following links:
+
+1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
 2. Download the weights of other components:
-  - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
+  - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
   - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
   - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
-  - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
-  - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
   - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
+  - [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
+  - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
 
 Finally, these weights should be organized in `models` as follows:
 ```

@@ -207,7 +244,7 @@ Finally, these weights should be organized in `models` as follows:
 ├── face-parse-bisent
 │   ├── 79999_iter.pth
 │   └── resnet18-5c106cde.pth
-├── sd-vae-ft-mse
+├── sd-vae
 │   ├── config.json
 │   └── diffusion_pytorch_model.bin
 └── whisper

@@ -221,21 +258,60 @@ Finally, these weights should be organized in `models` as follows:
 ### Inference
 We provide inference scripts for both versions of MuseTalk:
 
-#### MuseTalk 1.5 (Recommended)
+#### Prerequisites
+Before running inference, please ensure ffmpeg is installed and accessible:
 ```bash
-# Run MuseTalk 1.5 inference
-sh inference.sh v1.5 normal
+# Check ffmpeg installation
+ffmpeg -version
 ```
+If ffmpeg is not found, please install it first:
+- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
+- Linux: `sudo apt-get install ffmpeg`
 
-#### MuseTalk 1.0
+#### Normal Inference
+##### Linux Environment
 ```bash
-# Run MuseTalk 1.0 inference
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 normal
+
+# MuseTalk 1.0
 sh inference.sh v1.0 normal
 ```
 
-The inference script supports both MuseTalk 1.5 and 1.0 models:
-- For MuseTalk 1.5: Use the command above with the V1.5 model path
-- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
+##### Windows Environment
+Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
+
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
+
+#### Real-time Inference
+##### Linux Environment
+```bash
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 realtime
+
+# MuseTalk 1.0
+sh inference.sh v1.0 realtime
+```
+
+##### Windows Environment
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
 
 The configuration file `configs/inference/test.yaml` contains the inference settings, including:
 - `video_path`: Path to the input video, image file, or directory of images

@@ -243,21 +319,6 @@ The configuration file `configs/inference/test.yaml` contains the inference sett
 
 Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
 
-#### Real-time Inference
-For real-time inference, use the following command:
-```bash
-# Run real-time inference
-sh inference.sh v1.5 realtime  # For MuseTalk 1.5
-# or
-sh inference.sh v1.0 realtime  # For MuseTalk 1.0
-```
-
-The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
-- `preparation`: Set to `True` for new avatar preparation
-- `video_path`: Path to the input video
-- `bbox_shift`: Adjustable parameter for mouth region control
-- `audio_clips`: List of audio clips for generation
-
 Important notes for real-time inference:
 1. Set `preparation` to `True` when processing a new avatar
 2. After preparation, the avatar will generate videos using audio clips from `audio_clips`

@@ -269,6 +330,18 @@ For faster generation without saving images, you can use:
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
 ```
 
+## Gradio Demo
+We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
+
+
+For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. 
+
+Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
+
+```bash
+# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
+python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+```
+
 ## Training
 
 ### Data Preparation
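The Windows FFmpeg step in the updated README above is described only in prose (add the `ffmpeg-xxx\bin` folder to PATH, then check `ffmpeg -version`). A minimal sketch of that step for a single Command Prompt session, assuming the BtbN build named elsewhere in this commit was extracted into the repository folder:

```batch
:: Session-only PATH update; adjust the folder name to wherever the build was actually extracted
set "PATH=%CD%\ffmpeg-master-latest-win64-gpl-shared\bin;%PATH%"

:: Should now print the FFmpeg version banner
ffmpeg -version
```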
app.py: 476 lines changed

@@ -4,7 +4,6 @@ import pdb
 import re
 
 import gradio as gr
-import spaces
 import numpy as np
 import sys
 import subprocess

@@ -28,11 +27,101 @@ import gdown
 import imageio
 import ffmpeg
 from moviepy.editor import *
+from transformers import WhisperModel
 
 ProjectDir = os.path.abspath(os.path.dirname(__file__))
 CheckpointsDir = os.path.join(ProjectDir, "models")
 
+@torch.no_grad()
+def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+                     left_cheek_width=90, right_cheek_width=90):
+    """Debug inpainting parameters, only process the first frame"""
+    # Set default parameters
+    args_dict = {
+        "result_dir": './results/debug',
+        "fps": 25,
+        "batch_size": 1,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
+    args = Namespace(**args_dict)
+
+    # Create debug directory
+    os.makedirs(args.result_dir, exist_ok=True)
+
+    # Read first frame
+    if get_file_type(video_path) == "video":
+        reader = imageio.get_reader(video_path)
+        first_frame = reader.get_data(0)
+        reader.close()
+    else:
+        first_frame = cv2.imread(video_path)
+        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+    # Save first frame
+    debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
+    cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
+
+    # Get face coordinates
+    coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
+    bbox = coord_list[0]
+    frame = frame_list[0]
+
+    if bbox == coord_placeholder:
+        return None, "No face detected, please adjust bbox_shift parameter"
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
+    # Process first frame
+    x1, y1, x2, y2 = bbox
+    y2 = y2 + args.extra_margin
+    y2 = min(y2, frame.shape[0])
+    crop_frame = frame[y1:y2, x1:x2]
+    crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
+
+    # Generate random audio features
+    random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
+    audio_feature = pe(random_audio)
+
+    # Get latents
+    latents = vae.get_latents_for_unet(crop_frame)
+    latents = latents.to(dtype=weight_dtype)
+
+    # Generate prediction results
+    pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
+    recon = vae.decode_latents(pred_latents)
+
+    # Inpaint back to original image
+    res_frame = recon[0]
+    res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+    combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+    # Save results (no need to convert color space again since get_image already returns RGB format)
+    debug_result_path = os.path.join(args.result_dir, "debug_result.png")
+    cv2.imwrite(debug_result_path, combine_frame)
+
+    # Create information text
+    info_text = f"Parameter information:\n" + \
+                f"bbox_shift: {bbox_shift}\n" + \
+                f"extra_margin: {extra_margin}\n" + \
+                f"parsing_mode: {parsing_mode}\n" + \
+                f"left_cheek_width: {left_cheek_width}\n" + \
+                f"right_cheek_width: {right_cheek_width}\n" + \
+                f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
+
+    return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
+
 def print_directory_contents(path):
     for child in os.listdir(path):
         child_path = os.path.join(path, child)

@@ -40,119 +129,107 @@ def print_directory_contents(path):
         print(child_path)
 
 def download_model():
-    if not os.path.exists(CheckpointsDir):
-        os.makedirs(CheckpointsDir)
-        print("Checkpoint Not Downloaded, start downloading...")
-        tic = time.time()
-        snapshot_download(
-            repo_id="TMElyralab/MuseTalk",
-            local_dir=CheckpointsDir,
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        # weight
-        os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
-        snapshot_download(
-            repo_id="stabilityai/sd-vae-ft-mse",
-            local_dir=CheckpointsDir+'/sd-vae-ft-mse',
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        #dwpose
-        os.makedirs(f"{CheckpointsDir}/dwpose/")
-        snapshot_download(
-            repo_id="yzd-v/DWPose",
-            local_dir=CheckpointsDir+'/dwpose',
-            max_workers=8,
-            local_dir_use_symlinks=True,
-            force_download=True, resume_download=False
-        )
-        #vae
-        url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
-        response = requests.get(url)
-        # 确保请求成功 (make sure the request succeeded)
-        if response.status_code == 200:
-            # 指定文件保存的位置 (where to save the file)
-            file_path = f"{CheckpointsDir}/whisper/tiny.pt"
-            os.makedirs(f"{CheckpointsDir}/whisper/")
-            # 将文件内容写入指定位置 (write the response body to that path)
-            with open(file_path, "wb") as f:
-                f.write(response.content)
+    # 检查必需的模型文件是否存在 (check that the required model files exist)
+    required_models = {
+        "MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
+        "MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
+        "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
+        "Whisper": f"{CheckpointsDir}/whisper/config.json",
+        "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
+        "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
+        "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
+        "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
+    }
+
+    missing_models = []
+    for model_name, model_path in required_models.items():
+        if not os.path.exists(model_path):
+            missing_models.append(model_name)
+
+    if missing_models:
+        # 全用英文 (messages below intentionally all in English)
+        print("The following required model files are missing:")
+        for model in missing_models:
+            print(f"- {model}")
+        print("\nPlease run the download script to download the missing models:")
+        if sys.platform == "win32":
+            print("Windows: Run download_weights.bat")
         else:
-            print(f"请求失败,状态码:{response.status_code}")
-        #gdown face parse
-        url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
-        os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
-        file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
-        gdown.download(url, file_path, quiet=False)
-        #resnet
-        url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
-        response = requests.get(url)
-        # 确保请求成功 (make sure the request succeeded)
-        if response.status_code == 200:
-            # 指定文件保存的位置 (where to save the file)
-            file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
-            # 将文件内容写入指定位置 (write the response body to that path)
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-        else:
-            print(f"请求失败,状态码:{response.status_code}")
-
-        toc = time.time()
-        print(f"download cost {toc-tic} seconds")
-        print_directory_contents(CheckpointsDir)
-
+            print("Linux/Mac: Run ./download_weights.sh")
+        sys.exit(1)
     else:
-        print("Already download the model.")
+        print("All required model files exist.")
 
 
 download_model()  # for huggingface deployment.
 
-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
 from musetalk.utils.blending import get_image
-from musetalk.utils.utils import load_all_model
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.audio_processor import AudioProcessor
+from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
 
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 
-@spaces.GPU(duration=600)
 @torch.no_grad()
-def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
-    args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script
+def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+              left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
+    # Set default parameters, aligned with inference.py
+    args_dict = {
+        "result_dir": './results/output',
+        "fps": 25,
+        "batch_size": 8,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",  # Fixed use v15 version
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
     args = Namespace(**args_dict)
 
-    input_basename = os.path.basename(video_path).split('.')[0]
-    audio_basename = os.path.basename(audio_path).split('.')[0]
-    output_basename = f"{input_basename}_{audio_basename}"
-    result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
-    crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
-    os.makedirs(result_img_save_path,exist_ok =True)
-
-    if args.output_vid_name=="":
-        output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
+    # Check ffmpeg
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+    input_basename = os.path.basename(video_path).split('.')[0]
+    audio_basename = os.path.basename(audio_path).split('.')[0]
+    output_basename = f"{input_basename}_{audio_basename}"
+
+    # Create temporary directory
+    temp_dir = os.path.join(args.result_dir, f"{args.version}")
+    os.makedirs(temp_dir, exist_ok=True)
+
+    # Set result save path
+    result_img_save_path = os.path.join(temp_dir, output_basename)
+    crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
+    os.makedirs(result_img_save_path, exist_ok=True)
+
+    if args.output_vid_name == "":
+        output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
     else:
-        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
+        output_vid_name = os.path.join(temp_dir, args.output_vid_name)
 
     ############################################## extract frames from source video ##############################################
-    if get_file_type(video_path)=="video":
-        save_dir_full = os.path.join(args.result_dir, input_basename)
-        os.makedirs(save_dir_full,exist_ok = True)
-        # cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
-        # os.system(cmd)
-        # 读取视频
+    if get_file_type(video_path) == "video":
+        save_dir_full = os.path.join(temp_dir, input_basename)
+        os.makedirs(save_dir_full, exist_ok=True)
+        # Read video
         reader = imageio.get_reader(video_path)
 
-        # 保存图片
+        # Save images
         for i, im in enumerate(reader):
             imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
         input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))

@@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
         input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
         fps = args.fps
-    #print(input_img_list)
     ############################################## extract audio feature ##############################################
-    whisper_feature = audio_processor.audio2feat(audio_path)
-    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+    # Extract audio features
+    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+    whisper_chunks = audio_processor.get_whisper_chunk(
+        whisper_input_features,
+        device,
+        weight_dtype,
+        whisper,
+        librosa_length,
+        fps=fps,
+        audio_padding_length_left=args.audio_padding_length_left,
+        audio_padding_length_right=args.audio_padding_length_right,
+    )
 
     ############################################## preprocess input image ##############################################
     if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
         print("using extracted coordinates")

@@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
         with open(crop_coord_save_path, 'wb') as f:
             pickle.dump(coord_list, f)
-    bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
+    bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
     i = 0
     input_latent_list = []
     for bbox, frame in zip(coord_list, frame_list):
         if bbox == coord_placeholder:
             continue
         x1, y1, x2, y2 = bbox
+        y2 = y2 + args.extra_margin
+        y2 = min(y2, frame.shape[0])
         crop_frame = frame[y1:y2, x1:x2]
         crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
         latents = vae.get_latents_for_unet(crop_frame)

@@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     frame_list_cycle = frame_list + frame_list[::-1]
     coord_list_cycle = coord_list + coord_list[::-1]
     input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
 
     ############################################## inference batch by batch ##############################################
     print("start inference")
     video_num = len(whisper_chunks)
     batch_size = args.batch_size
-    gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
+    gen = datagen(
+        whisper_chunks=whisper_chunks,
+        vae_encode_latents=input_latent_list_cycle,
+        batch_size=batch_size,
+        delay_frame=0,
+        device=device,
+    )
     res_frame_list = []
     for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
-        audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
-        audio_feature_batch = pe(audio_feature_batch)
+        audio_feature_batch = pe(whisper_batch)
+        # Ensure latent_batch is consistent with model weight type
+        latent_batch = latent_batch.to(dtype=weight_dtype)
 
         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
         recon = vae.decode_latents(pred_latents)

@@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         bbox = coord_list_cycle[i%(len(coord_list_cycle))]
         ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
         x1, y1, x2, y2 = bbox
+        y2 = y2 + args.extra_margin
+        y2 = min(y2, frame.shape[0])
         try:
             res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
         except:
-            # print(bbox)
             continue
 
-        combine_frame = get_image(ori_frame,res_frame,bbox)
+        # Use v15 version blending
+        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
-    # cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
-    # print(cmd_img2video)
-    # os.system(cmd_img2video)
-    # 帧率
+    # Frame rate
     fps = 25
-    # 图片路径
-    # 输出视频路径
+    # Output video path
     output_video = 'temp.mp4'
 
-    # 读取图片
+    # Read images
    def is_valid_image(file):
         pattern = re.compile(r'\d{8}\.png')
         return pattern.match(file)

@@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
             images.append(imageio.imread(filename))
 
-    # 保存视频
+    # Save video
     imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
 
-    # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
-    # print(cmd_combine_audio)
-    # os.system(cmd_combine_audio)
-
     input_video = './temp.mp4'
     # Check if the input_video and audio_path exist
     if not os.path.exists(input_video):

@@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     if not os.path.exists(audio_path):
         raise FileNotFoundError(f"Audio file not found: {audio_path}")
 
-    # 读取视频
+    # Read video
     reader = imageio.get_reader(input_video)
-    fps = reader.get_meta_data()['fps'] # 获取原视频的帧率
-    reader.close() # 否则在win11上会报错:PermissionError: [WinError 32] 另一个程序正在使用此文件,进程无法访问。: 'temp.mp4'
-    # 将帧存储在列表中
+    fps = reader.get_meta_data()['fps']  # Get original video frame rate
+    reader.close()  # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
+    # Store frames in list
     frames = images
 
-    # 保存视频并添加音频 (save video and add audio)
-    # imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])
-
-    # input_video = ffmpeg.input(input_video)
-
-    # input_audio = ffmpeg.input(audio_path)
-
     print(len(frames))
 
-    # imageio.mimwrite(
-    #     output_video,
-    #     frames,
-    #     'FFMPEG',
-    #     fps=25,
-    #     codec='libx264',
-    #     audio_codec='aac',
-    #     input_params=['-i', audio_path],
-    #     output_params=['-y'],  # Add the '-y' flag to overwrite the output file if it exists
-    # )
-    # writer = imageio.get_writer(output_vid_name, fps = 25, codec='libx264', quality=10, pixelformat='yuvj444p')
-    # for im in frames:
-    #     writer.append_data(im)
-    # writer.close()
 
     # Load the video
     video_clip = VideoFileClip(input_video)

@@ -315,11 +388,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
 
 
 # load model weights
-audio_processor,vae,unet,pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+vae, unet, pe = load_all_model(
+    unet_model_path="./models/musetalkV15/unet.pth",
+    vae_type="sd-vae",
+    unet_config="./models/musetalkV15/musetalk.json",
+    device=device
+)
+
+# Parse command line arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
+parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
+parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+parser.add_argument("--share", action="store_true", help="Create a public link")
+parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
+args = parser.parse_args()
+
+# Set data type
+if args.use_float16:
+    # Convert models to half precision for better performance
+    pe = pe.half()
+    vae.vae = vae.vae.half()
+    unet.model = unet.model.half()
+    weight_dtype = torch.float16
+else:
+    weight_dtype = torch.float32
+
+# Move models to specified device
+pe = pe.to(device)
+vae.vae = vae.vae.to(device)
+unet.model = unet.model.to(device)
+
 timesteps = torch.tensor([0], device=device)
 
+# Initialize audio processor and Whisper model
+audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
+whisper = WhisperModel.from_pretrained("./models/whisper")
+whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+whisper.requires_grad_(False)
+
 
 def check_video(video):

@@ -340,9 +447,6 @@ def check_video(video):
         output_video = os.path.join('./results/input', output_file_name)
 
 
-    #     # Run the ffmpeg command to change the frame rate to 25fps
-    #     command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video}  -y"
-
         # read video
         reader = imageio.get_reader(video)
         fps = reader.get_meta_data()['fps']  # get fps from original video

@@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
 
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
-        "<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
+        """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
         <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
         </br>\
-        Yue Zhang <sup>\*</sup>,\
-        Minhao Liu<sup>\*</sup>,\
+        Yue Zhang <sup>*</sup>,\
+        Zhizhou Zhong <sup>*</sup>,\
+        Minhao Liu<sup>*</sup>,\
         Zhaokang Chen,\
         Bin Wu<sup>†</sup>,\
+        Yubin Zeng,\
+        Chao Zhang,\
         Yingjie He,\
-        Chao Zhan,\
-        Wenjiang Zhou\
+        Junxin Huang,\
+        Wenjiang Zhou <br>\
        (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
         Lyra Lab, Tencent Music Entertainment\
         </h2> \
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
-        <a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
-        <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
+        <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
     )
 
     with gr.Row():
         with gr.Column():
-            audio = gr.Audio(label="Driven Audio",type="filepath")
+            audio = gr.Audio(label="Drving Audio",type="filepath")
             video = gr.Video(label="Reference Video",sources=['upload'])
             bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
-            bbox_shift_scale = gr.Textbox(label="BBox_shift recommend value lower bound,The corresponding bbox range is generated after the initial result is generated. \n If the result is not good, it can be adjusted according to this reference value", value="",interactive=False)
+            extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
+            parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
+            left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
+            right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
+            bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' parameters determine the range of left and right cheeks editing when parsing model is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
 
-            btn = gr.Button("Generate")
-        out1 = gr.Video()
+            with gr.Row():
+                debug_btn = gr.Button("1. Test Inpainting ")
+                btn = gr.Button("2. Generate")
+        with gr.Column():
+            debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
+            debug_info = gr.Textbox(label="Parameter Information", lines=5)
+            out1 = gr.Video()
 
     video.change(
         fn=check_video, inputs=[video], outputs=[video]

@@ -412,15 +527,44 @@ with gr.Blocks(css=css) as demo:
             audio,
             video,
             bbox_shift,
+            extra_margin,
+            parsing_mode,
+            left_cheek_width,
+            right_cheek_width
         ],
         outputs=[out1,bbox_shift_scale]
     )
+    debug_btn.click(
+        fn=debug_inpainting,
+        inputs=[
+            video,
+            bbox_shift,
+            extra_margin,
+            parsing_mode,
+            left_cheek_width,
+            right_cheek_width
+        ],
+        outputs=[debug_image, debug_info]
+    )
 
-# Set the IP and port
-ip_address = "0.0.0.0"  # Replace with your desired IP address
-port_number = 7860  # Replace with your desired port number
+# Check ffmpeg and add to PATH
+if not fast_check_ffmpeg():
+    print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
+    # According to operating system, choose path separator
+    path_separator = ';' if sys.platform == 'win32' else ':'
+    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+# Solve asynchronous IO issues on Windows
+if sys.platform == 'win32':
+    import asyncio
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+# Start Gradio application
 demo.queue().launch(
-    share=False , debug=True, server_name=ip_address, server_port=port_number
+    share=args.share,
+    debug=True,
+    server_name=args.ip,
+    server_port=args.port
 )
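The argparse block added to app.py above defines the demo's full command-line surface (`--ffmpeg_path`, `--ip`, `--port`, `--share`, `--use_float16`). A usage sketch combining those flags; the ffmpeg folder name is only illustrative and should point at a real installation:

```batch
:: Half-precision run bound to localhost (defaults shown explicitly)
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin --ip 127.0.0.1 --port 7860

:: Full-precision run that also creates a public Gradio link
python app.py --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin --share
```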
assets/figs/gradio.png: new binary file (14 KiB, not shown)
assets/figs/gradio_2.png: new binary file (73 KiB, not shown)
download_weights.bat: new file, 45 lines

@echo off
setlocal

:: Set the checkpoints directory
set CheckpointsDir=models

:: Create necessary directories
mkdir %CheckpointsDir%\musetalk
mkdir %CheckpointsDir%\musetalkV15
mkdir %CheckpointsDir%\syncnet
mkdir %CheckpointsDir%\dwpose
mkdir %CheckpointsDir%\face-parse-bisent
mkdir %CheckpointsDir%\sd-vae-ft-mse
mkdir %CheckpointsDir%\whisper

:: Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown

:: Set HuggingFace endpoint
set HF_ENDPOINT=https://hf-mirror.com

:: Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%

:: Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"

:: Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"

:: Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"

:: Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"

:: Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth

:: Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth

echo All weights have been downloaded successfully!
endlocal
download_weights.sh: new file, 37 lines

#!/bin/bash

# Set the checkpoints directory
CheckpointsDir="models"

# Create necessary directories
mkdir -p $CheckpointsDir/{musetalk,musetalkV15,syncnet,dwpose,face-parse-bisent,sd-vae-ft-mse,whisper}

# Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown

# Set HuggingFace endpoint
export HF_ENDPOINT=https://hf-mirror.com

# Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir $CheckpointsDir

# Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir $CheckpointsDir/sd-vae --include "config.json" "diffusion_pytorch_model.bin"

# Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir $CheckpointsDir/whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"

# Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir $CheckpointsDir/dwpose --include "dw-ll_ucoco_384.pth"

# Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir $CheckpointsDir/syncnet --include "latentsync_syncnet.pt"

# Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth

# Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth

echo "All weights have been downloaded successfully!"
[file header not captured; likely musetalk/utils/audio_processor.py]

@@ -49,8 +49,9 @@ class AudioProcessor:
         whisper_feature = []
         # Process multiple 30s mel input features
         for input_feature in whisper_input_features:
-            audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
-            audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
+            input_feature = input_feature.to(device).to(weight_dtype)
+            audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
+            audio_feats = torch.stack(audio_feats, dim=2)
             whisper_feature.append(audio_feats)
 
         whisper_feature = torch.cat(whisper_feature, dim=1)
[file header not captured; likely musetalk/utils/utils.py]

@@ -8,26 +8,18 @@ from einops import rearrange
 import shutil
 import os.path as osp
 
-ffmpeg_path = os.getenv('FFMPEG_PATH')
-if ffmpeg_path is None:
-    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
-elif ffmpeg_path not in os.getenv('PATH'):
-    print("add ffmpeg to path")
-    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
-
-
 from musetalk.models.vae import VAE
 from musetalk.models.unet import UNet,PositionalEncoding
 
 
 def load_all_model(
-    unet_model_path="./models/musetalk/pytorch_model.bin",
-    vae_type="sd-vae-ft-mse",
-    unet_config="./models/musetalk/musetalk.json",
+    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
+    vae_type="sd-vae",
+    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
     device=None,
 ):
     vae = VAE(
-        model_path = f"./models/{vae_type}/",
+        model_path = os.path.join("models", vae_type),
     )
     print(f"load unet model from {unet_model_path}")
     unet = UNet(
[file header not captured; likely requirements.txt]

@@ -1,15 +1,15 @@
---extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.0.1
-torchvision==0.15.2
-torchaudio==2.0.2
-diffusers==0.27.2
+diffusers==0.30.2
 accelerate==0.28.0
+numpy==1.23.5
 tensorflow==2.12.0
 tensorboard==2.12.0
 opencv-python==4.9.0.80
 soundfile==0.12.1
 transformers==4.39.2
-huggingface_hub==0.25.0
+huggingface_hub==0.30.2
+librosa==0.11.0
+einops==0.8.1
+gradio==5.24.0
+
 gdown
 requests

@@ -17,6 +17,4 @@ imageio[ffmpeg]
 
 omegaconf
 ffmpeg-python
-gradio
-spaces
 moviepy
[file header not captured; likely scripts/inference.py]

@@ -8,9 +8,11 @@ import shutil
 import pickle
 import argparse
 import numpy as np
+import subprocess
 from tqdm import tqdm
 from omegaconf import OmegaConf
 from transformers import WhisperModel
+import sys
 
 from musetalk.utils.blending import get_image
 from musetalk.utils.face_parsing import FaceParsing

@@ -18,16 +20,26 @@ from musetalk.utils.audio_processor import AudioProcessor
 from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
 from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
 
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
+
 @torch.no_grad()
 def main(args):
     # Configure ffmpeg path
-    if args.ffmpeg_path not in os.getenv('PATH'):
+    if not fast_check_ffmpeg():
         print("Adding ffmpeg to PATH")
-        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
+        # Choose path separator based on operating system
+        path_separator = ';' if sys.platform == 'win32' else ':'
+        os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+        if not fast_check_ffmpeg():
+            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
 
     # Set computing device
     device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
 
     # Load model weights
     vae, unet, pe = load_all_model(
         unet_model_path=args.unet_model_path,
[file header not captured; likely musetalk/utils/preprocessing.py]

@@ -12,11 +12,23 @@ from mmpose.structures import merge_data_samples
 import torch
 import numpy as np
 from tqdm import tqdm
+import sys
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 
 ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
-if ffmpeg_path not in os.getenv('PATH'):
-    print("add ffmpeg to path")
-    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
+if not fast_check_ffmpeg():
+    print("Adding ffmpeg to PATH")
+    # Choose path separator based on operating system
+    path_separator = ';' if sys.platform == 'win32' else ':'
+    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
 
 class AnalyzeFace:
     def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
[file header not captured; likely scripts/realtime_inference.py]

@@ -23,6 +23,15 @@ import shutil
 import threading
 import queue
 import time
+import subprocess
+
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 
 
 def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):

@@ -318,7 +327,7 @@ if __name__ == "__main__":
     parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
     parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
     parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
-    parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
+    parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
     parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
     parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
     parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')

@@ -332,6 +341,15 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
+    # Configure ffmpeg path
+    if not fast_check_ffmpeg():
+        print("Adding ffmpeg to PATH")
+        # Choose path separator based on operating system
+        path_separator = ';' if sys.platform == 'win32' else ':'
+        os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+        if not fast_check_ffmpeg():
+            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
     # Set computing device
     device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
 
test_ffmpeg.py: new file, 33 lines

import os
import subprocess
import sys

def test_ffmpeg(ffmpeg_path):
    print(f"Testing ffmpeg path: {ffmpeg_path}")

    # Choose path separator based on operating system
    path_separator = ';' if sys.platform == 'win32' else ':'

    # Add ffmpeg path to environment variable
    os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"

    try:
        # Try to run ffmpeg
        result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
        print("FFmpeg test successful!")
        print("FFmpeg version information:")
        print(result.stdout)
        return True
    except Exception as e:
        print("FFmpeg test failed!")
        print(f"Error message: {str(e)}")
        return False

if __name__ == "__main__":
    # Default ffmpeg path, can be modified as needed
    default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"

    # Use command line argument if provided, otherwise use default path
    ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    test_ffmpeg(ffmpeg_path)
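Per the `__main__` block above, the checker takes an optional FFmpeg bin directory as its first argument and otherwise falls back to the hard-coded default. A usage sketch (the explicit path is illustrative):

```batch
:: Use the default path baked into the script
python test_ffmpeg.py

:: Point the check at an explicit FFmpeg bin directory
python test_ffmpeg.py ffmpeg-master-latest-win64-gpl-shared\bin
```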