From c762bb5b529d24102a4dee94380f853142c08bea Mon Sep 17 00:00:00 2001
From: adamnsandle
Date: Thu, 15 Apr 2021 14:01:05 +0000
Subject: [PATCH] add adaptive examples

---
 README.md        |  37 ++++++-
 silero-vad.ipynb | 260 ++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 281 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index e0c6a94..3f13f39 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,7 @@ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                               force_reload=True)
 
 (get_speech_ts,
+ get_speech_ts_adaptive,
  _,
  read_audio,
  _, _, _) = utils
@@ -122,9 +123,15 @@ files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
 wav = read_audio(f'{files_dir}/en.wav')  # full audio
 
 # get speech timestamps from full audio file
+
+# classic way
 speech_timestamps = get_speech_ts(wav, model,
                                   num_steps=4)
 pprint(speech_timestamps)
+
+# adaptive way
+speech_timestamps = get_speech_ts_adaptive(wav, model)
+pprint(speech_timestamps)
 ```
 
 #### Number Detector
@@ -195,6 +202,7 @@ _, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                           force_reload=True)
 
 (get_speech_ts,
+ get_speech_ts_adaptive,
  _,
  read_audio,
  _, _, _) = utils
@@ -208,14 +216,20 @@ def validate_onnx(model, inputs):
     ort_inputs = {'input': inputs.cpu().numpy()}
     outs = model.run(None, ort_inputs)
     outs = [torch.Tensor(x) for x in outs]
-    return outs
+    return outs[0]
 
 model = init_onnx_model(f'{files_dir}/model.onnx')
 wav = read_audio(f'{files_dir}/en.wav')
 
 # get speech timestamps from full audio file
+
+# classic way
 speech_timestamps = get_speech_ts(wav, model,
                                   num_steps=4,
                                   run_function=validate_onnx)
 pprint(speech_timestamps)
+
+# adaptive way
+speech_timestamps = get_speech_ts_adaptive(wav, model, run_function=validate_onnx)
+pprint(speech_timestamps)
 ```
 
 #### Number Detector
@@ -347,6 +361,9 @@ Since our VAD (only VAD, other networks are more flexible) was trained on chunks
 
 ### VAD Parameter Fine Tuning
 
+#### **Classic way**
+
+**This is the straightforward classic method `get_speech_ts`, where thresholds (`trig_sum` and `neg_trig_sum`) are specified by the user**
 - Among others, we provide several [utils](https://github.com/snakers4/silero-vad/blob/8b28767292b424e3e505c55f15cd3c4b91e4804b/utils.py#L52-L59) to simplify working with VAD;
 - We provide sensible basic hyper-parameters that work for us, but your case can be different;
 - `trig_sum` - overlapping windows are used for each audio chunk, trig sum defines average probability among those windows for switching into triggered state (speech state);
@@ -365,6 +382,24 @@
 speech_timestamps = get_speech_ts(wav, model,
                                   visualize_probs=True)
 ```
 
+#### **Adaptive way**
+
+**The adaptive algorithm (`get_speech_ts_adaptive`) automatically selects thresholds (`trig_sum` and `neg_trig_sum`) based on median speech probabilities over the whole audio. Note that some arguments differ from the classic method's arguments:**
+- `batch_size` - batch size to feed to silero VAD (default - `200`);
+- `step` - step size in samples (default - `500`); equals `num_samples_per_window` / `num_steps` from the classic method;
+- `num_samples_per_window` - number of samples in each window; our models were trained using `4000` samples (250 ms) per window, so this is the preferable value (lesser values reduce [quality](https://github.com/snakers4/silero-vad/issues/2#issuecomment-750840434));
+- `min_speech_samples` - minimum speech chunk duration in samples (default - `10000`);
+- `min_silence_samples` - minimum silence duration in samples between two separate speech chunks (default - `4000`);
+- `speech_pad_samples` - widen speech chunks by this many samples on each side (default - `2000`).
+
+```
+speech_timestamps = get_speech_ts_adaptive(wav, model,
+                                           num_samples_per_window=4000,
+                                           step=500,
+                                           visualize_probs=True)
+```
+
 The chart should look something like this:
 
 ![image](https://user-images.githubusercontent.com/12515440/106242896-79142580-6219-11eb-9add-fa7195d6fd26.png)
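The README hunks above describe the adaptive selection only in prose: `get_speech_ts_adaptive` derives `trig_sum` and `neg_trig_sum` from the median per-window speech probability of the whole recording. A minimal sketch of that idea — the weights below are made up for illustration, not the actual constants inside `get_speech_ts_adaptive`:

```python
import torch

def pick_thresholds(speech_probs: torch.Tensor):
    # Illustrative only: derive both trigger thresholds from the median
    # per-window speech probability over the whole audio, so "loud" and
    # "quiet" recordings end up with comparable sensitivity.
    # The 0.9 / 0.1 / 0.5 weights are assumptions for this sketch.
    median_prob = speech_probs.median().item()
    trig_sum = 0.9 * median_prob + 0.1   # enter-speech threshold
    neg_trig_sum = 0.5 * median_prob     # exit-speech threshold
    return trig_sum, neg_trig_sum
```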
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index 0cb0a9d..c8235cf 100755
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -3,6 +3,7 @@
   {
    "cell_type": "markdown",
    "metadata": {
+    "heading_collapsed": true,
     "id": "sVNOuHQQjsrp"
    },
    "source": [
@@ -12,7 +13,8 @@
   {
    "cell_type": "markdown",
    "metadata": {
-    "heading_collapsed": true
+    "heading_collapsed": true,
+    "hidden": true
    },
    "source": [
     "## VAD"
@@ -57,6 +59,7 @@
     "                              force_reload=True)\n",
     "\n",
     "(get_speech_ts,\n",
+    " get_speech_ts_adaptive,\n",
     " save_audio,\n",
     " read_audio,\n",
     " state_generator,\n",
@@ -77,6 +80,15 @@
     "### Full Audio"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -116,6 +128,43 @@
     "Audio('only_speech.wav')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Experimental adaptive method, the algorithm selects thresholds itself (see the readme for more information)**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "wav = read_audio(f'{files_dir}/en.wav')\n",
+    "# get speech timestamps from full audio file\n",
+    "speech_timestamps = get_speech_ts_adaptive(wav, model, step=500, num_samples_per_window=4000)\n",
+    "pprint(speech_timestamps)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "# merge all speech chunks to one audio\n",
+    "save_audio('only_speech.wav',\n",
+    "           collect_chunks(speech_timestamps, wav), 16000)\n",
+    "Audio('only_speech.wav')"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -127,6 +176,19 @@
     "### Single Audio Stream"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -147,6 +209,30 @@
     "        print(batch)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Experimental adaptive method, the algorithm selects thresholds itself (see the readme for more information)**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "wav = f'{files_dir}/en.wav'\n",
+    "\n",
+    "for batch in single_audio_stream(model, wav, iterator_type='adaptive'):\n",
+    "    if batch:\n",
+    "        print(batch)"
+   ]
+  },
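The adaptive cells above pass `step=500` together with `num_samples_per_window=4000`. As the README hunk earlier in this patch notes, `step` plays the role of `num_samples_per_window / num_steps` from the classic method; a quick sanity check of that relationship (the `num_steps` value here is only an example, not a default from the patch):

```python
# Relating the adaptive `step` argument to the classic parameters.
num_samples_per_window = 4000   # 250 ms windows at a 16 kHz sampling rate
num_steps = 8                   # hypothetical classic-method setting
step = num_samples_per_window // num_steps
assert step == 500              # the adaptive method's default step
```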
"id": "57avIBd6jsrz" }, "source": [ @@ -396,7 +485,8 @@ { "cell_type": "markdown", "metadata": { - "heading_collapsed": true + "heading_collapsed": true, + "hidden": true }, "source": [ "## VAD" @@ -415,13 +505,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { + "ExecuteTime": { + "end_time": "2021-04-15T13:30:22.938755Z", + "start_time": "2021-04-15T13:30:20.970574Z" + }, "cellView": "form", "hidden": true, "id": "Q4QIfSpprnkI" }, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'torch' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m\u001b[0m", + "\u001b[0;31mNameError\u001b[0mTraceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplay\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAudio\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m _, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'silero_vad'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m force_reload=True)\n", + "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" + ] + } + ], "source": [ "#@title Install and Import Dependencies\n", "\n", @@ -439,6 +545,7 @@ " force_reload=True)\n", "\n", "(get_speech_ts,\n", + " get_speech_ts_adaptive,\n", " save_audio,\n", " read_audio,\n", " state_generator,\n", @@ -470,17 +577,42 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "ExecuteTime": { - "end_time": "2020-12-15T13:09:06.643812Z", - "start_time": "2020-12-15T13:09:06.473386Z" + "end_time": "2021-04-15T13:34:22.554010Z", + "start_time": "2021-04-15T13:34:22.550308Z" + }, + "hidden": true + }, + "source": [ + "**Classic way of getting speech chunks, you may need to select the tresholds yourself**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2021-04-15T13:30:14.475412Z", + "start_time": "2021-04-15T13:30:14.427933Z" }, "hidden": true, "id": "krnGoA6Kjsr0" }, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'init_onnx_model' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m\u001b[0m", + "\u001b[0;31mNameError\u001b[0mTraceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit_onnx_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{files_dir}/model.onnx'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mwav\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_audio\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{files_dir}/en.wav'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m# get speech timestamps from full audio file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mspeech_timestamps\u001b[0m \u001b[0;34m=\u001b[0m 
@@ -508,6 +640,60 @@
     "Audio('only_speech.wav')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Experimental adaptive method, the algorithm selects thresholds itself (see the readme for more information)**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "model = init_onnx_model(f'{files_dir}/model.onnx')\n",
+    "wav = read_audio(f'{files_dir}/en.wav')\n",
+    "\n",
+    "# get speech timestamps from full audio file\n",
+    "speech_timestamps = get_speech_ts_adaptive(wav, model, run_function=validate_onnx)\n",
+    "pprint(speech_timestamps)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "# merge all speech chunks to one audio\n",
+    "save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), 16000)\n",
+    "Audio('only_speech.wav')"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -519,6 +705,15 @@
     "### Single Audio Stream"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -554,6 +749,40 @@
     "        pprint(batch)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "hidden": true
+   },
+   "source": [
+    "**Experimental adaptive method, the algorithm selects thresholds itself (see the readme for more information)**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "model = init_onnx_model(f'{files_dir}/model.onnx')\n",
+    "wav = f'{files_dir}/en.wav'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "hidden": true
+   },
+   "outputs": [],
+   "source": [
+    "for batch in single_audio_stream(model, wav, iterator_type='adaptive', run_function=validate_onnx):\n",
+    "    if batch:\n",
+    "        pprint(batch)"
+   ]
+  },
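The streaming cells above only print each batch as it arrives. A hypothetical consumer that accumulates results instead — assuming, as the `pprint(batch)` loops above suggest, that each non-empty batch is a list of timestamp entries:

```python
# Hypothetical: collect streamed speech chunks into one list as they arrive.
all_chunks = []
for batch in single_audio_stream(model, wav, iterator_type='adaptive',
                                 run_function=validate_onnx):
    if batch:                  # batches may be empty while silence streams by
        all_chunks.extend(batch)
pprint(all_chunks)
```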
"wav = f'{files_dir}/en.wav'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hidden": true + }, + "outputs": [], + "source": [ + "for batch in single_audio_stream(model, wav, iterator_type='adaptive', run_function=validate_onnx):\n", + " if batch:\n", + " pprint(batch)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -604,7 +833,8 @@ { "cell_type": "markdown", "metadata": { - "heading_collapsed": true + "heading_collapsed": true, + "hidden": true }, "source": [ "## Number detector" @@ -753,7 +983,8 @@ { "cell_type": "markdown", "metadata": { - "heading_collapsed": true + "heading_collapsed": true, + "hidden": true }, "source": [ "## Language detector" @@ -819,7 +1050,6 @@ { "cell_type": "markdown", "metadata": { - "heading_collapsed": true, "hidden": true, "id": "5JHErdB7jsr0" }, @@ -863,7 +1093,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.3" + "version": "3.8.8" }, "toc": { "base_numbering": 1,