diff --git a/hubconf.py b/hubconf.py index e38b9dc..1e15b44 100644 --- a/hubconf.py +++ b/hubconf.py @@ -23,11 +23,14 @@ def versiontuple(v): return tuple(version_list) -def silero_vad(onnx=False, force_onnx_cpu=False): +def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16): """Silero Voice Activity Detector Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples """ + available_ops = [15, 16] + if onnx and opset_version not in available_ops: + raise Exception(f'Available ONNX opset_version: {available_ops}') if not onnx: installed_version = torch.__version__ @@ -37,7 +40,11 @@ def silero_vad(onnx=False, force_onnx_cpu=False): model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data') if onnx: - model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu) + if opset_version == 16: + model_name = 'silero_vad.onnx' + else: + model_name = f'silero_vad_16k_op{opset_version}.onnx' + model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu) else: model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit')) utils = (get_speech_timestamps, diff --git a/src/silero_vad/data/silero_vad_16k_op15.onnx b/src/silero_vad/data/silero_vad_16k_op15.onnx new file mode 100644 index 0000000..0607ae8 Binary files /dev/null and b/src/silero_vad/data/silero_vad_16k_op15.onnx differ diff --git a/src/silero_vad/model.py b/src/silero_vad/model.py index 165e9c6..40792ef 100644 --- a/src/silero_vad/model.py +++ b/src/silero_vad/model.py @@ -2,10 +2,21 @@ from .utils_vad import init_jit_model, OnnxWrapper import torch torch.set_num_threads(1) -def load_silero_vad(onnx=False): - model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit' + +def load_silero_vad(onnx=False, opset_version=16): + available_ops = [15, 16] + if onnx and opset_version not in available_ops: + raise Exception(f'Available ONNX opset_version: {available_ops}') + + if onnx: + if opset_version == 16: + model_name = 'silero_vad.onnx' + else: + model_name = f'silero_vad_16k_op{opset_version}.onnx' + else: + model_name = 'silero_vad.jit' package_path = "silero_vad.data" - + try: import importlib_resources as impresources model_file_path = str(impresources.files(package_path).joinpath(model_name)) @@ -21,5 +32,5 @@ def load_silero_vad(onnx=False): model = OnnxWrapper(model_file_path, force_onnx_cpu=True) else: model = init_jit_model(model_file_path) - + return model diff --git a/src/silero_vad/utils_vad.py b/src/silero_vad/utils_vad.py index d95487d..9867c0b 100644 --- a/src/silero_vad/utils_vad.py +++ b/src/silero_vad/utils_vad.py @@ -23,7 +23,11 @@ class OnnxWrapper(): self.session = onnxruntime.InferenceSession(path, sess_options=opts) self.reset_states() - self.sample_rates = [8000, 16000] + if '16k' in path: + warnings.warn('This model support only 16000 sampling rate!') + self.sample_rates = [16000] + else: + self.sample_rates = [8000, 16000] def _validate_input(self, x, sr: int): if x.dim() == 1: