From 7254ca63069186c4a396a6d88f2a4cb65fcd7f8e Mon Sep 17 00:00:00 2001 From: shounak Date: Thu, 16 May 2024 18:24:44 +0000 Subject: [PATCH 1/5] initial data script --- scripts/data.py | 243 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 scripts/data.py diff --git a/scripts/data.py b/scripts/data.py new file mode 100644 index 0000000..22887b8 --- /dev/null +++ b/scripts/data.py @@ -0,0 +1,243 @@ +import cv2 +import os +# import dlib +import argparse +import os +from omegaconf import OmegaConf +import numpy as np +import cv2 +import torch +import glob +import pickle +from tqdm import tqdm +import copy +import uuid + +from musetalk.utils.utils import get_file_type,get_video_fps +from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image +from musetalk.utils.utils import load_all_model +import shutil + +# load model weights +audio_processor, vae, unet, pe = load_all_model() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +timesteps = torch.tensor([0], device=device) + +def datagen(whisper_chunks, + crop_images, + batch_size=8, + delay_frame=0): + whisper_batch, crop_batch = [], [] + for i, w in enumerate(whisper_chunks): + idx = (i+delay_frame)%len(crop_images) + crop_image = crop_images[idx] + whisper_batch.append(w) + crop_batch.append(crop_image) + + if len(crop_batch) >= batch_size: + whisper_batch = np.stack(whisper_batch) + # latent_batch = torch.cat(latent_batch, dim=0) + yield whisper_batch, crop_batch + whisper_batch, crop_batch = [], [] + + # the last batch may smaller than batch size + if len(crop_batch) > 0: + whisper_batch = np.stack(whisper_batch) + # latent_batch = torch.cat(latent_batch, dim=0) + + yield whisper_batch, crop_batch + +@torch.no_grad() +def main(args): + global pe + if args.use_float16 is True: + pe = pe.half() + vae.vae = vae.vae.half() + unet.model = unet.model.half() + + inference_config = OmegaConf.load(args.inference_config) + print(inference_config) + for task_id in inference_config: + video_path = inference_config[task_id]["video_path"] + audio_path = inference_config[task_id]["audio_path"] + bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) + folder_name = args.folder_name + if not os.path.exists(f"data/images/{folder_name}/"): + os.makedirs(f"data/images/{folder_name}") + if not os.path.exists(f"data/audios/{folder_name}/"): + os.makedirs(f"data/audios/{folder_name}") + input_basename = os.path.basename(video_path).split('.')[0] + audio_basename = os.path.basename(audio_path).split('.')[0] + output_basename = f"{input_basename}_{audio_basename}" + result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs + crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input + os.makedirs(result_img_save_path,exist_ok =True) + + if args.output_vid_name is None: + output_vid_name = os.path.join(args.result_dir, output_basename+".mp4") + else: + output_vid_name = os.path.join(args.result_dir, args.output_vid_name) + ############################################## extract frames from source video ############################################## + if get_file_type(video_path)=="video": + save_dir_full = os.path.join(args.result_dir, input_basename) + os.makedirs(save_dir_full,exist_ok = True) + cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" + os.system(cmd) + 
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) + fps = get_video_fps(video_path) + elif get_file_type(video_path)=="image": + input_img_list = [video_path, ] + fps = args.fps + elif os.path.isdir(video_path): # input img folder + input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + fps = args.fps + else: + raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") + print("LEN..........") + + print(len(input_img_list)) + ############################################## extract audio feature ############################################## + whisper_feature = audio_processor.audio2feat(audio_path) + print(len(whisper_feature)) + print("Whisper feature length........") + print(whisper_feature[0].shape) + # print(whisper_feature) + for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements + # Combine two consecutive chunks + # pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]]) + concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0) + # Save the pair to a .npy file + print("Pair shape",concatenated_chunks.shape) + np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks) + whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + print(len(whisper_chunks)) + # whisper_i=0 + # for chunk in whisper_chunks: + # # print("CHUNMK SHAPE...........") + # # print(chunk.shape) + # np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk) + # whisper_i+=1 + + ############################################## preprocess input image ############################################## + if os.path.exists(crop_coord_save_path) and args.use_saved_coord: + print("using extracted coordinates") + with open(crop_coord_save_path,'rb') as f: + coord_list = pickle.load(f) + frame_list = read_imgs(input_img_list) + else: + print("extracting landmarks...time consuming") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) + with open(crop_coord_save_path, 'wb') as f: + pickle.dump(coord_list, f) + + print(len(frame_list)) + + i = 0 + input_latent_list = [] + crop_i=0 + crop_data=[] + for bbox, frame in zip(coord_list, frame_list): + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + + x1=max(0,x1) + y1=max(0,y1) + x2=max(0,x2) + y2=max(0,y2) + + if ((y2-y1)<=0) or ((x2-x1)<=0): + continue + crop_frame = frame[y1:y2, x1:x2] + print("crop sizes",bbox) + crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame) + latents = vae.get_latents_for_unet(crop_frame) + crop_data.append(crop_frame) + input_latent_list.append(latents) + crop_i+=1 + + # to smooth the first and the last frame + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + crop_data = crop_data + crop_data[::-1] + ############################################## inference batch by batch ############################################## + print("start inference") + print(len(input_latent_list_cycle),len(whisper_chunks)) + video_num = len(whisper_chunks) + batch_size = args.batch_size + gen = datagen(whisper_chunks,crop_data,batch_size) + 
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): + print("BATCH LEN..............") + print(len(whisper_batch),len(crop_batch)) + crop_index=0 + for image,audio in zip(crop_batch,whisper_batch): + cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image) + crop_index+=1 + # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio) + print(folder_name) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml") + parser.add_argument("--bbox_shift", type=int, default=0) + parser.add_argument("--result_dir", default='./results', help="path to output") + parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="path to output") + + parser.add_argument("--fps", type=int, default=25) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--output_vid_name", type=str, default=None) + parser.add_argument("--use_saved_coord", + action="store_true", + help='use saved coordinate to save time') + parser.add_argument("--use_float16", + action="store_true", + help="Whether use float16 to speed up inference", + ) + + args = parser.parse_args() + main(args) + + +def process_audio(audio_path): + whisper_feature = audio_processor.audio2feat(audio_path) + np.save('audio/your_filename.npy', whisper_feature) + +def mask_face(image): + # Load dlib's face detector and the landmark predictor + detector = dlib.get_frontal_face_detector() + predictor_path = "/content/shape_predictor_68_face_landmarks.dat" # Set path to your downloaded predictor file + predictor = dlib.shape_predictor(predictor_path) + + # Load your input image + # image_path = "/content/ori_frame_00000077.png" # Replace with the path to your input image + # image = cv2.imread(image_path) + if image is None: + raise ValueError("Image not found or unable to load.") + + # Convert to grayscale for detection + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Detect faces in the image + faces = detector(gray) + + # Process each detected face + for face in faces: + # Predict landmarks + landmarks = predictor(gray, face) + + # The indices of nose landmarks are 27 to 35 + nose_tip = landmarks.part(33).y + + # Blacken the region below the nose tip + blacken_area = image[nose_tip:, :] + blacken_area[:] = (0, 0, 0) + + # Save the final image or display it + # cv2.imwrite("output_image.jpg", image) + return image From b4a592d7f3b79d1db8791541a014614641358a75 Mon Sep 17 00:00:00 2001 From: Shounak Banerjee Date: Mon, 3 Jun 2024 11:09:12 +0000 Subject: [PATCH 2/5] modified dataloader.py and inference.py for training and inference --- scripts/inference.py | 39 ++++++++++++++++----- train_codes/DataLoader.py | 15 ++++---- train_codes/README.md | 34 +++++++++++------- train_codes/train.py | 60 ++++++++++++++++++++++---------- train_codes/train.sh | 14 ++++---- train_codes/utils/model_utils.py | 2 +- 6 files changed, 106 insertions(+), 58 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index fe2a234..dddce31 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -15,14 +15,27 @@ from musetalk.utils.blending import get_image from musetalk.utils.utils import load_all_model import shutil +from accelerate import Accelerator + # load model weights audio_processor, vae, unet, pe = load_all_model() +accelerator = Accelerator( + mixed_precision="fp16", + ) +unet = accelerator.prepare( + unet, + + ) device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) @torch.no_grad() def main(args): global pe + if not (args.unet_checkpoint == None): + print("unet ckpt loaded") + accelerator.load_state(args.unet_checkpoint) + if args.use_float16 is True: pe = pe.half() vae.vae = vae.vae.half() @@ -63,8 +76,6 @@ def main(args): fps = args.fps else: raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") - - #print(input_img_list) ############################################## extract audio feature ############################################## whisper_feature = audio_processor.audio2feat(audio_path) whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) @@ -79,24 +90,27 @@ def main(args): coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) with open(crop_coord_save_path, 'wb') as f: pickle.dump(coord_list, f) + i = 0 input_latent_list = [] + crop_i=0 for bbox, frame in zip(coord_list, frame_list): if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox crop_frame = frame[y1:y2, x1:x2] crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame) latents = vae.get_latents_for_unet(crop_frame) input_latent_list.append(latents) + crop_i+=1 # to smooth the first and the last frame frame_list_cycle = frame_list + frame_list[::-1] coord_list_cycle = coord_list + coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] ############################################## inference batch by batch ############################################## - print("start inference") video_num = len(whisper_chunks) batch_size = args.batch_size gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) @@ -107,7 +121,6 @@ def main(args): dtype=unet.model.dtype) # torch, B, 5*N,384 audio_feature_batch = pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) - pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) for res_frame in recon: @@ -122,22 +135,29 @@ def main(args): try: res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: -# print(bbox) continue combine_frame = get_image(ori_frame,res_frame,bbox) + cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame) + cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame) cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" - print(cmd_img2video) os.system(cmd_img2video) cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" - print(cmd_combine_audio) os.system(cmd_combine_audio) os.remove("temp.mp4") - shutil.rmtree(result_img_save_path) + + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" + os.system(cmd_img2video) + + # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" + # print(cmd_combine_audio) + # os.system(cmd_combine_audio) + + # shutil.rmtree(result_img_save_path) print(f"result 
is save to {output_vid_name}") if __name__ == "__main__": @@ -156,6 +176,7 @@ if __name__ == "__main__": action="store_true", help="Whether use float16 to speed up inference", ) + parser.add_argument("--unet_checkpoint", type=str, default=None) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py index 8dacbde..f0652e3 100644 --- a/train_codes/DataLoader.py +++ b/train_codes/DataLoader.py @@ -57,13 +57,13 @@ class Dataset(object): self.audio_feature = [use_audio_length_left,use_audio_length_right] self.all_img_names = [] self.split = split - self.img_names_path = '...' + self.img_names_path = '../data' self.whisper_model_type = whisper_model_type self.use_audio_length_left = use_audio_length_left self.use_audio_length_right = use_audio_length_right if self.whisper_model_type =="tiny": - self.whisper_path = '...' + self.whisper_path = '../data/audios' self.whisper_feature_W = 5 self.whisper_feature_H = 384 elif self.whisper_model_type =="largeV2": @@ -72,6 +72,10 @@ class Dataset(object): self.whisper_feature_H = 1280 self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50 + if(self.split=="train"): + self.all_videos=["../data/images/train"] + if(self.split=="val"): + self.all_videos=["../data/images/test"] for vidname in tqdm(self.all_videos, desc="Preparing dataset"): json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json" if not os.path.exists(json_path_names): @@ -79,7 +83,6 @@ class Dataset(object): img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0])) with open(json_path_names, "w") as f: json.dump(img_names,f) - print(f"save to {json_path_names}") else: with open(json_path_names, "r") as f: img_names = json.load(f) @@ -147,7 +150,6 @@ class Dataset(object): vidname = self.all_videos[idx].split('/')[-1] video_imgs = self.all_img_names[idx] if len(video_imgs) == 0: -# print("video_imgs = 0:",vidname) continue img_name = random.choice(video_imgs) img_idx = int(basename(img_name).split(".")[0]) @@ -205,7 +207,6 @@ class Dataset(object): for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1): # 判定是否越界 audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy") - if not os.path.exists(audio_feat_path): is_index_out_of_range = True break @@ -226,8 +227,6 @@ class Dataset(object): print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}") continue audio_feature = torch.squeeze(torch.FloatTensor(audio_feature)) - - return ref_image, image, masked_image, mask, audio_feature @@ -243,10 +242,8 @@ if __name__ == "__main__": val_data_loader = data_utils.DataLoader( val_data, batch_size=4, shuffle=True, num_workers=1) - print("val_dataset:",val_data_loader.__len__()) for i, data in enumerate(val_data_loader): ref_image, image, masked_image, mask, audio_feature = data - print("ref_image: ", ref_image.shape) \ No newline at end of file diff --git a/train_codes/README.md b/train_codes/README.md index db9848c..ba0f370 100644 --- a/train_codes/README.md +++ b/train_codes/README.md @@ -1,32 +1,35 @@ -# Draft training codes +# Data preprocessing -We provde the draft training codes here. Unfortunately, data preprocessing code is still being reorganized. 
+Create two config yaml files, one for training and one for testing (both in the same format as configs/inference/test.yaml).
+The train yaml file should contain the training video paths and corresponding audio paths.
+The test yaml file should contain the validation video paths and corresponding audio paths.

-## Setup
+Run:
+```
+python -m scripts.data --inference_config path_to_train.yaml --folder_name train
+python -m scripts.data --inference_config path_to_test.yaml --folder_name test
+```
+This creates folders which contain the image frames and npy files.

-We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 20w+ steps. Using multiple GPUs should accelerate the training.

-## Data preprocessing
- You could refer the inference codes which [crop the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extract audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
-
-Finally, the data should be organized as follows:
+## Data organization
 ```
 ./data/
 ├── images
-│   └──RD_Radio10_000
+│   └──train
 │       └── 0.png
 │       └── 1.png
 │       └── xxx.png
-│   └──RD_Radio11_000
+│   └──test
 │       └── 0.png
 │       └── 1.png
 │       └── xxx.png
 ├── audios
-│   └──RD_Radio10_000
+│   └──train
 │       └── 0.npy
 │       └── 1.npy
 │       └── xxx.npy
-│   └──RD_Radio11_000
+│   └──test
 │       └── 0.npy
 │       └── 1.npy
 │       └── xxx.npy
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
 ```
 sh train.sh
 ```
+## Inference with trained checkpoint
+Simply run after training the model; the model checkpoints are usually saved at train_codes/output.
+```
+python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
+```
 ## TODO
-- [ ] release data preprocessing codes
+- [x] release data preprocessing codes
 - [ ] release some novel designs in training (after technical report)
\ No newline at end of file
diff --git a/train_codes/train.py b/train_codes/train.py
index cb50c1c..37a9447 100755
--- a/train_codes/train.py
+++ b/train_codes/train.py
@@ -27,10 +27,13 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version

+import sys
+sys.path.append("./")
+
 from DataLoader import Dataset
 from utils.utils import preprocess_img_tensor
 from torch.utils import data as data_utils
-from model_utils import validation,PositionalEncoding
+from utils.model_utils import validation,PositionalEncoding
 import time
 import pandas as pd
 from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
+
+
     return args

-
+def print_model_dtypes(model, model_name):
+    for name, param in model.named_parameters():
+        if(param.dtype!=torch.float32):
+            print(f"{name}: {param.dtype}")

 def main():
     args = parse_args()
-    print(args)
     args.output_dir = f"output/{args.output_dir}"
     args.val_out_dir = f"val/{args.val_out_dir}"
     os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
     optimizer_class = torch.optim.AdamW

     params_to_optimize = (
-        itertools.chain(unet.parameters())
+        itertools.chain(unet.parameters()))
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
         use_audio_length_right=args.use_audio_length_right,
         whisper_model_type=args.whisper_model_type
     )
-    print("train_dataset:",train_dataset.__len__())
     train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=8) @@ -359,7 +365,6 @@ def main(): use_audio_length_right=args.use_audio_length_right, whisper_model_type=args.whisper_model_type ) - print("val_dataset:",val_dataset.__len__()) val_data_loader = data_utils.DataLoader( val_dataset, batch_size=1, shuffle=False, num_workers=8) @@ -388,6 +393,7 @@ def main(): vae_fp32.requires_grad_(False) weight_dtype = torch.float32 + # weight_dtype = torch.float16 vae_fp32.to(accelerator.device, dtype=weight_dtype) vae_fp32.encoder = None if accelerator.mixed_precision == "fp16": @@ -412,6 +418,8 @@ def main(): # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print(f" Num batches each epoch = {len(train_data_loader)}") + logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_data_loader)}") @@ -433,6 +441,9 @@ def main(): dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None + # path="../models/pytorch_model.bin" + #TODO change path + # path=None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." @@ -458,10 +469,11 @@ def main(): # caluate the elapsed time elapsed_time = [] start = time.time() + + for epoch in range(first_epoch, args.num_train_epochs): unet.train() -# for step, batch in enumerate(train_dataloader): for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -470,24 +482,23 @@ def main(): continue dataloader_time = time.time() - start start = time.time() - masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device) - """ - print("=============epoch:{0}=step:{1}=====".format(epoch,step)) - print("ref_image: ",ref_image.shape) - print("masks: ", masks.shape) - print("masked_image: ", masked_image.shape) - print("audio feature: ", audio_feature.shape) - print("image: ", image.shape) - """ + # """ + # print("=============epoch:{0}=step:{1}=====".format(epoch,step)) + # print("ref_image: ",ref_image.shape) + # print("masks: ", masks.shape) + # print("masked_image: ", masked_image.shape) + # print("audio feature: ", audio_feature.shape) + # print("image: ", image.shape) + # """ ref_image = preprocess_img_tensor(ref_image).to(vae.device) image = preprocess_img_tensor(image).to(vae.device) masked_image = preprocess_img_tensor(masked_image).to(vae.device) img_process_time = time.time() - start start = time.time() - with accelerator.accumulate(unet): + vae = vae.half() # Convert images to latent space latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image latents = latents * vae.config.scaling_factor @@ -592,12 +603,23 @@ def main(): f"Running validation... 
epoch={epoch}, global_step={global_step}" ) print("===========start validation==========") + # Use the helper function to check the data types for each model + vae_new = vae.float() + print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE") + print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32") + print_model_dtypes(accelerator.unwrap_model(unet), "UNET") + + print(f"weight_dtype: {weight_dtype}") + print(f"epoch type: {type(epoch)}") + print(f"global_step type: {type(global_step)}") validation( - vae=accelerator.unwrap_model(vae), + # vae=accelerator.unwrap_model(vae), + vae=accelerator.unwrap_model(vae_new), vae_fp32=accelerator.unwrap_model(vae_fp32), unet=accelerator.unwrap_model(unet), unet_config=unet_config, - weight_dtype=weight_dtype, + # weight_dtype=weight_dtype, + weight_dtype=torch.float32, epoch=epoch, global_step=global_step, val_data_loader=val_data_loader, diff --git a/train_codes/train.sh b/train_codes/train.sh index 908a676..8c60b48 100644 --- a/train_codes/train.sh +++ b/train_codes/train.sh @@ -1,8 +1,8 @@ -export VAE_MODEL="./sd-vae-ft-mse/" -export DATASET="..." -export UNET_CONFIG="./musetalk.json" +export VAE_MODEL="../models/sd-vae-ft-mse/" +export DATASET="../data" +export UNET_CONFIG="../models/musetalk/musetalk.json" -accelerate launch --multi_gpu train.py \ +accelerate launch train.py \ --mixed_precision="fp16" \ --unet_config_file=$UNET_CONFIG \ --pretrained_model_name_or_path=$VAE_MODEL \ @@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \ --train_batch_size=8 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ ---max_train_steps=200000 \ +--max_train_steps=50000 \ --learning_rate=5e-05 \ --max_grad_norm=1 \ --lr_scheduler="cosine" \ --lr_warmup_steps=0 \ ---output_dir="..." \ ---val_out_dir='...' 
\
+--output_dir="output" \
+--val_out_dir='val' \
 --testing_speed \
 --checkpointing_steps=1000 \
 --validation_steps=1000 \
 --reconstruction \
 --resume_from_checkpoint="latest" \
 --use_audio_length_left=2 \
 --use_audio_length_right=2 \
 --whisper_model_type="tiny" \
diff --git a/train_codes/utils/model_utils.py b/train_codes/utils/model_utils.py
index e0fd2e2..0534521 100644
--- a/train_codes/utils/model_utils.py
+++ b/train_codes/utils/model_utils.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import time
 import math
-from utils import decode_latents, preprocess_img_tensor
+from utils.utils import decode_latents, preprocess_img_tensor
 import os
 from PIL import Image
 from typing import Any, Dict, List, Optional, Tuple, Union

From d74c4c098b3af834d795e0c8494214f92df5363a Mon Sep 17 00:00:00 2001
From: Shounak Banerjee
Date: Fri, 7 Jun 2024 18:39:24 +0000
Subject: [PATCH 3/5] clean code and separate finetuned_inference.py

---
 scripts/data.py                |  41 +++-----
 scripts/finetuned_inference.py | 182 +++++++++++++++++++++++++++++++++
 scripts/inference.py           |  37 ++-----
 train_codes/README.md          |   2 +-
 train_codes/train.sh           |   2 +-
 5 files changed, 206 insertions(+), 58 deletions(-)
 create mode 100644 scripts/finetuned_inference.py

diff --git a/scripts/data.py b/scripts/data.py
index 22887b8..f6cdb81 100644
--- a/scripts/data.py
+++ b/scripts/data.py
@@ -18,6 +18,7 @@ from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_p
 from musetalk.utils.blending import get_image
 from musetalk.utils.utils import load_all_model
 import shutil
+import gc

 # load model weights
 audio_processor, vae, unet, pe = load_all_model()
@@ -57,7 +58,10 @@ def main(args):
         unet.model = unet.model.half()

     inference_config = OmegaConf.load(args.inference_config)
-    print(inference_config)
+    total_audio_index=-1
+    total_image_index=-1
+    temp_audio_index=-1
+    temp_image_index=-1
     for task_id in inference_config:
         video_path = inference_config[task_id]["video_path"]
         audio_path = inference_config[task_id]["audio_path"]
@@ -95,32 +99,20 @@ def main(args):
         fps = args.fps
     else:
         raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
-    print("LEN..........")
-
-    print(len(input_img_list))
 ############################################## extract audio feature ##############################################
     whisper_feature = audio_processor.audio2feat(audio_path)
-    print(len(whisper_feature))
-    print("Whisper feature length........")
-    print(whisper_feature[0].shape)
-    # print(whisper_feature)
     for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
         # Combine two consecutive chunks
         # pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
         concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
         # Save the pair to a .npy file
-        print("Pair shape",concatenated_chunks.shape)
-        np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks)
+        np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
+        temp_audio_index=(__//2)+total_audio_index+1
+    total_audio_index=temp_audio_index
     whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
-    print(len(whisper_chunks))
-    # whisper_i=0
-    # for chunk in whisper_chunks:
-    #     # print("CHUNMK SHAPE...........")
-    #     # print(chunk.shape)
-    #     np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk)
-    #     whisper_i+=1

 ############################################## preprocess input image ##############################################
+    gc.collect()
     if os.path.exists(crop_coord_save_path) and
args.use_saved_coord: print("using extracted coordinates") with open(crop_coord_save_path,'rb') as f: @@ -131,8 +123,7 @@ def main(args): coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) with open(crop_coord_save_path, 'wb') as f: pickle.dump(coord_list, f) - - print(len(frame_list)) + i = 0 input_latent_list = [] @@ -151,9 +142,7 @@ def main(args): if ((y2-y1)<=0) or ((x2-x1)<=0): continue crop_frame = frame[y1:y2, x1:x2] - print("crop sizes",bbox) crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) - cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame) latents = vae.get_latents_for_unet(crop_frame) crop_data.append(crop_frame) input_latent_list.append(latents) @@ -165,20 +154,18 @@ def main(args): input_latent_list_cycle = input_latent_list + input_latent_list[::-1] crop_data = crop_data + crop_data[::-1] ############################################## inference batch by batch ############################################## - print("start inference") - print(len(input_latent_list_cycle),len(whisper_chunks)) video_num = len(whisper_chunks) batch_size = args.batch_size gen = datagen(whisper_chunks,crop_data,batch_size) for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): - print("BATCH LEN..............") - print(len(whisper_batch),len(crop_batch)) crop_index=0 for image,audio in zip(crop_batch,whisper_batch): - cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image) + cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image) crop_index+=1 + temp_image_index=i+crop_index+total_image_index+1 # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio) - print(folder_name) + total_image_index=temp_image_index + gc.collect() diff --git a/scripts/finetuned_inference.py b/scripts/finetuned_inference.py new file mode 100644 index 0000000..dddce31 --- /dev/null +++ b/scripts/finetuned_inference.py @@ -0,0 +1,182 @@ +import argparse +import os +from omegaconf import OmegaConf +import numpy as np +import cv2 +import torch +import glob +import pickle +from tqdm import tqdm +import copy + +from musetalk.utils.utils import get_file_type,get_video_fps,datagen +from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image +from musetalk.utils.utils import load_all_model +import shutil + +from accelerate import Accelerator + +# load model weights +audio_processor, vae, unet, pe = load_all_model() +accelerator = Accelerator( + mixed_precision="fp16", + ) +unet = accelerator.prepare( + unet, + + ) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +timesteps = torch.tensor([0], device=device) + +@torch.no_grad() +def main(args): + global pe + if not (args.unet_checkpoint == None): + print("unet ckpt loaded") + accelerator.load_state(args.unet_checkpoint) + + if args.use_float16 is True: + pe = pe.half() + vae.vae = vae.vae.half() + unet.model = unet.model.half() + + inference_config = OmegaConf.load(args.inference_config) + print(inference_config) + for task_id in inference_config: + video_path = inference_config[task_id]["video_path"] + audio_path = inference_config[task_id]["audio_path"] + bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) + + input_basename = os.path.basename(video_path).split('.')[0] + audio_basename = os.path.basename(audio_path).split('.')[0] + output_basename = 
f"{input_basename}_{audio_basename}" + result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs + crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input + os.makedirs(result_img_save_path,exist_ok =True) + + if args.output_vid_name is None: + output_vid_name = os.path.join(args.result_dir, output_basename+".mp4") + else: + output_vid_name = os.path.join(args.result_dir, args.output_vid_name) + ############################################## extract frames from source video ############################################## + if get_file_type(video_path)=="video": + save_dir_full = os.path.join(args.result_dir, input_basename) + os.makedirs(save_dir_full,exist_ok = True) + cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" + os.system(cmd) + input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) + fps = get_video_fps(video_path) + elif get_file_type(video_path)=="image": + input_img_list = [video_path, ] + fps = args.fps + elif os.path.isdir(video_path): # input img folder + input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + fps = args.fps + else: + raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") + ############################################## extract audio feature ############################################## + whisper_feature = audio_processor.audio2feat(audio_path) + whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + ############################################## preprocess input image ############################################## + if os.path.exists(crop_coord_save_path) and args.use_saved_coord: + print("using extracted coordinates") + with open(crop_coord_save_path,'rb') as f: + coord_list = pickle.load(f) + frame_list = read_imgs(input_img_list) + else: + print("extracting landmarks...time consuming") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) + with open(crop_coord_save_path, 'wb') as f: + pickle.dump(coord_list, f) + + + i = 0 + input_latent_list = [] + crop_i=0 + for bbox, frame in zip(coord_list, frame_list): + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + crop_frame = frame[y1:y2, x1:x2] + crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame) + latents = vae.get_latents_for_unet(crop_frame) + input_latent_list.append(latents) + crop_i+=1 + + # to smooth the first and the last frame + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + ############################################## inference batch by batch ############################################## + video_num = len(whisper_chunks) + batch_size = args.batch_size + gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) + res_frame_list = [] + for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): + audio_feature_batch = torch.from_numpy(whisper_batch) + audio_feature_batch = audio_feature_batch.to(device=unet.device, + dtype=unet.model.dtype) # torch, B, 5*N,384 + audio_feature_batch = 
pe(audio_feature_batch) + latent_batch = latent_batch.to(dtype=unet.model.dtype) + pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample + recon = vae.decode_latents(pred_latents) + for res_frame in recon: + res_frame_list.append(res_frame) + + ############################################## pad to full image ############################################## + print("pad talking image to original video") + for i, res_frame in enumerate(tqdm(res_frame_list)): + bbox = coord_list_cycle[i%(len(coord_list_cycle))] + ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) + x1, y1, x2, y2 = bbox + try: + res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) + except: + continue + + combine_frame = get_image(ori_frame,res_frame,bbox) + cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame) + cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame) + cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) + + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" + os.system(cmd_img2video) + + cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" + os.system(cmd_combine_audio) + + os.remove("temp.mp4") + + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" + os.system(cmd_img2video) + + # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" + # print(cmd_combine_audio) + # os.system(cmd_combine_audio) + + # shutil.rmtree(result_img_save_path) + print(f"result is save to {output_vid_name}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml") + parser.add_argument("--bbox_shift", type=int, default=0) + parser.add_argument("--result_dir", default='./results', help="path to output") + + parser.add_argument("--fps", type=int, default=25) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--output_vid_name", type=str, default=None) + parser.add_argument("--use_saved_coord", + action="store_true", + help='use saved coordinate to save time') + parser.add_argument("--use_float16", + action="store_true", + help="Whether use float16 to speed up inference", + ) + parser.add_argument("--unet_checkpoint", type=str, default=None) + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/scripts/inference.py b/scripts/inference.py index dddce31..b708774 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -15,27 +15,14 @@ from musetalk.utils.blending import get_image from musetalk.utils.utils import load_all_model import shutil -from accelerate import Accelerator - # load model weights audio_processor, vae, unet, pe = load_all_model() -accelerator = Accelerator( - mixed_precision="fp16", - ) -unet = accelerator.prepare( - unet, - - ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) @torch.no_grad() def main(args): global pe - if not (args.unet_checkpoint == None): - print("unet ckpt loaded") - accelerator.load_state(args.unet_checkpoint) - if args.use_float16 is True: pe = pe.half() 
vae.vae = vae.vae.half() @@ -76,6 +63,8 @@ def main(args): fps = args.fps else: raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") + + #print(input_img_list) ############################################## extract audio feature ############################################## whisper_feature = audio_processor.audio2feat(audio_path) whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) @@ -90,27 +79,24 @@ def main(args): coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) with open(crop_coord_save_path, 'wb') as f: pickle.dump(coord_list, f) - i = 0 input_latent_list = [] - crop_i=0 for bbox, frame in zip(coord_list, frame_list): if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox crop_frame = frame[y1:y2, x1:x2] crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) - cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame) latents = vae.get_latents_for_unet(crop_frame) input_latent_list.append(latents) - crop_i+=1 # to smooth the first and the last frame frame_list_cycle = frame_list + frame_list[::-1] coord_list_cycle = coord_list + coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] ############################################## inference batch by batch ############################################## + print("start inference") video_num = len(whisper_chunks) batch_size = args.batch_size gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) @@ -121,6 +107,7 @@ def main(args): dtype=unet.model.dtype) # torch, B, 5*N,384 audio_feature_batch = pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) + pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) for res_frame in recon: @@ -135,29 +122,22 @@ def main(args): try: res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: +# print(bbox) continue combine_frame = get_image(ori_frame,res_frame,bbox) - cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame) - cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame) cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" + print(cmd_img2video) os.system(cmd_img2video) cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" + print(cmd_combine_audio) os.system(cmd_combine_audio) os.remove("temp.mp4") - - cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" - os.system(cmd_img2video) - - # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}" - # print(cmd_combine_audio) - # os.system(cmd_combine_audio) - - # shutil.rmtree(result_img_save_path) + shutil.rmtree(result_img_save_path) print(f"result is save to {output_vid_name}") if __name__ == "__main__": @@ -176,7 +156,6 @@ if __name__ == "__main__": action="store_true", help="Whether use float16 to speed up inference", ) - parser.add_argument("--unet_checkpoint", type=str, default=None) args = parser.parse_args() main(args) \ No newline at end of file 
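The checkpoint restore in scripts/finetuned_inference.py above reduces to a few accelerate calls. A minimal standalone sketch, assuming a checkpoint folder previously written by accelerator.save_state() during training (the folder name below is a hypothetical example of train.py's checkpoint-<step> naming):

```
# Sketch: restore a fine-tuned MuseTalk UNet with accelerate before running inference.
import torch
from accelerate import Accelerator
from musetalk.utils.utils import load_all_model

audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(mixed_precision="fp16")
unet = accelerator.prepare(unet)  # wrap the model before restoring its state
accelerator.load_state("train_codes/output/checkpoint-2000")  # hypothetical checkpoint dir

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)  # single fixed timestep, as in the scripts above
```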
diff --git a/train_codes/README.md b/train_codes/README.md
index ba0f370..f303b87 100644
--- a/train_codes/README.md
+++ b/train_codes/README.md
@@ -43,7 +43,7 @@ sh train.sh
 ## Inference with trained checkpoint
 Simply run after training the model; the model checkpoints are usually saved at train_codes/output.
 ```
-python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
+python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
 ```

 ## TODO
diff --git a/train_codes/train.sh b/train_codes/train.sh
index 8c60b48..2e29d5c 100644
--- a/train_codes/train.sh
+++ b/train_codes/train.sh
@@ -10,7 +10,7 @@ accelerate launch train.py \
 --train_batch_size=8 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
---max_train_steps=50000 \
+--max_train_steps=100000 \
 --learning_rate=5e-05 \
 --max_grad_norm=1 \
 --lr_scheduler="cosine" \

From af82f3b00f9a56b66ddeda5684f8f47eaad66f5c Mon Sep 17 00:00:00 2001
From: Shounak Banerjee
Date: Thu, 13 Jun 2024 14:14:52 +0000
Subject: [PATCH 4/5] temporary commit to save changes

---
 data_new.sh               | 77 +++++++++++++++++++++++++++++++++++++++
 scripts/data.py           | 34 +++++++++++++++--
 train_codes/DataLoader.py | 17 +++++----
 train_codes/train.py      |  6 ++-
 train_codes/train.sh      |  6 ++-
 5 files changed, 125 insertions(+), 15 deletions(-)
 create mode 100755 data_new.sh

diff --git a/data_new.sh b/data_new.sh
new file mode 100755
index 0000000..61844b2
--- /dev/null
+++ b/data_new.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+
+# Function to extract video and audio sections
+extract_sections() {
+    input_video=$1
+    base_name=$(basename "$input_video" .mp4)
+    output_dir=$2
+    split=$3
+    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
+    IFS=: read -r hours minutes seconds <<< "$duration"
+    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
+    chunk_size=180 # 3 minutes in seconds
+    index=0
+
+    mkdir -p "$output_dir"
+
+    while [ $((index * chunk_size)) -lt $total_seconds ]; do
+        start_time=$((index * chunk_size))
+        section_video="${output_dir}/${base_name}_part${index}.mp4"
+        section_audio="${output_dir}/${base_name}_part${index}.mp3"
+
+        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
+        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"
+
+        # Create and update the config.yaml file
+        echo "task_0:" > config.yaml
+        echo " video_path: \"$section_video\"" >> config.yaml
+        echo " audio_path: \"$section_audio\"" >> config.yaml
+
+        # Run the Python script with the current config.yaml
+        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
+
+        index=$((index + 1))
+    done
+
+    # Clean up the intermediate save folder
+    rm -rf "$output_dir"
+}
+
+# Main script
+if [ $# -lt 3 ]; then
+    echo "Usage: $0 <split> <output_dir> <input_video1> [<input_video2> ...]"
+    exit 1
+fi
+
+split=$1
+output_dir=$2
+shift 2
+input_videos=("$@")
+
+# Initialize JSON array
+json_array="["
+
+for input_video in "${input_videos[@]}"; do
+    base_name=$(basename "$input_video" .mp4)
+
+    # Extract sections and run the Python script for each section
+    extract_sections "$input_video" "$output_dir" "$split"
+
+    # Add entry to JSON array
+    json_array+="\"../data/images/$base_name\","
+done
+
+# Remove trailing comma and close JSON array
+json_array="${json_array%,}]"
+
+# Write JSON array to the correct file
+if [ "$split" == "train" ]; then
+    echo "$json_array" > train.json
+elif [ "$split" ==
"test" ]; then + echo "$json_array" > test.json +else + echo "Invalid split: $split. Must be 'train' or 'test'." + exit 1 +fi + +echo "Processing complete." diff --git a/scripts/data.py b/scripts/data.py index f6cdb81..e5369d7 100644 --- a/scripts/data.py +++ b/scripts/data.py @@ -25,6 +25,32 @@ audio_processor, vae, unet, pe = load_all_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) +def get_largest_integer_filename(folder_path): + # Check if the folder exists + if not os.path.isdir(folder_path): + return -1 + + # Get the list of files in the folder + files = os.listdir(folder_path) + + # Check if the folder is empty + if not files: + return -1 + + # Extract the integer part of filenames and find the largest + largest_integer = -1 + for file in files: + try: + # Get the integer part of the filename + file_int = int(os.path.splitext(file)[0]) + if file_int > largest_integer: + largest_integer = file_int + except ValueError: + # Skip files that don't have an integer filename + continue + + return largest_integer + def datagen(whisper_chunks, crop_images, batch_size=8, @@ -58,10 +84,10 @@ def main(args): unet.model = unet.model.half() inference_config = OmegaConf.load(args.inference_config) - total_audio_index=-1 - total_image_index=-1 - temp_audio_index=-1 - temp_image_index=-1 + total_audio_index=get_largest_integer_filename(f"data/audios/{args.folder_name}") + total_image_index=get_largest_integer_filename(f"data/images/{args.folder_name}") + temp_audio_index=total_audio_index + temp_image_index=total_image_index for task_id in inference_config: video_path = inference_config[task_id]["video_path"] audio_path = inference_config[task_id]["audio_path"] diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py index f0652e3..431ee39 100644 --- a/train_codes/DataLoader.py +++ b/train_codes/DataLoader.py @@ -48,15 +48,15 @@ def get_image_list(data_root, split): class Dataset(object): def __init__(self, data_root, - split, + json_path, use_audio_length_left=1, use_audio_length_right=1, whisper_model_type = "tiny" ): - self.all_videos, self.all_imgNum = get_image_list(data_root, split) + # self.all_videos, self.all_imgNum = get_image_list(data_root, split) self.audio_feature = [use_audio_length_left,use_audio_length_right] self.all_img_names = [] - self.split = split + # self.split = split self.img_names_path = '../data' self.whisper_model_type = whisper_model_type self.use_audio_length_left = use_audio_length_left @@ -72,10 +72,13 @@ class Dataset(object): self.whisper_feature_H = 1280 self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50 - if(self.split=="train"): - self.all_videos=["../data/images/train"] - if(self.split=="val"): - self.all_videos=["../data/images/test"] + # if(self.split=="train"): + # self.all_videos=["../data/images/train"] + # if(self.split=="val"): + # self.all_videos=["../data/images/test"] + with open(json_path, 'r') as file: + self.all_videos = json.load(file) + for vidname in tqdm(self.all_videos, desc="Preparing dataset"): json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json" if not os.path.exists(json_path_names): diff --git a/train_codes/train.py b/train_codes/train.py index 37a9447..fae0cb2 100755 --- a/train_codes/train.py +++ b/train_codes/train.py @@ -140,6 +140,8 @@ def parse_args(): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient 
norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders") + parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders") parser.add_argument( "--hub_model_id", type=str, @@ -350,7 +352,7 @@ def main(): print("loading train_dataset ...") train_dataset = Dataset(args.data_root, - 'train', + args.train_json, use_audio_length_left=args.use_audio_length_left, use_audio_length_right=args.use_audio_length_right, whisper_model_type=args.whisper_model_type @@ -360,7 +362,7 @@ def main(): num_workers=8) print("loading val_dataset ...") val_dataset = Dataset(args.data_root, - 'val', + args.val_json, use_audio_length_left=args.use_audio_length_left, use_audio_length_right=args.use_audio_length_right, whisper_model_type=args.whisper_model_type diff --git a/train_codes/train.sh b/train_codes/train.sh index 2e29d5c..600632b 100644 --- a/train_codes/train.sh +++ b/train_codes/train.sh @@ -18,10 +18,12 @@ accelerate launch train.py \ --output_dir="output" \ --val_out_dir='val' \ --testing_speed \ ---checkpointing_steps=1000 \ ---validation_steps=1000 \ +--checkpointing_steps=2000 \ +--validation_steps=2000 \ --reconstruction \ --resume_from_checkpoint="latest" \ --use_audio_length_left=2 \ --use_audio_length_right=2 \ --whisper_model_type="tiny" \ +--train_json="/root/MuseTalk/train.json" \ +--val_json="/root/MuseTalk/val.json" \ From b9685481315441636a3604483836d16ff439fc61 Mon Sep 17 00:00:00 2001 From: Shounak Banerjee Date: Mon, 17 Jun 2024 18:39:15 +0000 Subject: [PATCH 5/5] fixed mltiple video data preperation --- train_codes/DataLoader.py | 5 ----- train_codes/README.md | 19 ++++++++++--------- train_codes/train.sh | 10 +++++----- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py index 431ee39..10ee8eb 100644 --- a/train_codes/DataLoader.py +++ b/train_codes/DataLoader.py @@ -71,11 +71,6 @@ class Dataset(object): self.whisper_feature_W = 33 self.whisper_feature_H = 1280 self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50 - - # if(self.split=="train"): - # self.all_videos=["../data/images/train"] - # if(self.split=="val"): - # self.all_videos=["../data/images/test"] with open(json_path, 'r') as file: self.all_videos = json.load(file) diff --git a/train_codes/README.md b/train_codes/README.md index f303b87..102a345 100644 --- a/train_codes/README.md +++ b/train_codes/README.md @@ -6,30 +6,29 @@ The test yaml file should contain the validation video paths and corresponding a Run: ``` -python -m scripts.data --inference_config path_to_train.yaml --folder_name train -python -m scripts.data --inference_config path_to_test.yaml --folder_name test +./data_new.sh train output train_video1.mp4 train_video2.mp4 +./data_new.sh test output test_video1.mp4 test_video2.mp4 ``` -This creates folders which contain the image frames and npy files. - +This creates folders which contain the image frames and npy files. This also creates train.json and val.json which can be used during the training. 
 ## Data organization
 ```
 ./data/
 ├── images
-│   └──train
+│   └──RD_Radio10_000
 │       └── 0.png
 │       └── 1.png
 │       └── xxx.png
-│   └──test
+│   └──RD_Radio11_000
 │       └── 0.png
 │       └── 1.png
 │       └── xxx.png
 ├── audios
-│   └──train
+│   └──RD_Radio10_000
 │       └── 0.npy
 │       └── 1.npy
 │       └── xxx.npy
-│   └──test
+│   └──RD_Radio11_000
 │       └── 0.npy
 │       └── 1.npy
 │       └── xxx.npy
@@ -38,7 +37,9 @@ Simply run after preparing the preprocessed data
 ```
-sh train.sh
+cd train_codes
+sh train.sh   # --train_json="../train.json" and --val_json="../val.json" in train.sh
+              # point to the files generated in the data preprocessing step
 ```
 ## Inference with trained checkpoint
 Simply run after training the model; the model checkpoints are usually saved at train_codes/output.
diff --git a/train_codes/train.sh b/train_codes/train.sh
index 600632b..f15ddf9 100644
--- a/train_codes/train.sh
+++ b/train_codes/train.sh
@@ -7,13 +7,12 @@ accelerate launch train.py \
 --unet_config_file=$UNET_CONFIG \
 --pretrained_model_name_or_path=$VAE_MODEL \
 --data_root=$DATASET \
---train_batch_size=8 \
---gradient_accumulation_steps=4 \
+--train_batch_size=256 \
+--gradient_accumulation_steps=16 \
 --gradient_checkpointing \
 --max_train_steps=100000 \
 --learning_rate=5e-05 \
 --max_grad_norm=1 \
---lr_scheduler="cosine" \
 --lr_warmup_steps=0 \
 --output_dir="output" \
 --val_out_dir='val' \
@@ -25,5 +24,6 @@ accelerate launch train.py \
 --use_audio_length_left=2 \
 --use_audio_length_right=2 \
 --whisper_model_type="tiny" \
---train_json="/root/MuseTalk/train.json" \
---val_json="/root/MuseTalk/val.json" \
+--train_json="../train.json" \
+--val_json="../val.json" \
+--lr_scheduler="cosine" \
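For reference, the per-task yaml format consumed by scripts/data.py (and emitted by data_new.sh as config.yaml) mirrors configs/inference/test.yaml. A minimal sketch with placeholder file paths:

```
task_0:
  video_path: "output/clip_a_part0.mp4"
  audio_path: "output/clip_a_part0.mp3"
task_1:
  video_path: "output/clip_a_part1.mp4"
  audio_path: "output/clip_a_part1.mp3"
```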