mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
feat: data preprocessing and training (#294)
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
This commit is contained in:
34
train.sh
Normal file
34
train.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MuseTalk Training Script
|
||||
# This script combines both training stages for the MuseTalk model
|
||||
# Usage: sh train.sh [stage1|stage2]
|
||||
# Example: sh train.sh stage1 # To run stage 1 training
|
||||
# Example: sh train.sh stage2 # To run stage 2 training
|
||||
|
||||
# Check if stage argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Error: Please specify the training stage"
|
||||
echo "Usage: ./train.sh [stage1|stage2]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
STAGE=$1
|
||||
|
||||
# Validate stage argument
|
||||
if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
|
||||
echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Launch distributed training using accelerate
|
||||
# --config_file: Path to the GPU configuration file
|
||||
# --main_process_port: Port number for the main process, used for distributed training communication
|
||||
# train.py: Training script
|
||||
# --config: Path to the training configuration file
|
||||
echo "Starting $STAGE training..."
|
||||
accelerate launch --config_file ./configs/training/gpu.yaml \
|
||||
--main_process_port 29502 \
|
||||
train.py --config ./configs/training/$STAGE.yaml
|
||||
|
||||
echo "Training completed for $STAGE"
|
||||
Reference in New Issue
Block a user