fixed multiple video data preparation

This commit is contained in:
Shounak Banerjee
2024-06-17 18:39:15 +00:00
parent af82f3b00f
commit b968548131
3 changed files with 15 additions and 19 deletions

View File

@@ -71,11 +71,6 @@ class Dataset(object):
self.whisper_feature_W = 33
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) # 33*2*(2+2+1) = 330
# if(self.split=="train"):
# self.all_videos=["../data/images/train"]
# if(self.split=="val"):
# self.all_videos=["../data/images/test"]
with open(json_path, 'r') as file:
self.all_videos = json.load(file)
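For reference, a minimal sketch of the JSON contract this change assumes: the file holds a flat list of directory paths, replacing the hard-coded split directories removed above. The exact schema is inferred from the removed code, and the example paths are hypothetical:
```
import json

# Hypothetical train.json contents, mirroring the removed hard-coded layout:
# ["./data/images/train/RD_Radio10_000", "./data/images/train/RD_Radio11_000"]
with open("train.json", "r") as file:
    all_videos = json.load(file)  # the same call the Dataset now makes

for video_dir in all_videos:
    print(video_dir)  # one folder of extracted frames per entry
```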

View File

@@ -6,30 +6,29 @@ The test yaml file should contain the validation video paths and corresponding a
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
./data_new.sh train output train_video1.mp4 train_video2.mp4
./data_new.sh test output test_video1.mp4 test_video2.mp4
```
This creates folders which contain the image frames and npy files.
This creates folders which contain the image frames and npy files. It also creates train.json and val.json, which are used during training.
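As a quick check that the step succeeded, here is a hedged sketch that assumes train.json and val.json each contain a JSON list of existing directory paths (the schema is inferred from the dataset code, not documented):
```
import json
import os

# Assumes train.json / val.json each hold a list of frame-folder paths.
for json_path in ("train.json", "val.json"):
    with open(json_path) as f:
        videos = json.load(f)
    missing = [v for v in videos if not os.path.isdir(v)]
    print(f"{json_path}: {len(videos)} entries, {len(missing)} missing folders")
```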
## Data organization
```
./data/
├── images
│   ├── train
│   │   └── RD_Radio10_000
│   │       ├── 0.png
│   │       ├── 1.png
│   │       └── xxx.png
│   └── test
│       └── RD_Radio11_000
│           ├── 0.png
│           ├── 1.png
│           └── xxx.png
├── audios
│   ├── train
│   │   └── RD_Radio10_000
│   │       ├── 0.npy
│   │       ├── 1.npy
│   │       └── xxx.npy
│   └── test
│       └── RD_Radio11_000
│           ├── 0.npy
│           ├── 1.npy
│           └── xxx.npy
@@ -38,7 +37,9 @@ This creates folders which contain the image frames and npy files.
## Training
Simply run after completing the data preparation step:
```
sh train.sh
cd train_codes
sh train.sh  # uses --train_json="../train.json" and --val_json="../val.json",
             # both generated in the data preprocessing step
```
## Inference with trained checkpoint
Simply run after training the model; the checkpoints are usually saved at train_codes/output.

View File

@@ -7,13 +7,12 @@ accelerate launch train.py \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
--data_root=$DATASET \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--train_batch_size=256 \
--gradient_accumulation_steps=16 \
--gradient_checkpointing \
--max_train_steps=100000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="output" \
--val_out_dir='val' \
@@ -25,5 +24,6 @@ accelerate launch train.py \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \
--train_json="/root/MuseTalk/train.json" \
--val_json="/root/MuseTalk/val.json" \
--train_json="../train.json" \
--val_json="../val.json" \
--lr_scheduler="cosine" \
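For context on the new numbers: under Accelerate, the samples consumed per optimizer update equal the per-device batch size times the gradient-accumulation steps times the number of processes. A sketch of the arithmetic (the GPU count is illustrative, not taken from the commit):
```
# Effective batch size per optimizer update with the new settings.
train_batch_size = 256            # per device, from train.sh
gradient_accumulation_steps = 16  # from train.sh
num_processes = 1                 # illustrative; set by `accelerate launch`

effective = train_batch_size * gradient_accumulation_steps * num_processes
print(effective)  # 4096 on one GPU, versus 8 * 4 = 32 before this commit
```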