diff --git a/scripts/data.py b/scripts/data.py
new file mode 100644
index 0000000..22887b8
--- /dev/null
+++ b/scripts/data.py
@@ -0,0 +1,243 @@
+import argparse
+import glob
+import os
+import pickle
+import uuid
+
+import cv2
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from musetalk.utils.utils import get_file_type, get_video_fps, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+
+# dlib is only needed by mask_face and is imported lazily there, so the main
+# pipeline runs without it.
+
+# load model weights
+audio_processor, vae, unet, pe = load_all_model()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+timesteps = torch.tensor([0], device=device)
+
+def datagen(whisper_chunks,
+            crop_images,
+            batch_size=8,
+            delay_frame=0):
+    # Pair each whisper chunk with a cropped face frame (frames are reused
+    # cyclically when there are more chunks than frames) and yield batches.
+    whisper_batch, crop_batch = [], []
+    for i, w in enumerate(whisper_chunks):
+        idx = (i + delay_frame) % len(crop_images)
+        crop_image = crop_images[idx]
+        whisper_batch.append(w)
+        crop_batch.append(crop_image)
+
+        if len(crop_batch) >= batch_size:
+            whisper_batch = np.stack(whisper_batch)
+            yield whisper_batch, crop_batch
+            whisper_batch, crop_batch = [], []
+
+    # the last batch may be smaller than the batch size
+    if len(crop_batch) > 0:
+        whisper_batch = np.stack(whisper_batch)
+        yield whisper_batch, crop_batch
+
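+# Illustrative sketch of how datagen batches its inputs; it is never called by
+# this script, and the dummy shapes are assumptions for demonstration only,
+# not the real whisper feature dimensions.
+def _demo_datagen():
+    dummy_chunks = [np.zeros((2, 5), dtype=np.float32) for _ in range(10)]
+    dummy_crops = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(4)]
+    for whisper_batch, crop_batch in datagen(dummy_chunks, dummy_crops, batch_size=4):
+        # yields two (4, 2, 5) batches, then a final smaller (2, 2, 5) batch
+        print(whisper_batch.shape, len(crop_batch))
+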
+@torch.no_grad()
+def main(args):
+    global pe
+    if args.use_float16:
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        unet.model = unet.model.half()
+
+    inference_config = OmegaConf.load(args.inference_config)
+    print(inference_config)
+    for task_id in inference_config:
+        video_path = inference_config[task_id]["video_path"]
+        audio_path = inference_config[task_id]["audio_path"]
+        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
+        folder_name = args.folder_name
+        os.makedirs(f"data/images/{folder_name}", exist_ok=True)
+        os.makedirs(f"data/audios/{folder_name}", exist_ok=True)
+        input_basename = os.path.splitext(os.path.basename(video_path))[0]
+        audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
+        output_basename = f"{input_basename}_{audio_basename}"
+        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
+        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
+        os.makedirs(result_img_save_path, exist_ok=True)
+
+        if args.output_vid_name is None:
+            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
+        else:
+            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
+        ############################################## extract frames from source video ##############################################
+        if get_file_type(video_path) == "video":
+            save_dir_full = os.path.join(args.result_dir, input_basename)
+            os.makedirs(save_dir_full, exist_ok=True)
+            cmd = f'ffmpeg -v fatal -i "{video_path}" -start_number 0 "{save_dir_full}/%08d.png"'
+            os.system(cmd)
+            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
+            fps = get_video_fps(video_path)
+        elif get_file_type(video_path) == "image":
+            input_img_list = [video_path]
+            fps = args.fps
+        elif os.path.isdir(video_path):  # input image folder
+            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
+            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+            fps = args.fps
+        else:
+            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
+        print(f"number of input frames: {len(input_img_list)}")
+        ############################################## extract audio feature ##############################################
+        whisper_feature = audio_processor.audio2feat(audio_path)
+        print(f"number of whisper feature windows: {len(whisper_feature)}")
+        print("whisper feature window shape:", whisper_feature[0].shape)
+        # concatenate consecutive feature windows pairwise and save each pair;
+        # the range stops at len - 1, so an odd trailing window is dropped
+        for j in range(0, len(whisper_feature) - 1, 2):
+            concatenated_chunks = np.concatenate([whisper_feature[j], whisper_feature[j + 1]], axis=0)
+            print("pair shape:", concatenated_chunks.shape)
+            np.save(f'data/audios/{folder_name}/{j // 2}.npy', concatenated_chunks)
+        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
+        print(f"number of whisper chunks: {len(whisper_chunks)}")
+
+        ############################################## preprocess input image ##############################################
+        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
+            print("using saved coordinates")
+            with open(crop_coord_save_path, 'rb') as f:
+                coord_list = pickle.load(f)
+            frame_list = read_imgs(input_img_list)
+        else:
+            print("extracting landmarks... time consuming")
+            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
+            with open(crop_coord_save_path, 'wb') as f:
+                pickle.dump(coord_list, f)
+
+        print(f"number of frames: {len(frame_list)}")
+
+        input_latent_list = []
+        crop_i = 0
+        crop_data = []
+        for bbox, frame in zip(coord_list, frame_list):
+            if bbox == coord_placeholder:
+                continue
+            # clamp the bbox to the frame and skip degenerate boxes
+            x1, y1, x2, y2 = bbox
+            x1 = max(0, x1)
+            y1 = max(0, y1)
+            x2 = max(0, x2)
+            y2 = max(0, y2)
+            if (y2 - y1) <= 0 or (x2 - x1) <= 0:
+                continue
+            crop_frame = frame[y1:y2, x1:x2]
+            print("crop bbox:", bbox)
+            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
+            latents = vae.get_latents_for_unet(crop_frame)
+            crop_data.append(crop_frame)
+            input_latent_list.append(latents)
+            crop_i += 1
+
+        # to smooth the first and the last frame, mirror each sequence
+        frame_list_cycle = frame_list + frame_list[::-1]
+        coord_list_cycle = coord_list + coord_list[::-1]
+        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+        crop_data = crop_data + crop_data[::-1]
+        ############################################## inference batch by batch ##############################################
+        print("start inference")
+        print(len(input_latent_list_cycle), len(whisper_chunks))
+        video_num = len(whisper_chunks)
+        batch_size = args.batch_size
+        gen = datagen(whisper_chunks, crop_data, batch_size)
+        for i, (whisper_batch, crop_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
+            print(f"batch {i}: {len(whisper_batch)} chunks, {len(crop_batch)} frames")
+            for crop_index, (image, audio) in enumerate(zip(crop_batch, whisper_batch)):
+                # index by i * batch_size + crop_index so successive batches do
+                # not overwrite each other's frames
+                cv2.imwrite(f"data/images/{folder_name}/{i * batch_size + crop_index}.png", image)
+        print(folder_name)
+
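+# Illustrative helper (not called by main): audio pair k on disk holds whisper
+# windows 2k and 2k+1 concatenated along axis 0, so with equal-length windows
+# its first dimension is twice a single window's. The path layout matches the
+# one main() creates.
+def _demo_load_audio_pair(folder_name, k):
+    return np.load(f"data/audios/{folder_name}/{k}.npy")
+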
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
+    parser.add_argument("--bbox_shift", type=int, default=0)
+    parser.add_argument("--result_dir", default='./results', help="path to output")
+    parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="subfolder name under data/images and data/audios")
+    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--output_vid_name", type=str, default=None)
+    parser.add_argument("--use_saved_coord",
+                        action="store_true",
+                        help="use saved coordinates to save time")
+    parser.add_argument("--use_float16",
+                        action="store_true",
+                        help="whether to use float16 to speed up inference")
+
+    args = parser.parse_args()
+    main(args)
+
+
+def process_audio(audio_path):
+    whisper_feature = audio_processor.audio2feat(audio_path)
+    np.save('audio/your_filename.npy', whisper_feature)
+
+
+def mask_face(image):
+    # dlib is optional for the main pipeline, so import it only here
+    import dlib
+
+    if image is None:
+        raise ValueError("Image not found or unable to load.")
+
+    # load dlib's face detector and the 68-point landmark predictor
+    detector = dlib.get_frontal_face_detector()
+    predictor_path = "/content/shape_predictor_68_face_landmarks.dat"  # set to your downloaded predictor file
+    predictor = dlib.shape_predictor(predictor_path)
+
+    # convert to grayscale for detection
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+    # detect faces in the image
+    faces = detector(gray)
+
+    # process each detected face
+    for face in faces:
+        landmarks = predictor(gray, face)
+
+        # in the 68-point model the nose landmarks are indices 27 to 35;
+        # index 33 is the nose tip
+        nose_tip = landmarks.part(33).y
+
+        # blacken the region below the nose tip
+        image[nose_tip:, :] = (0, 0, 0)
+
+    return image
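+
+# Illustrative usage sketch for mask_face (not executed by this script); the
+# input and output paths are placeholder assumptions.
+def _demo_mask_face(input_path="example_frame.png", output_path="masked_frame.png"):
+    image = cv2.imread(input_path)
+    masked = mask_face(image)  # blackens everything below the detected nose tip
+    cv2.imwrite(output_path, masked)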