clean code and separate finetuned_inference.py

This commit is contained in:
Shounak Banerjee
2024-06-07 18:39:24 +00:00
parent b4a592d7f3
commit d74c4c098b
5 changed files with 206 additions and 58 deletions


@@ -18,6 +18,7 @@ from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_p
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
import gc
# load model weights
audio_processor, vae, unet, pe = load_all_model()
@@ -57,7 +58,10 @@ def main(args):
unet.model = unet.model.half()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
total_audio_index=-1
total_image_index=-1
temp_audio_index=-1
temp_image_index=-1
for task_id in inference_config:
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
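The counters added here keep the exported audio and image file numbering continuous across every task in the inference config: each task starts writing at the index after the last file produced by the previous task instead of restarting at 0. A minimal standalone sketch of that bookkeeping (total_index and task_chunks are hypothetical names, not from the script):

total_index = -1  # index of the last file written so far; -1 means nothing written yet
for task_chunks in [["a", "b"], ["c", "d", "e"]]:
    for local_idx, chunk in enumerate(task_chunks):
        global_idx = total_index + local_idx + 1
        print(f"{global_idx}.npy <- {chunk}")  # first task writes 0-1, second writes 2-4
    total_index += len(task_chunks)  # carry the numbering into the next task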
@@ -95,32 +99,20 @@ def main(args):
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
print("LEN..........")
print(len(input_img_list))
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
print(len(whisper_feature))
print("Whisper feature length........")
print(whisper_feature[0].shape)
# print(whisper_feature)
for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
# Combine two consecutive chunks
# pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
# Save the pair to a .npy file
print("Pair shape",concatenated_chunks.shape)
np.save(f'data/audios/{folder_name}/{__//2}.npy', concatenated_chunks)
np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
temp_audio_index=(__//2)+total_audio_index+1
total_audio_index=temp_audio_index
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
print(len(whisper_chunks))
# whisper_i=0
# for chunk in whisper_chunks:
# # print("CHUNMK SHAPE...........")
# # print(chunk.shape)
# np.save(f'data/audios/{folder_name}/{str(whisper_i)}.npy', chunk)
# whisper_i+=1
############################################## preprocess input image ##############################################
gc.collect()
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
with open(crop_coord_save_path,'rb') as f:
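The new loop in this hunk pairs consecutive Whisper feature chunks, concatenates each pair along the time axis, and saves it to data/audios/<folder_name>/ under the running audio index, so the saved audio features stay aligned with the crop frames exported later. A minimal sketch of the pairing step, assuming whisper_feature is a list of 2-D numpy arrays of equal width (the shapes below are made up for illustration):

import numpy as np

whisper_feature = [np.random.rand(50, 384) for _ in range(5)]  # hypothetical chunk shapes
pairs = []
for k in range(0, len(whisper_feature) - 1, 2):  # step of 2; a trailing odd chunk is dropped
    pair = np.concatenate([whisper_feature[k], whisper_feature[k + 1]], axis=0)
    pairs.append(pair)  # each pair is what the script saves as one .npy file
print(len(pairs), pairs[0].shape)  # -> 2 (100, 384)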
@@ -131,8 +123,7 @@ def main(args):
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
print(len(frame_list))
i = 0
input_latent_list = []
@@ -151,9 +142,7 @@ def main(args):
if ((y2-y1)<=0) or ((x2-x1)<=0):
continue
crop_frame = frame[y1:y2, x1:x2]
print("crop sizes",bbox)
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
latents = vae.get_latents_for_unet(crop_frame)
crop_data.append(crop_frame)
input_latent_list.append(latents)
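In this hunk each frame is cropped to its detected face box and resized to 256x256 before being written out and encoded to latents with the VAE. A minimal sketch of just the crop-and-resize step with OpenCV, assuming frame is a BGR numpy image and bbox is (x1, y1, x2, y2) in pixels (crop_and_resize is a hypothetical helper, not part of the script):

import cv2

def crop_and_resize(frame, bbox, size=256):
    x1, y1, x2, y2 = bbox
    if (y2 - y1) <= 0 or (x2 - x1) <= 0:
        return None  # skip degenerate boxes, as the script does with continue
    crop = frame[y1:y2, x1:x2]
    return cv2.resize(crop, (size, size), interpolation=cv2.INTER_LANCZOS4)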
@@ -165,20 +154,18 @@ def main(args):
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
crop_data = crop_data + crop_data[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
print(len(input_latent_list_cycle),len(whisper_chunks))
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,crop_data,batch_size)
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
print("BATCH LEN..............")
print(len(whisper_batch),len(crop_batch))
crop_index=0
for image,audio in zip(crop_batch,whisper_batch):
cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index)}.png",image)
cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image)
crop_index+=1
temp_image_index=i+crop_index+total_image_index+1
# np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
print(folder_name)
total_image_index=temp_image_index
gc.collect()
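The batched loop above writes every crop image with a filename built from the batch-local position plus the running total_image_index, then advances the counter so the next task keeps numbering where the previous one stopped. A simplified standalone sketch of that export step (export_crops, out_dir, and start_index are hypothetical names; the real script interleaves this with its own index bookkeeping):

import cv2

def export_crops(gen, out_dir, start_index=-1):
    # gen yields (whisper_batch, crop_batch) pairs, as datagen does in the script
    next_index = start_index + 1
    for whisper_batch, crop_batch in gen:
        for image, audio in zip(crop_batch, whisper_batch):
            cv2.imwrite(f"{out_dir}/{next_index}.png", image)
            next_index += 1
    return next_index - 1  # index of the last image written, to carry into the next task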