diff --git a/files/joint_VAD_just_RU.onnx b/files/joint_VAD_just_RU.onnx
new file mode 100644
index 0000000..e4596f1
Binary files /dev/null and b/files/joint_VAD_just_RU.onnx differ
diff --git a/files/joint_VAD_just_RU_jit_cut_q.pth.tar b/files/joint_VAD_just_RU_jit_cut_q.pth.tar
index 20e8e4a..67be4fb 100644
Binary files a/files/joint_VAD_just_RU_jit_cut_q.pth.tar and b/files/joint_VAD_just_RU_jit_cut_q.pth.tar differ
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index aefc8b5..291067e 100644
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -2,35 +2,22 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T15:10:52.128138Z",
-     "start_time": "2020-12-11T15:10:51.548322Z"
+     "end_time": "2020-12-14T13:43:24.487521Z",
+     "start_time": "2020-12-14T13:43:23.780570Z"
     }
    },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
-      "  warnings.warn(\n",
-      "/opt/conda/lib/python3.8/site-packages/torchaudio/backend/utils.py:63: UserWarning: The interface of \"soundfile\" backend is planned to change in 0.8.0 to match that of \"sox_io\" backend and the current interface will be removed in 0.9.0. To use the new interface, do `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before setting the backend to \"soundfile\". Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
-      "  warnings.warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import glob\n",
     "import torch\n",
-    "import numpy as np # use only torch?\n",
-    "import soundfile as sf\n",
-    "# import torch.nn.functional as F\n",
     "from IPython.display import Audio\n",
     "torch.set_num_threads(1)\n",
-    "from utils import init_jit_model, STFTExtractor, get_speech_ts, read_audio, state_generator, single_audio_stream\n",
-    "extractor = STFTExtractor()"
+    "from utils import (init_jit_model, get_speech_ts,\n",
+    "                   save_audio, read_audio,\n",
+    "                   state_generator, single_audio_stream, init_onnx_model)"
    ]
   },
   {
@@ -42,11 +29,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:05.274301Z",
-     "start_time": "2020-12-11T14:25:05.271313Z"
+     "end_time": "2020-12-14T13:43:24.492506Z",
+     "start_time": "2020-12-14T13:43:24.489440Z"
     }
    },
    "outputs": [],
@@ -55,132 +42,82 @@
     "    speech_chunks = []\n",
     "    for i in tss:\n",
     "        speech_chunks.append(wav[i['start']: i['end']])\n",
-    "    return np.concatenate(speech_chunks)"
+    "    return torch.cat(speech_chunks)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:06.395183Z",
-     "start_time": "2020-12-11T14:25:06.082595Z"
+     "end_time": "2020-12-14T13:43:24.760714Z",
+     "start_time": "2020-12-14T13:43:24.493992Z"
     }
    },
    "outputs": [],
    "source": [
-    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file"
+    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')  # from yml file\n",
+    "# model = init_onnx_model('files/joint_VAD_just_RU.onnx')  # or use the ONNX backend instead"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:25.523423Z",
-     "start_time": "2020-12-11T14:25:25.493581Z"
-    }
-   },
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "        \n",
-       "      "
-      ],
-      "text/plain": [
-       ""
-      ]
-     },
-     "execution_count": 16,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "Audio('files/test_audio_8.wav')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:43.023784Z",
-     "start_time": "2020-12-11T14:25:43.017360Z"
+     "end_time": "2020-12-14T13:43:24.793384Z",
+     "start_time": "2020-12-14T13:43:24.762311Z"
     }
    },
    "outputs": [],
    "source": [
-    "wav = read_audio('files/test_audio_8.wav')"
+    "wav = read_audio('files/test_audio_2.wav')\n",
+    "Audio('files/test_audio_2.wav')"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:45.083872Z",
-     "start_time": "2020-12-11T14:25:43.371366Z"
+     "end_time": "2020-12-14T13:43:25.320324Z",
+     "start_time": "2020-12-14T13:43:24.808594Z"
     }
    },
    "outputs": [],
    "source": [
-    "speech_timestamps = get_speech_ts(wav, model, extractor, num_steps=4) # kill extractor"
+    "speech_timestamps = get_speech_ts(wav, model, num_steps=4)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:25:45.130371Z",
-     "start_time": "2020-12-11T14:25:45.091010Z"
+     "end_time": "2020-12-14T13:43:25.324901Z",
+     "start_time": "2020-12-14T13:43:25.321759Z"
     }
    },
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "        \n",
-       "      "
-      ],
-      "text/plain": [
-       ""
-      ]
-     },
-     "execution_count": 19,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "sf.write('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
+    "speech_timestamps"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-14T13:43:25.344065Z",
+     "start_time": "2020-12-14T13:43:25.326162Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n",
     "Audio('only_speech.wav')"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -190,76 +127,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T15:10:55.789272Z",
-     "start_time": "2020-12-11T15:10:55.543652Z"
+     "end_time": "2020-12-14T13:43:25.778585Z",
+     "start_time": "2020-12-14T13:43:25.496583Z"
     }
    },
    "outputs": [],
    "source": [
-    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')\n",
+    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')  # from yml file\n",
+    "# model = init_onnx_model('files/joint_VAD_just_RU.onnx')  # or use the ONNX backend instead\n",
     "audio = 'files/test_audio_6.wav'"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T15:10:59.503301Z",
-     "start_time": "2020-12-11T15:10:55.790671Z"
+     "end_time": "2020-12-14T13:43:29.402604Z",
+     "start_time": "2020-12-14T13:43:25.780037Z"
     }
    },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/opt/conda/lib/python3.8/site-packages/torch/functional.py:515: UserWarning: stft will require the return_complex parameter be explicitly specified in a future PyTorch release. Use return_complex=False to preserve the current behavior or return_complex=True to return a complex output. (Triggered internally at /opt/conda/conda-bld/pytorch_1603729096996/work/aten/src/ATen/native/SpectralOps.cpp:653.)\n",
-      "  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n",
-      "/opt/conda/lib/python3.8/site-packages/torch/functional.py:515: UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at /opt/conda/conda-bld/pytorch_1603729096996/work/aten/src/ATen/native/SpectralOps.cpp:590.)\n",
-      "  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[{183500: 'start'}]\n",
-      "[{202500: 'end'}]\n",
-      "[{226500: 'start'}]\n",
-      "[{283500: 'end'}]\n",
-      "[{337500: 'start'}]\n",
-      "[{503000: 'end'}]\n",
-      "[{507500: 'start'}]\n",
-      "[{627500: 'end'}]\n",
-      "[{631500: 'start'}]\n",
-      "[{927488: 'end'}]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "for i in single_audio_stream(model, audio, extractor):\n",
+    "for i in single_audio_stream(model, audio):\n",
     "    if i:\n",
     "        print(i)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -269,39 +166,29 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:28:09.649303Z",
-     "start_time": "2020-12-11T14:28:09.373634Z"
+     "end_time": "2020-12-14T13:43:29.674262Z",
+     "start_time": "2020-12-14T13:43:29.403972Z"
     }
    },
    "outputs": [],
    "source": [
-    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')"
+    "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')  # from yml file\n",
+    "# model = init_onnx_model('files/joint_VAD_just_RU.onnx')  # or use the ONNX backend instead"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:28:12.273951Z",
-     "start_time": "2020-12-11T14:28:12.269729Z"
+     "end_time": "2020-12-14T13:43:29.678449Z",
+     "start_time": "2020-12-14T13:43:29.675519Z"
     }
    },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "10"
-      ]
-     },
-     "execution_count": 21,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "audios_for_stream = glob.glob('files/test*.wav')\n",
     "len(audios_for_stream)"
    ]
   },
@@ -309,107 +196,34 @@
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-12-11T14:28:32.459872Z",
-     "start_time": "2020-12-11T14:28:14.502871Z"
+     "end_time": "2020-12-14T13:43:40.236387Z",
+     "start_time": "2020-12-14T13:43:29.679274Z"
     }
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Done initial Loading\n",
-      "[({106500: 'start'}, 'files/test_audio_1.wav')]\n",
-      "[({174000: 'start'}, 'files/test_audio_3.wav')]\n",
-      "[({261000: 'end'}, 'files/test_audio_1.wav')]\n",
-      "Loading next wav: files/test_audio_7.wav\n",
-      "[({134000: 'start'}, 'files/test_audio_7.wav')]\n",
-      "[({147500: 'end'}, 'files/test_audio_7.wav')]\n",
-      "[({442000: 'end'}, 'files/test_audio_3.wav')]\n",
-      "[({450500: 'start'}, 'files/test_audio_3.wav')]\n",
-      "[({209500: 'start'}, 'files/test_audio_7.wav')]\n",
-      "[({519500: 'end'}, 'files/test_audio_3.wav')]\n",
-      "[({533500: 'start'}, 'files/test_audio_3.wav')]\n",
-      "[({599904: 'end'}, 'files/test_audio_3.wav')]\n",
-      "Loading next wav: files/test_audio_6.wav\n",
-      "[({183500: 'start'}, 'files/test_audio_6.wav')]\n",
-      "[({503500: 'end'}, 'files/test_audio_7.wav')]\n",
-      "[({202500: 'end'}, 'files/test_audio_6.wav')]\n",
-      "[({537500: 'start'}, 'files/test_audio_7.wav')]\n",
-      "[({226500: 'start'}, 'files/test_audio_6.wav')]\n",
-      "[({283500: 'end'}, 'files/test_audio_6.wav')]\n",
-      "[({616500: 'end'}, 'files/test_audio_7.wav')]\n",
-      "[({337500: 'start'}, 'files/test_audio_6.wav')]\n",
-      "[({661500: 'start'}, 'files/test_audio_7.wav')]\n",
-      "[({785000: 'end'}, 'files/test_audio_7.wav')]\n",
-      "[({503000: 'end'}, 'files/test_audio_6.wav')]\n",
-      "[({507500: 'start'}, 'files/test_audio_6.wav')]\n",
-      "[({851500: 'start'}, 'files/test_audio_7.wav')]\n",
-      "[({919000: 'end'}, 'files/test_audio_7.wav')]\n",
-      "Loading next wav: files/test_audio_5.wav\n",
-      "[({627500: 'end'}, 'files/test_audio_6.wav')]\n",
-      "[({631500: 'start'}, 'files/test_audio_6.wav')]\n",
-      "[({151000: 'start'}, 'files/test_audio_5.wav')]\n",
-      "[({169000: 'end'}, 'files/test_audio_5.wav')]\n",
-      "[({211000: 'start'}, 'files/test_audio_5.wav')]\n",
-      "[({221500: 'end'}, 'files/test_audio_5.wav')]\n",
-      "Loading next wav: files/test_audio_2.wav\n",
-      "[({927488: 'end'}, 'files/test_audio_6.wav')]\n",
-      "Loading next wav: files/test_audio_8.wav\n",
-      "[({228000: 'start'}, 'files/test_audio_2.wav')]\n",
-      "[({179500: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({241500: 'end'}, 'files/test_audio_2.wav')]\n",
-      "[({279000: 'start'}, 'files/test_audio_2.wav')]\n",
-      "[({274500: 'end'}, 'files/test_audio_8.wav')]\n",
-      "[({300500: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({369500: 'end'}, 'files/test_audio_2.wav')]\n",
-      "[({378500: 'start'}, 'files/test_audio_2.wav')]\n",
-      "[({436500: 'end'}, 'files/test_audio_2.wav')]\n",
-      "[({423000: 'end'}, 'files/test_audio_8.wav')]\n",
-      "[({488500: 'start'}, 'files/test_audio_2.wav')]\n",
-      "[({458500: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({599904: 'end'}, 'files/test_audio_2.wav')]\n",
-      "Loading next wav: files/test_audio_4.wav\n",
-      "[({583500: 'end'}, 'files/test_audio_8.wav')]\n",
-      "[({599500: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({632500: 'end'}, 'files/test_audio_8.wav')]\n",
-      "[({660000: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({737000: 'end'}, 'files/test_audio_8.wav')]\n",
-      "[({761000: 'start'}, 'files/test_audio_8.wav')]\n",
-      "[({249500: 'start'}, 'files/test_audio_4.wav')]\n",
-      "[({257168: 'end'}, 'files/test_audio_4.wav')]\n",
-      "Loading next wav: files/test_audio_9.wav\n",
-      "[({843000: 'end'}, 'files/test_audio_8.wav')]\n",
-      "Loading next wav: files/test_audio_0.wav\n",
-      "[({133000: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({143500: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({272000: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({256500: 'start'}, 'files/test_audio_0.wav')]\n",
-      "[({336500: 'end'}, 'files/test_audio_9.wav'), ({281232: 'end'}, 'files/test_audio_0.wav')]\n",
-      "[({406500: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({460000: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({476000: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({494500: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({544500: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({564500: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({595000: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({682000: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({728500: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({786000: 'end'}, 'files/test_audio_9.wav')]\n",
-      "[({814000: 'start'}, 'files/test_audio_9.wav')]\n",
-      "[({826000: 'end'}, 'files/test_audio_9.wav')]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "for i in state_generator(model, audios_for_stream, extractor, audios_in_stream=2):\n",
+    "for i in state_generator(model, audios_for_stream, audios_in_stream=2):\n",
     "    if i:\n",
     "        print(i)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-12-14T13:46:49.812052Z",
+     "start_time": "2020-12-14T13:46:49.586637Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "!cp ../silero-models-research/model_saves/joint_VAD_just_RU_jit_cut_q.pth.tar files/"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/utils.py b/utils.py
index a4e05ed..0cda630 100644
--- a/utils.py
+++ b/utils.py
@@ -7,23 +7,10 @@ import torch.nn.functional as F
 from collections import deque
 import numpy as np
 from itertools import repeat
+import onnxruntime
 
 torchaudio.set_audio_backend("soundfile")  # switch backend
 
-
-def read_batch(audio_paths: List[str]):
-    return [read_audio(audio_path)
-            for audio_path
-            in audio_paths]
-
-
-def split_into_batches(lst: List[str],
-                       batch_size: int = 10):
-    return [lst[i:i + batch_size]
-            for i in
-            range(0, len(lst), batch_size)]
-
-
 def read_audio(path: str,
                target_sr: int = 16000):
@@ -42,15 +29,10 @@ def read_audio(path: str,
     assert sr == target_sr
     return wav.squeeze(0)
 
-
-def prepare_model_input(batch: List[torch.Tensor],
-                        device=torch.device('cpu')):
-    max_seqlength = max(max([len(_) for _ in batch]), 12800)
-    inputs = torch.zeros(len(batch), max_seqlength)
-    for i, wav in enumerate(batch):
-        inputs[i, :len(wav)].copy_(wav)
-    inputs = inputs.to(device)
-    return inputs
+def save_audio(path: str,
+               tensor: torch.Tensor,
+               sr: int):
+    torchaudio.save(path, tensor, sr)
 
 
 #def init_jit_model(model_url: str,
@@ -72,8 +54,11 @@ def init_jit_model(model_path,
     model.eval()
     return model
 
+def init_onnx_model(model_path):
+    return onnxruntime.InferenceSession(model_path)
 
-def get_speech_ts(wav, model, extractor,
+
+def get_speech_ts(wav, model,
                   trig_sum=0.25, neg_trig_sum=0.01,
                   num_steps=8, batch_size=200):
@@ -90,15 +75,13 @@ def get_speech_ts(wav, model, extractor,
         to_concat.append(chunk)
         if len(to_concat) >= batch_size:
             chunks = torch.Tensor(torch.vstack(to_concat))
-            with torch.no_grad():
-                out = model(extractor(chunks))[-2]
+            out = validate(model, chunks)[-2]
             outs.append(out)
             to_concat = []
     if to_concat:
         chunks = torch.Tensor(torch.vstack(to_concat))
-        with torch.no_grad():
-            out = model(extractor(chunks))[-2]
+        out = validate(model, chunks)[-2]
         outs.append(out)
 
     outs = torch.cat(outs, dim=0)
@@ -125,37 +108,6 @@ def get_speech_ts(wav, model, extractor,
             speeches.append(current_speech)
     return speeches
 
-
-class STFTExtractor(nn.Module):
-    def __init__(self, sr=16000, win_size=0.02, mode='mag'):
-        super(STFTExtractor, self).__init__()
-        self.sr = sr
-        self.n_fft = int(sr * (win_size + 1e-8))
-        self.win_length = self.n_fft
-        self.hop_length = self.win_length // 2
-        self.mode = 'mag' if mode == '' else mode
-
-    def forward(self, wav):
-        # center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
-        stft_sample = torch.stft(wav,
-                                 n_fft=self.n_fft,
-                                 win_length=self.win_length,
-                                 hop_length=self.hop_length,
-                                 center=True)
-        mag, phase = torchaudio.functional.magphase(stft_sample)
-
-        # It seems it is not a "mag", it is "power" (exp == 1).
-        # Also there is "energy" (exp == 2).
-        if self.mode == 'mag':
-            return mag
-        if self.mode == 'phase':
-            return phase
-        elif self.mode == 'magphase':
-            return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
-        else:
-            raise NotImplementedError()
-
-
 class VADiterator:
     def __init__(self, trig_sum=0.26, neg_trig_sum=0.01,
@@ -214,7 +166,7 @@ class VADiterator:
         return current_speech, self.current_name
 
 
-def state_generator(model, audios, extractor,
+def state_generator(model, audios,
                     trig_sum=0.26, neg_trig_sum=0.01,
                     num_steps=8, audios_in_stream=5):
@@ -223,14 +175,8 @@ def state_generator(model, audios, extractor,
         for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
         batch = torch.cat(for_batch)
 
-        with torch.no_grad():
-            if onnx:
-                ort_inputs = {'input': to_numpy(extractor(batch))}
-                ort_outs = model.run(None, ort_inputs)
-                vad_outs = np.split(ort_outs[-2], audios_in_stream)
-            else:
-                outs = model(extractor(batch))
-                vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
+        outs = validate(model, batch)
+        vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
 
         states = []
         for x, y in zip(VADiters, vad_outs):
@@ -273,7 +219,7 @@ def stream_imitator(audios, audios_in_stream):
             values.append((out, wav_name))
         yield values
 
-def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
+def single_audio_stream(model, audio, trig_sum=0.26,
                         neg_trig_sum=0.01, num_steps=8):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -281,18 +227,25 @@ def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
     wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
     for chunk in wav_chunks:
         batch = VADiter.prepare_batch(chunk)
-
-        with torch.no_grad():
-            if onnx:
-                ort_inputs = {'input': to_numpy(extractor(batch))}
-                ort_outs = model.run(None, ort_inputs)
-                vad_outs = ort_outs[-2]
-            else:
-                outs = model(extractor(batch))
-                vad_outs = outs[-2]
+
+        outs = validate(model, batch)
+        vad_outs = outs[-2]
 
         states = []
         state = VADiter.state(vad_outs)
         if state[0]:
             states.append(state[0])
         yield states
+
+
+def validate(model, inputs):
+    # Run a forward pass with either backend: ONNX Runtime sessions take
+    # numpy inputs, TorchScript modules take tensors directly.
+    with torch.no_grad():
+        if isinstance(model, onnxruntime.InferenceSession):
+            ort_inputs = {'input': inputs.cpu().numpy()}
+            outs = model.run(None, ort_inputs)
+            outs = [torch.Tensor(x) for x in outs]
+        else:
+            outs = model(inputs)
+    return outs
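
Note for reviewers: a minimal offline-usage sketch of the API after this patch (illustrative only, not part of the diff; it mirrors the notebook cells above and assumes the model files shipped in files/ plus a mono 16 kHz test wav). get_speech_ts now takes the raw waveform directly, and the new validate() helper hides the TorchScript/ONNX difference:

    import torch
    from utils import (init_jit_model, init_onnx_model, get_speech_ts,
                       read_audio, save_audio)

    # pick a backend; both go through the same validate() dispatch
    model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')
    # model = init_onnx_model('files/joint_VAD_just_RU.onnx')

    wav = read_audio('files/test_audio_2.wav')  # mono, resampled to 16 kHz
    speech_timestamps = get_speech_ts(wav, model, num_steps=4)

    # stitch the voiced spans back together and write them out
    speech = torch.cat([wav[ts['start']:ts['end']] for ts in speech_timestamps])
    # note: depending on the torchaudio version, save may expect a 2D
    # (channels, frames) tensor, i.e. speech.unsqueeze(0)
    save_audio('only_speech.wav', speech, 16000)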
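
The streaming entry points, again as an illustrative sketch mirroring the notebook: with the extractor argument gone, the chunked iterators consume raw 16 kHz audio directly, and the same model object works for both:

    import glob
    from utils import init_jit_model, single_audio_stream, state_generator

    model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu')

    # one file, streamed in 4000-sample (250 ms at 16 kHz) chunks
    for state in single_audio_stream(model, 'files/test_audio_6.wav'):
        if state:
            print(state)  # e.g. [{183500: 'start'}]

    # several files interleaved into one simulated stream
    for state in state_generator(model, glob.glob('files/test*.wav'),
                                 audios_in_stream=2):
        if state:
            print(state)  # e.g. [({106500: 'start'}, 'files/test_audio_1.wav')]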