diff --git a/files/joint_VAD_just_RU.onnx b/files/joint_VAD_just_RU.onnx
new file mode 100644
index 0000000..e4596f1
Binary files /dev/null and b/files/joint_VAD_just_RU.onnx differ
diff --git a/files/joint_VAD_just_RU_jit_cut_q.pth.tar b/files/joint_VAD_just_RU_jit_cut_q.pth.tar
index 20e8e4a..67be4fb 100644
Binary files a/files/joint_VAD_just_RU_jit_cut_q.pth.tar and b/files/joint_VAD_just_RU_jit_cut_q.pth.tar differ
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index aefc8b5..291067e 100644
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -2,35 +2,22 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T15:10:52.128138Z",
- "start_time": "2020-12-11T15:10:51.548322Z"
+ "end_time": "2020-12-14T13:43:24.487521Z",
+ "start_time": "2020-12-14T13:43:23.780570Z"
}
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
- " warnings.warn(\n",
- "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:63: UserWarning: The interface of \"soundfile\" backend is planned to change in 0.8.0 to match that of \"sox_io\" backend and the current interface will be removed in 0.9.0. To use the new interface, do `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before setting the backend to \"soundfile\". Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
- " warnings.warn(\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import glob\n",
"import torch\n",
- "import numpy as np # use only torch?\n",
- "import soundfile as sf\n",
- "# import torch.nn.functional as F\n",
"from IPython.display import Audio\n",
"torch.set_num_threads(1)\n",
- "from utils import init_jit_model, STFTExtractor, get_speech_ts, read_audio, state_generator, single_audio_stream\n",
- "extractor = STFTExtractor()"
+ "from utils import (init_jit_model, get_speech_ts, \n",
+ " save_audio, read_audio, \n",
+ " state_generator, single_audio_stream, init_onnx_model)"
]
},
{
@@ -42,11 +29,11 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:25:05.274301Z",
- "start_time": "2020-12-11T14:25:05.271313Z"
+ "end_time": "2020-12-14T13:43:24.492506Z",
+ "start_time": "2020-12-14T13:43:24.489440Z"
}
},
"outputs": [],
@@ -55,132 +42,82 @@
" speech_chunks = []\n",
" for i in tss:\n",
" speech_chunks.append(wav[i['start']: i['end']])\n",
- " return np.concatenate(speech_chunks)"
+ " return torch.cat(speech_chunks)"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:25:06.395183Z",
- "start_time": "2020-12-11T14:25:06.082595Z"
+ "end_time": "2020-12-14T13:43:24.760714Z",
+ "start_time": "2020-12-14T13:43:24.493992Z"
}
},
"outputs": [],
"source": [
- "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file"
+ "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n",
+ "model = init_onnx_model('files/joint_VAD_just_RU.onnx')"
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:25:25.523423Z",
- "start_time": "2020-12-11T14:25:25.493581Z"
- }
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "Audio('files/test_audio_8.wav')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2020-12-11T14:25:43.023784Z",
- "start_time": "2020-12-11T14:25:43.017360Z"
+ "end_time": "2020-12-14T13:43:24.793384Z",
+ "start_time": "2020-12-14T13:43:24.762311Z"
}
},
"outputs": [],
"source": [
- "wav = read_audio('files/test_audio_8.wav')"
+ "Audio('files/test_audio_2.wav')\n",
+ "wav = read_audio('files/test_audio_2.wav')"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:25:45.083872Z",
- "start_time": "2020-12-11T14:25:43.371366Z"
+ "end_time": "2020-12-14T13:43:25.320324Z",
+ "start_time": "2020-12-14T13:43:24.808594Z"
}
},
"outputs": [],
"source": [
- "speech_timestamps = get_speech_ts(wav, model, extractor, num_steps=4) # kill extractor"
+ "speech_timestamps = get_speech_ts(wav, model, num_steps=4) # kill extractor"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:25:45.130371Z",
- "start_time": "2020-12-11T14:25:45.091010Z"
+ "end_time": "2020-12-14T13:43:25.324901Z",
+ "start_time": "2020-12-14T13:43:25.321759Z"
}
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "sf.write('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
+ "speech_timestamps"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-14T13:43:25.344065Z",
+ "start_time": "2020-12-14T13:43:25.326162Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
"Audio('only_speech.wav')"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -190,76 +127,36 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T15:10:55.789272Z",
- "start_time": "2020-12-11T15:10:55.543652Z"
+ "end_time": "2020-12-14T13:43:25.778585Z",
+ "start_time": "2020-12-14T13:43:25.496583Z"
}
},
"outputs": [],
"source": [
- "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')\n",
+ "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n",
+ "#model = init_onnx_model('files/joint_VAD_just_RU.onnx')\n",
"audio = 'files/test_audio_6.wav'"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T15:10:59.503301Z",
- "start_time": "2020-12-11T15:10:55.790671Z"
+ "end_time": "2020-12-14T13:43:29.402604Z",
+ "start_time": "2020-12-14T13:43:25.780037Z"
}
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/opt/conda/lib/python3.8/site-packages/torch/functional.py:515: UserWarning: stft will require the return_complex parameter be explicitly specified in a future PyTorch release. Use return_complex=False to preserve the current behavior or return_complex=True to return a complex output. (Triggered internally at /opt/conda/conda-bld/pytorch_1603729096996/work/aten/src/ATen/native/SpectralOps.cpp:653.)\n",
- " return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore\n",
- "/opt/conda/lib/python3.8/site-packages/torch/functional.py:515: UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at /opt/conda/conda-bld/pytorch_1603729096996/work/aten/src/ATen/native/SpectralOps.cpp:590.)\n",
- " return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[{183500: 'start'}]\n",
- "[{202500: 'end'}]\n",
- "[{226500: 'start'}]\n",
- "[{283500: 'end'}]\n",
- "[{337500: 'start'}]\n",
- "[{503000: 'end'}]\n",
- "[{507500: 'start'}]\n",
- "[{627500: 'end'}]\n",
- "[{631500: 'start'}]\n",
- "[{927488: 'end'}]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "for i in single_audio_stream(model, audio, extractor):\n",
+ "for i in single_audio_stream(model, audio):\n",
" if i:\n",
" print(i)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -269,39 +166,29 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:28:09.649303Z",
- "start_time": "2020-12-11T14:28:09.373634Z"
+ "end_time": "2020-12-14T13:43:29.674262Z",
+ "start_time": "2020-12-14T13:43:29.403972Z"
}
},
"outputs": [],
"source": [
- "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+ "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n",
+ "model = init_onnx_model('files/joint_VAD_just_RU.onnx')"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:28:12.273951Z",
- "start_time": "2020-12-11T14:28:12.269729Z"
+ "end_time": "2020-12-14T13:43:29.678449Z",
+ "start_time": "2020-12-14T13:43:29.675519Z"
}
},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "10"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"audios_for_stream = glob.glob('files/test*.wav')\n",
"len(audios_for_stream)"
@@ -309,107 +196,34 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2020-12-11T14:28:32.459872Z",
- "start_time": "2020-12-11T14:28:14.502871Z"
+ "end_time": "2020-12-14T13:43:40.236387Z",
+ "start_time": "2020-12-14T13:43:29.679274Z"
}
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done initial Loading\n",
- "[({106500: 'start'}, 'files/test_audio_1.wav')]\n",
- "[({174000: 'start'}, 'files/test_audio_3.wav')]\n",
- "[({261000: 'end'}, 'files/test_audio_1.wav')]\n",
- "Loading next wav: files/test_audio_7.wav\n",
- "[({134000: 'start'}, 'files/test_audio_7.wav')]\n",
- "[({147500: 'end'}, 'files/test_audio_7.wav')]\n",
- "[({442000: 'end'}, 'files/test_audio_3.wav')]\n",
- "[({450500: 'start'}, 'files/test_audio_3.wav')]\n",
- "[({209500: 'start'}, 'files/test_audio_7.wav')]\n",
- "[({519500: 'end'}, 'files/test_audio_3.wav')]\n",
- "[({533500: 'start'}, 'files/test_audio_3.wav')]\n",
- "[({599904: 'end'}, 'files/test_audio_3.wav')]\n",
- "Loading next wav: files/test_audio_6.wav\n",
- "[({183500: 'start'}, 'files/test_audio_6.wav')]\n",
- "[({503500: 'end'}, 'files/test_audio_7.wav')]\n",
- "[({202500: 'end'}, 'files/test_audio_6.wav')]\n",
- "[({537500: 'start'}, 'files/test_audio_7.wav')]\n",
- "[({226500: 'start'}, 'files/test_audio_6.wav')]\n",
- "[({283500: 'end'}, 'files/test_audio_6.wav')]\n",
- "[({616500: 'end'}, 'files/test_audio_7.wav')]\n",
- "[({337500: 'start'}, 'files/test_audio_6.wav')]\n",
- "[({661500: 'start'}, 'files/test_audio_7.wav')]\n",
- "[({785000: 'end'}, 'files/test_audio_7.wav')]\n",
- "[({503000: 'end'}, 'files/test_audio_6.wav')]\n",
- "[({507500: 'start'}, 'files/test_audio_6.wav')]\n",
- "[({851500: 'start'}, 'files/test_audio_7.wav')]\n",
- "[({919000: 'end'}, 'files/test_audio_7.wav')]\n",
- "Loading next wav: files/test_audio_5.wav\n",
- "[({627500: 'end'}, 'files/test_audio_6.wav')]\n",
- "[({631500: 'start'}, 'files/test_audio_6.wav')]\n",
- "[({151000: 'start'}, 'files/test_audio_5.wav')]\n",
- "[({169000: 'end'}, 'files/test_audio_5.wav')]\n",
- "[({211000: 'start'}, 'files/test_audio_5.wav')]\n",
- "[({221500: 'end'}, 'files/test_audio_5.wav')]\n",
- "Loading next wav: files/test_audio_2.wav\n",
- "[({927488: 'end'}, 'files/test_audio_6.wav')]\n",
- "Loading next wav: files/test_audio_8.wav\n",
- "[({228000: 'start'}, 'files/test_audio_2.wav')]\n",
- "[({179500: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({241500: 'end'}, 'files/test_audio_2.wav')]\n",
- "[({279000: 'start'}, 'files/test_audio_2.wav')]\n",
- "[({274500: 'end'}, 'files/test_audio_8.wav')]\n",
- "[({300500: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({369500: 'end'}, 'files/test_audio_2.wav')]\n",
- "[({378500: 'start'}, 'files/test_audio_2.wav')]\n",
- "[({436500: 'end'}, 'files/test_audio_2.wav')]\n",
- "[({423000: 'end'}, 'files/test_audio_8.wav')]\n",
- "[({488500: 'start'}, 'files/test_audio_2.wav')]\n",
- "[({458500: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({599904: 'end'}, 'files/test_audio_2.wav')]\n",
- "Loading next wav: files/test_audio_4.wav\n",
- "[({583500: 'end'}, 'files/test_audio_8.wav')]\n",
- "[({599500: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({632500: 'end'}, 'files/test_audio_8.wav')]\n",
- "[({660000: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({737000: 'end'}, 'files/test_audio_8.wav')]\n",
- "[({761000: 'start'}, 'files/test_audio_8.wav')]\n",
- "[({249500: 'start'}, 'files/test_audio_4.wav')]\n",
- "[({257168: 'end'}, 'files/test_audio_4.wav')]\n",
- "Loading next wav: files/test_audio_9.wav\n",
- "[({843000: 'end'}, 'files/test_audio_8.wav')]\n",
- "Loading next wav: files/test_audio_0.wav\n",
- "[({133000: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({143500: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({272000: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({256500: 'start'}, 'files/test_audio_0.wav')]\n",
- "[({336500: 'end'}, 'files/test_audio_9.wav'), ({281232: 'end'}, 'files/test_audio_0.wav')]\n",
- "[({406500: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({460000: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({476000: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({494500: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({544500: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({564500: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({595000: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({682000: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({728500: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({786000: 'end'}, 'files/test_audio_9.wav')]\n",
- "[({814000: 'start'}, 'files/test_audio_9.wav')]\n",
- "[({826000: 'end'}, 'files/test_audio_9.wav')]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "for i in state_generator(model, audios_for_stream, extractor, audios_in_stream=2):\n",
+ "for i in state_generator(model, audios_for_stream, audios_in_stream=2):\n",
" if i:\n",
" print(i)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-14T13:46:49.812052Z",
+ "start_time": "2020-12-14T13:46:49.586637Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!cp ../silero-models-research/model_saves/joint_VAD_just_RU_jit_cut_q.pth.tar files/"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/utils.py b/utils.py
index a4e05ed..0cda630 100644
--- a/utils.py
+++ b/utils.py
@@ -7,23 +7,10 @@ import torch.nn.functional as F
from collections import deque
import numpy as np
from itertools import repeat
+import onnxruntime
torchaudio.set_audio_backend("soundfile") # switch backend
-
-def read_batch(audio_paths: List[str]):
- return [read_audio(audio_path)
- for audio_path
- in audio_paths]
-
-
-def split_into_batches(lst: List[str],
- batch_size: int = 10):
- return [lst[i:i + batch_size]
- for i in
- range(0, len(lst), batch_size)]
-
-
def read_audio(path: str,
target_sr: int = 16000):
@@ -42,15 +29,10 @@ def read_audio(path: str,
assert sr == target_sr
return wav.squeeze(0)
-
-def prepare_model_input(batch: List[torch.Tensor],
- device=torch.device('cpu')):
- max_seqlength = max(max([len(_) for _ in batch]), 12800)
- inputs = torch.zeros(len(batch), max_seqlength)
- for i, wav in enumerate(batch):
- inputs[i, :len(wav)].copy_(wav)
- inputs = inputs.to(device)
- return inputs
+def save_audio(path: str,
+ tensor: torch.Tensor,
+ sr: int):
+ torchaudio.save(path, tensor, sr)
#def init_jit_model(model_url: str,
@@ -72,8 +54,11 @@ def init_jit_model(model_path,
model.eval()
return model
+def init_onnx_model(model_path):
+ return onnxruntime.InferenceSession(model_path)
-def get_speech_ts(wav, model, extractor,
+
+def get_speech_ts(wav, model,
trig_sum=0.25, neg_trig_sum=0.01,
num_steps=8, batch_size=200):
@@ -90,15 +75,13 @@ def get_speech_ts(wav, model, extractor,
to_concat.append(chunk)
if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.vstack(to_concat))
- with torch.no_grad():
- out = model(extractor(chunks))[-2]
+ out = validate(model, chunks)[-2]
outs.append(out)
to_concat = []
if to_concat:
chunks = torch.Tensor(torch.vstack(to_concat))
- with torch.no_grad():
- out = model(extractor(chunks))[-2]
+ out = validate(model, chunks)[-2]
outs.append(out)
outs = torch.cat(outs, dim=0)
@@ -125,37 +108,6 @@ def get_speech_ts(wav, model, extractor,
speeches.append(current_speech)
return speeches
-
-class STFTExtractor(nn.Module):
- def __init__(self, sr=16000, win_size=0.02, mode='mag'):
- super(STFTExtractor, self).__init__()
- self.sr = sr
- self.n_fft = int(sr * (win_size + 1e-8))
- self.win_length = self.n_fft
- self.hop_length = self.win_length // 2
- self.mode = 'mag' if mode == '' else mode
-
- def forward(self, wav):
- # center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
- stft_sample = torch.stft(wav,
- n_fft=self.n_fft,
- win_length=self.win_length,
- hop_length=self.hop_length,
- center=True)
- mag, phase = torchaudio.functional.magphase(stft_sample)
-
- # It seems it is not a "mag", it is "power" (exp == 1).
- # Also there is "energy" (exp == 2).
- if self.mode == 'mag':
- return mag
- if self.mode == 'phase':
- return phase
- elif self.mode == 'magphase':
- return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
- else:
- raise NotImplementedError()
-
-
class VADiterator:
def __init__(self,
trig_sum=0.26, neg_trig_sum=0.01,
@@ -214,7 +166,7 @@ class VADiterator:
return current_speech, self.current_name
-def state_generator(model, audios, extractor,
+def state_generator(model, audios,
onnx=False,
trig_sum=0.26, neg_trig_sum=0.01,
num_steps=8, audios_in_stream=5):
@@ -223,14 +175,8 @@ def state_generator(model, audios, extractor,
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
batch = torch.cat(for_batch)
- with torch.no_grad():
- if onnx:
- ort_inputs = {'input': to_numpy(extractor(batch))}
- ort_outs = model.run(None, ort_inputs)
- vad_outs = np.split(ort_outs[-2], audios_in_stream)
- else:
- outs = model(extractor(batch))
- vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
+ outs = validate(model, batch)
+ vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
states = []
for x, y in zip(VADiters, vad_outs):
@@ -273,7 +219,7 @@ def stream_imitator(audios, audios_in_stream):
values.append((out, wav_name))
yield values
-def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
+def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
neg_trig_sum=0.01, num_steps=8):
num_samples = 4000
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -281,18 +227,25 @@ def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
for chunk in wav_chunks:
batch = VADiter.prepare_batch(chunk)
-
- with torch.no_grad():
- if onnx:
- ort_inputs = {'input': to_numpy(extractor(batch))}
- ort_outs = model.run(None, ort_inputs)
- vad_outs = ort_outs[-2]
- else:
- outs = model(extractor(batch))
- vad_outs = outs[-2]
+
+ outs = validate(model, batch)
+ vad_outs = outs[-2]
states = []
state = VADiter.state(vad_outs)
if state[0]:
states.append(state[0])
yield states
+
+def validate(model, inputs):
+ onnx = False
+ if type(model) == onnxruntime.capi.session.InferenceSession:
+ onnx = True
+ with torch.no_grad():
+ if onnx:
+ ort_inputs = {'input': inputs.cpu().numpy()}
+ outs = model.run(None, ort_inputs)
+ outs = [torch.Tensor(x) for x in outs]
+ else:
+ outs = model(inputs)
+ return outs