mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 09:59:18 +08:00
modified dataloader.py and inference.py for training and inference
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
export VAE_MODEL="./sd-vae-ft-mse/"
|
||||
export DATASET="..."
|
||||
export UNET_CONFIG="./musetalk.json"
|
||||
export VAE_MODEL="../models/sd-vae-ft-mse/"
|
||||
export DATASET="../data"
|
||||
export UNET_CONFIG="../models/musetalk/musetalk.json"
|
||||
|
||||
accelerate launch --multi_gpu train.py \
|
||||
accelerate launch train.py \
|
||||
--mixed_precision="fp16" \
|
||||
--unet_config_file=$UNET_CONFIG \
|
||||
--pretrained_model_name_or_path=$VAE_MODEL \
|
||||
@@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=200000 \
|
||||
--max_train_steps=50000 \
|
||||
--learning_rate=5e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="..." \
|
||||
--val_out_dir='...' \
|
||||
--output_dir="output" \
|
||||
--val_out_dir='val' \
|
||||
--testing_speed \
|
||||
--checkpointing_steps=1000 \
|
||||
--validation_steps=1000 \
|
||||
|
||||
Reference in New Issue
Block a user