feat: data preprocessing and training (#294)

* docs: update readme

* docs: update readme

* feat: training codes

* feat: data preprocess

* docs: release training
This commit is contained in:
Zhizhou Zhong
2025-04-04 22:10:03 +08:00
committed by GitHub
parent e636166b85
commit 1ab53a626b
23 changed files with 3854 additions and 6 deletions

View File

@@ -130,8 +130,9 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
- [ ] training and dataloader code (Expected completion on 04/04/2025).
- [x] realtime inference code for 1.5 version.
- [x] training and data preprocessing codes.
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
# Getting Started
@@ -187,6 +188,7 @@ huggingface-cli download TMElyralab/MuseTalk --local-dir models/
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
Finally, these weights should be organized in `models` as follows:
@@ -198,6 +200,8 @@ Finally, these weights should be organized in `models` as follows:
├── musetalkV15
│ └── musetalk.json
│ └── unet.pth
├── syncnet
│ └── latentsync_syncnet.pt
├── dwpose
│ └── dw-ll_ucoco_384.pth
├── face-parse-bisent
@@ -265,6 +269,73 @@ For faster generation without saving images, you can use:
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## Training
### Data Preparation
To train MuseTalk, you need to prepare your dataset following these steps:
1. **Place your source videos**
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
2. **Run the preprocessing script**
```bash
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
```
This script will:
- Extract frames from videos
- Detect and align faces
- Generate audio features
- Create the necessary data structure for training
### Training Process
After data preprocessing, you can start the training process:
1. **First Stage**
```bash
sh train.sh stage1
```
2. **Second Stage**
```bash
sh train.sh stage2
```
### Configuration Adjustment
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
1. **GPU Configuration** (`configs/training/gpu.yaml`):
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
- `num_processes`: Set this to match the number of GPUs you're using
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
- `random_init_unet`: Must be set to `False` to use the model from stage 1
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
### GPU Memory Requirements
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
#### Stage 1 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 8 | 1 | ~32GB | |
| 16 | 1 | ~45GB | |
| 32 | 1 | ~74GB | ✓ |
#### Stage 2 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 1 | 8 | ~54GB | |
| 2 | 2 | ~80GB | |
| 2 | 8 | ~85GB | ✓ |
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
@@ -368,7 +439,7 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.