mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-04 17:39:19 +08:00
Update lite_avatar.py
This commit is contained in:
@@ -41,18 +41,21 @@ class liteAvatar(object):
|
||||
num_threads=1,
|
||||
use_bg_as_idle=False,
|
||||
fps=30,
|
||||
generate_offline=False):
|
||||
generate_offline=False,
|
||||
use_gpu=False):
|
||||
|
||||
logger.info('liteAvatar init start...')
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.fps = fps
|
||||
self.use_bg_as_idle = use_bg_as_idle
|
||||
self.use_gpu = use_gpu
|
||||
self.device = "cuda" if use_gpu else "cpu"
|
||||
|
||||
s = time.time()
|
||||
from audio2mouth_cpu import Audio2Mouth
|
||||
|
||||
self.audio2mouth = Audio2Mouth()
|
||||
self.audio2mouth = Audio2Mouth(use_gpu)
|
||||
logger.info(f'audio2mouth init over in {time.time() - s}s')
|
||||
|
||||
self.p_list = [str(ii) for ii in range(32)]
|
||||
@@ -82,8 +85,8 @@ class liteAvatar(object):
|
||||
def load_dynamic_model(self, data_dir):
|
||||
logger.info("start to load dynamic data")
|
||||
start_time = time.time()
|
||||
self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt')
|
||||
self.generator = torch.jit.load(f'{data_dir}/net_decode.pt')
|
||||
self.encoder = torch.jit.load(f'{data_dir}/net_encode.pt').to(self.device)
|
||||
self.generator = torch.jit.load(f'{data_dir}/net_decode.pt').to(self.device)
|
||||
|
||||
self.load_data_sync(data_dir=data_dir, bg_frame_cnt=150)
|
||||
self.load_data(data_dir=data_dir, bg_frame_cnt=150)
|
||||
@@ -137,7 +140,7 @@ class liteAvatar(object):
|
||||
image = cv2.cvtColor(cv2.imread(img_file_path)[:,:,0:3],cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_LINEAR)
|
||||
ref_img = self.image_transforms(np.uint8(image))
|
||||
encoder_input = ref_img.unsqueeze(0).float()
|
||||
encoder_input = ref_img.unsqueeze(0).float().to(self.device)
|
||||
x = self.encoder(encoder_input)
|
||||
self.ref_img_list.append(x)
|
||||
|
||||
@@ -179,8 +182,8 @@ class liteAvatar(object):
|
||||
param_val.append(val)
|
||||
param_val = np.asarray(param_val)
|
||||
|
||||
source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float())
|
||||
source_img = source_img.detach()
|
||||
source_img = self.generator(self.ref_img_list[bg_frame_id], torch.from_numpy(param_val).unsqueeze(0).float().to(self.device))
|
||||
source_img = source_img.detach().to("cpu")
|
||||
|
||||
return source_img
|
||||
|
||||
@@ -362,4 +365,4 @@ if __name__ == '__main__':
|
||||
lite_avatar = liteAvatar(data_dir=args.data_dir, num_threads=1, generate_offline=True)
|
||||
|
||||
lite_avatar.handle(audio_file, tmp_frame_dir)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user