mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 09:29:22 +08:00
add just 16k model
This commit is contained in:
@@ -28,7 +28,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
|
|||||||
Returns a model with a set of utils
|
Returns a model with a set of utils
|
||||||
Please see https://github.com/snakers4/silero-vad for usage examples
|
Please see https://github.com/snakers4/silero-vad for usage examples
|
||||||
"""
|
"""
|
||||||
available_ops = [13, 14, 15, 16]
|
available_ops = [15, 16]
|
||||||
if onnx and opset_version not in available_ops:
|
if onnx and opset_version not in available_ops:
|
||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
|
|||||||
if opset_version == 16:
|
if opset_version == 16:
|
||||||
model_name = 'silero_vad.onnx'
|
model_name = 'silero_vad.onnx'
|
||||||
else:
|
else:
|
||||||
model_name = f'silero_vad_op{opset_version}.onnx'
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
|
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
|
||||||
else:
|
else:
|
||||||
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
|
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,7 +4,7 @@ torch.set_num_threads(1)
|
|||||||
|
|
||||||
|
|
||||||
def load_silero_vad(onnx=False, opset_version=16):
|
def load_silero_vad(onnx=False, opset_version=16):
|
||||||
available_ops = [13, 14, 15, 16]
|
available_ops = [15, 16]
|
||||||
if onnx and opset_version not in available_ops:
|
if onnx and opset_version not in available_ops:
|
||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ def load_silero_vad(onnx=False, opset_version=16):
|
|||||||
if opset_version == 16:
|
if opset_version == 16:
|
||||||
model_name = 'silero_vad.onnx'
|
model_name = 'silero_vad.onnx'
|
||||||
else:
|
else:
|
||||||
model_name = f'silero_vad_op{opset_version}.onnx'
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
else:
|
else:
|
||||||
model_name = 'silero_vad.jit'
|
model_name = 'silero_vad.jit'
|
||||||
package_path = "silero_vad.data"
|
package_path = "silero_vad.data"
|
||||||
|
|||||||
@@ -23,7 +23,11 @@ class OnnxWrapper():
|
|||||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||||
|
|
||||||
self.reset_states()
|
self.reset_states()
|
||||||
self.sample_rates = [8000, 16000]
|
if '16k' in path:
|
||||||
|
warnings.warn('This model support only 16000 sampling rate!')
|
||||||
|
self.sample_rates = [16000]
|
||||||
|
else:
|
||||||
|
self.sample_rates = [8000, 16000]
|
||||||
|
|
||||||
def _validate_input(self, x, sr: int):
|
def _validate_input(self, x, sr: int):
|
||||||
if x.dim() == 1:
|
if x.dim() == 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user