modified dataloader.py and inference.py for training and inference

This commit is contained in:
Shounak Banerjee
2024-06-03 11:09:12 +00:00
parent 7254ca6306
commit b4a592d7f3
6 changed files with 106 additions and 58 deletions

View File

@@ -1,32 +1,35 @@
# Draft training codes
# Data preprocessing
We provde the draft training codes here. Unfortunately, data preprocessing code is still being reorganized.
Create two config yaml files, one for training and other for testing (both in same format as configs/inference/test.yaml)
The train yaml file should contain the training video paths and corresponding audio paths
The test yaml file should contain the validation video paths and corresponding audio paths
## Setup
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
```
This creates folders which contain the image frames and npy files.
We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 20w+ steps. Using multiple GPUs should accelerate the training.
## Data preprocessing
You could refer the inference codes which [crop the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extract audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
│ └──RD_Radio10_000
│ └──train
│ └── 0.png
│ └── 1.png
│ └── xxx.png
│ └──RD_Radio11_000
│ └──test
│ └── 0.png
│ └── 1.png
│ └── xxx.png
├── audios
│ └──RD_Radio10_000
│ └──train
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
│ └──RD_Radio11_000
│ └──test
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
```
sh train.sh
```
## Inference with trained checkpoit
Simply run after training the model, the model checkpoints are saved at train_codes/output usually
```
python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)