From 48d267400667e4c32f115dad82773f9ce49e5b73 Mon Sep 17 00:00:00 2001 From: sudowind Date: Fri, 11 Apr 2025 09:36:53 +0800 Subject: [PATCH] Update lite_avatar.py --- lite_avatar.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/lite_avatar.py b/lite_avatar.py index 647df3a..ead894e 100644 --- a/lite_avatar.py +++ b/lite_avatar.py @@ -41,18 +41,21 @@ class liteAvatar(object): num_threads=1, use_bg_as_idle=False, fps=30, - generate_offline=False): + generate_offline=False, + use_gpu=False): logger.info('liteAvatar init start...') self.data_dir = data_dir self.fps = fps self.use_bg_as_idle = use_bg_as_idle + self.use_gpu = use_gpu + self.device = "cuda" if use_gpu else "cpu" s = time.time() from audio2mouth_cpu import Audio2Mouth - self.audio2mouth = Audio2Mouth() + self.audio2mouth = Audio2Mouth(use_gpu) logger.info(f'audio2mouth init over in {time.time() - s}s') self.p_list = [str(ii) for ii in range(32)] @@ -82,8 +85,8 @@ class liteAvatar(object): def load_dynamic_model(self, data_dir): logger.info("start to load dynamic data") start_time = time.time() - self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt') - self.generator = torch.jit.load(f'{data_dir}/net_decode.pt') + self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt').to(self.device) + self.generator = torch.jit.load(f'{data_dir}/net_decode.pt').to(self.device) self.load_data_sync(data_dir=data_dir, bg_frame_cnt=150) self.load_data(data_dir=data_dir, bg_frame_cnt=150) @@ -137,7 +140,7 @@ class liteAvatar(object): image = cv2.cvtColor(cv2.imread(img_file_path)[:,:,0:3],cv2.COLOR_BGR2RGB) image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_LINEAR) ref_img = self.image_transforms(np.uint8(image)) - encoder_input = ref_img.unsqueeze(0).float() + encoder_input = ref_img.unsqueeze(0).float().to(self.device) x = self.encoder(encoder_input) self.ref_img_list.append(x) @@ -179,8 +182,8 @@ class liteAvatar(object): param_val.append(val) param_val = np.asarray(param_val) - source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float()) - source_img = source_img.detach() + source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float().to(self.device)) + source_img = source_img.detach().to("cpu") return source_img @@ -362,4 +365,4 @@ if __name__ == '__main__': lite_avatar = liteAvatar(data_dir=args.data_dir, num_threads=1, generate_offline=True) lite_avatar.handle(audio_file, tmp_frame_dir) - \ No newline at end of file +