"""Real-time MuseTalk inference.

This script simulates online chatting: the expensive pre-processing (face
detection, face parsing, VAE encoding of the avatar frames, blending masks) is
done once per avatar in advance. During online chatting only the UNet and the
VAE decoder are involved, which makes MuseTalk real-time.

The inference config (``--inference_config``) maps each avatar id to a node
with the keys ``preparation``, ``video_path``, ``bbox_shift`` and
``audio_clips`` (a mapping from output-video name to audio file path).
"""
import argparse
import copy
import glob
import json
import os
import pickle
import queue
import shutil
import subprocess
import sys
import threading
import time

from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
from tqdm import tqdm

from musetalk.utils.utils import get_file_type, get_video_fps, datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending
from musetalk.utils.utils import load_all_model

# load model weights (module-level: shared by every Avatar instance)
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)

# Glob pattern used everywhere frames/masks are listed (matches png/jpg/jpeg).
_IMG_GLOB = '*.[jpJP][pnPN]*[gG]'


def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    """Dump the frames of ``vid_path`` into ``save_path`` as ``%08d.png`` files.

    Args:
        vid_path: path of a video readable by ``cv2.VideoCapture``.
        save_path: existing directory that receives the frames.
        ext: accepted for backward compatibility but ignored; frames are
            always written as ``.png`` because the rest of this script rewrites
            the same directory with ``.png`` files.
        cut_frame: index of the last frame to extract (so at most
            ``cut_frame + 1`` frames are written).
    """
    cap = cv2.VideoCapture(vid_path)
    try:
        count = 0
        while count <= cut_frame:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
    finally:
        # Release the decoder handle even if imwrite raises.
        cap.release()


def osmakedirs(path_list):
    """Create every directory in ``path_list`` (including parents); existing ones are kept."""
    for path in path_list:
        os.makedirs(path, exist_ok=True)


class Avatar:
    """A talking-head avatar whose per-frame materials are cached on disk.

    Materials live under ``./results/avatars/<avatar_id>``: the (cycled) full
    frames, face bboxes, VAE latents of the face crops, and blending masks with
    their crop boxes. ``inference`` turns an audio file into a lip-synced video
    using only the UNet and VAE decoder.
    """

    # Autograd is never needed here. torch only supports decorating functions,
    # so the methods that run models are decorated individually.
    @torch.no_grad()
    def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
        """
        Args:
            avatar_id: name of the avatar; also its directory name.
            video_path: source video file, or a directory of ``.png`` frames.
            bbox_shift: forwarded to ``get_landmark_and_bbox``; a change of this
                value forces the avatar to be re-created.
            batch_size: number of frames handed to ``datagen`` per UNet call.
            preparation: if true, (re)build the materials; otherwise load the
                cached ones.
        """
        self.avatar_id = avatar_id
        self.video_path = video_path
        self.bbox_shift = bbox_shift
        self.avatar_path = f"./results/avatars/{avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        self.latents_out_path = f"{self.avatar_path}/latents.pt"
        self.video_out_path = f"{self.avatar_path}/vid_output/"
        self.mask_out_path = f"{self.avatar_path}/mask"
        self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
        self.avatar_info = {
            "avatar_id": avatar_id,
            "video_path": video_path,
            "bbox_shift": bbox_shift,
        }
        self.preparation = preparation
        self.batch_size = batch_size
        self.idx = 0  # index of the next output frame written by process_frames
        self.init()

    def init(self):
        """Create the avatar materials or load them from disk, asking the user when ambiguous.

        Exits the process (``sys.exit``) if ``bbox_shift`` changed since the
        materials were built and the user declines to re-create them.
        """
        if self.preparation:
            if os.path.exists(self.avatar_path):
                response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)")
                if response.lower() == "y":
                    shutil.rmtree(self.avatar_path)
                    self._create_avatar()
                else:
                    self._load_avatar()
            else:
                self._create_avatar()
            return

        with open(self.avatar_info_path, "r") as f:
            avatar_info = json.load(f)
        if avatar_info['bbox_shift'] == self.avatar_info['bbox_shift']:
            self._load_avatar()
            return
        # Cached materials were built with another bbox_shift and are stale.
        response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
        if response.lower() == "c":
            shutil.rmtree(self.avatar_path)
            self._create_avatar()
        else:
            sys.exit()

    def _create_avatar(self):
        """Create the avatar directory tree and compute all materials from scratch."""
        print("*********************************")
        print(f" creating avator: {self.avatar_id}")
        print("*********************************")
        osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
        self.prepare_material()

    def _load_avatar(self):
        """Load previously prepared materials (latents, bboxes, frames, masks) from disk."""
        # NOTE(review): torch.load / pickle.load execute pickled data; only load
        # avatar directories produced by this script.
        self.input_latent_list_cycle = torch.load(self.latents_out_path)
        with open(self.coords_path, 'rb') as f:
            self.coord_list_cycle = pickle.load(f)
        self.frame_list_cycle = read_imgs(self._sorted_frame_paths(self.full_imgs_path))
        with open(self.mask_coords_path, 'rb') as f:
            self.mask_coords_list_cycle = pickle.load(f)
        self.mask_list_cycle = read_imgs(self._sorted_frame_paths(self.mask_out_path))

    @staticmethod
    def _sorted_frame_paths(directory):
        """Return the image files in ``directory`` ordered by their integer file stem."""
        paths = glob.glob(os.path.join(directory, _IMG_GLOB))
        return sorted(paths, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))

    def prepare_material(self):
        """Extract frames, face bboxes, VAE latents and blending masks, and save them.

        Every per-frame list is "cycled" (sequence followed by its reverse) so
        that looping over it during inference has no visible jump.
        """
        print("preparing data materials ... ...")
        with open(self.avatar_info_path, "w") as f:
            json.dump(self.avatar_info, f)

        if os.path.isfile(self.video_path):
            video2imgs(self.video_path, self.full_imgs_path, ext='png')
        else:
            print(f"copy files in {self.video_path}")
            files = sorted(name for name in os.listdir(self.video_path) if name.split(".")[-1] == "png")
            # Copy to zero-padded numeric names: _load_avatar sorts frames by
            # int(file stem), which fails on arbitrary source file names.
            for frame_no, filename in enumerate(files):
                shutil.copyfile(os.path.join(self.video_path, filename),
                                os.path.join(self.full_imgs_path, f"{frame_no:08d}.png"))
        input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, _IMG_GLOB)))

        print("extracting landmarks...")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)

        input_latent_list = []
        placeholder = (0.0, 0.0, 0.0, 0.0)  # bbox marker for frames without a usable face
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == placeholder:
                # NOTE(review): skipped frames get no latent, so the latent list
                # becomes shorter than the frame/bbox lists and indices drift.
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            input_latent_list.append(vae.get_latents_for_unet(resized_crop_frame))

        self.frame_list_cycle = frame_list + frame_list[::-1]
        self.coord_list_cycle = coord_list + coord_list[::-1]
        self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        self.mask_coords_list_cycle = []
        self.mask_list_cycle = []

        for i, frame in enumerate(tqdm(self.frame_list_cycle)):
            cv2.imwrite(f"{self.full_imgs_path}/{i:08d}.png", frame)
            mask, crop_box = get_image_prepare_material(frame, self.coord_list_cycle[i])
            cv2.imwrite(f"{self.mask_out_path}/{i:08d}.png", mask)
            self.mask_coords_list_cycle.append(crop_box)
            self.mask_list_cycle.append(mask)

        with open(self.mask_coords_path, 'wb') as f:
            pickle.dump(self.mask_coords_list_cycle, f)
        with open(self.coords_path, 'wb') as f:
            pickle.dump(self.coord_list_cycle, f)
        torch.save(self.input_latent_list_cycle, self.latents_out_path)

    def process_frames(self, res_frame_queue, video_len):
        """Consume generated face crops and write composited frames to ``<avatar>/tmp``.

        Runs on the worker thread started by ``inference``. Output frame
        ``self.idx`` is blended onto background frame ``self.idx % len(cycle)``.

        Args:
            res_frame_queue: queue filled by ``inference`` with one decoded face
                crop per output frame; a ``None`` item means the producer stopped
                early and nothing more will arrive.
            video_len: number of output frames to write.
        """
        print(video_len)
        while self.idx < video_len:
            try:
                start = time.time()
                res_frame = res_frame_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            if res_frame is None:
                break

            bbox = self.coord_list_cycle[self.idx % len(self.coord_list_cycle)]
            ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % len(self.frame_list_cycle)])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception as err:  # best-effort, e.g. the (0,0,0,0) "no face" bbox
                # Still emit a frame and advance, so the %08d sequence stays
                # contiguous and the loop can terminate.
                print(f"frame {self.idx}: cannot paste generated face ({err}), keeping original frame")
                combine_frame = ori_frame
            else:
                mask = self.mask_list_cycle[self.idx % len(self.mask_list_cycle)]
                mask_crop_box = self.mask_coords_list_cycle[self.idx % len(self.mask_coords_list_cycle)]
                combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)

            fps = 1 / (time.time() - start + 1e-6)
            print(f"Generating the {self.idx}-th frame with FPS: {fps:.2f}")
            cv2.imwrite(f"{self.avatar_path}/tmp/{self.idx:08d}.png", combine_frame)
            self.idx += 1

    @torch.no_grad()
    def inference(self, audio_path, out_vid_name, fps):
        """Generate a lip-synced video of this avatar driven by ``audio_path``.

        Frames are produced batch by batch on this thread and composited/written
        by a ``process_frames`` worker thread. If ``out_vid_name`` is not None,
        the frames are encoded with ffmpeg, muxed with the audio into
        ``<video_out_path>/<out_vid_name>.mp4`` and the temporary files removed.

        Raises:
            subprocess.CalledProcessError: if an ffmpeg step fails.
        """
        tmp_dir = f"{self.avatar_path}/tmp"
        # Start from an empty frame directory: leftovers from an earlier, longer
        # run would otherwise be picked up by ffmpeg's image2 reader.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        os.makedirs(tmp_dir, exist_ok=True)

        ############################################## extract audio feature ##############################################
        start_time = time.time()
        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")

        ############################################## inference batch by batch ##############################################
        video_num = len(whisper_chunks)
        print("start inference")
        res_frame_queue = queue.Queue()
        self.idx = 0
        process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num))
        process_thread.start()

        gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size)
        try:
            for whisper_batch, latent_batch in tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size))):
                tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
                audio_feature_batch = torch.stack(tensor_list).to(unet.device)  # torch, B, 5*N,384
                audio_feature_batch = pe(audio_feature_batch)
                pred_latents = unet.model(latent_batch, timesteps,
                                          encoder_hidden_states=audio_feature_batch).sample
                for res_frame in vae.decode_latents(pred_latents):
                    res_frame_queue.put(res_frame)
        finally:
            # Sentinel: lets the writer thread exit even if generation stopped
            # early; after a complete run it is simply never read.
            res_frame_queue.put(None)
            process_thread.join()

        if out_vid_name is None:  # optional
            return

        temp_vid = f"{self.avatar_path}/temp.mp4"
        cmd_img2video = [
            "ffmpeg", "-y", "-v", "fatal", "-r", str(fps), "-f", "image2",
            "-i", f"{tmp_dir}/%08d.png", "-vcodec", "libx264",
            "-vf", "format=rgb24,scale=out_color_matrix=bt709,format=yuv420p",
            "-crf", "18", temp_vid,
        ]
        print(" ".join(cmd_img2video))
        subprocess.run(cmd_img2video, check=True)

        output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4")
        cmd_combine_audio = ["ffmpeg", "-y", "-v", "fatal", "-i", audio_path, "-i", temp_vid, output_vid]
        print(" ".join(cmd_combine_audio))
        subprocess.run(cmd_combine_audio, check=True)

        os.remove(temp_vid)
        shutil.rmtree(tmp_dir)
        print(f"result is save to {output_vid}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)

    for avatar_id in inference_config:
        avatar_cfg = inference_config[avatar_id]
        avatar = Avatar(
            avatar_id=avatar_id,
            video_path=avatar_cfg["video_path"],
            bbox_shift=avatar_cfg["bbox_shift"],
            batch_size=args.batch_size,
            preparation=avatar_cfg["preparation"],
        )
        for audio_num, audio_path in avatar_cfg["audio_clips"].items():
            print("Inferring using:", audio_path)
            avatar.inference(audio_path, audio_num, args.fps)