mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
<enhance>: support using float16 in inference to speed up
This commit is contained in:
@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
# load model weights
|
||||
audio_processor,vae,unet,pe = load_all_model()
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
@@ -96,10 +102,11 @@ def main(args):
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
|
||||
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
|
||||
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
@@ -145,7 +152,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -22,10 +22,12 @@ import queue
|
||||
import time
|
||||
|
||||
# load model weights
|
||||
audio_processor,vae,unet,pe = load_all_model()
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
|
||||
cap = cv2.VideoCapture(vid_path)
|
||||
@@ -99,6 +101,10 @@ class Avatar:
|
||||
osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
|
||||
self.prepare_material()
|
||||
else:
|
||||
if not os.path.exists(self.avatar_path):
|
||||
print(f"{self.avatar_id} does not exist, you should set preparation to True")
|
||||
sys.exit()
|
||||
|
||||
with open(self.avatar_info_path, "r") as f:
|
||||
avatar_info = json.load(f)
|
||||
|
||||
@@ -182,7 +188,10 @@ class Avatar:
|
||||
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
||||
#
|
||||
|
||||
def process_frames(self, res_frame_queue,video_len):
|
||||
def process_frames(self,
|
||||
res_frame_queue,
|
||||
video_len,
|
||||
skip_save_images):
|
||||
print(video_len)
|
||||
while True:
|
||||
if self.idx>=video_len-1:
|
||||
@@ -205,44 +214,62 @@ class Avatar:
|
||||
#combine_frame = get_image(ori_frame,res_frame,bbox)
|
||||
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||
|
||||
fps = 1/(time.time()-start+1e-6)
|
||||
print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}")
|
||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
|
||||
if skip_save_images is False:
|
||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
|
||||
self.idx = self.idx + 1
|
||||
|
||||
def inference(self, audio_path, out_vid_name, fps):
|
||||
def inference(self,
|
||||
audio_path,
|
||||
out_vid_name,
|
||||
fps,
|
||||
skip_save_images):
|
||||
os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
|
||||
print("start inference")
|
||||
############################################## extract audio feature ##############################################
|
||||
start_time = time.time()
|
||||
whisper_feature = audio_processor.audio2feat(audio_path)
|
||||
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
print("start inference")
|
||||
res_frame_queue = queue.Queue()
|
||||
self.idx = 0
|
||||
# # Create a sub-thread and start it
|
||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
|
||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
||||
process_thread.start()
|
||||
start_time = time.time()
|
||||
gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size)
|
||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||
|
||||
gen = datagen(whisper_chunks,
|
||||
self.input_latent_list_cycle,
|
||||
self.batch_size)
|
||||
start_time = time.time()
|
||||
res_frame_list = []
|
||||
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
|
||||
start_time = time.time()
|
||||
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
|
||||
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype)
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_queue.put(res_frame)
|
||||
# Close the queue and sub-thread after all tasks are completed
|
||||
process_thread.join()
|
||||
|
||||
if out_vid_name is not None:
|
||||
if args.skip_save_images is True:
|
||||
print('Total process time of {} frames without saving images = {}s'.format(
|
||||
video_num,
|
||||
time.time()-start_time))
|
||||
else:
|
||||
print('Total process time of {} frames including saving images = {}s'.format(
|
||||
video_num,
|
||||
time.time()-start_time))
|
||||
|
||||
if out_vid_name is not None and args.skip_save_images is False:
|
||||
# optional
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
print(cmd_img2video)
|
||||
@@ -256,20 +283,31 @@ class Avatar:
|
||||
os.remove(f"{self.avatar_path}/temp.mp4")
|
||||
shutil.rmtree(f"{self.avatar_path}/tmp")
|
||||
print(f"result is save to {output_vid}")
|
||||
print("\n")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
||||
'''
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
|
||||
parser.add_argument("--fps", type=int, default=25)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--inference_config",
|
||||
type=str,
|
||||
default="configs/inference/realtime.yaml",
|
||||
)
|
||||
parser.add_argument("--fps",
|
||||
type=int,
|
||||
default=25,
|
||||
)
|
||||
parser.add_argument("--batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument("--skip_save_images",
|
||||
action="store_true",
|
||||
help="Whether skip saving images for better generation speed calculation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -291,5 +329,7 @@ if __name__ == "__main__":
|
||||
audio_clips = inference_config[avatar_id]["audio_clips"]
|
||||
for audio_num, audio_path in audio_clips.items():
|
||||
print("Inferring using:",audio_path)
|
||||
avatar.inference(audio_path, audio_num, args.fps)
|
||||
|
||||
avatar.inference(audio_path,
|
||||
audio_num,
|
||||
args.fps,
|
||||
args.skip_save_images)
|
||||
|
||||
Reference in New Issue
Block a user