mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 09:59:21 +08:00
run with cpu
This commit is contained in:
@@ -65,7 +65,7 @@ class InferBase:
|
|||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
self.logger.info(f"Num params: {n_parameters}")
|
self.logger.info(f"Num params: {n_parameters}")
|
||||||
model = create_ddp_model(
|
model = create_ddp_model(
|
||||||
model.cuda(),
|
model,
|
||||||
broadcast_buffers=False,
|
broadcast_buffers=False,
|
||||||
find_unused_parameters=self.cfg.find_unused_parameters,
|
find_unused_parameters=self.cfg.find_unused_parameters,
|
||||||
)
|
)
|
||||||
@@ -117,9 +117,9 @@ class Audio2ExpressionInfer(InferBase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
||||||
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None,...]
|
self.cfg.model.backbone.num_identity_classes)[None,...]
|
||||||
speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
|
speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
|
||||||
input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None,...]
|
input_dict['input_audio_array'] = torch.FloatTensor(speech_array)[None,...]
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
output_dict = self.model(input_dict)
|
output_dict = self.model(input_dict)
|
||||||
@@ -198,9 +198,9 @@ class Audio2ExpressionInfer(InferBase):
|
|||||||
try:
|
try:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
|
||||||
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[
|
self.cfg.model.backbone.num_identity_classes)[
|
||||||
None, ...]
|
None, ...]
|
||||||
input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
|
input_dict['input_audio_array'] = torch.FloatTensor(input_audio)[None, ...]
|
||||||
output_dict = self.model(input_dict)
|
output_dict = self.model(input_dict)
|
||||||
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
|
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
|
||||||
except:
|
except:
|
||||||
|
|||||||
Reference in New Issue
Block a user