add onnx vad

This commit is contained in:
adamnsandle
2021-12-17 14:48:32 +00:00
parent 0feae6cbbe
commit 74f759c8f8
10 changed files with 185 additions and 377 deletions

@@ -5,25 +5,68 @@ import torch.nn.functional as F
 import warnings

 languages = ['ru', 'en', 'de', 'es']

-onnx_url_dict = {
-    'lang_classifier_95': 'https://models.silero.ai/vad_models/lang_classifier_95.onnx',
-    'number_detector': 'https://models.silero.ai/vad_models/number_detector.onnx'
-}
-
-def donwload_onnx_model(model_name):
-    if model_name not in ['lang_classifier_95', 'number_detector']:
-        raise ValueError
-    torch.hub.download_url_to_file(onnx_url_dict[model_name], f'{model_name}.onnx')
-
-def validate(model,
-             inputs: torch.Tensor):
-    with torch.no_grad():
-        outs = model(inputs)
-    return outs
+class OnnxWrapper():
+
+    def __init__(self, path):
+        import numpy as np
+        global np
+        import onnxruntime
+        self.session = onnxruntime.InferenceSession(path)
+        self.session.intra_op_num_threads = 1
+        self.session.inter_op_num_threads = 1
+        self.reset_states()
+
+    def reset_states(self):
+        self._h = np.zeros((2, 1, 64)).astype('float32')
+        self._c = np.zeros((2, 1, 64)).astype('float32')
+
+    def __call__(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+        if x.shape[0] > 1:
+            raise ValueError("Onnx model does not support batching")
+        if sr not in [16000]:
+            raise ValueError(f"Supported sample rates: {[16000]}")
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
+        ort_outs = self.session.run(None, ort_inputs)
+        out, self._h, self._c = ort_outs
+
+        out = torch.tensor(out).squeeze(2)[:, 1]  # make output type match JIT analog
+        return out
+
+
+class Validator():
+    def __init__(self, url):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+        return outs

 def read_audio(path: str,
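
A minimal usage sketch for the new OnnxWrapper (illustrative, not part of the commit; the local model path, chunk length, and the utils_vad module name are assumptions):

    import torch
    from utils_vad import OnnxWrapper  # assuming this file is importable as utils_vad

    model = OnnxWrapper('silero_vad.onnx')   # hypothetical local copy of the VAD onnx model
    chunk = torch.randn(1536)                # mono float32; >= 512 samples at 16 kHz passes the length check
    speech_prob = model(chunk, sr=16000)     # 1-D input is unsqueezed to a batch of one
    print(float(speech_prob))                # single speech probability

    model.reset_states()                     # zero the LSTM h0/c0 state before a new audio stream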
@@ -215,10 +258,9 @@ def get_number_ts(wav: torch.Tensor,
                   model,
                   model_stride=8,
                   hop_length=160,
-                  sample_rate=16000,
-                  run_function=validate):
+                  sample_rate=16000):
     wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = run_function(model, wav)[0]
+    perframe_logits = model(wav)[0]
     perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
     extended_preds = []
     for i in perframe_preds:
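
With run_function removed, get_number_ts calls the model object directly, so either a JIT model or a Validator can be passed in. A sketch, reusing the number_detector URL from the deleted onnx_url_dict and assuming wav is a 1-D 16 kHz float tensor:

    number_model = Validator('https://models.silero.ai/vad_models/number_detector.onnx')
    number_ts = get_number_ts(wav, number_model)  # timestamps of spoken numbers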
@@ -245,10 +287,9 @@ def get_number_ts(wav: torch.Tensor,
 def get_language(wav: torch.Tensor,
-                 model,
-                 run_function=validate):
+                 model):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = run_function(model, wav)[2]
+    lang_logits = model(wav)[2]
     lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
     assert lang_pred < len(languages)
     return languages[lang_pred]
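
The same direct-call pattern applies to get_language; a sketch using the lang_classifier_95 URL from the deleted onnx_url_dict:

    lang_model = Validator('https://models.silero.ai/vad_models/lang_classifier_95.onnx')
    language = get_language(wav, lang_model)  # one of ['ru', 'en', 'de', 'es']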
@@ -258,10 +299,9 @@ def get_language_and_group(wav: torch.Tensor,
                            model,
                            lang_dict: dict,
                            lang_group_dict: dict,
-                           top_n=1,
-                           run_function=validate):
+                           top_n=1):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = run_function(model, wav)
+    lang_logits, lang_group_logits = model(wav)
     softm = torch.softmax(lang_logits, dim=1).squeeze()
     softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
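
And likewise for get_language_and_group; the dictionaries below are hypothetical placeholders (the real index-to-label mappings ship with the model) shown only to illustrate the call:

    lang_dict = {0: 'ru', 1: 'en', 2: 'de', 3: 'es'}  # hypothetical index -> language map
    lang_group_dict = {0: 'slavic', 1: 'other'}       # hypothetical index -> group map
    top_langs = get_language_and_group(wav, lang_model, lang_dict, lang_group_dict, top_n=2)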
@@ -332,6 +372,13 @@ class VADIterator:
         return_seconds: bool (default - False)
             whether return timestamps in seconds (default - samples)
         """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
         self.current_sample += window_size_samples
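
With the new tensor guard, streaming callers can feed VADIterator plain lists or numpy arrays and the cast happens internally. A sketch, assuming a VADIterator instance named vad_iterator was constructed elsewhere:

    import numpy as np

    chunk = np.zeros(512, dtype=np.float32)                  # previously required an explicit torch cast
    speech_dict = vad_iterator(chunk, return_seconds=True)   # timestamps reported in seconds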