From 48d267400667e4c32f115dad82773f9ce49e5b73 Mon Sep 17 00:00:00 2001
From: sudowind
Date: Fri, 11 Apr 2025 09:36:53 +0800
Subject: [PATCH 1/3] Update lite_avatar.py

---
 lite_avatar.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/lite_avatar.py b/lite_avatar.py
index 647df3a..ead894e 100644
--- a/lite_avatar.py
+++ b/lite_avatar.py
@@ -41,18 +41,21 @@ class liteAvatar(object):
                  num_threads=1,
                  use_bg_as_idle=False,
                  fps=30,
-                 generate_offline=False):
+                 generate_offline=False,
+                 use_gpu=False):
         
         logger.info('liteAvatar init start...')
         
         self.data_dir = data_dir
         self.fps = fps
         self.use_bg_as_idle = use_bg_as_idle
+        self.use_gpu = use_gpu
+        self.device = "cuda" if use_gpu else "cpu"
         
         
         s = time.time()
         from audio2mouth_cpu import Audio2Mouth
-        self.audio2mouth = Audio2Mouth()
+        self.audio2mouth = Audio2Mouth(use_gpu)
         logger.info(f'audio2mouth init over in {time.time() - s}s')
         
         self.p_list = [str(ii) for ii in range(32)]
@@ -82,8 +85,8 @@ class liteAvatar(object):
     def load_dynamic_model(self, data_dir):
         logger.info("start to load dynamic data")
         start_time = time.time()
-        self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt')
-        self.generator = torch.jit.load(f'{data_dir}/net_decode.pt')
+        self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt').to(self.device)
+        self.generator = torch.jit.load(f'{data_dir}/net_decode.pt').to(self.device)
         self.load_data_sync(data_dir=data_dir, bg_frame_cnt=150)
         self.load_data(data_dir=data_dir, bg_frame_cnt=150)
         
@@ -137,7 +140,7 @@ class liteAvatar(object):
             image = cv2.cvtColor(cv2.imread(img_file_path)[:,:,0:3],cv2.COLOR_BGR2RGB)
             image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_LINEAR)
             ref_img = self.image_transforms(np.uint8(image))
-            encoder_input = ref_img.unsqueeze(0).float()
+            encoder_input = ref_img.unsqueeze(0).float().to(self.device)
             x = self.encoder(encoder_input)
             self.ref_img_list.append(x)
         
@@ -179,8 +182,8 @@ class liteAvatar(object):
             param_val.append(val)
         
         param_val = np.asarray(param_val)
         
-        source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float())
-        source_img = source_img.detach()
+        source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float().to(self.device))
+        source_img = source_img.detach().to("cpu")
         
         return source_img
@@ -362,4 +365,4 @@ if __name__ == '__main__':
     
     lite_avatar = liteAvatar(data_dir=args.data_dir, num_threads=1, generate_offline=True)
     lite_avatar.handle(audio_file, tmp_frame_dir)
-    
\ No newline at end of file
+    

From 7ed7dc549dcadb3ca8e527e017b9cb35a03b63f5 Mon Sep 17 00:00:00 2001
From: sudowind
Date: Fri, 11 Apr 2025 09:38:26 +0800
Subject: [PATCH 2/3] Update audio2mouth_cpu.py

---
 audio2mouth_cpu.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/audio2mouth_cpu.py b/audio2mouth_cpu.py
index c5ce12d..bec0e47 100644
--- a/audio2mouth_cpu.py
+++ b/audio2mouth_cpu.py
@@ -8,12 +8,13 @@ from extract_paraformer_feature import extract_para_feature
 from scipy import signal
 
 class Audio2Mouth(object):
-    def __init__(self):
+    def __init__(self. use_gpu):
         
         self.p_list = [str(ii) for ii in range(32)]
         
         model_path = './weights/model_1.onnx'
-        self.audio2mouth_model=onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
+        provider = "CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider"
+        self.audio2mouth_model=onnxruntime.InferenceSession(model_path, providers=[provider])
         
         self.w = np.array([1.0]).astype(np.float32)
         self.sp = np.array([2]).astype(np.int64)

From b2cd07cfd9c091167c003a6377ea846fab2b8b88 Mon Sep 17 00:00:00 2001
From: sudowind
Date: Fri, 11 Apr 2025 09:41:53 +0800
Subject: [PATCH 3/3] Update audio2mouth_cpu.py

---
 audio2mouth_cpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/audio2mouth_cpu.py b/audio2mouth_cpu.py
index bec0e47..3f459be 100644
--- a/audio2mouth_cpu.py
+++ b/audio2mouth_cpu.py
@@ -8,7 +8,7 @@ from extract_paraformer_feature import extract_para_feature
 from scipy import signal
 
 class Audio2Mouth(object):
-    def __init__(self. use_gpu):
+    def __init__(self, use_gpu):
         
         self.p_list = [str(ii) for ii in range(32)]
         
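Usage note: with the three patches above applied, GPU execution is opt-in through the new use_gpu flag. A minimal caller sketch, assuming a CUDA-capable PyTorch build and the onnxruntime-gpu package are installed; the data directory, audio file, and frame directory below are placeholder paths, not ones from the patches:

    from lite_avatar import liteAvatar

    # use_gpu=True moves the TorchScript encoder/decoder to self.device ("cuda")
    # and selects CUDAExecutionProvider for the Audio2Mouth ONNX session;
    # use_gpu=False (the default) keeps the previous CPU-only behaviour.
    avatar = liteAvatar(data_dir='./data/avatar_sample',
                        num_threads=1,
                        generate_offline=True,
                        use_gpu=True)
    avatar.handle('input.wav', './tmp_frames')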