diff --git a/hubconf.py b/hubconf.py index e38b9dc..d10073f 100644 --- a/hubconf.py +++ b/hubconf.py @@ -23,11 +23,14 @@ def versiontuple(v): return tuple(version_list) -def silero_vad(onnx=False, force_onnx_cpu=False): +def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16): """Silero Voice Activity Detector Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples """ + available_ops = [13, 14, 15, 16] + if onnx and opset_version not in available_ops: + raise Exception(f'Available ONNX opset_version: {available_ops}') if not onnx: installed_version = torch.__version__ @@ -37,7 +40,11 @@ def silero_vad(onnx=False, force_onnx_cpu=False): model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data') if onnx: - model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu) + if opset_version == 16: + model_name = 'silero_vad.onnx' + else: + model_name = f'silero_vad_op{opset_version}.onnx' + model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu) else: model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit')) utils = (get_speech_timestamps, diff --git a/src/silero_vad/data/silero_vad_op13.onnx b/src/silero_vad/data/silero_vad_op13.onnx new file mode 100644 index 0000000..3e33527 Binary files /dev/null and b/src/silero_vad/data/silero_vad_op13.onnx differ diff --git a/src/silero_vad/data/silero_vad_op14.onnx b/src/silero_vad/data/silero_vad_op14.onnx new file mode 100644 index 0000000..b3e3a90 Binary files /dev/null and b/src/silero_vad/data/silero_vad_op14.onnx differ diff --git a/src/silero_vad/data/silero_vad_op15.onnx b/src/silero_vad/data/silero_vad_op15.onnx new file mode 100644 index 0000000..b3e3a90 Binary files /dev/null and b/src/silero_vad/data/silero_vad_op15.onnx differ diff --git a/src/silero_vad/model.py b/src/silero_vad/model.py index 165e9c6..067c7f7 100644 --- a/src/silero_vad/model.py +++ b/src/silero_vad/model.py @@ -2,10 +2,21 @@ from .utils_vad import init_jit_model, OnnxWrapper import torch torch.set_num_threads(1) -def load_silero_vad(onnx=False): - model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit' + +def load_silero_vad(onnx=False, opset_version=16): + available_ops = [13, 14, 15, 16] + if onnx and opset_version not in available_ops: + raise Exception(f'Available ONNX opset_version: {available_ops}') + + if onnx: + if opset_version == 16: + model_name = 'silero_vad.onnx' + else: + model_name = f'silero_vad_op{opset_version}.onnx' + else: + model_name = 'silero_vad.jit' package_path = "silero_vad.data" - + try: import importlib_resources as impresources model_file_path = str(impresources.files(package_path).joinpath(model_name)) @@ -21,5 +32,5 @@ def load_silero_vad(onnx=False): model = OnnxWrapper(model_file_path, force_onnx_cpu=True) else: model = init_jit_model(model_file_path) - + return model