22 Commits
v3.0 ... v3.1

Author SHA1 Message Date
adamnsandle
f9876dd5f9 v3.1 path fix 2024-07-01 09:26:52 +00:00
Dimitrii Voronin
8ebaf139c6 Merge pull request #138 from snakers4/adamnsandle
Update README.md
2021-12-17 18:14:03 +03:00
Dimitrii Voronin
0a90316625 Update README.md 2021-12-17 17:13:33 +02:00
Dimitrii Voronin
35d8969322 Merge pull request #137 from snakers4/adamnsandle
Adamnsandle
2021-12-17 17:50:13 +03:00
adamnsandle
7c3eb8bfb5 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-17 14:48:58 +00:00
adamnsandle
74f759c8f8 add onnx vad 2021-12-17 14:48:32 +00:00
Dimitrii Voronin
5816eb08c4 Merge pull request #135 from snakers4/adamnsandle
Adamnsandle
2021-12-10 14:28:59 +03:00
adamnsandle
0feae6cbbe Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 11:28:25 +00:00
adamnsandle
fc0a70f42e improved model 2021-12-10 11:28:07 +00:00
Dimitrii Voronin
13fd927b84 Merge pull request #134 from snakers4/adamnsandle
Update README.md
2021-12-10 13:57:17 +03:00
Dimitrii Voronin
124d6564a0 Update README.md 2021-12-10 12:56:59 +02:00
Dimitrii Voronin
56fa93a1c9 Merge pull request #133 from snakers4/adamnsandle
Adamnsandle
2021-12-10 13:08:54 +03:00
adamnsandle
1a93276208 fx example 2021-12-10 10:07:38 +00:00
Dimitrii Voronin
9fbd0c4c2d Merge pull request #132 from snakers4/adamnsandle
delete big files from repo
2021-12-10 12:53:52 +03:00
adamnsandle
7b05a183a3 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 09:53:37 +00:00
adamnsandle
f67e68efc3 delete big files from repo 2021-12-10 09:52:22 +00:00
Alexander Veysov
51b1365bb0 Merge pull request #131 from snakers4/adamnsandle
add collab record example
2021-12-10 12:20:26 +03:00
adamnsandle
79fdb55f1c add collab record example 2021-12-10 09:18:15 +00:00
Alexander Veysov
b17da75dac Merge pull request #129 from snakers4/adamnsandle
Adamnsandle
2021-12-07 15:40:07 +03:00
adamnsandle
184e384697 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-07 12:32:16 +00:00
adamnsandle
adf5d6d020 fx example 2021-12-07 12:32:04 +00:00
Alexander Veysov
41ee0f6b9f Update README.md 2021-12-07 15:26:13 +03:00
18 changed files with 476 additions and 10271 deletions

View File

@@ -15,7 +15,7 @@ This repository also includes Number Detector and Language classifier [models](h
 <br/>
 <p align="center">
-<img src="https://user-images.githubusercontent.com/36505480/145007002-8473f909-5985-4942-bbcf-9ac86d156c2f.png" />
+<img src="https://user-images.githubusercontent.com/36505480/145563071-681b57e3-06b5-4cd0-bdee-e2ade3d50a60.png" />
 </p>
 <details>
@@ -47,7 +47,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 - **Flexible sampling rate**
-Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
+Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** (JIT) and **16000 Hz** (ONNX) [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
 - **Flexible chunk size**
@@ -72,6 +72,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 - [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
 - Number Detector and Language classifier [models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
 - [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
+- [Further reading](https://github.com/snakers4/silero-models#further-reading)
 <br/>
 <h2 align="center">Get In Touch</h2>
@@ -94,4 +95,4 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers
 commit = {insert_some_commit_here},
 email = {hello@silero.ai}
 }
 ```
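The sampling-rate edit above is the user-visible summary of this release: the JIT model still accepts both 8000 Hz and 16000 Hz, while the newly bundled ONNX model is 16000 Hz only. A minimal sketch of what that means in practice; `model`, `read_audio` and `get_speech_timestamps` are assumed to come from `torch.hub.load('snakers4/silero-vad', 'silero_vad')`, and `en_example.wav` is assumed downloaded, as in the examples notebook further down:

```python
# Hedged sketch: `model`, `read_audio` and `get_speech_timestamps` are assumed
# loaded via torch.hub.load(..., model='silero_vad') as in the notebook below.
for rate in (8000, 16000):  # the JIT model accepts both; the ONNX model only 16000
    wav = read_audio('en_example.wav', sampling_rate=rate)
    speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=rate)
    print(rate, speech_timestamps[:2])  # first segments, in samples at that rate
    model.reset_states()  # reset model states after each audio
```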

View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bccAucKjnPHm"
},
"source": [
"### Dependencies and inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cSih95WFmwgi"
},
"outputs": [],
"source": [
"!pip -q install pydub\n",
"from google.colab import output\n",
"from base64 import b64decode, b64encode\n",
"from io import BytesIO\n",
"import numpy as np\n",
"from pydub import AudioSegment\n",
"from IPython.display import HTML, display\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import moviepy.editor as mpe\n",
"from matplotlib.animation import FuncAnimation, FFMpegWriter\n",
"import matplotlib\n",
"matplotlib.use('Agg')\n",
"\n",
"torch.set_num_threads(1)\n",
"\n",
"model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"def int2float(sound):\n",
" abs_max = np.abs(sound).max()\n",
" sound = sound.astype('float32')\n",
" if abs_max > 0:\n",
" sound *= 1/abs_max\n",
" sound = sound.squeeze()\n",
" return sound\n",
"\n",
"AUDIO_HTML = \"\"\"\n",
"<script>\n",
"var my_div = document.createElement(\"DIV\");\n",
"var my_p = document.createElement(\"P\");\n",
"var my_btn = document.createElement(\"BUTTON\");\n",
"var t = document.createTextNode(\"Press to start recording\");\n",
"\n",
"my_btn.appendChild(t);\n",
"//my_p.appendChild(my_btn);\n",
"my_div.appendChild(my_btn);\n",
"document.body.appendChild(my_div);\n",
"\n",
"var base64data = 0;\n",
"var reader;\n",
"var recorder, gumStream;\n",
"var recordButton = my_btn;\n",
"\n",
"var handleSuccess = function(stream) {\n",
" gumStream = stream;\n",
" var options = {\n",
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
" mimeType : 'audio/webm;codecs=opus'\n",
" //mimeType : 'audio/webm;codecs=pcm'\n",
" }; \n",
" //recorder = new MediaRecorder(stream, options);\n",
" recorder = new MediaRecorder(stream);\n",
" recorder.ondataavailable = function(e) { \n",
" var url = URL.createObjectURL(e.data);\n",
" // var preview = document.createElement('audio');\n",
" // preview.controls = true;\n",
" // preview.src = url;\n",
" // document.body.appendChild(preview);\n",
"\n",
" reader = new FileReader();\n",
" reader.readAsDataURL(e.data); \n",
" reader.onloadend = function() {\n",
" base64data = reader.result;\n",
" //console.log(\"Inside FileReader:\" + base64data);\n",
" }\n",
" };\n",
" recorder.start();\n",
" };\n",
"\n",
"recordButton.innerText = \"Recording... press to stop\";\n",
"\n",
"navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
"\n",
"\n",
"function toggleRecording() {\n",
" if (recorder && recorder.state == \"recording\") {\n",
" recorder.stop();\n",
" gumStream.getAudioTracks()[0].stop();\n",
" recordButton.innerText = \"Saving recording...\"\n",
" }\n",
"}\n",
"\n",
"// https://stackoverflow.com/a/951057\n",
"function sleep(ms) {\n",
" return new Promise(resolve => setTimeout(resolve, ms));\n",
"}\n",
"\n",
"var data = new Promise(resolve=>{\n",
"//recordButton.addEventListener(\"click\", toggleRecording);\n",
"recordButton.onclick = ()=>{\n",
"toggleRecording()\n",
"\n",
"sleep(2000).then(() => {\n",
" // wait 2000ms for the data to be available...\n",
" // ideally this should use something like await...\n",
" //console.log(\"Inside data:\" + base64data)\n",
" resolve(base64data.toString())\n",
"\n",
"});\n",
"\n",
"}\n",
"});\n",
" \n",
"</script>\n",
"\"\"\"\n",
"\n",
"def record(sec=10):\n",
" display(HTML(AUDIO_HTML))\n",
" s = output.eval_js(\"data\")\n",
" b = b64decode(s.split(',')[1])\n",
" audio = AudioSegment.from_file(BytesIO(b))\n",
" audio.export('test.mp3', format='mp3')\n",
" audio = audio.set_channels(1)\n",
" audio = audio.set_frame_rate(16000)\n",
" audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
" audio_tens = torch.tensor(audio_float )\n",
" return audio_tens\n",
"\n",
"def make_animation(probs, audio_duration, interval=40):\n",
" fig = plt.figure(figsize=(16, 9))\n",
" ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n",
" line, = ax.plot([], [], lw=2)\n",
" x = [i / 16000 * 512 for i in range(len(probs))]\n",
" plt.xlabel('Time, seconds', fontsize=16)\n",
" plt.ylabel('Speech Probability', fontsize=16)\n",
"\n",
" def init():\n",
" plt.fill_between(x, probs, color='#064273')\n",
" line.set_data([], [])\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" def animate(i):\n",
" x = i * interval / 1000 - 0.04\n",
" y = np.linspace(0, 1.02, 2)\n",
" \n",
" line.set_data(x, y)\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
"\n",
" f = r\"animation.mp4\" \n",
" writervideo = FFMpegWriter(fps=1000/interval) \n",
" anim.save(f, writer=writervideo)\n",
" plt.close('all')\n",
"\n",
"def combine_audio(vidname, audname, outname, fps=25): \n",
" my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
" audio_background = mpe.AudioFileClip(audname)\n",
" final_clip = my_clip.set_audio(audio_background)\n",
" final_clip.write_videofile(outname,fps=fps,verbose=False)\n",
"\n",
"def record_make_animation():\n",
" tensor = record()\n",
"\n",
" print('Calculating probabilities...')\n",
" speech_probs = []\n",
" window_size_samples = 512\n",
" for i in range(0, len(tensor), window_size_samples):\n",
" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
" break\n",
" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
" speech_probs.append(speech_prob)\n",
" model.reset_states()\n",
" print('Making animation...')\n",
" make_animation(speech_probs, len(tensor) / 16000)\n",
"\n",
" print('Merging your voice with animation...')\n",
" combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n",
" print('Done!')\n",
" mp4 = open('merged.mp4','rb').read()\n",
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" display(HTML(\"\"\"\n",
" <video width=800 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFVs3GvTnpB1"
},
"source": [
"## Record example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5EBjrTwiqAaQ"
},
"outputs": [],
"source": [
"record_make_animation()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"bccAucKjnPHm"
],
"name": "Untitled2.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
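Stripped of the Colab recording and animation machinery, the core of this new notebook is the per-chunk probability loop inside `record_make_animation()`. A minimal sketch of just that part, assuming `model` was loaded via `torch.hub.load` as in the first cell and `tensor` is a mono 16 kHz float tensor:

```python
# Per-chunk speech probabilities, distilled from record_make_animation() above.
# 512 samples at 16000 Hz is 32 ms of audio per probability.
speech_probs = []
window_size_samples = 512
for i in range(0, len(tensor), window_size_samples):
    chunk = tensor[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # drop the trailing partial chunk, as the notebook does
    speech_probs.append(model(chunk, 16000).item())
model.reset_states()  # reset model states after each audio
```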

File diff suppressed because one or more lines are too long

11 binary files changed (diffs not shown)

BIN files/silero_vad.onnx (new file; binary diff not shown)

View File

@@ -1,5 +1,6 @@
 dependencies = ['torch', 'torchaudio']
 import torch
+import os
 import json
 from utils_vad import (init_jit_model,
                        get_speech_timestamps,
@@ -10,16 +11,21 @@ from utils_vad import (init_jit_model,
                        read_audio,
                        VADIterator,
                        collect_chunks,
-                       drop_chunks)
+                       drop_chunks,
+                       Validator,
+                       OnnxWrapper)


-def silero_vad(**kwargs):
+def silero_vad(onnx=False):
     """Silero Voice Activity Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
+    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    if onnx:
+        model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'))
+    else:
+        model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
     utils = (get_speech_timestamps,
              save_audio,
              read_audio,
@@ -29,13 +35,16 @@ def silero_vad(**kwargs):
     return model, utils


-def silero_number_detector(**kwargs):
+def silero_number_detector(onnx=False):
     """Silero Number Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url)
     utils = (get_number_ts,
              save_audio,
              read_audio,
@@ -45,32 +54,39 @@ def silero_number_detector(**kwargs):
     return model, utils


-def silero_lang_detector(**kwargs):
+def silero_lang_detector(onnx=False):
     """Silero Language Classifier
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url)
     utils = (get_language,
              read_audio)
     return model, utils


-def silero_lang_detector_95(**kwargs):
+def silero_lang_detector_95(onnx=False):
     """Silero Language Classifier (95 languages)
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/lang_classifier_95.jit')
-
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
+    model = Validator(url)
+    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    with open(os.path.join(model_dir, 'lang_dict_95.json'), 'r') as f:
         lang_dict = json.load(f)
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
+    with open(os.path.join(model_dir, 'lang_group_dict_95.json'), 'r') as f:
         lang_group_dict = json.load(f)
     utils = (get_language_and_group, read_audio)
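The effect of these hubconf changes: every entry point now takes an explicit `onnx` flag instead of ignoring `**kwargs`, the VAD weights are resolved relative to the checked-out repo rather than a hard-coded hub cache path, and the number/language models are fetched from models.silero.ai through the new `Validator`. Loading the ONNX VAD then looks like this, a sketch mirroring the examples notebook below (`onnxruntime` must be installed when `onnx=True`):

```python
import torch

USE_ONNX = True  # requires onnxruntime; False selects the JIT model as before

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=True,
                              onnx=USE_ONNX)
(get_speech_timestamps,
 save_audio,
 read_audio,
 VADIterator,
 collect_chunks) = utils
```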

View File

@@ -1,21 +1,12 @@
 {
 "cells": [
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "sVNOuHQQjsrp"
-},
-"source": [
-"# PyTorch Examples"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "FpMplOCA2Fwp"
 },
 "source": [
-"## VAD"
+"#VAD"
 ]
 },
 {
@@ -25,7 +16,7 @@
 "id": "62A6F_072Fwq"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -42,26 +33,39 @@
 "# this assumes that you have a relevant version of PyTorch installed\n",
 "!pip install -q torchaudio\n",
 "\n",
-"SAMPLE_RATE = 16000\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "pSifus5IilRp"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_vad',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
 "(get_speech_timestamps,\n",
 " save_audio,\n",
 " read_audio,\n",
 " VADIterator,\n",
-" collect_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+" collect_chunks) = utils"
 ]
 },
 {
@@ -70,7 +74,7 @@
 "id": "fXbbaUO3jsrw"
 },
 "source": [
-"### Full Audio"
+"## Full Audio"
 ]
 },
 {
@@ -90,9 +94,9 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "# get speech timestamps from full audio file\n",
-"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLE_RATE)\n",
+"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
 "pprint(speech_timestamps)"
 ]
 },
@@ -106,7 +110,7 @@
 "source": [
 "# merge all speech chunks to one audio\n",
 "save_audio('only_speech.wav',\n",
-" collect_chunks(speech_timestamps, wav), sampling_rate=16000) \n",
+" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
 "Audio('only_speech.wav')"
 ]
 },
@@ -116,7 +120,7 @@
 "id": "iDKQbVr8jsry"
 },
 "source": [
-"### Stream imitation example"
+"## Stream imitation example"
 ]
 },
 {
@@ -130,7 +134,7 @@
 "## using VADIterator class\n",
 "\n",
 "vad_iterator = VADIterator(model)\n",
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "\n",
 "window_size_samples = 1536 # number of samples in a single audio chunk\n",
 "for i in range(0, len(wav), window_size_samples):\n",
@@ -150,14 +154,15 @@
 "source": [
 "## just probabilities\n",
 "\n",
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "speech_probs = []\n",
 "window_size_samples = 1536\n",
 "for i in range(0, len(wav), window_size_samples):\n",
-" speech_prob = model(wav[i: i+ window_size_samples], SAMPLE_RATE).item()\n",
+" speech_prob = model(wav[i: i+ window_size_samples], SAMPLING_RATE).item()\n",
 " speech_probs.append(speech_prob)\n",
+"vad_iterator.reset_states() # reset model states after each audio\n",
 "\n",
-"pprint(speech_probs[:100])"
+"print(speech_probs[:10]) # first 10 chunks predicts"
 ]
 },
 {
@@ -167,7 +172,7 @@
 "id": "36jY0niD2Fww"
 },
 "source": [
-"## Number detector"
+"# Number detector"
 ]
 },
 {
@@ -178,7 +183,7 @@
 "id": "scd1DlS42Fwx"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -193,26 +198,41 @@
 "#@title Install and Import Dependencies\n",
 "\n",
 "# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
+"!pip install -q torchaudio\n",
+"\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en_num.wav', 'en_number_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "dPwCFHmFycUF"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_number_detector',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
 "(get_number_ts,\n",
 " save_audio,\n",
 " read_audio,\n",
 " collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+" drop_chunks) = utils\n"
 ]
 },
 {
@@ -223,7 +243,7 @@
 "id": "qhPa30ij2Fwy"
 },
 "source": [
-"### Full audio"
+"## Full audio"
 ]
 },
 {
@@ -235,7 +255,7 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
+"wav = read_audio('en_number_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "# get number timestamps from full audio file\n",
 "number_timestamps = get_number_ts(wav, model)\n",
 "pprint(number_timestamps)"
@@ -250,11 +270,10 @@
 },
 "outputs": [],
 "source": [
-"sample_rate = 16000\n",
 "# convert ms in timestamps to samples\n",
 "for timestamp in number_timestamps:\n",
-" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
+" timestamp['start'] = int(timestamp['start'] * SAMPLING_RATE / 1000)\n",
+" timestamp['end'] = int(timestamp['end'] * SAMPLING_RATE / 1000)"
 ]
 },
 {
@@ -268,7 +287,7 @@
 "source": [
 "# merge all number chunks to one audio\n",
 "save_audio('only_numbers.wav',\n",
-" collect_chunks(number_timestamps, wav), sample_rate) \n",
+" collect_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
 "Audio('only_numbers.wav')"
 ]
 },
@@ -283,7 +302,7 @@
 "source": [
 "# drop all number chunks from audio\n",
 "save_audio('no_numbers.wav',\n",
-" drop_chunks(number_timestamps, wav), sample_rate) \n",
+" drop_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
 "Audio('no_numbers.wav')"
 ]
 },
@@ -294,7 +313,7 @@
 "id": "PnKtJKbq2Fwz"
 },
 "source": [
-"## Language detector"
+"# Language detector"
 ]
 },
 {
@@ -305,7 +324,7 @@
 "id": "F5cAmMbP2Fwz"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -320,23 +339,37 @@
 "#@title Install and Import Dependencies\n",
 "\n",
 "# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
+"!pip install -q torchaudio\n",
+"\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "JfRKDZiRztFe"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_lang_detector',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+"get_language, read_audio = utils"
 ]
 },
 {
@@ -347,7 +380,7 @@
 "id": "iC696eMX2Fwz"
 },
 "source": [
-"### Full audio"
+"## Full audio"
 ]
 },
 {
@@ -359,266 +392,10 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en.wav')\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "lang = get_language(wav, model)\n",
 "print(lang)"
 ]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "57avIBd6jsrz"
-},
-"source": [
-"# ONNX Example"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "hEhnfORV2Fw0"
-},
-"source": [
-"## VAD"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "Cy7y-NAyALSe"
-},
-"source": [
-"**TO BE DONE**"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "7QMvUvpg2Fw4"
-},
-"source": [
-"## Number detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "tBPDkpHr2Fw4"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "PdjGd56R2Fw5"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-" model='silero_number_detector',\n",
-" force_reload=True)\n",
-"\n",
-"(get_number_ts,\n",
-" save_audio,\n",
-" read_audio,\n",
-" collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-" return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-" with torch.no_grad():\n",
-" ort_inputs = {'input': inputs.cpu().numpy()}\n",
-" outs = model.run(None, ort_inputs)\n",
-" outs = [torch.Tensor(x) for x in outs]\n",
-" return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "I9QWSFZh2Fw5"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "_r6QZiwu2Fw5"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
-"\n",
-"# get number timestamps from full audio file\n",
-"number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
-"pprint(number_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "FN4aDwLV2Fw5"
-},
-"outputs": [],
-"source": [
-"sample_rate = 16000\n",
-"# convert ms in timestamps to samples\n",
-"for timestamp in number_timestamps:\n",
-" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "JnvS6WTK2Fw5"
-},
-"outputs": [],
-"source": [
-"# merge all number chunks to one audio\n",
-"save_audio('only_numbers.wav',\n",
-" collect_chunks(number_timestamps, wav), 16000) \n",
-"Audio('only_numbers.wav')"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "yUxOcOFG2Fw6"
-},
-"outputs": [],
-"source": [
-"# drop all number chunks from audio\n",
-"save_audio('no_numbers.wav',\n",
-" drop_chunks(number_timestamps, wav), 16000) \n",
-"Audio('no_numbers.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "SR8Bgcd52Fw6"
-},
-"source": [
-"## Language detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "PBnXPtKo2Fw6"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "iNkDWJ3H2Fw6"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-" model='silero_lang_detector',\n",
-" force_reload=True)\n",
-"\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-" return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-" with torch.no_grad():\n",
-" ort_inputs = {'input': inputs.cpu().numpy()}\n",
-" outs = model.run(None, ort_inputs)\n",
-" outs = [torch.Tensor(x) for x in outs]\n",
-" return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"hidden": true,
-"id": "G8N8oP4q2Fw6"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "WHXnh9IV2Fw6"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"\n",
-"lang = get_language(wav, model, run_function=validate_onnx)\n",
-"print(lang)"
-]
 }
 ],
 "metadata": {

View File

@@ -7,14 +7,66 @@ import warnings

 languages = ['ru', 'en', 'de', 'es']


-def validate(model,
-             inputs: torch.Tensor,
-             **kwargs):
-    with torch.no_grad():
-        outs = model(inputs, **kwargs)
-    if len(outs.shape) == 1:
-        return outs[1:]
-    return outs[:, 1]  # 0 for noise, 1 for speech
+class OnnxWrapper():
+
+    def __init__(self, path):
+        import numpy as np
+        global np
+        import onnxruntime
+        self.session = onnxruntime.InferenceSession(path)
+        self.session.intra_op_num_threads = 1
+        self.session.inter_op_num_threads = 1
+        self.reset_states()
+
+    def reset_states(self):
+        self._h = np.zeros((2, 1, 64)).astype('float32')
+        self._c = np.zeros((2, 1, 64)).astype('float32')
+
+    def __call__(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+        if x.shape[0] > 1:
+            raise ValueError("Onnx model does not support batching")
+        if sr not in [16000]:
+            raise ValueError(f"Supported sample rates: {[16000]}")
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
+        ort_outs = self.session.run(None, ort_inputs)
+        out, self._h, self._c = ort_outs
+        out = torch.tensor(out).squeeze(2)[:, 1]  # make output type match JIT analog
+        return out
+
+
+class Validator():
+    def __init__(self, url):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+        return outs


 def read_audio(path: str,
@@ -206,10 +258,9 @@ def get_number_ts(wav: torch.Tensor,
                   model,
                   model_stride=8,
                   hop_length=160,
-                  sample_rate=16000,
-                  run_function=validate):
+                  sample_rate=16000):
     wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = run_function(model, wav)[0]
+    perframe_logits = model(wav)[0]
     perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
     extended_preds = []
     for i in perframe_preds:
@@ -236,10 +287,9 @@ def get_number_ts(wav: torch.Tensor,


 def get_language(wav: torch.Tensor,
-                 model,
-                 run_function=validate):
+                 model):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = run_function(model, wav)[2]
+    lang_logits = model(wav)[2]
     lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
     assert lang_pred < len(languages)
     return languages[lang_pred]
@@ -249,10 +299,9 @@ def get_language_and_group(wav: torch.Tensor,
                            model,
                            lang_dict: dict,
                            lang_group_dict: dict,
-                           top_n=1,
-                           run_function=validate):
+                           top_n=1):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = run_function(model, wav)
+    lang_logits, lang_group_logits = model(wav)
     softm = torch.softmax(lang_logits, dim=1).squeeze()
     softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
@@ -323,6 +372,13 @@ class VADIterator:
         return_seconds: bool (default - False)
             whether return timestamps in seconds (default - samples)
         """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
         self.current_sample += window_size_samples
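Two behavioral notes on these utils_vad changes: the new OnnxWrapper enforces 16000 Hz, single-stream input, and chunks of at least sr / 31.25 = 512 samples, and it carries h/c recurrent state between calls, so reset_states() should be called between audios; and VADIterator now casts non-tensor input itself, so a plain numpy chunk is accepted. A minimal streaming sketch, assuming `model`, `read_audio` and `VADIterator` were obtained via torch.hub.load with onnx=True as above and `en_example.wav` was downloaded as in the examples notebook:

```python
# Hedged sketch: `model`, `read_audio` and `VADIterator` are assumed to come
# from torch.hub.load(..., model='silero_vad', onnx=True) shown earlier.
vad_iterator = VADIterator(model)
wav = read_audio('en_example.wav', sampling_rate=16000)

window_size_samples = 1536  # chunk size from the examples notebook; >= 512 for ONNX
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # OnnxWrapper rejects chunks shorter than 512 samples
    speech_dict = vad_iterator(chunk.numpy(), return_seconds=True)  # ndarray ok: cast internally now
    if speech_dict:
        print(speech_dict)  # {'start': ...} when speech begins, {'end': ...} when it ends
vad_iterator.reset_states()  # reset iterator bookkeeping and model state between audios
```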