mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
fixed multiple video data preparation
This commit is contained in:
@@ -71,11 +71,6 @@ class Dataset(object):
|
||||
self.whisper_feature_W = 33
|
||||
self.whisper_feature_H = 1280
|
||||
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
|
||||
|
||||
# if(self.split=="train"):
|
||||
# self.all_videos=["../data/images/train"]
|
||||
# if(self.split=="val"):
|
||||
# self.all_videos=["../data/images/test"]
|
||||
with open(json_path, 'r') as file:
|
||||
self.all_videos = json.load(file)
|
||||
|
||||
|
||||
@@ -6,30 +6,29 @@ The test yaml file should contain the validation video paths and corresponding a
|
||||
|
||||
Run:
|
||||
```
|
||||
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
|
||||
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
|
||||
./data_new.sh train output train_video1.mp4 train_video2.mp4
|
||||
./data_new.sh test output test_video1.mp4 test_video2.mp4
|
||||
```
|
||||
This creates folders which contain the image frames and npy files.
|
||||
|
||||
This creates folders which contain the image frames and npy files. This also creates train.json and val.json which can be used during the training.
|
||||
|
||||
## Data organization
|
||||
```
|
||||
./data/
|
||||
├── images
|
||||
│ └──train
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
│ └──test
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.png
|
||||
│ └── 1.png
|
||||
│ └── xxx.png
|
||||
├── audios
|
||||
│ └──train
|
||||
│ └──RD_Radio10_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
│ └──test
|
||||
│ └──RD_Radio11_000
|
||||
│ └── 0.npy
|
||||
│ └── 1.npy
|
||||
│ └── xxx.npy
|
||||
@@ -38,7 +37,9 @@ This creates folders which contain the image frames and npy files.
|
||||
## Training
|
||||
Simply run after preparing the preprocessed data
|
||||
```
|
||||
sh train.sh
|
||||
cd train_codes
|
||||
sh train.sh #--train_json="../train.json" \(Generated in Data preprocessing step.)
|
||||
#--val_json="../val.json" \
|
||||
```
|
||||
## Inference with trained checkpoint
|
||||
Simply run after training the model; the model checkpoints are usually saved at train_codes/output
|
||||
|
||||
@@ -7,13 +7,12 @@ accelerate launch train.py \
|
||||
--unet_config_file=$UNET_CONFIG \
|
||||
--pretrained_model_name_or_path=$VAE_MODEL \
|
||||
--data_root=$DATASET \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--train_batch_size=256 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=100000 \
|
||||
--learning_rate=5e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="output" \
|
||||
--val_out_dir='val' \
|
||||
@@ -25,5 +24,6 @@ accelerate launch train.py \
|
||||
--use_audio_length_left=2 \
|
||||
--use_audio_length_right=2 \
|
||||
--whisper_model_type="tiny" \
|
||||
--train_json="/root/MuseTalk/train.json" \
|
||||
--val_json="/root/MuseTalk/val.json" \
|
||||
--train_json="../train.json" \
|
||||
--val_json="../val.json" \
|
||||
--lr_scheduler="cosine" \
|
||||
|
||||
Reference in New Issue
Block a user