mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
clean code and sepaarate finetuned_inference.py
This commit is contained in:
@@ -18,6 +18,7 @@ from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_p
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
@@ -57,7 +58,10 @@ def main(args):
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
total_audio_index=-1
|
||||
total_image_index=-1
|
||||
temp_audio_index=-1
|
||||
temp_image_index=-1
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
@@ -95,32 +99,20 @@ def main(args):
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
print("LEN..........")
|
||||
|
||||
print(len(input_img_list))
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
print(len(whisper_feature))
|
||||
print("Whisper feature length........")
|
||||
print(whisper_feature[0].shape)
|
||||
# print(whisper_feature)
|
||||
for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
|
||||
# Combine two consecutive chunks
|
||||
# pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
|
||||
concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
|
||||
# Save the pair to a .npy file
|
||||
print("Pair shape",concatenated_chunks.shape)
|
||||
np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks)
|
||||
np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
|
||||
temp_audio_index=(__//2)+total_audio_index+1
|
||||
total_audio_index=temp_audio_index
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
print(len(whisper_chunks))
|
||||
# whisper_i=0
|
||||
# for chunk in whisper_chunks:
|
||||
# # print("CHUNMK SHAPE...........")
|
||||
# # print(chunk.shape)
|
||||
# np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk)
|
||||
# whisper_i+=1
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
gc.collect()
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
@@ -131,8 +123,7 @@ def main(args):
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
print(len(frame_list))
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
@@ -151,9 +142,7 @@ def main(args):
|
||||
if ((y2-y1)<=0) or ((x2-x1)<=0):
|
||||
continue
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
print("crop sizes",bbox)
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
crop_data.append(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
@@ -165,20 +154,18 @@ def main(args):
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
crop_data = crop_data + crop_data[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
print(len(input_latent_list_cycle),len(whisper_chunks))
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,crop_data,batch_size)
|
||||
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
print("BATCH LEN..............")
|
||||
print(len(whisper_batch),len(crop_batch))
|
||||
crop_index=0
|
||||
for image,audio in zip(crop_batch,whisper_batch):
|
||||
cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image)
|
||||
cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image)
|
||||
crop_index+=1
|
||||
temp_image_index=i+crop_index+total_image_index+1
|
||||
# np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
|
||||
print(folder_name)
|
||||
total_image_index=temp_image_index
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
|
||||
182
scripts/finetuned_inference.py
Normal file
182
scripts/finetuned_inference.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
|
||||
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
unet = accelerator.prepare(
|
||||
unet,
|
||||
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if not (args.unet_checkpoint == None):
|
||||
print("unet ckpt loaded")
|
||||
accelerator.load_state(args.unet_checkpoint)
|
||||
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
|
||||
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
|
||||
os.makedirs(result_img_save_path,exist_ok =True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path)=="video":
|
||||
save_dir_full = os.path.join(args.result_dir, input_basename)
|
||||
os.makedirs(save_dir_full,exist_ok = True)
|
||||
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
os.system(cmd)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path)=="image":
|
||||
input_img_list = [video_path, ]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path): # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
# print(cmd_combine_audio)
|
||||
# os.system(cmd_combine_audio)
|
||||
|
||||
# shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0)
|
||||
parser.add_argument("--result_dir", default='./results', help="path to output")
|
||||
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--output_vid_name", type=str, default=None)
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
parser.add_argument("--unet_checkpoint", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -15,27 +15,14 @@ from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
# load model weights
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
unet = accelerator.prepare(
|
||||
unet,
|
||||
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if not (args.unet_checkpoint == None):
|
||||
print("unet ckpt loaded")
|
||||
accelerator.load_state(args.unet_checkpoint)
|
||||
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
@@ -76,6 +63,8 @@ def main(args):
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
|
||||
#print(input_img_list)
|
||||
############################################## extract audio feature ##############################################
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
@@ -90,27 +79,24 @@ def main(args):
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
crop_i=0
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
crop_i+=1
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
@@ -121,6 +107,7 @@ def main(args):
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
@@ -135,29 +122,22 @@ def main(args):
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
# print(bbox)
|
||||
continue
|
||||
|
||||
combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
print(cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
print(cmd_combine_audio)
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
|
||||
os.system(cmd_img2video)
|
||||
|
||||
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
|
||||
# print(cmd_combine_audio)
|
||||
# os.system(cmd_combine_audio)
|
||||
|
||||
# shutil.rmtree(result_img_save_path)
|
||||
shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -176,7 +156,6 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
parser.add_argument("--unet_checkpoint", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user