From 17209e6c4f7086e20f2863958fbec20f12592ec1 Mon Sep 17 00:00:00 2001 From: ChiehKai Yang Date: Wed, 12 Oct 2022 01:56:43 +0800 Subject: [PATCH] add new parameter: force_onnx_cpu --- hubconf.py | 16 ++++++++-------- utils_vad.py | 14 ++++++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/hubconf.py b/hubconf.py index a2b3754..e97d669 100644 --- a/hubconf.py +++ b/hubconf.py @@ -15,14 +15,14 @@ from utils_vad import (init_jit_model, OnnxWrapper) -def silero_vad(onnx=False): +def silero_vad(onnx=False, force_onnx_cpu=False): """Silero Voice Activity Detector Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples """ hub_dir = torch.hub.get_dir() if onnx: - model = OnnxWrapper(f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.onnx') + model = OnnxWrapper(f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.onnx', force_onnx_cpu) else: model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit') utils = (get_speech_timestamps, @@ -34,7 +34,7 @@ def silero_vad(onnx=False): return model, utils -def silero_number_detector(onnx=False): +def silero_number_detector(onnx=False, force_onnx_cpu=False): """Silero Number Detector Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples @@ -43,7 +43,7 @@ def silero_number_detector(onnx=False): url = 'https://models.silero.ai/vad_models/number_detector.onnx' else: url = 'https://models.silero.ai/vad_models/number_detector.jit' - model = Validator(url) + model = Validator(url, force_onnx_cpu) utils = (get_number_ts, save_audio, read_audio, @@ -53,7 +53,7 @@ def silero_number_detector(onnx=False): return model, utils -def silero_lang_detector(onnx=False): +def silero_lang_detector(onnx=False, force_onnx_cpu=False): """Silero Language Classifier Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples @@ -62,14 +62,14 @@ def silero_lang_detector(onnx=False): url = 'https://models.silero.ai/vad_models/number_detector.onnx' else: url = 'https://models.silero.ai/vad_models/number_detector.jit' - model = Validator(url) + model = Validator(url, force_onnx_cpu) utils = (get_language, read_audio) return model, utils -def silero_lang_detector_95(onnx=False): +def silero_lang_detector_95(onnx=False, force_onnx_cpu=False): """Silero Language Classifier (95 languages) Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples @@ -80,7 +80,7 @@ def silero_lang_detector_95(onnx=False): url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx' else: url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit' - model = Validator(url) + model = Validator(url, force_onnx_cpu) with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f: lang_dict = json.load(f) diff --git a/utils_vad.py b/utils_vad.py index 80db3c9..e4fcfbf 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -9,11 +9,14 @@ languages = ['ru', 'en', 'de', 'es'] class OnnxWrapper(): - def __init__(self, path): + def __init__(self, path, force_onnx_cpu): import numpy as np global np import onnxruntime - self.session = onnxruntime.InferenceSession(path) + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider']) + else: + self.session = onnxruntime.InferenceSession(path) self.session.intra_op_num_threads = 1 self.session.inter_op_num_threads = 1 @@ -53,12 +56,15 @@ class OnnxWrapper(): class Validator(): - def __init__(self, url): + def __init__(self, url, force_onnx_cpu): self.onnx = True if url.endswith('.onnx') else False torch.hub.download_url_to_file(url, 'inf.model') if self.onnx: import onnxruntime - self.model = onnxruntime.InferenceSession('inf.model') + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider']) + else: + self.model = onnxruntime.InferenceSession('inf.model') else: self.model = init_jit_model(model_path='inf.model')