fx dtype bug

This commit is contained in:
adamnsandle
2024-07-01 09:00:59 +00:00
parent 89e66a3474
commit 902cfc9248
2 changed files with 1 additions and 2 deletions

View File

@@ -37,7 +37,6 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
model_dir = os.path.join(os.path.dirname(__file__), 'files')
print(model_dir, os.path.dirname(__file__))
if onnx:
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
else:

View File

@@ -72,7 +72,7 @@ class OnnxWrapper():
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)}
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)