mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
add just 16k model
This commit is contained in:
@@ -28,7 +28,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
|
||||
Returns a model with a set of utils
|
||||
Please see https://github.com/snakers4/silero-vad for usage examples
|
||||
"""
|
||||
available_ops = [13, 14, 15, 16]
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
|
||||
@@ -43,7 +43,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_op{opset_version}.onnx'
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
|
||||
else:
|
||||
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,7 +4,7 @@ torch.set_num_threads(1)
|
||||
|
||||
|
||||
def load_silero_vad(onnx=False, opset_version=16):
|
||||
available_ops = [13, 14, 15, 16]
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
|
||||
@@ -12,7 +12,7 @@ def load_silero_vad(onnx=False, opset_version=16):
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_op{opset_version}.onnx'
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
else:
|
||||
model_name = 'silero_vad.jit'
|
||||
package_path = "silero_vad.data"
|
||||
|
||||
@@ -23,7 +23,11 @@ class OnnxWrapper():
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
self.sample_rates = [8000, 16000]
|
||||
if '16k' in path:
|
||||
warnings.warn('This model support only 16000 sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
|
||||
Reference in New Issue
Block a user