Mirror of https://github.com/TMElyralab/MuseTalk.git, synced 2026-02-05 01:49:20 +08:00
clean code and separate finetuned_inference.py
@@ -18,6 +18,7 @@ from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_p
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
import gc

# load model weights
audio_processor, vae, unet, pe = load_all_model()
@@ -57,7 +58,10 @@ def main(args):
    unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    total_audio_index=-1
    total_image_index=-1
    temp_audio_index=-1
    temp_image_index=-1
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
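For context, OmegaConf is loading a task list in which each top-level key names a task carrying a video_path and an audio_path; that structure is implied by the loop above. A minimal sketch of a config with that shape (the task id and file paths are invented for illustration):

from omegaconf import OmegaConf

# Illustrative stand-in for the file behind args.inference_config;
# only the task_id -> {video_path, audio_path} structure is assumed.
inference_config = OmegaConf.create({
    "task_0": {"video_path": "data/video/talker.mp4",
               "audio_path": "data/audio/speech.wav"},
})
for task_id in inference_config:
    print(inference_config[task_id]["video_path"],
          inference_config[task_id]["audio_path"])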
@@ -95,32 +99,20 @@ def main(args):
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        print("LEN..........")

        print(len(input_img_list))
        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        print(len(whisper_feature))
        print("Whisper feature length........")
        print(whisper_feature[0].shape)
        # print(whisper_feature)
        for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
            # Combine two consecutive chunks
            # pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
            concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
            # Save the pair to a .npy file
            print("Pair shape",concatenated_chunks.shape)
            np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks)
            np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
            temp_audio_index=(__//2)+total_audio_index+1
            total_audio_index=temp_audio_index
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
        print(len(whisper_chunks))
        # whisper_i=0
        # for chunk in whisper_chunks:
        #     # print("CHUNMK SHAPE...........")
        #     # print(chunk.shape)
        #     np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk)
        #     whisper_i+=1

        ############################################## preprocess input image ##############################################
        gc.collect()
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path,'rb') as f:
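In the extract-audio-feature block above, the two adjacent np.save calls read as the before/after pair of this commit: the new version offsets the file name by total_audio_index so numbering continues across tasks instead of restarting at 0 per audio file. The bookkeeping boils down to something like this sketch (the helper name is invented):

import numpy as np

def save_paired_features(whisper_feature, out_dir, start_index):
    """Concatenate consecutive Whisper feature chunks two at a time and
    save each pair under a running global index; returns the next free
    index. A trailing odd chunk is dropped, as in the loop above."""
    index = start_index
    for k in range(0, len(whisper_feature) - 1, 2):
        pair = np.concatenate([whisper_feature[k], whisper_feature[k + 1]], axis=0)
        np.save(f"{out_dir}/{index}.npy", pair)
        index += 1
    return index

# e.g. total_audio_index = save_paired_features(whisper_feature,
#          f"data/audios/{folder_name}", total_audio_index + 1) - 1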
@@ -131,8 +123,7 @@ def main(args):
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        print(len(frame_list))

        i = 0
        input_latent_list = []
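The if/else around this hunk implements a simple cache: reuse the pickled face coordinates when args.use_saved_coord is set and the file exists, otherwise run landmark detection and pickle the result. The same pattern as a small wrapper (a hypothetical helper, not part of the commit):

import os
import pickle

def load_or_compute_coords(cache_path, compute_fn, use_saved):
    """Return cached coordinates when allowed, else recompute and cache.
    Here compute_fn would call get_landmark_and_bbox(input_img_list,
    bbox_shift) and return just the coordinate list."""
    if use_saved and os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    coords = compute_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(coords, f)
    return coords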
@@ -151,9 +142,7 @@ def main(args):
            if ((y2-y1)<=0) or ((x2-x1)<=0):
                continue
            crop_frame = frame[y1:y2, x1:x2]
            print("crop sizes",bbox)
            crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
            latents = vae.get_latents_for_unet(crop_frame)
            crop_data.append(crop_frame)
            input_latent_list.append(latents)
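Each frame's face box is validated, cropped, and resized to the 256x256 input the VAE expects; degenerate boxes are skipped rather than crashing the resize. As a standalone sketch (the helper name is invented):

import cv2

def crop_face(frame, bbox, size=256):
    """Crop a face bbox from a frame and resize it for the VAE;
    returns None for empty or inverted boxes, mirroring the
    continue branch above."""
    x1, y1, x2, y2 = bbox
    if (y2 - y1) <= 0 or (x2 - x1) <= 0:
        return None
    crop = frame[y1:y2, x1:x2]
    return cv2.resize(crop, (size, size), interpolation=cv2.INTER_LANCZOS4)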
@@ -165,20 +154,18 @@ def main(args):
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        crop_data = crop_data + crop_data[::-1]
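Concatenating a list with its reverse builds a ping-pong cycle: stepping through it plays the frames forward then backward, so looping a short clip has no visible jump at either end. For example:

def pingpong(items):
    # [a, b, c] -> [a, b, c, c, b, a]
    return items + items[::-1]

assert pingpong([0, 1, 2]) == [0, 1, 2, 2, 1, 0]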
        ############################################## inference batch by batch ##############################################
        print("start inference")
        print(len(input_latent_list_cycle),len(whisper_chunks))
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks,crop_data,batch_size)
        for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
            print("BATCH LEN..............")
            print(len(whisper_batch),len(crop_batch))
            crop_index=0
            for image,audio in zip(crop_batch,whisper_batch):
                cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image)
                cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image)
                crop_index+=1
                temp_image_index=i+crop_index+total_image_index+1
                # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
            print(folder_name)
            total_image_index=temp_image_index
            gc.collect()
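As in the audio block, the pair of cv2.imwrite lines reads as the before/after of this commit, with total_image_index added so crop images from successive tasks do not overwrite each other. Note that the batch counter i also enters the filename arithmetic, which appears to leave gaps in the numbering between batches; a single running counter expresses the same intent more directly (a sketch, with an invented counter name):

import cv2

global_image_index = 0  # would persist across tasks, like total_image_index

for whisper_batch, crop_batch in gen:
    for image, audio in zip(crop_batch, whisper_batch):
        cv2.imwrite(f"data/images/{folder_name}/{global_image_index}.png", image)
        global_image_index += 1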