diff --git a/scripts/data.py b/scripts/data.py
index 22887b8..f6cdb81 100644
--- a/scripts/data.py
+++ b/scripts/data.py
@@ -18,6 +18,7 @@ from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_p
 from musetalk.utils.blending import get_image
 from musetalk.utils.utils import load_all_model
 import shutil
+import gc
 
 # load model weights
 audio_processor, vae, unet, pe = load_all_model()
@@ -57,7 +58,10 @@ def main(args):
         unet.model = unet.model.half()
 
     inference_config = OmegaConf.load(args.inference_config)
-    print(inference_config)
+    total_audio_index=-1
+    total_image_index=-1
+    temp_audio_index=-1
+    temp_image_index=-1
     for task_id in inference_config:
         video_path = inference_config[task_id]["video_path"]
         audio_path = inference_config[task_id]["audio_path"]
@@ -95,32 +99,20 @@ def main(args):
             fps = args.fps
         else:
             raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
-        print("LEN..........")
-
-        print(len(input_img_list))
         ############################################## extract audio feature ##############################################
         whisper_feature = audio_processor.audio2feat(audio_path)
-        print(len(whisper_feature))
-        print("Whisper feature length........")
-        print(whisper_feature[0].shape)
-        # print(whisper_feature)
         for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
             # Combine two consecutive chunks
             # pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
             concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
             # Save the pair to a .npy file
-            print("Pair shape",concatenated_chunks.shape)
-            np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks)
+            np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
+            temp_audio_index=(__//2)+total_audio_index+1
+            total_audio_index=temp_audio_index
         whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
-        print(len(whisper_chunks))
-        # whisper_i=0
-        # for chunk in whisper_chunks:
-        #     # print("CHUNMK SHAPE...........")
-        #     # print(chunk.shape)
-        #     np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk)
-        #     whisper_i+=1
         ############################################## preprocess input image ##############################################
+        gc.collect()
         if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
             print("using extracted coordinates")
             with open(crop_coord_save_path,'rb') as f:
@@ -131,8 +123,7 @@ def main(args):
             coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
             with open(crop_coord_save_path, 'wb') as f:
                 pickle.dump(coord_list, f)
-
-        print(len(frame_list))
+
         i = 0
         input_latent_list = []
@@ -151,9 +142,7 @@ def main(args):
             if ((y2-y1)<=0) or ((x2-x1)<=0):
                 continue
             crop_frame = frame[y1:y2, x1:x2]
-            print("crop sizes",bbox)
             crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
-            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
             latents = vae.get_latents_for_unet(crop_frame)
             crop_data.append(crop_frame)
             input_latent_list.append(latents)
@@ -165,20 +154,18 @@ def main(args):
         input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
         crop_data = crop_data + crop_data[::-1]
         ############################################## inference batch by batch ##############################################
-        print("start inference")
-        print(len(input_latent_list_cycle),len(whisper_chunks))
         video_num = len(whisper_chunks)
         batch_size = args.batch_size
         gen = datagen(whisper_chunks,crop_data,batch_size)
         for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-            print("BATCH LEN..............")
-            print(len(whisper_batch),len(crop_batch))
             crop_index=0
             for image,audio in zip(crop_batch,whisper_batch):
-                cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image)
+                cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image)
                 crop_index+=1
+                temp_image_index=i+crop_index+total_image_index+1
                 # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
-        print(folder_name)
+        total_image_index=temp_image_index
+        gc.collect()
diff --git a/scripts/finetuned_inference.py b/scripts/finetuned_inference.py
new file mode 100644
index 0000000..dddce31
--- /dev/null
+++ b/scripts/finetuned_inference.py
@@ -0,0 +1,182 @@
+import argparse
+import os
+from omegaconf import OmegaConf
+import numpy as np
+import cv2
+import torch
+import glob
+import pickle
+from tqdm import tqdm
+import copy
+
+from musetalk.utils.utils import get_file_type,get_video_fps,datagen
+from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
+from musetalk.utils.blending import get_image
+from musetalk.utils.utils import load_all_model
+import shutil
+
+from accelerate import Accelerator
+
+# load model weights
+audio_processor, vae, unet, pe = load_all_model()
+accelerator = Accelerator(
+    mixed_precision="fp16",
+    )
+unet = accelerator.prepare(
+    unet,
+
+    )
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+timesteps = torch.tensor([0], device=device)
+
+@torch.no_grad()
+def main(args):
+    global pe
+    if not (args.unet_checkpoint == None):
+        print("unet ckpt loaded")
+        accelerator.load_state(args.unet_checkpoint)
+
+    if args.use_float16 is True:
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        unet.model = unet.model.half()
+
+    inference_config = OmegaConf.load(args.inference_config)
+    print(inference_config)
+    for task_id in inference_config:
+        video_path = inference_config[task_id]["video_path"]
+        audio_path = inference_config[task_id]["audio_path"]
+        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
+
+        input_basename = os.path.basename(video_path).split('.')[0]
+        audio_basename = os.path.basename(audio_path).split('.')[0]
+        output_basename = f"{input_basename}_{audio_basename}"
+        result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
+        crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
+        os.makedirs(result_img_save_path,exist_ok =True)
+
+        if args.output_vid_name is None:
+            output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
+        else:
+            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
+        ############################################## extract frames from source video ##############################################
+        if get_file_type(video_path)=="video":
+            save_dir_full = os.path.join(args.result_dir, input_basename)
+            os.makedirs(save_dir_full,exist_ok = True)
+            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
+            os.system(cmd)
+            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
+            fps = get_video_fps(video_path)
+        elif get_file_type(video_path)=="image":
+            input_img_list = [video_path, ]
+            fps = args.fps
+        elif os.path.isdir(video_path): # input img folder
+            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
+            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+            fps = args.fps
+        else:
+            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
+        ############################################## extract audio feature ##############################################
+        whisper_feature = audio_processor.audio2feat(audio_path)
+        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+        ############################################## preprocess input image ##############################################
+        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
+            print("using extracted coordinates")
+            with open(crop_coord_save_path,'rb') as f:
+                coord_list = pickle.load(f)
+            frame_list = read_imgs(input_img_list)
+        else:
+            print("extracting landmarks...time consuming")
+            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
+            with open(crop_coord_save_path, 'wb') as f:
+                pickle.dump(coord_list, f)
+
+
+        i = 0
+        input_latent_list = []
+        crop_i=0
+        for bbox, frame in zip(coord_list, frame_list):
+            if bbox == coord_placeholder:
+                continue
+            x1, y1, x2, y2 = bbox
+            crop_frame = frame[y1:y2, x1:x2]
+            crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
+            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
+            latents = vae.get_latents_for_unet(crop_frame)
+            input_latent_list.append(latents)
+            crop_i+=1
+
+        # to smooth the first and the last frame
+        frame_list_cycle = frame_list + frame_list[::-1]
+        coord_list_cycle = coord_list + coord_list[::-1]
+        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+        ############################################## inference batch by batch ##############################################
+        video_num = len(whisper_chunks)
+        batch_size = args.batch_size
+        gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
+        res_frame_list = []
+        for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype) # torch, B, 5*N,384
+            audio_feature_batch = pe(audio_feature_batch)
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
+            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
+            recon = vae.decode_latents(pred_latents)
+            for res_frame in recon:
+                res_frame_list.append(res_frame)
+
+        ############################################## pad to full image ##############################################
+        print("pad talking image to original video")
+        for i, res_frame in enumerate(tqdm(res_frame_list)):
+            bbox = coord_list_cycle[i%(len(coord_list_cycle))]
+            ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
+            x1, y1, x2, y2 = bbox
+            try:
+                res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
+            except:
+                continue
+
+            combine_frame = get_image(ori_frame,res_frame,bbox)
+            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
+            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
+            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
+
+        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
+        os.system(cmd_img2video)
+
+        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
+        os.system(cmd_combine_audio)
+
+        os.remove("temp.mp4")
+
+        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
+        os.system(cmd_img2video)
+
+        # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
+        # print(cmd_combine_audio)
+        # os.system(cmd_combine_audio)
+
+        # shutil.rmtree(result_img_save_path)
+        print(f"result is save to {output_vid_name}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
+    parser.add_argument("--bbox_shift", type=int, default=0)
+    parser.add_argument("--result_dir", default='./results', help="path to output")
+
+    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--output_vid_name", type=str, default=None)
+    parser.add_argument("--use_saved_coord",
+                        action="store_true",
+                        help='use saved coordinate to save time')
+    parser.add_argument("--use_float16",
+                        action="store_true",
+                        help="Whether use float16 to speed up inference",
+    )
+    parser.add_argument("--unet_checkpoint", type=str, default=None)
+
+    args = parser.parse_args()
+    main(args)
\ No newline at end of file
diff --git a/scripts/inference.py b/scripts/inference.py
index dddce31..b708774 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -15,27 +15,14 @@ from musetalk.utils.blending import get_image
 from musetalk.utils.utils import load_all_model
 import shutil
 
-from accelerate import Accelerator
-
 # load model weights
 audio_processor, vae, unet, pe = load_all_model()
-accelerator = Accelerator(
-    mixed_precision="fp16",
-    )
-unet = accelerator.prepare(
-    unet,
-
-    )
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
     global pe
-    if not (args.unet_checkpoint == None):
-        print("unet ckpt loaded")
-        accelerator.load_state(args.unet_checkpoint)
-
     if args.use_float16 is True:
         pe = pe.half()
         vae.vae = vae.vae.half()
         unet.model = unet.model.half()
@@ -76,6 +63,8 @@ def main(args):
             fps = args.fps
         else:
             raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
+
+        #print(input_img_list)
         ############################################## extract audio feature ##############################################
         whisper_feature = audio_processor.audio2feat(audio_path)
         whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
@@ -90,27 +79,24 @@ def main(args):
             coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
             with open(crop_coord_save_path, 'wb') as f:
                 pickle.dump(coord_list, f)
-
         i = 0
         input_latent_list = []
-        crop_i=0
         for bbox, frame in zip(coord_list, frame_list):
             if bbox == coord_placeholder:
                 continue
             x1, y1, x2, y2 = bbox
             crop_frame = frame[y1:y2, x1:x2]
             crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
-            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
             latents = vae.get_latents_for_unet(crop_frame)
             input_latent_list.append(latents)
-            crop_i+=1
 
         # to smooth the first and the last frame
         frame_list_cycle = frame_list + frame_list[::-1]
         coord_list_cycle = coord_list + coord_list[::-1]
         input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
         ############################################## inference batch by batch ##############################################
+        print("start inference")
         video_num = len(whisper_chunks)
         batch_size = args.batch_size
         gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
@@ -121,6 +107,7 @@ def main(args):
                                                          dtype=unet.model.dtype) # torch, B, 5*N,384
             audio_feature_batch = pe(audio_feature_batch)
             latent_batch = latent_batch.to(dtype=unet.model.dtype)
+
             pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
             recon = vae.decode_latents(pred_latents)
             for res_frame in recon:
@@ -135,29 +122,22 @@ def main(args):
             try:
                 res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
             except:
+#                print(bbox)
                 continue
 
             combine_frame = get_image(ori_frame,res_frame,bbox)
-            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
-            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
             cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
         cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
+        print(cmd_img2video)
        os.system(cmd_img2video)
 
         cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
+        print(cmd_combine_audio)
         os.system(cmd_combine_audio)
 
         os.remove("temp.mp4")
-
-        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
-        os.system(cmd_img2video)
-
-        # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
-        # print(cmd_combine_audio)
-        # os.system(cmd_combine_audio)
-
-        # shutil.rmtree(result_img_save_path)
+        shutil.rmtree(result_img_save_path)
         print(f"result is save to {output_vid_name}")
 
 if __name__ == "__main__":
@@ -176,7 +156,6 @@ if __name__ == "__main__":
                         action="store_true",
                         help="Whether use float16 to speed up inference",
     )
-    parser.add_argument("--unet_checkpoint", type=str, default=None)
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/train_codes/README.md b/train_codes/README.md
index ba0f370..f303b87 100644
--- a/train_codes/README.md
+++ b/train_codes/README.md
@@ -43,7 +43,7 @@ sh train.sh
 ## Inference with trained checkpoit
 Simply run after training the model, the model checkpoints are saved at train_codes/output usually
 ```
-python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
+python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
 ```
 
 ## TODO
diff --git a/train_codes/train.sh b/train_codes/train.sh
index 8c60b48..2e29d5c 100644
--- a/train_codes/train.sh
+++ b/train_codes/train.sh
@@ -10,7 +10,7 @@ accelerate launch train.py \
 --train_batch_size=8 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
---max_train_steps=50000 \
+--max_train_steps=100000 \
 --learning_rate=5e-05 \
 --max_grad_norm=1 \
 --lr_scheduler="cosine" \
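Note on the new inference entry point: `scripts/finetuned_inference.py` wraps the UNet with `accelerator.prepare(...)` and, when `--unet_checkpoint` is passed, restores the training state with `accelerator.load_state(...)`, which is why the README now points at `scripts.finetuned_inference` for fine-tuned checkpoints. Below is a minimal, self-contained sketch of that accelerate checkpoint round-trip, assuming the checkpoint folder was written by `accelerator.save_state()` during training; the `torch.nn.Linear` stand-in and the example path are illustrative placeholders, not part of this patch.

```
import os

import torch
from accelerate import Accelerator

# Sketch of the accelerate state restore used by scripts/finetuned_inference.py.
# The Linear layer and the checkpoint path are placeholders for illustration.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")
model = torch.nn.Linear(384, 384)  # stand-in for the MuseTalk UNet
model = accelerator.prepare(model)

ckpt_dir = "path_to_trained_checkpoint_folder"  # e.g. a folder under train_codes/output
if os.path.isdir(ckpt_dir):
    # Restores whatever accelerator.save_state() wrote during training,
    # mirroring the --unet_checkpoint handling in finetuned_inference.py.
    accelerator.load_state(ckpt_dir)
model.eval()
```

With such a state folder in place, this call sequence is what the `--unet_checkpoint` flag triggers in the new script before batched inference begins.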