diff --git a/files/model.jit b/files/model.jit new file mode 100644 index 0000000..0bec210 Binary files /dev/null and b/files/model.jit differ diff --git a/files/joint_VAD_just_RU.onnx b/files/model.onnx similarity index 100% rename from files/joint_VAD_just_RU.onnx rename to files/model.onnx diff --git a/silero-vad.ipynb b/silero-vad.ipynb index 291067e..0bfec50 100644 --- a/silero-vad.ipynb +++ b/silero-vad.ipynb @@ -1,30 +1,39 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Jit example" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:24.487521Z", - "start_time": "2020-12-14T13:43:23.780570Z" + "end_time": "2020-12-15T11:54:25.940761Z", + "start_time": "2020-12-15T11:54:25.933842Z" } }, "outputs": [], "source": [ + "# imports\n", "import glob\n", "import torch\n", "from IPython.display import Audio\n", "torch.set_num_threads(1)\n", - "from utils import (init_jit_model, get_speech_ts, \n", + "\n", + "from utils import (init_jit_model, get_speech_ts,\n", " save_audio, read_audio, \n", - " state_generator, single_audio_stream, init_onnx_model)" + " state_generator, single_audio_stream)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Full audio example" + "## Full audio" ] }, { @@ -32,8 +41,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:24.492506Z", - "start_time": "2020-12-14T13:43:24.489440Z" + "end_time": "2020-12-15T11:54:27.939388Z", + "start_time": "2020-12-15T11:54:27.936636Z" } }, "outputs": [], @@ -50,14 +59,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:24.760714Z", - "start_time": "2020-12-14T13:43:24.493992Z" + "end_time": "2020-12-15T11:54:28.415177Z", + "start_time": "2020-12-15T11:54:28.231677Z" } }, "outputs": [], "source": [ - "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", - "model = init_onnx_model('files/joint_VAD_just_RU.onnx')" + "model = init_jit_model('files/model.jit', 'cpu')" ] }, { @@ -65,14 +73,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:24.793384Z", - "start_time": "2020-12-14T13:43:24.762311Z" + "end_time": "2020-12-15T11:54:28.560822Z", + "start_time": "2020-12-15T11:54:28.549811Z" } }, "outputs": [], "source": [ - "Audio('files/test_audio_2.wav')\n", - "wav = read_audio('files/test_audio_2.wav')" + "wav = read_audio('files/en.wav')" ] }, { @@ -80,13 +87,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:25.320324Z", - "start_time": "2020-12-14T13:43:24.808594Z" + "end_time": "2020-12-15T11:54:30.088721Z", + "start_time": "2020-12-15T11:54:29.019358Z" } }, "outputs": [], "source": [ - "speech_timestamps = get_speech_ts(wav, model, num_steps=4) # kill extractor" + "speech_timestamps = get_speech_ts(wav, model, num_steps=4) # get speech timestamps from full audio file" ] }, { @@ -94,8 +101,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:25.324901Z", - "start_time": "2020-12-14T13:43:25.321759Z" + "end_time": "2020-12-15T11:54:30.198484Z", + "start_time": "2020-12-15T11:54:30.188311Z" } }, "outputs": [], @@ -108,13 +115,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:25.344065Z", - "start_time": "2020-12-14T13:43:25.326162Z" + "end_time": "2020-12-15T11:54:30.816893Z", + "start_time": "2020-12-15T11:54:30.782667Z" } }, 
"outputs": [], "source": [ - "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n", + "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000) # merge all speech chunks to one audio\n", "Audio('only_speech.wav')" ] }, @@ -122,7 +129,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Single stream example" + "## Single audio stream" ] }, { @@ -130,15 +137,14 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:25.778585Z", - "start_time": "2020-12-14T13:43:25.496583Z" + "end_time": "2020-12-15T11:54:31.886189Z", + "start_time": "2020-12-15T11:54:31.572194Z" } }, "outputs": [], "source": [ - "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", - "#model = init_onnx_model('files/joint_VAD_just_RU.onnx')\n", - "audio = 'files/test_audio_6.wav'" + "model = init_jit_model('files/model.jit', 'cpu')\n", + "wav = 'files/en.wav'" ] }, { @@ -146,13 +152,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:29.402604Z", - "start_time": "2020-12-14T13:43:25.780037Z" + "end_time": "2020-12-15T11:54:35.624279Z", + "start_time": "2020-12-15T11:54:32.049532Z" } }, "outputs": [], "source": [ - "for i in single_audio_stream(model, audio):\n", + "for i in single_audio_stream(model, wav):\n", " if i:\n", " print(i)" ] @@ -161,7 +167,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Multiple stream example" + "## Multiple audio stream" ] }, { @@ -169,14 +175,13 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:29.674262Z", - "start_time": "2020-12-14T13:43:29.403972Z" + "end_time": "2020-12-15T11:40:13.406225Z", + "start_time": "2020-12-15T11:40:13.206354Z" } }, "outputs": [], "source": [ - "model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", - "model = init_onnx_model('files/joint_VAD_just_RU.onnx')" + "model = init_jit_model('files/model.jit', 'cpu')" ] }, { @@ -184,14 +189,14 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:29.678449Z", - "start_time": "2020-12-14T13:43:29.675519Z" + "end_time": "2020-12-15T11:41:08.470917Z", + "start_time": "2020-12-15T11:41:08.467369Z" } }, "outputs": [], "source": [ - "audios_for_stream = glob.glob('files/test*.wav')\n", - "len(audios_for_stream)" + "audios_for_stream = glob.glob('files/*.wav')\n", + "len(audios_for_stream) # total 4 audios" ] }, { @@ -199,29 +204,211 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:43:40.236387Z", - "start_time": "2020-12-14T13:43:29.679274Z" + "end_time": "2020-12-15T11:41:25.685356Z", + "start_time": "2020-12-15T11:41:16.222672Z" } }, "outputs": [], "source": [ - "for i in state_generator(model, audios_for_stream, audios_in_stream=2):\n", + "for i in state_generator(model, audios_for_stream, audios_in_stream=2): # 2 audio stream\n", " if i:\n", " print(i)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Onnx example" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-12-14T13:46:49.812052Z", - "start_time": "2020-12-14T13:46:49.586637Z" + "end_time": "2020-12-15T11:55:45.597504Z", + "start_time": "2020-12-15T11:55:45.582356Z" } }, "outputs": [], "source": [ - "!cp 
../silero-models-research/model_saves/joint_VAD_just_RU_jit_cut_q.pth.tar files/" + "# imports\n", + "import glob\n", + "import torch\n", + "from IPython.display import Audio\n", + "torch.set_num_threads(1)\n", + "import onnxruntime\n", + "\n", + "from utils import (get_speech_ts, save_audio, read_audio, \n", + " state_generator, single_audio_stream)\n", + "\n", + "def init_onnx_model(model_path: str):\n", + " return onnxruntime.InferenceSession(model_path)\n", + "\n", + "def validate_onnx(model, inputs):\n", + " with torch.no_grad():\n", + " ort_inputs = {'input': inputs.cpu().numpy()}\n", + " outs = model.run(None, ort_inputs)\n", + " outs = [torch.Tensor(x) for x in outs]\n", + " return outs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Full audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:55:56.874376Z", + "start_time": "2020-12-15T11:55:56.782230Z" + } + }, + "outputs": [], + "source": [ + "model = init_onnx_model('files/model.onnx')\n", + "wav = read_audio('files/en.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:56:12.159463Z", + "start_time": "2020-12-15T11:56:11.446991Z" + } + }, + "outputs": [], + "source": [ + "speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx) # get speech timestamps from full audio file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:56:20.488863Z", + "start_time": "2020-12-15T11:56:20.485485Z" + } + }, + "outputs": [], + "source": [ + "speech_timestamps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:56:27.908128Z", + "start_time": "2020-12-15T11:56:27.870978Z" + } + }, + "outputs": [], + "source": [ + "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000) # merge all speech chunks to one audio\n", + "Audio('only_speech.wav')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Single audio stream" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:58:09.012892Z", + "start_time": "2020-12-15T11:58:08.940907Z" + } + }, + "outputs": [], + "source": [ + "model = init_onnx_model('files/model.onnx')\n", + "wav = 'files/en.wav'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:58:11.562186Z", + "start_time": "2020-12-15T11:58:09.949825Z" + } + }, + "outputs": [], + "source": [ + "for i in single_audio_stream(model, wav, run_function=validate_onnx):\n", + " if i:\n", + " print(i)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multiple audio stream" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = init_onnx_model('files/model.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-12-15T11:59:09.381687Z", + "start_time": "2020-12-15T11:59:09.378552Z" + } + }, + "outputs": [], + "source": [ + "audios_for_stream = glob.glob('files/*.wav')\n", + "len(audios_for_stream) # total 4 audios" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": 
"2020-12-15T11:59:27.712905Z", + "start_time": "2020-12-15T11:59:21.608435Z" + } + }, + "outputs": [], + "source": [ + "for i in state_generator(model, audios_for_stream, audios_in_stream=2, run_function=validate_onnx): # 2 audio stream\n", + " if i:\n", + " print(i)" ] }, { diff --git a/utils.py b/utils.py index 5c9dd25..a5c00bb 100644 --- a/utils.py +++ b/utils.py @@ -1,15 +1,16 @@ import torch import torchaudio -import onnxruntime import numpy as np -from typing import List from itertools import repeat from collections import deque import torch.nn.functional as F - torchaudio.set_audio_backend("soundfile") # switch backend +def validate(model, inputs): + with torch.no_grad(): + outs = model(inputs) + return outs def read_audio(path: str, target_sr: int = 16000): @@ -43,14 +44,9 @@ def init_jit_model(model_path: str, model.eval() return model - -def init_onnx_model(model_path: str): - return onnxruntime.InferenceSession(model_path) - - def get_speech_ts(wav, model, - trig_sum=0.25, neg_trig_sum=0.01, - num_steps=8, batch_size=200): + trig_sum=0.25, neg_trig_sum=0.02, + num_steps=8, batch_size=200, run_function=validate): num_samples = 4000 assert num_samples % num_steps == 0 @@ -62,16 +58,16 @@ def get_speech_ts(wav, model, chunk = wav[i: i+num_samples] if len(chunk) < num_samples: chunk = F.pad(chunk, (0, num_samples - len(chunk))) - to_concat.append(chunk) + to_concat.append(chunk.unsqueeze(0)) if len(to_concat) >= batch_size: - chunks = torch.Tensor(torch.vstack(to_concat)) - out = validate(model, chunks)[-2] + chunks = torch.Tensor(torch.cat(to_concat, dim=0)) + out = run_function(model, chunks)[-2] outs.append(out) to_concat = [] if to_concat: - chunks = torch.Tensor(torch.vstack(to_concat)) - out = validate(model, chunks)[-2] + chunks = torch.Tensor(torch.cat(to_concat, dim=0)) + out = run_function(model, chunks)[-2] outs.append(out) outs = torch.cat(outs, dim=0) @@ -101,7 +97,7 @@ def get_speech_ts(wav, model, class VADiterator: def __init__(self, - trig_sum=0.26, neg_trig_sum=0.01, + trig_sum=0.26, neg_trig_sum=0.02, num_steps=8): self.num_samples = 4000 self.num_steps = num_steps @@ -133,11 +129,11 @@ class VADiterator: wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio self.last = True - stacked = torch.hstack([self.prev, wav_chunk]) + stacked = torch.cat([self.prev, wav_chunk]) self.prev = wav_chunk - overlap_chunks = [stacked[i:i+self.num_samples] for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough - return torch.vstack(overlap_chunks) + overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough + return torch.cat(overlap_chunks, dim=0) def state(self, model_out): current_speech = {} @@ -159,14 +155,14 @@ class VADiterator: def state_generator(model, audios, onnx=False, - trig_sum=0.26, neg_trig_sum=0.01, - num_steps=8, audios_in_stream=5): + trig_sum=0.26, neg_trig_sum=0.02, + num_steps=8, audios_in_stream=5, run_function=validate): VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)] for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)): for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)] batch = torch.cat(for_batch) - outs = validate(model, batch) + outs = run_function(model, batch) vad_outs = np.split(outs[-2].numpy(), audios_in_stream) states = [] @@ -212,7 +208,7 @@ def stream_imitator(audios, 
audios_in_stream): def single_audio_stream(model, audio, onnx=False, trig_sum=0.26, - neg_trig_sum=0.01, num_steps=8): + neg_trig_sum=0.02, num_steps=8, run_function=validate): num_samples = 4000 VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps) wav = read_audio(audio) @@ -220,7 +216,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26, for chunk in wav_chunks: batch = VADiter.prepare_batch(chunk) - outs = validate(model, batch) + outs = run_function(model, batch) vad_outs = outs[-2] # this is very misleading states = [] @@ -228,17 +224,3 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26, if state[0]: states.append(state[0]) yield states - - -def validate(model, inputs): - onnx = False - if type(model) == onnxruntime.capi.session.InferenceSession: - onnx = True - with torch.no_grad(): - if onnx: - ort_inputs = {'input': inputs.cpu().numpy()} - outs = model.run(None, ort_inputs) - outs = [torch.Tensor(x) for x in outs] - else: - outs = model(inputs) - return outs
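
The utils.py changes above drop the onnxruntime dependency: validate() becomes a plain torch call, and get_speech_ts, single_audio_stream and state_generator gain a run_function hook, while the notebook now defines the ONNX helpers inline. A minimal sketch of both backends driving the same helper, assuming the renamed files from this diff (files/model.jit, files/model.onnx, files/en.wav):

```python
# Minimal sketch: JIT and ONNX backends sharing the refactored helpers.
# Assumes the utils.py from this diff and the renamed model/audio files.
import torch
import onnxruntime

from utils import init_jit_model, get_speech_ts, read_audio


def init_onnx_model(model_path: str):
    # Removed from utils.py in this diff; the notebook now defines it inline.
    return onnxruntime.InferenceSession(model_path)


def validate_onnx(model, inputs):
    # ONNX counterpart of utils.validate: run the batch through onnxruntime
    # and wrap outputs back into tensors so downstream code stays unchanged.
    with torch.no_grad():
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        return [torch.Tensor(x) for x in outs]


wav = read_audio('files/en.wav')

# TorchScript backend: run_function defaults to utils.validate.
jit_model = init_jit_model('files/model.jit', 'cpu')
jit_timestamps = get_speech_ts(wav, jit_model, num_steps=4)

# ONNX backend: same helper, different run_function.
onnx_model = init_onnx_model('files/model.onnx')
onnx_timestamps = get_speech_ts(wav, onnx_model, num_steps=4,
                                run_function=validate_onnx)
```

Keeping validate() torch-only means utils.py no longer imports onnxruntime, so the JIT path works without it installed; the ONNX-specific pieces live in the notebook instead.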
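Both the JIT and ONNX sections call collect_speeches(speech_timestamps, wav) before save_audio, but its body sits in a notebook cell outside this diff. A hypothetical stand-in, assuming get_speech_ts returns dicts of start/end sample indices:

```python
# Hypothetical stand-in for the notebook's collect_speeches helper (its
# definition is outside this hunk); assumes timestamps are dicts with
# 'start'/'end' sample indices as produced by get_speech_ts.
import torch


def collect_speeches(speech_timestamps, wav):
    # Concatenate every region flagged as speech into one waveform.
    return torch.cat([wav[ts['start']:ts['end']] for ts in speech_timestamps])
```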
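The two streaming sections of the notebook exercise the same run_function hook. A short sketch of the streaming entry points with the TorchScript backend (run_function defaults to utils.validate; pass run_function=validate_onnx for an ONNX session, as in the notebook cells):

```python
# Sketch of the streaming generators after the refactor (TorchScript backend).
import glob

from utils import init_jit_model, single_audio_stream, state_generator

model = init_jit_model('files/model.jit', 'cpu')

# Single simulated stream over one file: yields non-empty lists of state
# updates (speech start/end markers) as 4000-sample chunks are processed.
for states in single_audio_stream(model, 'files/en.wav'):
    if states:
        print(states)

# Several files interleaved into simulated streams, two audios per stream.
audios_for_stream = glob.glob('files/*.wav')
for states in state_generator(model, audios_for_stream, audios_in_stream=2):
    if states:
        print(states)
```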
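One subtle utils.py change is how batches are assembled: torch.hstack/torch.vstack become torch.cat, with each 1-D window unsqueezed to a row first. A small sketch checking that the new formulation in VADiterator.prepare_batch still yields a (num_steps, num_samples) batch of overlapping windows; the zero-initialised prev buffer is an assumption for illustration:

```python
# Sketch of the chunk-stacking rewrite in VADiterator.prepare_batch:
# unsqueeze each window to shape (1, num_samples) and torch.cat along dim 0,
# which matches what torch.vstack did for the 1-D slices.
import torch

num_samples, num_steps = 4000, 8
step = num_samples // num_steps            # 500, per the "500 step" comment
prev = torch.zeros(num_samples)            # assumed initial buffer state
wav_chunk = torch.randn(num_samples)

stacked = torch.cat([prev, wav_chunk])     # was torch.hstack([...])
overlap_chunks = [stacked[i:i + num_samples].unsqueeze(0)
                  for i in range(step, num_samples + 1, step)]
batch = torch.cat(overlap_chunks, dim=0)   # was torch.vstack(...)

assert batch.shape == (num_steps, num_samples)
```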