Merge pull request #573 from snakers4/adamnsandle

Adamnsandle
This commit is contained in:
Alexander Veysov
2024-11-13 12:32:55 +03:00
committed by GitHub
4 changed files with 29 additions and 7 deletions

View File

@@ -23,11 +23,14 @@ def versiontuple(v):
return tuple(version_list) return tuple(version_list)
def silero_vad(onnx=False, force_onnx_cpu=False): def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
"""Silero Voice Activity Detector """Silero Voice Activity Detector
Returns a model with a set of utils Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples Please see https://github.com/snakers4/silero-vad for usage examples
""" """
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if not onnx: if not onnx:
installed_version = torch.__version__ installed_version = torch.__version__
@@ -37,7 +40,11 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data') model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
if onnx: if onnx:
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu) if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
else: else:
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit')) model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
utils = (get_speech_timestamps, utils = (get_speech_timestamps,

Binary file not shown.

View File

@@ -2,8 +2,19 @@ from .utils_vad import init_jit_model, OnnxWrapper
import torch import torch
torch.set_num_threads(1) torch.set_num_threads(1)
def load_silero_vad(onnx=False):
model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit' def load_silero_vad(onnx=False, opset_version=16):
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else:
model_name = 'silero_vad.jit'
package_path = "silero_vad.data" package_path = "silero_vad.data"
try: try:

View File

@@ -23,6 +23,10 @@ class OnnxWrapper():
self.session = onnxruntime.InferenceSession(path, sess_options=opts) self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states() self.reset_states()
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000] self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int): def _validate_input(self, x, sr: int):