modified dataloader.py and inference.py for training and inference

This commit is contained in:
Shounak Banerjee
2024-06-03 11:09:12 +00:00
parent 7254ca6306
commit b4a592d7f3
6 changed files with 106 additions and 58 deletions

View File

@@ -15,14 +15,27 @@ from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
from accelerate import Accelerator
# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
mixed_precision="fp16",
)
unet = accelerator.prepare(
unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if not (args.unet_checkpoint == None):
print("unet ckpt loaded")
accelerator.load_state(args.unet_checkpoint)
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
@@ -63,8 +76,6 @@ def main(args):
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
#print(input_img_list)
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
@@ -79,24 +90,27 @@ def main(args):
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
crop_i=0
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
crop_i+=1
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
@@ -107,7 +121,6 @@ def main(args):
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
@@ -122,22 +135,29 @@ def main(args):
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
# print(bbox)
continue
combine_frame = get_image(ori_frame,res_frame,bbox)
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove("temp.mp4")
shutil.rmtree(result_img_save_path)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
os.system(cmd_img2video)
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
# print(cmd_combine_audio)
# os.system(cmd_combine_audio)
# shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}")
if __name__ == "__main__":
@@ -156,6 +176,7 @@ if __name__ == "__main__":
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--unet_checkpoint", type=str, default=None)
args = parser.parse_args()
main(args)
main(args)