From af82f3b00f9a56b66ddeda5684f8f47eaad66f5c Mon Sep 17 00:00:00 2001
From: Shounak Banerjee
Date: Thu, 13 Jun 2024 14:14:52 +0000
Subject: [PATCH] temporary commit to save changes

---
 data_new.sh               | 77 +++++++++++++++++++++++++++++++++++++++
 scripts/data.py           | 34 +++++++++++++++--
 train_codes/DataLoader.py | 17 +++++----
 train_codes/train.py      |  6 ++-
 train_codes/train.sh      |  6 ++-
 5 files changed, 125 insertions(+), 15 deletions(-)
 create mode 100755 data_new.sh

diff --git a/data_new.sh b/data_new.sh
new file mode 100755
index 0000000..61844b2
--- /dev/null
+++ b/data_new.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+
+# Function to extract video and audio sections
+extract_sections() {
+    input_video=$1
+    base_name=$(basename "$input_video" .mp4)
+    output_dir=$2
+    split=$3
+    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
+    IFS=: read -r hours minutes seconds <<< "$duration"
+    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
+    chunk_size=180 # 3 minutes in seconds
+    index=0
+
+    mkdir -p "$output_dir"
+
+    while [ $((index * chunk_size)) -lt $total_seconds ]; do
+        start_time=$((index * chunk_size))
+        section_video="${output_dir}/${base_name}_part${index}.mp4"
+        section_audio="${output_dir}/${base_name}_part${index}.mp3"
+
+        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
+        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"
+
+        # Create and update the config.yaml file
+        echo "task_0:" > config.yaml
+        echo "  video_path: \"$section_video\"" >> config.yaml
+        echo "  audio_path: \"$section_audio\"" >> config.yaml
+
+        # Run the Python script with the current config.yaml
+        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
+
+        index=$((index + 1))
+    done
+
+    # Clean up save folder
+    rm -rf "$output_dir"
+}
+
+# Main script
+if [ $# -lt 3 ]; then
+    echo "Usage: $0 <split> <output_dir> <input_video1> [input_video2 ...]"
+    exit 1
+fi
+
+split=$1
+output_dir=$2
+shift 2
+input_videos=("$@")
+
+# Initialize JSON array
+json_array="["
+
+for input_video in "${input_videos[@]}"; do
+    base_name=$(basename "$input_video" .mp4)
+
+    # Extract sections and run the Python script for each section
+    extract_sections "$input_video" "$output_dir" "$split"
+
+    # Add entry to JSON array
+    json_array+="\"../data/images/$base_name\","
+done
+
+# Remove trailing comma and close JSON array
+json_array="${json_array%,}]"
+
+# Write JSON array to the correct file
+if [ "$split" == "train" ]; then
+    echo "$json_array" > train.json
+elif [ "$split" == "test" ]; then
+    echo "$json_array" > test.json
+else
+    echo "Invalid split: $split. Must be 'train' or 'test'."
+    exit 1
+fi
+
+echo "Processing complete."
diff --git a/scripts/data.py b/scripts/data.py
index f6cdb81..e5369d7 100644
--- a/scripts/data.py
+++ b/scripts/data.py
@@ -25,6 +25,32 @@ audio_processor, vae, unet, pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
 
+def get_largest_integer_filename(folder_path):
+    # Check if the folder exists
+    if not os.path.isdir(folder_path):
+        return -1
+
+    # Get the list of files in the folder
+    files = os.listdir(folder_path)
+
+    # Check if the folder is empty
+    if not files:
+        return -1
+
+    # Extract the integer part of filenames and find the largest
+    largest_integer = -1
+    for file in files:
+        try:
+            # Get the integer part of the filename
+            file_int = int(os.path.splitext(file)[0])
+            if file_int > largest_integer:
+                largest_integer = file_int
+        except ValueError:
+            # Skip files that don't have an integer filename
+            continue
+
+    return largest_integer
+
 def datagen(whisper_chunks,
             crop_images,
             batch_size=8,
@@ -58,10 +84,10 @@ def main(args):
         unet.model = unet.model.half()
 
     inference_config = OmegaConf.load(args.inference_config)
-    total_audio_index=-1
-    total_image_index=-1
-    temp_audio_index=-1
-    temp_image_index=-1
+    total_audio_index=get_largest_integer_filename(f"data/audios/{args.folder_name}")
+    total_image_index=get_largest_integer_filename(f"data/images/{args.folder_name}")
+    temp_audio_index=total_audio_index
+    temp_image_index=total_image_index
     for task_id in inference_config:
         video_path = inference_config[task_id]["video_path"]
         audio_path = inference_config[task_id]["audio_path"]
diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py
index f0652e3..431ee39 100644
--- a/train_codes/DataLoader.py
+++ b/train_codes/DataLoader.py
@@ -48,15 +48,15 @@ def get_image_list(data_root, split):
 
 class Dataset(object):
     def __init__(self, data_root,
-                 split,
+                 json_path,
                  use_audio_length_left=1,
                  use_audio_length_right=1,
                  whisper_model_type = "tiny"
                  ):
-        self.all_videos, self.all_imgNum = get_image_list(data_root, split)
+        # self.all_videos, self.all_imgNum = get_image_list(data_root, split)
         self.audio_feature = [use_audio_length_left,use_audio_length_right]
         self.all_img_names = []
-        self.split = split
+        # self.split = split
         self.img_names_path = '../data'
         self.whisper_model_type = whisper_model_type
         self.use_audio_length_left = use_audio_length_left
@@ -72,10 +72,13 @@ class Dataset(object):
         self.whisper_feature_H = 1280
         self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
 
-        if(self.split=="train"):
-            self.all_videos=["../data/images/train"]
-        if(self.split=="val"):
-            self.all_videos=["../data/images/test"]
+        # if(self.split=="train"):
+        #     self.all_videos=["../data/images/train"]
+        # if(self.split=="val"):
+        #     self.all_videos=["../data/images/test"]
+        with open(json_path, 'r') as file:
+            self.all_videos = json.load(file)
+
         for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
             json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
             if not os.path.exists(json_path_names):
diff --git a/train_codes/train.py b/train_codes/train.py
index 37a9447..fae0cb2 100755
--- a/train_codes/train.py
+++ b/train_codes/train.py
@@ -140,6 +140,8 @@ def parse_args():
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
+    parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
     parser.add_argument(
         "--hub_model_id",
         type=str,
@@ -350,7 +352,7 @@ def main():
 
     print("loading train_dataset ...")
     train_dataset = Dataset(args.data_root,
-                            'train',
+                            args.train_json,
                             use_audio_length_left=args.use_audio_length_left,
                             use_audio_length_right=args.use_audio_length_right,
                             whisper_model_type=args.whisper_model_type
@@ -360,7 +362,7 @@ def main():
                             num_workers=8)
     print("loading val_dataset ...")
     val_dataset = Dataset(args.data_root,
-                          'val',
+                          args.val_json,
                           use_audio_length_left=args.use_audio_length_left,
                           use_audio_length_right=args.use_audio_length_right,
                           whisper_model_type=args.whisper_model_type
diff --git a/train_codes/train.sh b/train_codes/train.sh
index 2e29d5c..600632b 100644
--- a/train_codes/train.sh
+++ b/train_codes/train.sh
@@ -18,10 +18,12 @@ accelerate launch train.py \
 --output_dir="output" \
 --val_out_dir='val' \
 --testing_speed \
---checkpointing_steps=1000 \
---validation_steps=1000 \
+--checkpointing_steps=2000 \
+--validation_steps=2000 \
 --reconstruction \
 --resume_from_checkpoint="latest" \
 --use_audio_length_left=2 \
 --use_audio_length_right=2 \
 --whisper_model_type="tiny" \
+--train_json="/root/MuseTalk/train.json" \
+--val_json="/root/MuseTalk/val.json" \
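
For reference, a typical end-to-end invocation of the new pipeline might look
like the following. This is a sketch, not part of the patch: the video
filenames and the sections/ scratch directory are illustrative. data_new.sh
expects the split name first, then an output directory for the 3-minute
sections, then one or more .mp4 inputs.

    # build the training split (writes train.json listing ../data/images/<video> per input)
    ./data_new.sh train sections/ speaker1.mp4 speaker2.mp4

    # build the test split (writes test.json)
    ./data_new.sh test sections/ speaker3.mp4

    # launch training with the generated JSON lists
    # (the hard-coded --train_json/--val_json paths in train.sh assume /root/MuseTalk)
    cd train_codes && bash train.sh

Because scripts/data.py now seeds its counters from the largest integer
filename already present under data/audios/<folder_name> and
data/images/<folder_name>, re-running data_new.sh on the same speaker appends
new samples after the existing ones instead of restarting the numbering at 0.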