<enhance>: support using float16 in inference to speed it up

This commit is contained in:
czk32611
2024-04-27 14:26:50 +08:00
parent 2c52de01b4
commit 865a68c60e
6 changed files with 103 additions and 51 deletions

View File

@@ -267,10 +267,8 @@ As a complete solution to virtual human generation, you are suggested to first a
Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encoding in advance. During inference, only the UNet and the VAE decoder are involved, which makes MuseTalk real-time.
Note that in this script, the generation time is also limited by I/O (e.g. saving images).
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
```
configs/inference/realtime.yaml is the path to the real-time inference configuration file, which includes `preparation`, `video_path`, `bbox_shift` and `audio_clips`.
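For reference, this is how scripts/realtime_inference.py consumes that file: it iterates over the avatar entries and reads the fields listed above for each one. A minimal sketch, assuming only the keys named in this README (the exact YAML layout is not shown in this diff):
```
# Illustrative sketch: how scripts/realtime_inference.py reads realtime.yaml.
# Keys follow the description above; the YAML layout itself is an assumption.
from omegaconf import OmegaConf

inference_config = OmegaConf.load("configs/inference/realtime.yaml")
for avatar_id in inference_config:
    cfg = inference_config[avatar_id]
    print(avatar_id, cfg["preparation"], cfg["video_path"], cfg["bbox_shift"])
    for audio_num, audio_path in cfg["audio_clips"].items():
        print("  clip", audio_num, "->", audio_path)
```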
@@ -280,17 +278,14 @@ configs/inference/realtime.yaml is the path to the real-time inference configura
Inferring using: data/audio/yongen.wav
```
1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to users (a minimal threading sketch is given after this list). The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
```
2%|██▍ | 3/141 [00:00<00:32, 4.30it/s] # inference process
Displaying the 6-th frame with FPS: 48.58 # display process
Displaying the 7-th frame with FPS: 48.74
Displaying the 8-th frame with FPS: 49.17
3%|███▎ | 4/141 [00:00<00:32, 4.21it/s]
```
1. Set `preparation` to `False` and run this script if you want to generate more videos using the same avatar.
If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
##### Note for Real-time inference
1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
```
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
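The streaming behaviour mentioned above comes from a producer/consumer split: the main thread generates frames while a worker thread pulls them off a queue and displays or saves them as they arrive. A minimal standalone sketch of that pattern, with dummy frames rather than the project's Avatar class:
```
# Minimal producer/consumer sketch (standalone example, not the project's code):
# the main thread produces frames, a worker thread consumes them from a queue
# as soon as they are ready.
import queue
import threading
import time

import numpy as np

def consume(frame_queue):
    while True:
        frame = frame_queue.get()
        if frame is None:          # sentinel: generation finished
            break
        # display or save the frame here (e.g. cv2.imwrite / streaming)
        time.sleep(0.01)

frame_queue = queue.Queue()
worker = threading.Thread(target=consume, args=(frame_queue,))
worker.start()

for _ in range(10):                # stands in for the UNet + VAE decode loop
    frame_queue.put(np.zeros((256, 256, 3), dtype=np.uint8))
frame_queue.put(None)
worker.join()
```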
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).

View File

@@ -37,11 +37,11 @@ class UNet():
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(self.weights)
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:
self.model = self.model.half()
self.model.to(self.device)
if __name__ == "__main__":
unet = UNet()
unet = UNet()
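The `use_float16` branch follows the usual PyTorch half-precision pattern: `.half()` converts the weights to float16, which halves their memory footprint and enables faster fp16 kernels on supported GPUs, and inputs must then be cast to the same dtype before the forward pass. A minimal standalone sketch with a generic module, not the project's UNet2DConditionModel:
```
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_float16 = device.type == "cuda"   # treat fp16 inference as a GPU-only optimization here

model = torch.nn.Linear(384, 384)     # stand-in for the UNet
if use_float16:
    model = model.half()              # weights become torch.float16
model.to(device)

model_dtype = next(model.parameters()).dtype
x = torch.randn(4, 384)               # inputs typically start as float32
y = model(x.to(device=device, dtype=model_dtype))  # cast to match the weights
print(y.dtype)                        # torch.float16 on GPU, torch.float32 on CPU
```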

View File

@@ -39,7 +39,10 @@ def get_video_fps(video_path):
video.release()
return fps
def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
def datagen(whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0):
whisper_batch, latent_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
latent_batch.append(latent)
if len(latent_batch) >= batch_size:
whisper_batch = np.asarray(whisper_batch)
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
whisper_batch, latent_batch = [], []
# the last batch may be smaller than the batch size
if len(latent_batch) > 0:
whisper_batch = np.asarray(whisper_batch)
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
yield whisper_batch, latent_batch
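The switch from `np.asarray` to `np.stack` makes the batching explicit: a list of equally shaped whisper feature arrays becomes a single array with a leading batch axis, which the inference scripts then hand to `torch.from_numpy`. A small illustration, where the per-chunk feature shape is only an example:
```
# Illustration only; the (10, 384) per-chunk feature shape is an assumption.
import numpy as np
import torch

whisper_batch = [np.zeros((10, 384), dtype=np.float32) for _ in range(4)]
batched = np.stack(whisper_batch)             # shape (4, 10, 384)
audio_feature_batch = torch.from_numpy(batched)
print(audio_feature_batch.shape)              # torch.Size([4, 10, 384])
```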

View File

@@ -13,7 +13,11 @@ class Audio2Feature():
self.whisper_model_type = whisper_model_type
self.model = load_model(model_path) #
def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
def get_sliced_feature(self,
feature_array,
vid_idx,
audio_feat_length=[2,2],
fps=25):
"""
Get sliced features based on a given index
:param feature_array:
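Only the start of the docstring is visible in this diff. As a rough illustration of what index-based slicing like this typically does (an assumption, not the repository's actual implementation), one could map the video frame index onto the audio-feature timeline and take a small window around it:
```
# Rough sketch only: assumed logic, not the repository's get_sliced_feature.
import numpy as np

def sliced_feature_sketch(feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
    # feature_array: (T, 384) audio features; ~50 feature frames per second
    # of audio is itself an assumption made for this illustration.
    center = int(vid_idx * 50 / fps)
    left = max(center - audio_feat_length[0], 0)
    right = min(center + audio_feat_length[1] + 1, len(feature_array))
    return feature_array[left:right]          # a small window of audio context

print(sliced_feature_sketch(np.zeros((500, 384)), vid_idx=10).shape)  # (5, 384)
```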

View File

@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
import shutil
# load model weights
audio_processor,vae,unet,pe = load_all_model()
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for task_id in inference_config:
@@ -96,10 +102,11 @@ def main(args):
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
@@ -145,7 +152,10 @@ if __name__ == "__main__":
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
args = parser.parse_args()
main(args)
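Besides casting both inputs to `unet.model.dtype`, the loop replaces the per-array `torch.FloatTensor(...)` plus `torch.stack` round trip with a single `torch.from_numpy` call on the already stacked batch; `from_numpy` shares memory with the NumPy array, so no extra copy happens before the `.to(device, dtype)` transfer. A small comparison, with illustrative shapes:
```
# Both paths produce the same (4, 10, 384) float32 tensor; from_numpy avoids
# the Python-level loop and the intermediate per-array copies.
import numpy as np
import torch

whisper_batch = np.random.rand(4, 10, 384).astype(np.float32)

old = torch.stack([torch.FloatTensor(arr) for arr in whisper_batch])  # old path
new = torch.from_numpy(whisper_batch)                                 # new path

print(torch.equal(old, new))   # True
```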

View File

@@ -22,10 +22,12 @@ import queue
import time
# load model weights
audio_processor,vae,unet,pe = load_all_model()
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
cap = cv2.VideoCapture(vid_path)
@@ -99,6 +101,10 @@ class Avatar:
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
self.prepare_material()
else:
if not os.path.exists(self.avatar_path):
print(f"{self.avatar_id} does not exist, you should set preparation to True")
sys.exit()
with open(self.avatar_info_path, "r") as f:
avatar_info = json.load(f)
@@ -182,7 +188,10 @@ class Avatar:
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
#
def process_frames(self, res_frame_queue,video_len):
def process_frames(self,
res_frame_queue,
video_len,
skip_save_images):
print(video_len)
while True:
if self.idx>=video_len-1:
@@ -205,44 +214,62 @@ class Avatar:
#combine_frame = get_image(ori_frame,res_frame,bbox)
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
fps = 1/(time.time()-start+1e-6)
print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}")
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
if skip_save_images is False:
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
self.idx = self.idx + 1
def inference(self, audio_path, out_vid_name, fps):
def inference(self,
audio_path,
out_vid_name,
fps,
skip_save_images):
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
print("start inference")
############################################## extract audio feature ##############################################
start_time = time.time()
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
############################################## inference batch by batch ##############################################
video_num = len(whisper_chunks)
print("start inference")
res_frame_queue = queue.Queue()
self.idx = 0
# Create a sub-thread and start it
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
process_thread.start()
start_time = time.time()
gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size)
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
gen = datagen(whisper_chunks,
self.input_latent_list_cycle,
self.batch_size)
start_time = time.time()
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
start_time = time.time()
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_queue.put(res_frame)
# Close the queue and sub-thread after all tasks are completed
process_thread.join()
if out_vid_name is not None:
if args.skip_save_images is True:
print('Total process time of {} frames without saving images = {}s'.format(
video_num,
time.time()-start_time))
else:
print('Total process time of {} frames including saving images = {}s'.format(
video_num,
time.time()-start_time))
if out_vid_name is not None and args.skip_save_images is False:
# optional
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
print(cmd_img2video)
@@ -256,20 +283,31 @@ class Avatar:
os.remove(f"{self.avatar_path}/temp.mp4")
shutil.rmtree(f"{self.avatar_path}/tmp")
print(f"result is save to {output_vid}")
print("\n")
if __name__ == "__main__":
'''
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
'''
parser = argparse.ArgumentParser()
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--inference_config",
type=str,
default="configs/inference/realtime.yaml",
)
parser.add_argument("--fps",
type=int,
default=25,
)
parser.add_argument("--batch_size",
type=int,
default=4,
)
parser.add_argument("--skip_save_images",
action="store_true",
help="Whether skip saving images for better generation speed calculation",
)
args = parser.parse_args()
@@ -291,5 +329,7 @@ if __name__ == "__main__":
audio_clips = inference_config[avatar_id]["audio_clips"]
for audio_num, audio_path in audio_clips.items():
print("Inferring using:",audio_path)
avatar.inference(audio_path, audio_num, args.fps)
avatar.inference(audio_path,
audio_num,
args.fps,
args.skip_save_images)
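For context on the `cmd_img2video` string above: `-r {fps}` sets the frame rate, the `image2` demuxer reads the zero-padded `%08d.png` frames written by `cv2.imwrite`, `libx264` encodes H.264, the `-vf` chain converts to yuv420p with a BT.709 color matrix, and `-crf 18` gives visually near-lossless quality. The same command can also be launched via `subprocess`; a hedged sketch follows, where the paths are placeholders and the audio-muxing step that presumably comes next is not visible in this excerpt:
```
# Running the same frame-sequence-to-video step via subprocess.
# avatar_path and fps are placeholders for the values used in the script.
import subprocess

avatar_path, fps = "./results/avatars/my_avatar", 25
cmd_img2video = (
    f"ffmpeg -y -v warning -r {fps} -f image2 -i {avatar_path}/tmp/%08d.png "
    f"-vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p "
    f"-crf 18 {avatar_path}/temp.mp4"
)
subprocess.run(cmd_img2video, shell=True, check=True)
```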