mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
add new parameter: force_onnx_cpu
This commit is contained in:
14
utils_vad.py
14
utils_vad.py
@@ -9,11 +9,14 @@ languages = ['ru', 'en', 'de', 'es']
|
||||
|
||||
class OnnxWrapper():
|
||||
|
||||
def __init__(self, path):
|
||||
def __init__(self, path, force_onnx_cpu):
|
||||
import numpy as np
|
||||
global np
|
||||
import onnxruntime
|
||||
self.session = onnxruntime.InferenceSession(path)
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path)
|
||||
self.session.intra_op_num_threads = 1
|
||||
self.session.inter_op_num_threads = 1
|
||||
|
||||
@@ -53,12 +56,15 @@ class OnnxWrapper():
|
||||
|
||||
|
||||
class Validator():
|
||||
def __init__(self, url):
|
||||
def __init__(self, url, force_onnx_cpu):
|
||||
self.onnx = True if url.endswith('.onnx') else False
|
||||
torch.hub.download_url_to_file(url, 'inf.model')
|
||||
if self.onnx:
|
||||
import onnxruntime
|
||||
self.model = onnxruntime.InferenceSession('inf.model')
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.model = onnxruntime.InferenceSession('inf.model')
|
||||
else:
|
||||
self.model = init_jit_model(model_path='inf.model')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user