diff --git a/files/joint_VAD_just_RU_jit_cut_q.pth.tar b/files/joint_VAD_just_RU_jit_cut_q.pth.tar
new file mode 100644
index 0000000..20e8e4a
Binary files /dev/null and b/files/joint_VAD_just_RU_jit_cut_q.pth.tar differ
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
new file mode 100644
index 0000000..403df15
--- /dev/null
+++ b/silero-vad.ipynb
@@ -0,0 +1,322 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:30:32.615246Z",
+     "start_time": "2020-12-11T13:30:32.126553Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
+      " warnings.warn(\n",
+      "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:63: UserWarning: The interface of \"soundfile\" backend is planned to change in 0.8.0 to match that of \"sox_io\" backend and the current interface will be removed in 0.9.0. To use the new interface, do `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before setting the backend to \"soundfile\". Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
+      " warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import numpy as np\n",
+    "import glob\n",
+    "import torch.nn.functional as F\n",
+    "import soundfile as sf\n",
+    "from IPython.display import Audio\n",
+    "torch.set_num_threads(1)\n",
+    "from utils import init_jit_model, STFTExtractor, get_speech_ts, read_audio, state_generator\n",
+    "extractor = STFTExtractor()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Full audio example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:32:01.978079Z",
+     "start_time": "2020-12-11T13:32:01.974912Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def collect_speeches(tss, wav):\n",
+    "    speech_chunks = []\n",
+    "    for i in tss:\n",
+    "        speech_chunks.append(wav[i['start']: i['end']])\n",
+    "    return np.concatenate(speech_chunks)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:31:55.255097Z",
+     "start_time": "2020-12-11T13:31:55.020705Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:32:10.391589Z",
+     "start_time": "2020-12-11T13:32:10.387109Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "wav = read_audio('files/test_audio_6.wav')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:32:11.670091Z",
+     "start_time": "2020-12-11T13:32:10.814378Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "speech_timestamps = get_speech_ts(wav, model, extractor, num_steps=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:32:11.698816Z",
+     "start_time": "2020-12-11T13:32:11.671735Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sf.write('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
+    "Audio('only_speech.wav')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Stream example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:31:34.137062Z",
+     "start_time": "2020-12-11T13:31:33.957092Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:31:36.332200Z",
+     "start_time": "2020-12-11T13:31:36.328087Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "10"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "audios_for_stream = glob.glob('files/test*.wav')\n",
+    "len(audios_for_stream)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-11T13:31:52.668041Z",
+     "start_time": "2020-12-11T13:31:37.357340Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Done initial Loading\n",
+      "[({106500: 'start'}, 'files/test_audio_1.wav')]\n",
+      "[({174000: 'start'}, 'files/test_audio_3.wav')]\n",
+      "[({261000: 'end'}, 'files/test_audio_1.wav')]\n",
+      "Loading next wav: files/test_audio_7.wav\n",
+      "[({134000: 'start'}, 'files/test_audio_7.wav')]\n",
+      "[({147500: 'end'}, 'files/test_audio_7.wav')]\n",
+      "[({442000: 'end'}, 'files/test_audio_3.wav')]\n",
+      "[({450500: 'start'}, 'files/test_audio_3.wav')]\n",
+      "[({209500: 'start'}, 'files/test_audio_7.wav')]\n",
+      "[({519500: 'end'}, 'files/test_audio_3.wav')]\n",
+      "[({533500: 'start'}, 'files/test_audio_3.wav')]\n",
+      "[({599904: 'end'}, 'files/test_audio_3.wav')]\n",
+      "Loading next wav: files/test_audio_6.wav\n",
+      "[({183500: 'start'}, 'files/test_audio_6.wav')]\n",
+      "[({503500: 'end'}, 'files/test_audio_7.wav')]\n",
+      "[({202500: 'end'}, 'files/test_audio_6.wav')]\n",
+      "[({537500: 'start'}, 'files/test_audio_7.wav')]\n",
+      "[({226500: 'start'}, 'files/test_audio_6.wav')]\n",
+      "[({283500: 'end'}, 'files/test_audio_6.wav')]\n",
+      "[({616500: 'end'}, 'files/test_audio_7.wav')]\n",
+      "[({337500: 'start'}, 'files/test_audio_6.wav')]\n",
+      "[({661500: 'start'}, 'files/test_audio_7.wav')]\n",
+      "[({785000: 'end'}, 'files/test_audio_7.wav')]\n",
+      "[({503000: 'end'}, 'files/test_audio_6.wav')]\n",
+      "[({507500: 'start'}, 'files/test_audio_6.wav')]\n",
+      "[({851500: 'start'}, 'files/test_audio_7.wav')]\n",
+      "[({919000: 'end'}, 'files/test_audio_7.wav')]\n",
+      "Loading next wav: files/test_audio_5.wav\n",
+      "[({627500: 'end'}, 'files/test_audio_6.wav')]\n",
+      "[({631500: 'start'}, 'files/test_audio_6.wav')]\n",
+      "[({151000: 'start'}, 'files/test_audio_5.wav')]\n",
+      "[({169000: 'end'}, 'files/test_audio_5.wav')]\n",
+      "[({211000: 'start'}, 'files/test_audio_5.wav')]\n",
+      "[({221500: 'end'}, 'files/test_audio_5.wav')]\n",
+      "Loading next wav: files/test_audio_2.wav\n",
+      "[({927488: 'end'}, 'files/test_audio_6.wav')]\n",
+      "Loading next wav: files/test_audio_8.wav\n",
+      "[({228000: 'start'}, 'files/test_audio_2.wav')]\n",
+      "[({179500: 'start'}, 'files/test_audio_8.wav')]\n",
+      "[({241500: 'end'}, 'files/test_audio_2.wav')]\n",
+      "[({279000: 'start'}, 'files/test_audio_2.wav')]\n",
"[({274500: 'end'}, 'files/test_audio_8.wav')]\n", + "[({300500: 'start'}, 'files/test_audio_8.wav')]\n", + "[({369500: 'end'}, 'files/test_audio_2.wav')]\n", + "[({378500: 'start'}, 'files/test_audio_2.wav')]\n", + "[({436500: 'end'}, 'files/test_audio_2.wav')]\n", + "[({423000: 'end'}, 'files/test_audio_8.wav')]\n", + "[({488500: 'start'}, 'files/test_audio_2.wav')]\n", + "[({458500: 'start'}, 'files/test_audio_8.wav')]\n", + "[({599904: 'end'}, 'files/test_audio_2.wav')]\n", + "Loading next wav: files/test_audio_4.wav\n", + "[({583500: 'end'}, 'files/test_audio_8.wav')]\n", + "[({599500: 'start'}, 'files/test_audio_8.wav')]\n", + "[({632500: 'end'}, 'files/test_audio_8.wav')]\n", + "[({660000: 'start'}, 'files/test_audio_8.wav')]\n", + "[({737000: 'end'}, 'files/test_audio_8.wav')]\n", + "[({761000: 'start'}, 'files/test_audio_8.wav')]\n", + "[({249500: 'start'}, 'files/test_audio_4.wav')]\n", + "[({257168: 'end'}, 'files/test_audio_4.wav')]\n", + "Loading next wav: files/test_audio_9.wav\n", + "[({843000: 'end'}, 'files/test_audio_8.wav')]\n", + "Loading next wav: files/test_audio_0.wav\n", + "[({133000: 'start'}, 'files/test_audio_9.wav')]\n", + "[({143500: 'end'}, 'files/test_audio_9.wav')]\n", + "[({272000: 'start'}, 'files/test_audio_9.wav')]\n", + "[({256500: 'start'}, 'files/test_audio_0.wav')]\n", + "[({336500: 'end'}, 'files/test_audio_9.wav'), ({281232: 'end'}, 'files/test_audio_0.wav')]\n", + "[({406500: 'start'}, 'files/test_audio_9.wav')]\n", + "[({460000: 'end'}, 'files/test_audio_9.wav')]\n", + "[({476000: 'start'}, 'files/test_audio_9.wav')]\n", + "[({494500: 'end'}, 'files/test_audio_9.wav')]\n", + "[({544500: 'start'}, 'files/test_audio_9.wav')]\n", + "[({564500: 'end'}, 'files/test_audio_9.wav')]\n", + "[({595000: 'start'}, 'files/test_audio_9.wav')]\n", + "[({682000: 'end'}, 'files/test_audio_9.wav')]\n", + "[({728500: 'start'}, 'files/test_audio_9.wav')]\n", + "[({786000: 'end'}, 'files/test_audio_9.wav')]\n", + "[({814000: 'start'}, 'files/test_audio_9.wav')]\n", + "[({826000: 'end'}, 'files/test_audio_9.wav')]\n" + ] + } + ], + "source": [ + "for i in state_generator(model, audios_for_stream, extractor, audios_in_stream=2):\n", + " if i:\n", + " print(i)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/utils.py b/utils.py index 23019df..8355a73 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,11 @@ import torch import tempfile import torchaudio from typing import List +import torch.nn as nn +import torch.nn.functional as F +from collections import deque +import numpy as np +from itertools import repeat torchaudio.set_audio_backend("soundfile") # switch backend @@ -48,13 +53,210 @@ def prepare_model_input(batch: List[torch.Tensor], return inputs -def init_jit_model(model_url: str, - device: torch.device = torch.device('cpu')): +#def init_jit_model(model_url: str, +# device: torch.device 
+#    torch.set_grad_enabled(False)
+#    with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
+#        torch.hub.download_url_to_file(model_url,
+#                                       f.name,
+#                                       progress=True)
+#        model = torch.jit.load(f.name, map_location=device)
+#        model.eval()
+#    return model
+
+
+def init_jit_model(model_path,
+                   device):
     torch.set_grad_enabled(False)
-    with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
-        torch.hub.download_url_to_file(model_url,
-                                       f.name,
-                                       progress=True)
-        model = torch.jit.load(f.name, map_location=device)
-        model.eval()
+    model = torch.jit.load(model_path, map_location=device)
+    model.eval()
     return model
+
+
+def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_steps=8, batch_size=200):
+    assert 4000 % num_steps == 0
+    step = int(4000 / num_steps)
+    outs = []
+    to_concat = []
+    for i in range(0, len(wav), step):
+        chunk = wav[i: i+4000]
+        if len(chunk) < 4000:
+            chunk = F.pad(chunk, (0, 4000 - len(chunk)))
+        to_concat.append(chunk)
+        if len(to_concat) >= batch_size:
+            chunks = torch.Tensor(torch.vstack(to_concat))
+            with torch.no_grad():
+                out = model(extractor(chunks))[-2]
+            outs.append(out)
+            to_concat = []
+
+    if to_concat:
+        chunks = torch.Tensor(torch.vstack(to_concat))
+        with torch.no_grad():
+            out = model(extractor(chunks))[-2]
+        outs.append(out)
+
+    outs = torch.cat(outs, dim=0)
+
+    buffer = deque(maxlen=num_steps)
+    triggered = False
+    speeches = []
+    current_speech = {}
+    for i, predict in enumerate(outs[:, 1]):
+        buffer.append(predict)
+        if (np.mean(buffer) >= trig_sum) and not triggered:
+            triggered = True
+            current_speech['start'] = step * max(0, i-num_steps)
+        if (np.mean(buffer) < neg_trig_sum) and triggered:
+            current_speech['end'] = step * i
+            if (current_speech['end'] - current_speech['start']) > 10000:
+                speeches.append(current_speech)
+            current_speech = {}
+            triggered = False
+    if current_speech:
+        current_speech['end'] = len(wav)
+        speeches.append(current_speech)
+    return speeches
+
+
+class STFTExtractor(nn.Module):
+    def __init__(self, sr=16000, win_size=0.02, mode='mag'):
+        super(STFTExtractor, self).__init__()
+        self.sr = sr
+        self.n_fft = int(sr * (win_size + 1e-8))
+        self.win_length = self.n_fft
+        self.hop_length = self.win_length // 2
+        self.mode = 'mag' if mode == '' else mode
+
+    def forward(self, wav):
+        # center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
+        stft_sample = torch.stft(wav,
+                                 n_fft=self.n_fft,
+                                 win_length=self.win_length,
+                                 hop_length=self.hop_length,
+                                 center=True)
+        mag, phase = torchaudio.functional.magphase(stft_sample)
+
+        # It seems it is not a "mag", it is "power" (exp == 1).
+        # Also there is "energy" (exp == 2).
+        if self.mode == 'mag':
+            return mag
+        if self.mode == 'phase':
+            return phase
+        elif self.mode == 'magphase':
+            return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
+        else:
+            raise NotImplementedError()
+
+
+class VADiterator:
+    def __init__(self, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8):
+        self.num_steps = num_steps
+        assert 4000 % num_steps == 0
+        self.step = int(4000 / num_steps)
+        self.prev = torch.zeros(4000)
+        self.last = False
+        self.triggered = False
+        self.buffer = deque(maxlen=8)
+        self.num_frames = 0
+        self.trig_sum = trig_sum
+        self.neg_trig_sum = neg_trig_sum
+        self.current_name = ''
+
+    def refresh(self):
+        self.prev = torch.zeros(4000)
+        self.last = False
+        self.triggered = False
+        self.buffer = deque(maxlen=8)
+        self.num_frames = 0
+
+    def prepare_batch(self, wav_chunk, name=None):
+        if (name is not None) and (name != self.current_name):
+            self.refresh()
+            self.current_name = name
+        assert len(wav_chunk) <= 4000
+        self.num_frames += len(wav_chunk)
+        if len(wav_chunk) < 4000:
+            wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk)))  # assume that short chunk means end of the audio
+            self.last = True
+
+        stacked = torch.hstack([self.prev, wav_chunk])
+        self.prev = wav_chunk
+
+        overlap_chunks = [stacked[i:i+4000] for i in range(500, 4001, self.step)]  # 500 step is good enough
+        return torch.vstack(overlap_chunks)
+
+    def state(self, model_out):
+        current_speech = {}
+        for i, predict in enumerate(model_out[:, 1]):
+            self.buffer.append(predict)
+            if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
+                self.triggered = True
+                current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'start'
+            if (np.mean(self.buffer) < self.neg_trig_sum) and self.triggered:
+                current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'end'
+                self.triggered = False
+        if self.triggered and self.last:
+            current_speech[self.num_frames] = 'end'
+        if self.last:
+            self.refresh()
+        return current_speech, self.current_name
+
+
+
+def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8, audios_in_stream=5):
+    VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
+    for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
+        for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
+        batch = torch.cat(for_batch)
+
+        with torch.no_grad():
+            if onnx:
+                ort_inputs = {'input': to_numpy(extractor(batch))}
+                ort_outs = model.run(None, ort_inputs)
+                vad_outs = np.split(ort_outs[-2], audios_in_stream)
+            else:
+                outs = model(extractor(batch))
+                vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
+
+        states = []
+        for x, y in zip(VADiters, vad_outs):
+            cur_st = x.state(y)
+            if cur_st[0]:
+                states.append(cur_st)
+        yield states
+
+
+def stream_imitator(stereo, audios_in_stream):
+    stereo_iter = iter(stereo)
+    iterators = []
+    # initial wavs
+    for i in range(audios_in_stream):
+        next_wav = next(stereo_iter)
+        wav = read_audio(next_wav)
+        wav_chunks = iter([(wav[i:i+4000], next_wav) for i in range(0, len(wav), 4000)])
+        iterators.append(wav_chunks)
+    print('Done initial Loading')
+    good_iters = audios_in_stream
+    while True:
+        values = []
+        for i, it in enumerate(iterators):
+            try:
+                out, wav_name = next(it)
+            except StopIteration:
+                try:
+                    next_wav = next(stereo_iter)
+                    print('Loading next wav: ', next_wav)
+                    wav = read_audio(next_wav)
+                    iterators[i] = iter([(wav[i:i+4000], next_wav) for i in range(0, len(wav), 4000)])
+                    out, wav_name = next(iterators[i])
+                except StopIteration:
+                    good_iters -= 1
+                    iterators[i] = repeat((torch.zeros(4000), 'junk'))
+                    out, wav_name = next(iterators[i])
+                    if good_iters == 0:
+                        return
+            values.append((out, wav_name))
+        yield values
+
+
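Note on the ONNX branch of state_generator: with onnx=True it calls a to_numpy() helper, which this version of utils.py neither defines nor imports, so that branch cannot run as committed. Below is a minimal sketch, not part of the patch, of what would be needed. The helper follows the usual ONNX Runtime idiom; the model path 'files/model.onnx', the use of onnxruntime.InferenceSession, and the assumption that the exported graph takes a single input named 'input' (the key state_generator passes) are all hypothetical.

# Sketch only: assumes onnxruntime is installed and that an ONNX export of the VAD
# model exists at the hypothetical path 'files/model.onnx'.

# A helper along these lines would have to be added to utils.py itself, since
# state_generator resolves to_numpy in utils.py's own namespace:
def to_numpy(tensor):
    # turn a torch tensor into a numpy array for onnxruntime
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# The stream loop from the notebook could then be driven by an onnxruntime session
# instead of the TorchScript model:
import glob
import onnxruntime
from utils import STFTExtractor, state_generator

ort_session = onnxruntime.InferenceSession('files/model.onnx')
extractor = STFTExtractor()
audios_for_stream = glob.glob('files/test*.wav')

for states in state_generator(ort_session, audios_for_stream, extractor,
                              onnx=True, audios_in_stream=2):
    if states:
        print(states)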