Add new parameter `force_onnx_cpu`: when set, pin ONNX inference to the CPU execution provider if onnxruntime reports it as available.

This commit is contained in:
ChiehKai Yang
2022-10-12 01:56:43 +08:00
parent 7c671a75c2
commit 17209e6c4f
2 changed files with 18 additions and 12 deletions

View File

@@ -9,11 +9,14 @@ languages = ['ru', 'en', 'de', 'es']
class OnnxWrapper():
def __init__(self, path, force_onnx_cpu=False):
    """Create an onnxruntime inference session for the model at *path*.

    Args:
        path: Filesystem path to the .onnx model file.
        force_onnx_cpu: When True, restrict inference to the CPU execution
            provider (only if onnxruntime reports it available). Defaults
            to False so existing callers keep their previous behavior.
    """
    # `global` must be declared before the name is bound inside this
    # function, otherwise CPython raises "name 'np' is assigned to before
    # global declaration" at compile time.
    global np
    import numpy as np  # exposed module-wide for the rest of the file
    import onnxruntime
    if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
        # Pin to CPU even when GPU/other providers are installed.
        self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
    else:
        self.session = onnxruntime.InferenceSession(path)
    # NOTE(review): these are plain attributes on the session object, not
    # SessionOptions — confirm onnxruntime actually honors them here.
    self.session.intra_op_num_threads = 1
    self.session.inter_op_num_threads = 1
@@ -53,12 +56,15 @@ class OnnxWrapper():
class Validator():
def __init__(self, url, force_onnx_cpu=False):
    """Download a model from *url* and build the matching inference backend.

    A '.onnx' URL suffix selects onnxruntime; anything else is loaded as a
    TorchScript model via `init_jit_model`.

    Args:
        url: Direct download URL for the model file.
        force_onnx_cpu: When True, restrict ONNX inference to the CPU
            execution provider (only if onnxruntime reports it available).
            Defaults to False so existing callers are unaffected.
    """
    # `endswith` already yields a bool; no need for `True if ... else False`.
    self.onnx = url.endswith('.onnx')
    torch.hub.download_url_to_file(url, 'inf.model')
    if self.onnx:
        import onnxruntime
        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            # Pin to CPU even when GPU/other providers are installed.
            self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
        else:
            self.model = onnxruntime.InferenceSession('inf.model')
    else:
        self.model = init_jit_model(model_path='inf.model')