Update lite_avatar.py

Authored by sudowind on 2025-04-11 09:36:53 +08:00, committed by GitHub
parent d2cb9dac2b
commit 48d2674006

@@ -41,18 +41,21 @@ class liteAvatar(object):
                  num_threads=1,
                  use_bg_as_idle=False,
                  fps=30,
-                 generate_offline=False):
+                 generate_offline=False,
+                 use_gpu=False):
         logger.info('liteAvatar init start...')
         self.data_dir = data_dir
         self.fps = fps
         self.use_bg_as_idle = use_bg_as_idle
+        self.use_gpu = use_gpu
+        self.device = "cuda" if use_gpu else "cpu"
         s = time.time()
         from audio2mouth_cpu import Audio2Mouth
-        self.audio2mouth = Audio2Mouth()
+        self.audio2mouth = Audio2Mouth(use_gpu)
         logger.info(f'audio2mouth init over in {time.time() - s}s')
         self.p_list = [str(ii) for ii in range(32)]
@@ -82,8 +85,8 @@ class liteAvatar(object):
     def load_dynamic_model(self, data_dir):
         logger.info("start to load dynamic data")
         start_time = time.time()
-        self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt')
-        self.generator = torch.jit.load(f'{data_dir}/net_decode.pt')
+        self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt').to(self.device)
+        self.generator = torch.jit.load(f'{data_dir}/net_decode.pt').to(self.device)
         self.load_data_sync(data_dir=data_dir, bg_frame_cnt=150)
         self.load_data(data_dir=data_dir, bg_frame_cnt=150)
@@ -137,7 +140,7 @@ class liteAvatar(object):
             image = cv2.cvtColor(cv2.imread(img_file_path)[:,:,0:3],cv2.COLOR_BGR2RGB)
             image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_LINEAR)
             ref_img = self.image_transforms(np.uint8(image))
-            encoder_input = ref_img.unsqueeze(0).float()
+            encoder_input = ref_img.unsqueeze(0).float().to(self.device)
             x = self.encoder(encoder_input)
             self.ref_img_list.append(x)
@@ -179,8 +182,8 @@ class liteAvatar(object):
             param_val.append(val)
         param_val = np.asarray(param_val)
-        source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float())
-        source_img = source_img.detach()
+        source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float().to(self.device))
+        source_img = source_img.detach().to("cpu")
         return source_img
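
A minimal usage sketch of the new flag, for reference. Only the liteAvatar class name, the use_gpu/fps/data_dir parameters, and the device handling are taken from this diff; the module import path, the data_dir value, and treating the change as otherwise behaviour-preserving on CPU are assumptions.

    # Hypothetical usage; module path and data_dir value are assumptions.
    from lite_avatar import liteAvatar

    # Default CPU path: use_gpu=False keeps self.device == "cpu", matching the
    # previous behaviour.
    avatar_cpu = liteAvatar(data_dir="./weights", fps=30, use_gpu=False)

    # GPU path: self.device becomes "cuda", the TorchScript encoder/decoder are
    # moved to the GPU, inputs are sent with .to(self.device), and the generated
    # frame is copied back to the CPU before being returned.
    avatar_gpu = liteAvatar(data_dir="./weights", fps=30, use_gpu=True)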