diff --git a/files/joint_VAD_just_RU_jit_cut_q.pth.tar b/files/joint_VAD_just_RU_jit_cut_q.pth.tar
new file mode 100644
index 0000000..20e8e4a
Binary files /dev/null and b/files/joint_VAD_just_RU_jit_cut_q.pth.tar differ
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
new file mode 100644
index 0000000..403df15
--- /dev/null
+++ b/silero-vad.ipynb
@@ -0,0 +1,322 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:30:32.615246Z",
+ "start_time": "2020-12-11T13:30:32.126553Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:63: UserWarning: The interface of \"soundfile\" backend is planned to change in 0.8.0 to match that of \"sox_io\" backend and the current interface will be removed in 0.9.0. To use the new interface, do `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before setting the backend to \"soundfile\". Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import torch.nn.functional as F\n",
+ "import soundfile as sf\n",
+ "from IPython.display import Audio\n",
+ "torch.set_num_threads(1)\n",
+ "from utils import init_jit_model, STFTExtractor, get_speech_ts, read_audio, state_generator\n",
+ "extractor = STFTExtractor()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Full audio example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:32:01.978079Z",
+ "start_time": "2020-12-11T13:32:01.974912Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def collect_speeches(tss, wav):\n",
+ " speech_chunks = []\n",
+ " for i in tss:\n",
+ " speech_chunks.append(wav[i['start']: i['end']])\n",
+ " return np.concatenate(speech_chunks)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:31:55.255097Z",
+ "start_time": "2020-12-11T13:31:55.020705Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:32:10.391589Z",
+ "start_time": "2020-12-11T13:32:10.387109Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "wav = read_audio('files/test_audio_6.wav')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:32:11.670091Z",
+ "start_time": "2020-12-11T13:32:10.814378Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "speech_timestamps = get_speech_ts(wav, model, extractor, num_steps=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:32:11.698816Z",
+ "start_time": "2020-12-11T13:32:11.671735Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sf.write('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
+ "Audio('only_speech.wav')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Stream example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:31:34.137062Z",
+ "start_time": "2020-12-11T13:31:33.957092Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:31:36.332200Z",
+ "start_time": "2020-12-11T13:31:36.328087Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "10"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audios_for_stream = glob.glob('files/test*.wav')\n",
+ "len(audios_for_stream)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2020-12-11T13:31:52.668041Z",
+ "start_time": "2020-12-11T13:31:37.357340Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Done initial Loading\n",
+ "[({106500: 'start'}, 'files/test_audio_1.wav')]\n",
+ "[({174000: 'start'}, 'files/test_audio_3.wav')]\n",
+ "[({261000: 'end'}, 'files/test_audio_1.wav')]\n",
+ "Loading next wav: files/test_audio_7.wav\n",
+ "[({134000: 'start'}, 'files/test_audio_7.wav')]\n",
+ "[({147500: 'end'}, 'files/test_audio_7.wav')]\n",
+ "[({442000: 'end'}, 'files/test_audio_3.wav')]\n",
+ "[({450500: 'start'}, 'files/test_audio_3.wav')]\n",
+ "[({209500: 'start'}, 'files/test_audio_7.wav')]\n",
+ "[({519500: 'end'}, 'files/test_audio_3.wav')]\n",
+ "[({533500: 'start'}, 'files/test_audio_3.wav')]\n",
+ "[({599904: 'end'}, 'files/test_audio_3.wav')]\n",
+ "Loading next wav: files/test_audio_6.wav\n",
+ "[({183500: 'start'}, 'files/test_audio_6.wav')]\n",
+ "[({503500: 'end'}, 'files/test_audio_7.wav')]\n",
+ "[({202500: 'end'}, 'files/test_audio_6.wav')]\n",
+ "[({537500: 'start'}, 'files/test_audio_7.wav')]\n",
+ "[({226500: 'start'}, 'files/test_audio_6.wav')]\n",
+ "[({283500: 'end'}, 'files/test_audio_6.wav')]\n",
+ "[({616500: 'end'}, 'files/test_audio_7.wav')]\n",
+ "[({337500: 'start'}, 'files/test_audio_6.wav')]\n",
+ "[({661500: 'start'}, 'files/test_audio_7.wav')]\n",
+ "[({785000: 'end'}, 'files/test_audio_7.wav')]\n",
+ "[({503000: 'end'}, 'files/test_audio_6.wav')]\n",
+ "[({507500: 'start'}, 'files/test_audio_6.wav')]\n",
+ "[({851500: 'start'}, 'files/test_audio_7.wav')]\n",
+ "[({919000: 'end'}, 'files/test_audio_7.wav')]\n",
+ "Loading next wav: files/test_audio_5.wav\n",
+ "[({627500: 'end'}, 'files/test_audio_6.wav')]\n",
+ "[({631500: 'start'}, 'files/test_audio_6.wav')]\n",
+ "[({151000: 'start'}, 'files/test_audio_5.wav')]\n",
+ "[({169000: 'end'}, 'files/test_audio_5.wav')]\n",
+ "[({211000: 'start'}, 'files/test_audio_5.wav')]\n",
+ "[({221500: 'end'}, 'files/test_audio_5.wav')]\n",
+ "Loading next wav: files/test_audio_2.wav\n",
+ "[({927488: 'end'}, 'files/test_audio_6.wav')]\n",
+ "Loading next wav: files/test_audio_8.wav\n",
+ "[({228000: 'start'}, 'files/test_audio_2.wav')]\n",
+ "[({179500: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({241500: 'end'}, 'files/test_audio_2.wav')]\n",
+ "[({279000: 'start'}, 'files/test_audio_2.wav')]\n",
+ "[({274500: 'end'}, 'files/test_audio_8.wav')]\n",
+ "[({300500: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({369500: 'end'}, 'files/test_audio_2.wav')]\n",
+ "[({378500: 'start'}, 'files/test_audio_2.wav')]\n",
+ "[({436500: 'end'}, 'files/test_audio_2.wav')]\n",
+ "[({423000: 'end'}, 'files/test_audio_8.wav')]\n",
+ "[({488500: 'start'}, 'files/test_audio_2.wav')]\n",
+ "[({458500: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({599904: 'end'}, 'files/test_audio_2.wav')]\n",
+ "Loading next wav: files/test_audio_4.wav\n",
+ "[({583500: 'end'}, 'files/test_audio_8.wav')]\n",
+ "[({599500: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({632500: 'end'}, 'files/test_audio_8.wav')]\n",
+ "[({660000: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({737000: 'end'}, 'files/test_audio_8.wav')]\n",
+ "[({761000: 'start'}, 'files/test_audio_8.wav')]\n",
+ "[({249500: 'start'}, 'files/test_audio_4.wav')]\n",
+ "[({257168: 'end'}, 'files/test_audio_4.wav')]\n",
+ "Loading next wav: files/test_audio_9.wav\n",
+ "[({843000: 'end'}, 'files/test_audio_8.wav')]\n",
+ "Loading next wav: files/test_audio_0.wav\n",
+ "[({133000: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({143500: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({272000: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({256500: 'start'}, 'files/test_audio_0.wav')]\n",
+ "[({336500: 'end'}, 'files/test_audio_9.wav'), ({281232: 'end'}, 'files/test_audio_0.wav')]\n",
+ "[({406500: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({460000: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({476000: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({494500: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({544500: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({564500: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({595000: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({682000: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({728500: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({786000: 'end'}, 'files/test_audio_9.wav')]\n",
+ "[({814000: 'start'}, 'files/test_audio_9.wav')]\n",
+ "[({826000: 'end'}, 'files/test_audio_9.wav')]\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i in state_generator(model, audios_for_stream, extractor, audios_in_stream=2):\n",
+ " if i:\n",
+ " print(i)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.3"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/utils.py b/utils.py
index 23019df..8355a73 100644
--- a/utils.py
+++ b/utils.py
@@ -2,6 +2,11 @@ import torch
import tempfile
import torchaudio
from typing import List
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import deque
+import numpy as np
+from itertools import repeat
torchaudio.set_audio_backend("soundfile") # switch backend
@@ -48,13 +53,210 @@ def prepare_model_input(batch: List[torch.Tensor],
return inputs
-def init_jit_model(model_url: str,
- device: torch.device = torch.device('cpu')):
+
+
+def init_jit_model(model_path: str,
+                   device: torch.device = torch.device('cpu')):
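+    """Load a TorchScript VAD model from a local path and switch it to eval mode."""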
torch.set_grad_enabled(False)
- with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
- torch.hub.download_url_to_file(model_url,
- f.name,
- progress=True)
- model = torch.jit.load(f.name, map_location=device)
- model.eval()
+ model = torch.jit.load(model_path, map_location=device)
+ model.eval()
return model
+
+
+def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_steps=8, batch_size=200):
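+    """
+    Run the VAD model over a whole waveform and return speech timestamps.
+
+    The wav is split into 4000-sample windows shifted by 4000 / num_steps samples,
+    features are computed with `extractor`, and per-window speech probabilities are
+    smoothed over a sliding buffer to produce a list of
+    {'start': sample_index, 'end': sample_index} dicts.
+    """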
+ assert 4000 % num_steps == 0
+ step = int(4000 / num_steps)
+ outs = []
+ to_concat = []
+ for i in range(0, len(wav), step):
+ chunk = wav[i: i+4000]
+ if len(chunk) < 4000:
+ chunk = F.pad(chunk, (0, 4000 - len(chunk)))
+ to_concat.append(chunk)
+ if len(to_concat) >= batch_size:
+            chunks = torch.vstack(to_concat)
+ with torch.no_grad():
+ out = model(extractor(chunks))[-2]
+ outs.append(out)
+ to_concat = []
+
+ if to_concat:
+        chunks = torch.vstack(to_concat)
+ with torch.no_grad():
+ out = model(extractor(chunks))[-2]
+ outs.append(out)
+
+ outs = torch.cat(outs, dim=0)
+
+ buffer = deque(maxlen=num_steps)
+ triggered = False
+ speeches = []
+ current_speech = {}
+ for i, predict in enumerate(outs[:, 1]):
+ buffer.append(predict)
+ if (np.mean(buffer) >= trig_sum) and not triggered:
+ triggered = True
+ current_speech['start'] = step * max(0, i-num_steps)
+ if (np.mean(buffer) < neg_trig_sum) and triggered:
+ current_speech['end'] = step * i
+ if (current_speech['end'] - current_speech['start']) > 10000:
+ speeches.append(current_speech)
+ current_speech = {}
+ triggered = False
+ if current_speech:
+ current_speech['end'] = len(wav)
+ speeches.append(current_speech)
+ return speeches
+
+
+class STFTExtractor(nn.Module):
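+    """Computes STFT features (magnitude, phase, or a real/imaginary 'magphase' stack) for a batch of waveforms."""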
+ def __init__(self, sr=16000, win_size=0.02, mode='mag'):
+ super(STFTExtractor, self).__init__()
+ self.sr = sr
+ self.n_fft = int(sr * (win_size + 1e-8))
+ self.win_length = self.n_fft
+ self.hop_length = self.win_length // 2
+ self.mode = 'mag' if mode == '' else mode
+
+ def forward(self, wav):
+ # center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
+ stft_sample = torch.stft(wav,
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ center=True)
+ mag, phase = torchaudio.functional.magphase(stft_sample)
+
+        # Note: strictly speaking this is not "mag" but "power" (exponent 1);
+        # there is also "energy" (exponent 2).
+        if self.mode == 'mag':
+            return mag
+        elif self.mode == 'phase':
+            return phase
+        elif self.mode == 'magphase':
+            return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
+        else:
+            raise NotImplementedError()
+
+
+class VADiterator:
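+    """Holds per-stream VAD state (previous chunk, trigger flags, smoothing buffer, frame counter) for chunk-by-chunk streaming inference."""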
+ def __init__(self, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8):
+ self.num_steps = num_steps
+ assert 4000 % num_steps == 0
+ self.step = int(4000 / num_steps)
+ self.prev = torch.zeros(4000)
+ self.last = False
+ self.triggered = False
+        self.buffer = deque(maxlen=num_steps)
+ self.num_frames = 0
+ self.trig_sum = trig_sum
+ self.neg_trig_sum = neg_trig_sum
+ self.current_name = ''
+
+ def refresh(self):
+ self.prev = torch.zeros(4000)
+ self.last = False
+ self.triggered = False
+        self.buffer = deque(maxlen=self.num_steps)
+ self.num_frames = 0
+
+ def prepare_batch(self, wav_chunk, name=None):
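+        # Reset the state when a new file starts, pad the incoming chunk to 4000 samples
+        # and split the (previous + current) buffer into overlapping windows for the model.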
+ if (name is not None) and (name != self.current_name):
+ self.refresh()
+ self.current_name = name
+ assert len(wav_chunk) <= 4000
+ self.num_frames += len(wav_chunk)
+ if len(wav_chunk) < 4000:
+ wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk))) # assume that short chunk means end of the audio
+ self.last = True
+
+ stacked = torch.hstack([self.prev, wav_chunk])
+ self.prev = wav_chunk
+
+        overlap_chunks = [stacked[i:i+4000] for i in range(500, 4001, self.step)]  # overlapping 4000-sample windows; a 500-sample step is good enough
+ return torch.vstack(overlap_chunks)
+
+ def state(self, model_out):
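+        # Smooth the per-window speech probabilities and emit {sample_index: 'start'/'end'}
+        # events relative to the position in the current stream.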
+ current_speech = {}
+ for i, predict in enumerate(model_out[:, 1]):
+ self.buffer.append(predict)
+ if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
+ self.triggered = True
+ current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'start'
+ if (np.mean(self.buffer) < self.neg_trig_sum) and self.triggered:
+ current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'end'
+ self.triggered = False
+ if self.triggered and self.last:
+ current_speech[self.num_frames] = 'end'
+ if self.last:
+ self.refresh()
+ return current_speech, self.current_name
+
+
+def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8, audios_in_stream=5):
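+    """
+    Imitate `audios_in_stream` parallel streams over the given audio files, run the
+    model (TorchScript or ONNX) on the batched chunks and yield, for each step, the
+    list of ({sample_index: 'start'/'end'}, file_name) state tuples that were produced.
+    """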
+ VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
+ for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
+ for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
+ batch = torch.cat(for_batch)
+
+ with torch.no_grad():
+ if onnx:
+ ort_inputs = {'input': to_numpy(extractor(batch))}
+ ort_outs = model.run(None, ort_inputs)
+ vad_outs = np.split(ort_outs[-2], audios_in_stream)
+ else:
+ outs = model(extractor(batch))
+ vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
+
+ states = []
+ for x, y in zip(VADiters, vad_outs):
+ cur_st = x.state(y)
+ if cur_st[0]:
+ states.append(cur_st)
+ yield states
+
+
+def stream_imitator(audios, audios_in_stream):
+    audio_iter = iter(audios)
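+    # Yield, at every step, one (4000-sample chunk, file_name) pair per simulated stream;
+    # when a file is exhausted the next one is loaded, and finished streams are padded
+    # with zero chunks until all files have been consumed.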
+ iterators = []
+ # initial wavs
+ for i in range(audios_in_stream):
+        next_wav = next(audio_iter)
+ wav = read_audio(next_wav)
+        wav_chunks = iter([(wav[j:j+4000], next_wav) for j in range(0, len(wav), 4000)])
+ iterators.append(wav_chunks)
+ print('Done initial Loading')
+ good_iters = audios_in_stream
+ while True:
+ values = []
+ for i, it in enumerate(iterators):
+ try:
+ out, wav_name = next(it)
+ except StopIteration:
+ try:
+                    next_wav = next(audio_iter)
+ print('Loading next wav: ', next_wav)
+ wav = read_audio(next_wav)
+                    iterators[i] = iter([(wav[j:j+4000], next_wav) for j in range(0, len(wav), 4000)])
+ out, wav_name = next(iterators[i])
+ except StopIteration:
+ good_iters -= 1
+ iterators[i] = repeat((torch.zeros(4000), 'junk'))
+ out, wav_name = next(iterators[i])
+ if good_iters == 0:
+ return
+ values.append((out, wav_name))
+ yield values