Mirror of https://github.com/TMElyralab/MuseTalk.git, synced 2026-02-05 01:49:20 +08:00
modified dataloader.py and inference.py for training and inference
@@ -1,32 +1,35 @@
# Draft training codes
# Data preprocessing

We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config YAML files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
The train YAML file should contain the training video paths and the corresponding audio paths.
The test YAML file should contain the validation video paths and the corresponding audio paths.
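As a rough illustration, the two config files could be generated with a short script like the one below. It assumes the `task_N` / `video_path` / `audio_path` layout of `configs/inference/test.yaml` in the upstream repository, so check that file before relying on it. The helper names `build_config` and `write_config` and the example media paths are hypothetical.

```python
from pathlib import Path


def build_config(pairs):
    """Render (video_path, audio_path) pairs as YAML text.

    Assumption: the config uses one task_N entry per video/audio pair,
    as in configs/inference/test.yaml.
    """
    lines = []
    for index, (video_path, audio_path) in enumerate(pairs):
        lines.append(f"task_{index}:")
        lines.append(f'  video_path: "{video_path}"')
        lines.append(f'  audio_path: "{audio_path}"')
    return "\n".join(lines) + "\n"


def write_config(path, pairs):
    """Write the rendered config to disk and return the path."""
    target = Path(path)
    target.write_text(build_config(pairs))
    return target


# Hypothetical example paths; replace them with your own videos and audio.
train_pairs = [("data/video/clip_000.mp4", "data/audio/clip_000.wav")]
test_pairs = [("data/video/clip_100.mp4", "data/audio/clip_100.wav")]
```

Calling `write_config("path_to_train.yaml", train_pairs)` and `write_config("path_to_test.yaml", test_pairs)` would then produce the two files expected by the preprocessing commands.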

## Setup
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
```
This creates folders that contain the image frames and npy files.

We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 200k+ steps. Using multiple GPUs should accelerate the training.
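For reference, the effective batch size implied by these settings can be worked out as below. This is a small sketch of the usual gradient-accumulation arithmetic, not code from the repository. The `num_gpus` parameter is an assumption for the multi-GPU case, and the source does not state whether the reported batch size is per device.

```python
def effective_batch_size(batch_size, gradient_accumulation_steps, num_gpus=1):
    """Number of samples contributing to one optimizer step."""
    return batch_size * gradient_accumulation_steps * num_gpus


# Settings quoted above, on a single GPU.
single_gpu = effective_batch_size(batch_size=8, gradient_accumulation_steps=4)
print(single_gpu)  # 32
```

With more GPUs, the accumulation steps can be reduced to keep the same effective batch size while finishing each optimizer step sooner.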

## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).

Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
│     └──RD_Radio10_000
│     └──train
│         └── 0.png
│         └── 1.png
│         └── xxx.png
│     └──RD_Radio11_000
│     └──test
│         └── 0.png
│         └── 1.png
│         └── xxx.png
├── audios
│     └──RD_Radio10_000
│     └──train
│         └── 0.npy
│         └── 1.npy
│         └── xxx.npy
│     └──RD_Radio11_000
│     └──test
│         └── 0.npy
│         └── 1.npy
│         └── xxx.npy
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
```
sh train.sh
```
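Before launching `train.sh`, it may be worth checking that the preprocessed data matches the expected layout. The sketch below assumes frames live in `data/images/<split>/<n>.png` with one feature file per frame in `data/audios/<split>/<n>.npy`, as the matching file names in the tree suggest. The function name `find_unpaired_frames` is hypothetical and not part of the repository.

```python
from pathlib import Path


def find_unpaired_frames(data_root, split):
    """Return frame stems that exist in only one of images/ or audios/.

    Assumed layout: data_root/images/<split>/<n>.png and
    data_root/audios/<split>/<n>.npy.
    """
    root = Path(data_root)
    image_stems = {p.stem for p in (root / "images" / split).glob("*.png")}
    audio_stems = {p.stem for p in (root / "audios" / split).glob("*.npy")}
    # Symmetric difference: stems present on one side only.
    return sorted(image_stems ^ audio_stems)


def report(data_root="./data", splits=("train", "test")):
    """Print the unpaired frames for each split."""
    for split in splits:
        unpaired = find_unpaired_frames(data_root, split)
        print(split, "unpaired frames:", unpaired or "none")
```

An empty result for both splits means every image frame has a matching audio feature file. Note that a missing split folder also yields an empty result, so confirm the folders exist first.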
## Inference with trained checkpoint
After training the model, simply run the command below. The model checkpoints are usually saved at train_codes/output.
```
python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
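If several checkpoint folders accumulate under `train_codes/output`, a small helper can pick the most recent one and assemble the command above. This is only a sketch: it assumes each checkpoint folder name ends in a step number (for example `checkpoint-2000`), which the source does not confirm, and both function names are hypothetical. Inspect your output folder before using it.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir):
    """Pick the sub-folder whose name ends in the largest integer.

    Assumption: checkpoint folders are named with a trailing step number.
    """
    candidates = []
    for path in Path(output_dir).iterdir():
        match = re.search(r"(\d+)$", path.name)
        if path.is_dir() and match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint folders in {output_dir}")
    return max(candidates, key=lambda item: item[0])[1]


def inference_command(checkpoint_dir, config="configs/inference/test.yaml"):
    """Build the inference command as an argument list for subprocess."""
    return [
        "python", "-m", "scripts.inference",
        "--inference_config", config,
        "--unet_checkpoint", str(checkpoint_dir),
    ]
```

For example, `subprocess.run(inference_command(latest_checkpoint("train_codes/output")), check=True)` would launch inference with the newest folder found.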

## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)