mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 17:39:24 +08:00
Run on CPU (drop the hard-coded `.cuda()` calls on the model and input tensors)
This commit is contained in:
@@ -65,7 +65,7 @@ class InferBase:
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Num params: {n_parameters}")
|
||||
model = create_ddp_model(
|
||||
model.cuda(),
|
||||
model,
|
||||
broadcast_buffers=False,
|
||||
find_unused_parameters=self.cfg.find_unused_parameters,
|
||||
)
|
||||
@@ -117,9 +117,9 @@ class Audio2ExpressionInfer(InferBase):
|
||||
with torch.no_grad():
|
||||
input_dict = {}
|
||||
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
||||
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None,...]
|
||||
self.cfg.model.backbone.num_identity_classes)[None,...]
|
||||
speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
|
||||
input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None,...]
|
||||
input_dict['input_audio_array'] = torch.FloatTensor(speech_array)[None,...]
|
||||
|
||||
end = time.time()
|
||||
output_dict = self.model(input_dict)
|
||||
@@ -198,9 +198,9 @@ class Audio2ExpressionInfer(InferBase):
|
||||
try:
|
||||
input_dict = {}
|
||||
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
||||
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[
|
||||
self.cfg.model.backbone.num_identity_classes)[
|
||||
None, ...]
|
||||
input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
|
||||
input_dict['input_audio_array'] = torch.FloatTensor(input_audio)[None, ...]
|
||||
output_dict = self.model(input_dict)
|
||||
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
|
||||
except:
|
||||
|
||||
Reference in New Issue
Block a user