From b9685481315441636a3604483836d16ff439fc61 Mon Sep 17 00:00:00 2001 From: Shounak Banerjee Date: Mon, 17 Jun 2024 18:39:15 +0000 Subject: [PATCH] fixed mltiple video data preperation --- train_codes/DataLoader.py | 5 ----- train_codes/README.md | 19 ++++++++++--------- train_codes/train.sh | 10 +++++----- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py index 431ee39..10ee8eb 100644 --- a/train_codes/DataLoader.py +++ b/train_codes/DataLoader.py @@ -71,11 +71,6 @@ class Dataset(object): self.whisper_feature_W = 33 self.whisper_feature_H = 1280 self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50 - - # if(self.split=="train"): - # self.all_videos=["../data/images/train"] - # if(self.split=="val"): - # self.all_videos=["../data/images/test"] with open(json_path, 'r') as file: self.all_videos = json.load(file) diff --git a/train_codes/README.md b/train_codes/README.md index f303b87..102a345 100644 --- a/train_codes/README.md +++ b/train_codes/README.md @@ -6,30 +6,29 @@ The test yaml file should contain the validation video paths and corresponding a Run: ``` -python -m scripts.data --inference_config path_to_train.yaml --folder_name train -python -m scripts.data --inference_config path_to_test.yaml --folder_name test +./data_new.sh train output train_video1.mp4 train_video2.mp4 +./data_new.sh test output test_video1.mp4 test_video2.mp4 ``` -This creates folders which contain the image frames and npy files. - +This creates folders which contain the image frames and npy files. This also creates train.json and val.json which can be used during the training. ## Data organization ``` ./data/ ├── images -│ └──train +│ └──RD_Radio10_000 │ └── 0.png │ └── 1.png │ └── xxx.png -│ └──test +│ └──RD_Radio11_000 │ └── 0.png │ └── 1.png │ └── xxx.png ├── audios -│ └──train +│ └──RD_Radio10_000 │ └── 0.npy │ └── 1.npy │ └── xxx.npy -│ └──test +│ └──RD_Radio11_000 │ └── 0.npy │ └── 1.npy │ └── xxx.npy @@ -38,7 +37,9 @@ This creates folders which contain the image frames and npy files. ## Training Simply run after preparing the preprocessed data ``` -sh train.sh +cd train_codes +sh train.sh #--train_json="../train.json" \(Generated in Data preprocessing step.) + #--val_json="../val.json" \ ``` ## Inference with trained checkpoit Simply run after training the model, the model checkpoints are saved at train_codes/output usually diff --git a/train_codes/train.sh b/train_codes/train.sh index 600632b..f15ddf9 100644 --- a/train_codes/train.sh +++ b/train_codes/train.sh @@ -7,13 +7,12 @@ accelerate launch train.py \ --unet_config_file=$UNET_CONFIG \ --pretrained_model_name_or_path=$VAE_MODEL \ --data_root=$DATASET \ ---train_batch_size=8 \ ---gradient_accumulation_steps=4 \ +--train_batch_size=256 \ +--gradient_accumulation_steps=16 \ --gradient_checkpointing \ --max_train_steps=100000 \ --learning_rate=5e-05 \ --max_grad_norm=1 \ ---lr_scheduler="cosine" \ --lr_warmup_steps=0 \ --output_dir="output" \ --val_out_dir='val' \ @@ -25,5 +24,6 @@ accelerate launch train.py \ --use_audio_length_left=2 \ --use_audio_length_right=2 \ --whisper_model_type="tiny" \ ---train_json="/root/MuseTalk/train.json" \ ---val_json="/root/MuseTalk/val.json" \ +--train_json="../train.json" \ +--val_json="../val.json" \ +--lr_scheduler="cosine" \