From 865a68c60e6275026eb20c7ce94fdad56fb2172a Mon Sep 17 00:00:00 2001 From: czk32611 Date: Sat, 27 Apr 2024 14:26:50 +0800 Subject: [PATCH] : support using float16 in inference to speed up --- README.md | 19 +++---- musetalk/models/unet.py | 6 +- musetalk/utils/utils.py | 11 ++-- musetalk/whisper/audio2feature.py | 6 +- scripts/inference.py | 20 +++++-- scripts/realtime_inference.py | 92 ++++++++++++++++++++++--------- 6 files changed, 103 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 7e4073c..e824b50 100644 --- a/README.md +++ b/README.md @@ -267,10 +267,8 @@ As a complete solution to virtual human generation, you are suggested to first a Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time. -Note that in this script, the generation time is also limited by I/O (e.g. saving images). - ``` -python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml +python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4 ``` configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`. @@ -280,17 +278,14 @@ configs/inference/realtime.yaml is the path to the real-time inference configura Inferring using: data/audio/yongen.wav ``` 1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100. - ``` - 2%|██▍ | 3/141 [00:00<00:32, 4.30it/s] # inference process - Displaying the 6-th frame with FPS: 48.58 # display process - Displaying the 7-th frame with FPS: 48.74 - Displaying the 8-th frame with FPS: 49.17 - 3%|███▎ | 4/141 [00:00<00:32, 4.21it/s] - ``` 1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar. -If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process. - +##### Note for Real-time inference +1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process. +1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run +``` +python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images +``` # Acknowledgement 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch). diff --git a/musetalk/models/unet.py b/musetalk/models/unet.py index 8968657..2bcc2b0 100755 --- a/musetalk/models/unet.py +++ b/musetalk/models/unet.py @@ -37,11 +37,11 @@ class UNet(): self.model = UNet2DConditionModel(**unet_config) self.pe = PositionalEncoding(d_model=384) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) - self.model.load_state_dict(self.weights) + weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) + self.model.load_state_dict(weights) if use_float16: self.model = self.model.half() self.model.to(self.device) if __name__ == "__main__": - unet = UNet() \ No newline at end of file + unet = UNet() diff --git a/musetalk/utils/utils.py b/musetalk/utils/utils.py index 6b2f02f..caac0fb 100644 --- a/musetalk/utils/utils.py +++ b/musetalk/utils/utils.py @@ -39,7 +39,10 @@ def get_video_fps(video_path): video.release() return fps -def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0): +def datagen(whisper_chunks, + vae_encode_latents, + batch_size=8, + delay_frame=0): whisper_batch, latent_batch = [], [] for i, w in enumerate(whisper_chunks): idx = (i+delay_frame)%len(vae_encode_latents) @@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0): latent_batch.append(latent) if len(latent_batch) >= batch_size: - whisper_batch = np.asarray(whisper_batch) + whisper_batch = np.stack(whisper_batch) latent_batch = torch.cat(latent_batch, dim=0) yield whisper_batch, latent_batch whisper_batch, latent_batch = [], [] # the last batch may smaller than batch size if len(latent_batch) > 0: - whisper_batch = np.asarray(whisper_batch) + whisper_batch = np.stack(whisper_batch) latent_batch = torch.cat(latent_batch, dim=0) - yield whisper_batch, latent_batch \ No newline at end of file + yield whisper_batch, latent_batch diff --git a/musetalk/whisper/audio2feature.py b/musetalk/whisper/audio2feature.py index 908db3d..2cfd3a9 100644 --- a/musetalk/whisper/audio2feature.py +++ b/musetalk/whisper/audio2feature.py @@ -13,7 +13,11 @@ class Audio2Feature(): self.whisper_model_type = whisper_model_type self.model = load_model(model_path) # - def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25): + def get_sliced_feature(self, + feature_array, + vid_idx, + audio_feat_length=[2,2], + fps=25): """ Get sliced features based on a given index :param feature_array: diff --git a/scripts/inference.py b/scripts/inference.py index ad9de8d..fe2a234 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model import shutil # load model weights -audio_processor,vae,unet,pe = load_all_model() +audio_processor, vae, unet, pe = load_all_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) @torch.no_grad() def main(args): + global pe + if args.use_float16 is True: + pe = pe.half() + vae.vae = vae.vae.half() + unet.model = unet.model.half() + inference_config = OmegaConf.load(args.inference_config) print(inference_config) for task_id in inference_config: @@ -96,10 +102,11 @@ def main(args): gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) res_frame_list = [] for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): - - tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] - audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384 + audio_feature_batch = torch.from_numpy(whisper_batch) + audio_feature_batch = audio_feature_batch.to(device=unet.device, + dtype=unet.model.dtype) # torch, B, 5*N,384 audio_feature_batch = pe(audio_feature_batch) + latent_batch = latent_batch.to(dtype=unet.model.dtype) pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) @@ -145,7 +152,10 @@ if __name__ == "__main__": parser.add_argument("--use_saved_coord", action="store_true", help='use saved coordinate to save time') - + parser.add_argument("--use_float16", + action="store_true", + help="Whether use float16 to speed up inference", + ) args = parser.parse_args() main(args) diff --git a/scripts/realtime_inference.py b/scripts/realtime_inference.py index e9d1338..18bb856 100644 --- a/scripts/realtime_inference.py +++ b/scripts/realtime_inference.py @@ -22,10 +22,12 @@ import queue import time # load model weights -audio_processor,vae,unet,pe = load_all_model() +audio_processor, vae, unet, pe = load_all_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) - +pe = pe.half() +vae.vae = vae.vae.half() +unet.model = unet.model.half() def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000): cap = cv2.VideoCapture(vid_path) @@ -99,6 +101,10 @@ class Avatar: osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) self.prepare_material() else: + if not os.path.exists(self.avatar_path): + print(f"{self.avatar_id} does not exist, you should set preparation to True") + sys.exit() + with open(self.avatar_info_path, "r") as f: avatar_info = json.load(f) @@ -182,7 +188,10 @@ class Avatar: torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path)) # - def process_frames(self, res_frame_queue,video_len): + def process_frames(self, + res_frame_queue, + video_len, + skip_save_images): print(video_len) while True: if self.idx>=video_len-1: @@ -205,44 +214,62 @@ class Avatar: #combine_frame = get_image(ori_frame,res_frame,bbox) combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) - fps = 1/(time.time()-start+1e-6) - print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}") - cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame) + if skip_save_images is False: + cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame) self.idx = self.idx + 1 - def inference(self, audio_path, out_vid_name, fps): + def inference(self, + audio_path, + out_vid_name, + fps, + skip_save_images): os.makedirs(self.avatar_path+'/tmp',exist_ok =True) + print("start inference") ############################################## extract audio feature ############################################## + start_time = time.time() whisper_feature = audio_processor.audio2feat(audio_path) whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) + print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms") ############################################## inference batch by batch ############################################## video_num = len(whisper_chunks) - print("start inference") res_frame_queue = queue.Queue() self.idx = 0 # # Create a sub-thread and start it - process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num)) + process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images)) process_thread.start() - start_time = time.time() - gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size) - print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms") + + gen = datagen(whisper_chunks, + self.input_latent_list_cycle, + self.batch_size) start_time = time.time() res_frame_list = [] for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))): - start_time = time.time() - tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] - audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384 + audio_feature_batch = torch.from_numpy(whisper_batch) + audio_feature_batch = audio_feature_batch.to(device=unet.device, + dtype=unet.model.dtype) audio_feature_batch = pe(audio_feature_batch) - - pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample + latent_batch = latent_batch.to(dtype=unet.model.dtype) + + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample recon = vae.decode_latents(pred_latents) for res_frame in recon: res_frame_queue.put(res_frame) # Close the queue and sub-thread after all tasks are completed process_thread.join() - if out_vid_name is not None: + if args.skip_save_images is True: + print('Total process time of {} frames without saving images = {}s'.format( + video_num, + time.time()-start_time)) + else: + print('Total process time of {} frames including saving images = {}s'.format( + video_num, + time.time()-start_time)) + + if out_vid_name is not None and args.skip_save_images is False: # optional cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4" print(cmd_img2video) @@ -256,20 +283,31 @@ class Avatar: os.remove(f"{self.avatar_path}/temp.mp4") shutil.rmtree(f"{self.avatar_path}/tmp") print(f"result is save to {output_vid}") + print("\n") - - - if __name__ == "__main__": ''' This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time. ''' parser = argparse.ArgumentParser() - parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml") - parser.add_argument("--fps", type=int, default=25) - parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--inference_config", + type=str, + default="configs/inference/realtime.yaml", + ) + parser.add_argument("--fps", + type=int, + default=25, + ) + parser.add_argument("--batch_size", + type=int, + default=4, + ) + parser.add_argument("--skip_save_images", + action="store_true", + help="Whether skip saving images for better generation speed calculation", + ) args = parser.parse_args() @@ -291,5 +329,7 @@ if __name__ == "__main__": audio_clips = inference_config[avatar_id]["audio_clips"] for audio_num, audio_path in audio_clips.items(): print("Inferring using:",audio_path) - avatar.inference(audio_path, audio_num, args.fps) - + avatar.inference(audio_path, + audio_num, + args.fps, + args.skip_save_images)