# MuseTalk/scripts/data.py

import argparse
import glob
import os
import pickle
import uuid

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from musetalk.utils.utils import get_file_type, get_video_fps, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image
# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
def datagen(whisper_chunks,
            crop_images,
            batch_size=8,
            delay_frame=0):
    """Pair each whisper feature chunk with a crop image and yield batches."""
    whisper_batch, crop_batch = [], []
    for i, w in enumerate(whisper_chunks):
        # Cycle through the crop images, optionally offset by delay_frame.
        idx = (i + delay_frame) % len(crop_images)
        whisper_batch.append(w)
        crop_batch.append(crop_images[idx])
        if len(crop_batch) >= batch_size:
            yield np.stack(whisper_batch), crop_batch
            whisper_batch, crop_batch = [], []

    # The last batch may be smaller than batch_size.
    if len(crop_batch) > 0:
        yield np.stack(whisper_batch), crop_batch
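
# A minimal sketch of how datagen is consumed (illustrative; the chunk and
# image shapes are assumptions based on the MuseTalk pipeline):
#
#   gen = datagen(whisper_chunks, crop_data, batch_size=8)
#   for whisper_batch, crop_batch in gen:
#       # whisper_batch: stacked audio features, shape (B, ...)
#       # crop_batch:    list of 256x256 BGR face crops
#       ...
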
@torch.no_grad()
def main(args):
    global pe
    if args.use_float16:
        # Run the positional encoder, VAE and UNet in half precision.
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        folder_name = args.folder_name
        os.makedirs(f"data/images/{folder_name}", exist_ok=True)
        os.makedirs(f"data/audios/{folder_name}", exist_ok=True)

        input_basename = os.path.splitext(os.path.basename(video_path))[0]
        audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # depends on both the video and audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # depends only on the video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################################## extract frames from source video ##############################################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f'ffmpeg -v fatal -i "{video_path}" -start_number 0 "{save_dir_full}/%08d.png"'
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input is a folder of images
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            # Sort numerically by the integer filename stem (0.png, 1.png, ...).
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        print("number of input frames:", len(input_img_list))
        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        print("whisper feature count:", len(whisper_feature))
        print("whisper feature shape:", whisper_feature[0].shape)

        # Save consecutive whisper feature chunks in pairs. The `- 1` bound
        # avoids an index error when the list has an odd number of elements.
        for k in range(0, len(whisper_feature) - 1, 2):
            concatenated_chunks = np.concatenate([whisper_feature[k], whisper_feature[k + 1]], axis=0)
            print("pair shape:", concatenated_chunks.shape)
            np.save(f'data/audios/{folder_name}/{k // 2}.npy', concatenated_chunks)

        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        print("whisper chunk count:", len(whisper_chunks))
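
        # Sketch of how a downstream consumer might reload a saved pair
        # (illustrative; the path layout mirrors the np.save call above):
        #
        #   pair = np.load(f'data/audios/{folder_name}/0.npy')
        #   print(pair.shape)  # two whisper chunks concatenated along axis 0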
        ############################################## preprocess input image ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using saved landmark coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks... this is time-consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)
        print("number of frames:", len(frame_list))
        input_latent_list = []
        crop_i = 0
        crop_data = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:  # no face was detected in this frame
                continue
            x1, y1, x2, y2 = bbox
            # Clamp negative coordinates to the image border.
            x1, y1, x2, y2 = max(0, x1), max(0, y1), max(0, x2), max(0, y2)
            if (y2 - y1) <= 0 or (x2 - x1) <= 0:
                continue
            crop_frame = frame[y1:y2, x1:x2]
            print("crop bbox:", bbox)
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
            latents = vae.get_latents_for_unet(crop_frame)
            crop_data.append(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # Append the reversed lists so playback can ping-pong between the last
        # and first frames, e.g. [f0, f1, f2] -> [f0, f1, f2, f2, f1, f0].
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        crop_data = crop_data + crop_data[::-1]
        ############################################## dump paired data batch by batch ##############################################
        print("start batching")
        print(len(input_latent_list_cycle), len(whisper_chunks))
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, crop_data, batch_size)
        for i, (whisper_batch, crop_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
            print("batch sizes:", len(whisper_batch), len(crop_batch))
            for crop_index, image in enumerate(crop_batch):
                # Index frames globally: batch i contributes frames
                # i * batch_size + 0 .. i * batch_size + len(crop_batch) - 1.
                # (The original `i + crop_index` made filenames collide across batches.)
                cv2.imwrite(f"data/images/{folder_name}/{i * batch_size + crop_index}.png", image)
        print(folder_name)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="subfolder name under data/images and data/audios")
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help="use saved coordinates to save time")
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="whether to use float16 to speed up inference")
    args = parser.parse_args()
    main(args)
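
# Example invocation (illustrative; the flags correspond to the argparse
# options above, and the config file must list video_path/audio_path per task):
#
#   python scripts/data.py --inference_config configs/inference/test_img.yaml \
#       --folder_name my_dataset --batch_size 8 --use_float16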

def process_audio(audio_path):
    """Extract whisper features for one audio file and save them to disk."""
    whisper_feature = audio_processor.audio2feat(audio_path)
    np.save('audio/your_filename.npy', whisper_feature)

def mask_face(image):
    """Blacken everything below the nose tip of each face found in `image` (BGR)."""
    import dlib  # imported lazily: dlib is only needed for this helper

    # Load dlib's face detector and the 68-point landmark predictor.
    detector = dlib.get_frontal_face_detector()
    predictor_path = "/content/shape_predictor_68_face_landmarks.dat"  # set to your downloaded predictor file
    predictor = dlib.shape_predictor(predictor_path)

    if image is None:
        raise ValueError("Image not found or unable to load.")

    # Convert to grayscale for detection.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = detector(gray)

    for face in faces:
        landmarks = predictor(gray, face)
        # In the 68-point scheme, landmarks 27-35 belong to the nose;
        # index 33 is the nose tip.
        nose_tip = landmarks.part(33).y
        # Blacken the region below the nose tip.
        image[nose_tip:, :] = (0, 0, 0)

    return image
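
# Minimal usage sketch for mask_face (illustrative: the frame path below is an
# assumption, not something this script produces):
#
#   frame = cv2.imread("crop_frame_00000000.png")
#   masked = mask_face(frame)
#   cv2.imwrite("masked_frame.png", masked)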