Mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
add onnx vad
91 utils_vad.py
@@ -5,25 +5,68 @@ import torch.nn.functional as F
 import warnings
 
 languages = ['ru', 'en', 'de', 'es']
 onnx_url_dict = {
     'lang_classifier_95': 'https://models.silero.ai/vad_models/lang_classifier_95.onnx',
     'number_detector':'https://models.silero.ai/vad_models/number_detector.onnx'
 }
 
 
-def donwload_onnx_model(model_name):
-    if model_name not in ['lang_classifier_95', 'number_detector']:
-        raise ValueError
-    torch.hub.download_url_to_file(onnx_url_dict[model_name], f'{model_name}.onnx')
+class OnnxWrapper():
+
+    def __init__(self, path):
+        import numpy as np
+        global np
+        import onnxruntime
+        self.session = onnxruntime.InferenceSession(path)
+        self.session.intra_op_num_threads = 1
+        self.session.inter_op_num_threads = 1
+        self.reset_states()
+
+    def reset_states(self):
+        self._h = np.zeros((2, 1, 64)).astype('float32')
+        self._c = np.zeros((2, 1, 64)).astype('float32')
+
+    def __call__(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+        if x.shape[0] > 1:
+            raise ValueError("Onnx model does not support batching")
+
+        if sr not in [16000]:
+            raise ValueError(f"Supported sample rates: {[16000]}")
+
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
+        ort_outs = self.session.run(None, ort_inputs)
+        out, self._h, self._c = ort_outs
+
+        out = torch.tensor(out).squeeze(2)[:, 1]  # make output type match JIT analog
+
+        return out
 
 
-def validate(model,
-             inputs: torch.Tensor):
-    with torch.no_grad():
-        outs = model(inputs)
-    return outs
+class Validator():
+    def __init__(self, url):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+
+        return outs
 
 
 def read_audio(path: str,
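For orientation, a minimal sketch of driving the new OnnxWrapper, assuming a Silero VAD graph exported to ONNX with the 'input'/'h0'/'c0' input names the session.run() call above expects; the local path 'model.onnx' is a placeholder, not something this commit downloads:

    import torch
    from utils_vad import OnnxWrapper

    model = OnnxWrapper('model.onnx')      # placeholder path to an ONNX VAD export

    chunk = torch.randn(1, 4000)           # 0.25 s at 16 kHz; 16000 / 4000 = 4 <= 31.25, so it passes the length check
    speech_prob = model(chunk, sr=16000)   # shape (1,) tensor, matching the JIT model's output

    model.reset_states()                   # clear the recurrent h/c state before a new audio stream

Note the wrapper is stateful: h/c carry over between calls, which is why reset_states() exists at stream boundaries.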
@@ -215,10 +258,9 @@ def get_number_ts(wav: torch.Tensor,
                   model,
                   model_stride=8,
                   hop_length=160,
-                  sample_rate=16000,
-                  run_function=validate):
+                  sample_rate=16000):
     wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = run_function(model, wav)[0]
+    perframe_logits = model(wav)[0]
     perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
     extended_preds = []
     for i in perframe_preds:
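Because Validator instances are callable, the run_function indirection removed here is no longer needed: any of these helpers can take a Validator directly as its model argument. A sketch under that assumption, reusing the number_detector URL from onnx_url_dict above ('test.wav' is a placeholder path, and read_audio is the loader defined later in this file):

    from utils_vad import Validator, get_number_ts, read_audio

    model = Validator('https://models.silero.ai/vad_models/number_detector.onnx')
    wav = read_audio('test.wav')                   # placeholder input file
    number_timestamps = get_number_ts(wav, model)  # model(wav) now happens inside the helper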
@@ -245,10 +287,9 @@ def get_number_ts(wav: torch.Tensor,
 
 
 def get_language(wav: torch.Tensor,
-                 model,
-                 run_function=validate):
+                 model):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = run_function(model, wav)[2]
+    lang_logits = model(wav)[2]
     lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
     assert lang_pred < len(languages)
     return languages[lang_pred]
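To make the indexing above concrete: the prediction is a softmax over one logit per entry of languages, followed by argmax into that list. A tiny worked example with made-up logits:

    import torch

    languages = ['ru', 'en', 'de', 'es']
    lang_logits = torch.tensor([[0.1, 2.3, 0.4, -1.0]])  # made-up logits for a single clip
    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()
    print(languages[lang_pred])                          # -> 'en' (index 1 has the largest logit)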
@@ -258,10 +299,9 @@ def get_language_and_group(wav: torch.Tensor,
                            model,
                            lang_dict: dict,
                            lang_group_dict: dict,
-                           top_n=1,
-                           run_function=validate):
+                           top_n=1):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = run_function(model, wav)
+    lang_logits, lang_group_logits = model(wav)
 
     softm = torch.softmax(lang_logits, dim=1).squeeze()
     softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
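get_language_and_group keeps its top_n parameter; the standard way to turn the squeezed softmax scores into the n best candidates is torch.topk. A sketch of just that step, with made-up probabilities (not this function's verbatim implementation, which the hunk truncates):

    import torch

    softm = torch.tensor([0.70, 0.20, 0.07, 0.03])  # made-up per-language probabilities
    scores, indices = torch.topk(softm, k=2)        # top_n = 2
    # scores  -> tensor([0.7000, 0.2000])
    # indices -> tensor([0, 1]), i.e. positions of the two most likely languages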
@@ -332,6 +372,13 @@ class VADIterator:
         return_seconds: bool (default - False)
             whether return timestamps in seconds (default - samples)
         """
 
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
         self.current_sample += window_size_samples
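The practical effect of the new guard is that VADIterator.__call__ now accepts plain Python lists and numpy arrays, not just tensors. The cast logic in isolation, as a runnable sketch:

    import torch

    x = [0.0] * 512                  # e.g. a chunk handed over by an audio callback
    if not torch.is_tensor(x):
        try:
            x = torch.Tensor(x)      # lists and numpy arrays cast cleanly
        except Exception:
            raise TypeError("Audio cannot be casted to tensor. Cast it manually")
    print(x.shape)                   # torch.Size([512])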