mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
temporary commit to save changes
data_new.sh (executable file, 77 lines added)
@@ -0,0 +1,77 @@
#!/bin/bash

# Function to extract video and audio sections
extract_sections() {
    input_video=$1
    base_name=$(basename "$input_video" .mp4)
    output_dir=$2
    split=$3  # accepted for symmetry with the caller; not used inside this function
    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
    IFS=: read -r hours minutes seconds <<< "$duration"
    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
    chunk_size=180  # 3 minutes in seconds
    index=0

    mkdir -p "$output_dir"

    while [ $((index * chunk_size)) -lt "$total_seconds" ]; do
        start_time=$((index * chunk_size))
        section_video="${output_dir}/${base_name}_part${index}.mp4"
        section_audio="${output_dir}/${base_name}_part${index}.mp3"

        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"

        # Create and update the config.yaml file
        echo "task_0:" > config.yaml
        echo " video_path: \"$section_video\"" >> config.yaml
        echo " audio_path: \"$section_audio\"" >> config.yaml

        # Run the Python script with the current config.yaml
        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"

        index=$((index + 1))
    done

    # Clean up the save folder
    rm -rf "$output_dir"
}

# Main script
if [ $# -lt 3 ]; then
    echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
    exit 1
fi

split=$1
output_dir=$2
shift 2
input_videos=("$@")

# Initialize the JSON array
json_array="["

for input_video in "${input_videos[@]}"; do
    base_name=$(basename "$input_video" .mp4)

    # Extract sections and run the Python script for each section
    extract_sections "$input_video" "$output_dir" "$split"

    # Add an entry to the JSON array
    json_array+="\"../data/images/$base_name\","
done

# Remove the trailing comma and close the JSON array
json_array="${json_array%,}]"

# Write the JSON array to the file matching the split
if [ "$split" == "train" ]; then
    echo "$json_array" > train.json
elif [ "$split" == "test" ]; then
    echo "$json_array" > test.json
else
    echo "Invalid split: $split. Must be 'train' or 'test'."
    exit 1
fi

echo "Processing complete."
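For reference, a hypothetical invocation of the script above (file and directory names are illustrative): it cuts each input video into 3-minute sections under the given output directory, runs scripts.data on every section, removes the section folder, and writes train.json with one entry per input video.

./data_new.sh train ./sections speaker_a.mp4 speaker_b.mp4
# Expected train.json afterwards, following the entries built in the loop:
# ["../data/images/speaker_a","../data/images/speaker_b"]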
@@ -25,6 +25,32 @@ audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)

def get_largest_integer_filename(folder_path):
    # Check if the folder exists
    if not os.path.isdir(folder_path):
        return -1

    # Get the list of files in the folder
    files = os.listdir(folder_path)

    # Check if the folder is empty
    if not files:
        return -1

    # Extract the integer part of filenames and find the largest
    largest_integer = -1
    for file in files:
        try:
            # Get the integer part of the filename
            file_int = int(os.path.splitext(file)[0])
            if file_int > largest_integer:
                largest_integer = file_int
        except ValueError:
            # Skip files that don't have an integer filename
            continue

    return largest_integer

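A minimal usage sketch of the helper above (folder and file names are illustrative, not from the repo): if data/audios/demo contains 0.npy, 1.npy, and 7.npy plus a stray readme.txt, the function returns 7; a missing or empty folder returns -1, so the next free index is always the result plus one.

next_index = get_largest_integer_filename("data/audios/demo") + 1  # 8 here; 0 for a fresh folder
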
def datagen(whisper_chunks,
            crop_images,
            batch_size=8,
@@ -58,10 +84,10 @@ def main(args):
    unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    # Resume numbering from the largest index already saved for this folder
    # (get_largest_integer_filename returns -1 when nothing exists yet)
    total_audio_index = get_largest_integer_filename(f"data/audios/{args.folder_name}")
    total_image_index = get_largest_integer_filename(f"data/images/{args.folder_name}")
    temp_audio_index = total_audio_index
    temp_image_index = total_image_index
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]

@@ -48,15 +48,15 @@ def get_image_list(data_root, split):
class Dataset(object):
    def __init__(self,
                 data_root,
                 split,
                 json_path,
                 use_audio_length_left=1,
                 use_audio_length_right=1,
                 whisper_model_type="tiny"
                 ):
        # self.all_videos, self.all_imgNum = get_image_list(data_root, split)
        self.audio_feature = [use_audio_length_left, use_audio_length_right]
        self.all_img_names = []
        # self.split = split
        self.img_names_path = '../data'
        self.whisper_model_type = whisper_model_type
        self.use_audio_length_left = use_audio_length_left
@@ -72,10 +72,13 @@ class Dataset(object):
        self.whisper_feature_H = 1280
        self.whisper_feature_concateW = self.whisper_feature_W * 2 * (self.use_audio_length_left + self.use_audio_length_right + 1)  # 5*2*(2+2+1) = 50

        # if(self.split=="train"):
        #     self.all_videos=["../data/images/train"]
        # if(self.split=="val"):
        #     self.all_videos=["../data/images/test"]
        with open(json_path, 'r') as file:
            self.all_videos = json.load(file)

        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
            if not os.path.exists(json_path_names):
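For context, json_path here points at the files data_new.sh emits: a flat JSON array of per-video image folders, for example (folder names illustrative):

["../data/images/speaker_a","../data/images/speaker_b"]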
@@ -140,6 +140,8 @@ def parse_args():
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--train_json", type=str, default="train.json", help="The JSON file listing the training image folders.")
    parser.add_argument("--val_json", type=str, default="test.json", help="The JSON file listing the validation image folders.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
@@ -350,7 +352,7 @@ def main():

    print("loading train_dataset ...")
    train_dataset = Dataset(args.data_root,
                            'train',
                            args.train_json,
                            use_audio_length_left=args.use_audio_length_left,
                            use_audio_length_right=args.use_audio_length_right,
                            whisper_model_type=args.whisper_model_type
@@ -360,7 +362,7 @@ def main():
                              num_workers=8)
    print("loading val_dataset ...")
    val_dataset = Dataset(args.data_root,
                          'val',
                          args.val_json,
                          use_audio_length_left=args.use_audio_length_left,
                          use_audio_length_right=args.use_audio_length_right,
                          whisper_model_type=args.whisper_model_type

@@ -18,10 +18,12 @@ accelerate launch train.py \
    --output_dir="output" \
    --val_out_dir='val' \
    --testing_speed \
    --checkpointing_steps=2000 \
    --validation_steps=2000 \
    --reconstruction \
    --resume_from_checkpoint="latest" \
    --use_audio_length_left=2 \
    --use_audio_length_right=2 \
    --whisper_model_type="tiny" \
    --train_json="/root/MuseTalk/train.json" \
    --val_json="/root/MuseTalk/val.json" \