62 Commits
v3.0 ... v4.0

Author SHA1 Message Date
adamnsandle
915dd3d639 v4.0stable force_onnx_cpu fx 2024-07-08 10:16:52 +00:00
adamnsandle
ac128b3c55 v4.0stable fx 2024-07-01 09:53:25 +00:00
Dimitrii Voronin
82d199ff22 Merge pull request #256 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:57:10 +03:00
adamnsandle
5ba388d894 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:55:59 +00:00
adamnsandle
790844ba0f revert to exception 2022-10-28 10:55:46 +00:00
Dimitrii Voronin
51b5245410 Merge pull request #255 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:33:18 +03:00
adamnsandle
888970e77d Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:32:08 +00:00
adamnsandle
cb6d308335 fx 2022-10-28 10:31:55 +00:00
adamnsandle
1b212c6e95 change exception to warning 2022-10-28 10:26:07 +00:00
adamnsandle
452060ad65 fx 2022-10-28 10:13:00 +00:00
Dimitrii Voronin
c7eab751b5 Merge pull request #253 from snakers4/adamnsandle
add torch version check
2022-10-28 13:09:18 +03:00
adamnsandle
d1714a9ff7 add torch version check 2022-10-28 10:08:07 +00:00
Dimitrii Voronin
94c79d899d Merge pull request #251 from snakers4/adamnsandle
v4 hotfix
2022-10-27 20:26:31 +03:00
adamnsandle
1baf307b35 v4 hotfix 2022-10-27 17:25:31 +00:00
Dimitrii Voronin
e324285cdc Merge pull request #247 from snakers4/adamnsandle
Adamnsandle
2022-10-26 19:17:44 +03:00
adamnsandle
13dce2d067 Merge branch 'MASTER' into adamnsandle 2022-10-26 16:13:37 +00:00
adamnsandle
081e6b9886 VAD v4 2022-10-26 16:10:20 +00:00
Alexander Veysov
572134fdf1 Update README.md 2022-10-25 05:52:53 +03:00
Dimitrii Voronin
a799dea837 Merge pull request #244 from owlsometech-kenyang/feature/support-force-onnx-cpu
Suggesting a new kwarg: force_onnx_cpu
2022-10-14 11:53:46 +03:00
ChiehKai Yang
17209e6c4f add new parameter: force_onnx_cpu 2022-10-12 01:56:43 +08:00
adamnsandle
6661cc9691 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:41:54 +00:00
Dimitrii Voronin
7c671a75c2 Merge pull request #199 from snakers4/adamnsandle
fx end of chunk may exceed audio length
2022-06-02 13:40:42 +03:00
adamnsandle
622016e672 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:40:11 +00:00
adamnsandle
8eba346bc9 fx end of chunk may exceed audio length 2022-06-02 10:39:16 +00:00
Dimitrii Voronin
900c71a109 Merge pull request #198 from snakers4/adamnsandle
fx get_speech ts start of an audio chunk pad
2022-06-02 13:33:36 +03:00
adamnsandle
bf0127e016 fx get_speech ts start of an audio chunk pad 2022-06-02 10:32:32 +00:00
Dimitrii Voronin
ea7af70fe9 Merge pull request #182 from snakers4/adamnsandle
Adamnsandle
2022-04-05 14:36:00 +03:00
adamnsandle
8cdc8d36c9 fx 2022-04-05 11:35:23 +00:00
adamnsandle
6e9fd77500 fx stram imitation example bug 2022-04-05 11:33:34 +00:00
Alexander Veysov
6cc08b1077 Merge pull request #170 from gabrielziegler3/169-fix-min-speech-duration-bug
Fix #169
2022-02-10 12:18:23 +03:00
Gabriel Ziegler
0e8e080894 Remove unnecessary if statement 2022-02-09 19:22:04 -03:00
Gabriel Ziegler
af6931d1de Fix bug where min_speech_duration_ms is not checked in the last speech segment
Signed-off-by: Gabriel Ziegler <gabrielziegler3@gmail.com>
2022-02-09 19:18:48 -03:00
Alexander Veysov
76687cbe25 Update README.md 2021-12-21 14:43:36 +03:00
Dimitrii Voronin
b2329fa5f2 Merge pull request #144 from snakers4/adamnsandle
Update README.md
2021-12-21 14:25:56 +03:00
Dimitrii Voronin
005886e7eb Update README.md 2021-12-21 13:25:14 +02:00
Dimitrii Voronin
f6b1294cb2 Merge pull request #143 from snakers4/adamnsandle
Adamnsandle
2021-12-21 14:02:25 +03:00
adamnsandle
2392ea33f4 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-21 11:01:25 +00:00
adamnsandle
45d72863b6 add multiple of 16k sr support 2021-12-21 11:01:07 +00:00
Alexander Veysov
f40cc128a4 Update utils_vad.py 2021-12-21 08:24:48 +03:00
Alexander Veysov
0d61e4cee1 Update README.md 2021-12-17 22:03:49 +03:00
Alexander Veysov
011268e492 Polish the copy a bit 2021-12-17 22:00:36 +03:00
Dimitrii Voronin
8ebaf139c6 Merge pull request #138 from snakers4/adamnsandle
Update README.md
2021-12-17 18:14:03 +03:00
Dimitrii Voronin
0a90316625 Update README.md 2021-12-17 17:13:33 +02:00
Dimitrii Voronin
35d8969322 Merge pull request #137 from snakers4/adamnsandle
Adamnsandle
2021-12-17 17:50:13 +03:00
adamnsandle
7c3eb8bfb5 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-17 14:48:58 +00:00
adamnsandle
74f759c8f8 add onnx vad 2021-12-17 14:48:32 +00:00
Dimitrii Voronin
5816eb08c4 Merge pull request #135 from snakers4/adamnsandle
Adamnsandle
2021-12-10 14:28:59 +03:00
adamnsandle
0feae6cbbe Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 11:28:25 +00:00
adamnsandle
fc0a70f42e imporved model 2021-12-10 11:28:07 +00:00
Dimitrii Voronin
13fd927b84 Merge pull request #134 from snakers4/adamnsandle
Update README.md
2021-12-10 13:57:17 +03:00
Dimitrii Voronin
124d6564a0 Update README.md 2021-12-10 12:56:59 +02:00
Dimitrii Voronin
56fa93a1c9 Merge pull request #133 from snakers4/adamnsandle
Adamnsandle
2021-12-10 13:08:54 +03:00
adamnsandle
1a93276208 fx example 2021-12-10 10:07:38 +00:00
Dimitrii Voronin
9fbd0c4c2d Merge pull request #132 from snakers4/adamnsandle
delete big files from repo
2021-12-10 12:53:52 +03:00
adamnsandle
7b05a183a3 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 09:53:37 +00:00
adamnsandle
f67e68efc3 delete big files from repo 2021-12-10 09:52:22 +00:00
Alexander Veysov
51b1365bb0 Merge pull request #131 from snakers4/adamnsandle
add collab record example
2021-12-10 12:20:26 +03:00
adamnsandle
79fdb55f1c add collab record example 2021-12-10 09:18:15 +00:00
Alexander Veysov
b17da75dac Merge pull request #129 from snakers4/adamnsandle
Adamnsandle
2021-12-07 15:40:07 +03:00
adamnsandle
184e384697 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-07 12:32:16 +00:00
adamnsandle
adf5d6d020 fx example 2021-12-07 12:32:04 +00:00
Alexander Veysov
41ee0f6b9f Update README.md 2021-12-07 15:26:13 +03:00
18 changed files with 580 additions and 10279 deletions

View File

@@ -15,7 +15,7 @@ This repository also includes Number Detector and Language classifier [models](h
 <br/>
 <p align="center">
-<img src="https://user-images.githubusercontent.com/36505480/145007002-8473f909-5985-4942-bbcf-9ac86d156c2f.png" />
+<img src="https://user-images.githubusercontent.com/36505480/198026365-8da383e0-5398-4a12-b7f8-22c2c0059512.png" />
 </p>
 <details>
@@ -29,17 +29,17 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 <h2 align="center">Key Features</h2>
 <br/>
-- **High accuracy**
+- **Stellar accuracy**
   Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks.
 - **Fast**
-  One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) around **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably.
+  One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably. Under certain conditions ONNX may even run up to 4-5x faster.
 - **Lightweight**
-  JIT model is less than one megabyte in size.
+  JIT model is around one megabyte in size.
 - **General**
@@ -51,7 +51,15 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 - **Flexible chunk size**
-  Model was trained on audio chunks of different lengths. **30 ms**, **60 ms** and **100 ms** long chunks are supported directly, others may work as well.
+  Model was trained on **30 ms**. Longer chunks are supported directly, others may work as well.
+- **Highly Portable**
+  Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX** running everywhere where these runtimes are available.
+- **No Strings Attached**
+  Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.
 <br/>
 <h2 align="center">Typical Use Cases</h2>
@@ -70,8 +78,10 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 - [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
 - [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
 - [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
-- Number Detector and Language classifier [models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
+- [Number Detector and Language classifier models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
 - [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
+- [Further reading](https://github.com/snakers4/silero-models#further-reading)
+- [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)
 <br/>
 <h2 align="center">Get In Touch</h2>
@@ -95,3 +105,9 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers
   email = {hello@silero.ai}
 }
 ```
+<br/>
+<h2 align="center">VAD-based Community Apps</h2>
+<br/>
+- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
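Taken together, the README changes above describe the chunk-based workflow of the v4 model. A minimal sketch of the full-audio path it advertises, assuming the model is fetched via torch.hub and `en_example.wav` is a placeholder 16 kHz mono file (this mirrors the updated example notebook further below):

```python
import torch

torch.set_num_threads(1)

# Fetch the VAD model and its helper utils from torch.hub (JIT flavor).
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad')
(get_speech_timestamps, save_audio, read_audio,
 VADIterator, collect_chunks) = utils

wav = read_audio('en_example.wav', sampling_rate=16000)

# Timestamps (in samples) of the detected speech regions.
speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
print(speech_timestamps)

# Keep only the speech and write it back out as one file.
save_audio('only_speech.wav',
           collect_chunks(speech_timestamps, wav), sampling_rate=16000)
```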

View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bccAucKjnPHm"
},
"source": [
"### Dependencies and inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cSih95WFmwgi"
},
"outputs": [],
"source": [
"!pip -q install pydub\n",
"from google.colab import output\n",
"from base64 import b64decode, b64encode\n",
"from io import BytesIO\n",
"import numpy as np\n",
"from pydub import AudioSegment\n",
"from IPython.display import HTML, display\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import moviepy.editor as mpe\n",
"from matplotlib.animation import FuncAnimation, FFMpegWriter\n",
"import matplotlib\n",
"matplotlib.use('Agg')\n",
"\n",
"torch.set_num_threads(1)\n",
"\n",
"model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"def int2float(sound):\n",
" abs_max = np.abs(sound).max()\n",
" sound = sound.astype('float32')\n",
" if abs_max > 0:\n",
" sound *= 1/abs_max\n",
" sound = sound.squeeze()\n",
" return sound\n",
"\n",
"AUDIO_HTML = \"\"\"\n",
"<script>\n",
"var my_div = document.createElement(\"DIV\");\n",
"var my_p = document.createElement(\"P\");\n",
"var my_btn = document.createElement(\"BUTTON\");\n",
"var t = document.createTextNode(\"Press to start recording\");\n",
"\n",
"my_btn.appendChild(t);\n",
"//my_p.appendChild(my_btn);\n",
"my_div.appendChild(my_btn);\n",
"document.body.appendChild(my_div);\n",
"\n",
"var base64data = 0;\n",
"var reader;\n",
"var recorder, gumStream;\n",
"var recordButton = my_btn;\n",
"\n",
"var handleSuccess = function(stream) {\n",
" gumStream = stream;\n",
" var options = {\n",
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
" mimeType : 'audio/webm;codecs=opus'\n",
" //mimeType : 'audio/webm;codecs=pcm'\n",
" }; \n",
" //recorder = new MediaRecorder(stream, options);\n",
" recorder = new MediaRecorder(stream);\n",
" recorder.ondataavailable = function(e) { \n",
" var url = URL.createObjectURL(e.data);\n",
" // var preview = document.createElement('audio');\n",
" // preview.controls = true;\n",
" // preview.src = url;\n",
" // document.body.appendChild(preview);\n",
"\n",
" reader = new FileReader();\n",
" reader.readAsDataURL(e.data); \n",
" reader.onloadend = function() {\n",
" base64data = reader.result;\n",
" //console.log(\"Inside FileReader:\" + base64data);\n",
" }\n",
" };\n",
" recorder.start();\n",
" };\n",
"\n",
"recordButton.innerText = \"Recording... press to stop\";\n",
"\n",
"navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
"\n",
"\n",
"function toggleRecording() {\n",
" if (recorder && recorder.state == \"recording\") {\n",
" recorder.stop();\n",
" gumStream.getAudioTracks()[0].stop();\n",
" recordButton.innerText = \"Saving recording...\"\n",
" }\n",
"}\n",
"\n",
"// https://stackoverflow.com/a/951057\n",
"function sleep(ms) {\n",
" return new Promise(resolve => setTimeout(resolve, ms));\n",
"}\n",
"\n",
"var data = new Promise(resolve=>{\n",
"//recordButton.addEventListener(\"click\", toggleRecording);\n",
"recordButton.onclick = ()=>{\n",
"toggleRecording()\n",
"\n",
"sleep(2000).then(() => {\n",
" // wait 2000ms for the data to be available...\n",
" // ideally this should use something like await...\n",
" //console.log(\"Inside data:\" + base64data)\n",
" resolve(base64data.toString())\n",
"\n",
"});\n",
"\n",
"}\n",
"});\n",
" \n",
"</script>\n",
"\"\"\"\n",
"\n",
"def record(sec=10):\n",
" display(HTML(AUDIO_HTML))\n",
" s = output.eval_js(\"data\")\n",
" b = b64decode(s.split(',')[1])\n",
" audio = AudioSegment.from_file(BytesIO(b))\n",
" audio.export('test.mp3', format='mp3')\n",
" audio = audio.set_channels(1)\n",
" audio = audio.set_frame_rate(16000)\n",
" audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
" audio_tens = torch.tensor(audio_float )\n",
" return audio_tens\n",
"\n",
"def make_animation(probs, audio_duration, interval=40):\n",
" fig = plt.figure(figsize=(16, 9))\n",
" ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n",
" line, = ax.plot([], [], lw=2)\n",
" x = [i / 16000 * 512 for i in range(len(probs))]\n",
" plt.xlabel('Time, seconds', fontsize=16)\n",
" plt.ylabel('Speech Probability', fontsize=16)\n",
"\n",
" def init():\n",
" plt.fill_between(x, probs, color='#064273')\n",
" line.set_data([], [])\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" def animate(i):\n",
" x = i * interval / 1000 - 0.04\n",
" y = np.linspace(0, 1.02, 2)\n",
" \n",
" line.set_data(x, y)\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
"\n",
" f = r\"animation.mp4\" \n",
" writervideo = FFMpegWriter(fps=1000/interval) \n",
" anim.save(f, writer=writervideo)\n",
" plt.close('all')\n",
"\n",
"def combine_audio(vidname, audname, outname, fps=25): \n",
" my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
" audio_background = mpe.AudioFileClip(audname)\n",
" final_clip = my_clip.set_audio(audio_background)\n",
" final_clip.write_videofile(outname,fps=fps,verbose=False)\n",
"\n",
"def record_make_animation():\n",
" tensor = record()\n",
"\n",
" print('Calculating probabilities...')\n",
" speech_probs = []\n",
" window_size_samples = 512\n",
" for i in range(0, len(tensor), window_size_samples):\n",
" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
" break\n",
" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
" speech_probs.append(speech_prob)\n",
" model.reset_states()\n",
" print('Making animation...')\n",
" make_animation(speech_probs, len(tensor) / 16000)\n",
"\n",
" print('Merging your voice with animation...')\n",
" combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n",
" print('Done!')\n",
" mp4 = open('merged.mp4','rb').read()\n",
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" display(HTML(\"\"\"\n",
" <video width=800 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFVs3GvTnpB1"
},
"source": [
"## Record example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5EBjrTwiqAaQ"
},
"outputs": [],
"source": [
"record_make_animation()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"bccAucKjnPHm"
],
"name": "Untitled2.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
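The record notebook above feeds the model int16 microphone samples converted to float32. A standalone sketch of its `int2float` conversion (the sample values below are made up for illustration). Note the design choice: it peak-normalizes each recording instead of dividing by a fixed 32768, so quiet recordings are scaled up to full range:

```python
import numpy as np
import torch

def int2float(sound: np.ndarray) -> np.ndarray:
    # Peak-normalize int16 PCM into float32 in [-1, 1]; all-zero input stays zero.
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1 / abs_max
    return sound.squeeze()

# e.g. samples obtained from pydub via audio.get_array_of_samples()
pcm = np.array([0, 8192, -16384], dtype=np.int16)
wav = torch.tensor(int2float(pcm))
print(wav)  # tensor([ 0.0000,  0.5000, -1.0000])
```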

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
files/silero_vad.onnx Normal file

Binary file not shown.

View File

@@ -1,5 +1,6 @@
 dependencies = ['torch', 'torchaudio']
 import torch
+import os
 import json
 from utils_vad import (init_jit_model,
                        get_speech_timestamps,
@@ -10,16 +11,32 @@ from utils_vad import (init_jit_model,
                        read_audio,
                        VADIterator,
                        collect_chunks,
-                       drop_chunks)
+                       drop_chunks,
+                       Validator,
+                       OnnxWrapper)
-def silero_vad(**kwargs):
+def versiontuple(v):
+    return tuple(map(int, (v.split('+')[0].split("."))))
+def silero_vad(onnx=False, force_onnx_cpu=False):
     """Silero Voice Activity Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
+    if not onnx:
+        installed_version = torch.__version__
+        supported_version = '1.12.0'
+        if versiontuple(installed_version) < versiontuple(supported_version):
+            raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
+    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    if onnx:
+        model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
+    else:
+        model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
     utils = (get_speech_timestamps,
              save_audio,
              read_audio,
@@ -29,13 +46,16 @@ def silero_vad(**kwargs):
     return model, utils
-def silero_number_detector(**kwargs):
+def silero_number_detector(onnx=False, force_onnx_cpu=False):
     """Silero Number Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url, force_onnx_cpu)
     utils = (get_number_ts,
              save_audio,
              read_audio,
@@ -45,32 +65,39 @@ def silero_number_detector(**kwargs):
     return model, utils
-def silero_lang_detector(**kwargs):
+def silero_lang_detector(onnx=False, force_onnx_cpu=False):
     """Silero Language Classifier
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url, force_onnx_cpu)
     utils = (get_language,
              read_audio)
     return model, utils
-def silero_lang_detector_95(**kwargs):
+def silero_lang_detector_95(onnx=False, force_onnx_cpu=False):
     """Silero Language Classifier (95 languages)
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/lang_classifier_95.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
+    model = Validator(url, force_onnx_cpu)
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
+    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    with open(os.path.join(model_dir, 'lang_dict_95.json'), 'r') as f:
         lang_dict = json.load(f)
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
+    with open(os.path.join(model_dir, 'lang_group_dict_95.json'), 'r') as f:
         lang_group_dict = json.load(f)
     utils = (get_language_and_group, read_audio)
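The reworked hubconf adds two kwargs, `onnx` and `force_onnx_cpu`, plus a torch version gate for the JIT path. A minimal usage sketch (assumes onnxruntime is installed when `onnx=True`):

```python
import torch

USE_ONNX = True  # pip install -q onnxruntime first

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=True,
                              onnx=USE_ONNX,        # ONNX instead of TorchScript
                              force_onnx_cpu=True)  # pin to CPUExecutionProvider

# The JIT path (onnx=False) raises unless torch >= 1.12.0; note that
# versiontuple('1.13.0+cu117') -> (1, 13, 0), i.e. local version tags are stripped.
```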

View File

@@ -1,21 +1,12 @@
 {
 "cells": [
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "sVNOuHQQjsrp"
-},
-"source": [
-"# PyTorch Examples"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "FpMplOCA2Fwp"
 },
 "source": [
-"## VAD"
+"#VAD"
 ]
 },
 {
@@ -25,7 +16,7 @@
 "id": "62A6F_072Fwq"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -42,26 +33,39 @@
 "# this assumes that you have a relevant version of PyTorch installed\n",
 "!pip install -q torchaudio\n",
 "\n",
-"SAMPLE_RATE = 16000\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "pSifus5IilRp"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_vad',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
 "(get_speech_timestamps,\n",
 " save_audio,\n",
 " read_audio,\n",
 " VADIterator,\n",
-" collect_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+" collect_chunks) = utils"
 ]
 },
 {
@@ -70,7 +74,7 @@
 "id": "fXbbaUO3jsrw"
 },
 "source": [
-"### Full Audio"
+"## Full Audio"
 ]
 },
 {
@@ -90,9 +94,9 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "# get speech timestamps from full audio file\n",
-"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLE_RATE)\n",
+"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
 "pprint(speech_timestamps)"
 ]
 },
@@ -106,7 +110,7 @@
 "source": [
 "# merge all speech chunks to one audio\n",
 "save_audio('only_speech.wav',\n",
-" collect_chunks(speech_timestamps, wav), sampling_rate=16000) \n",
+" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
 "Audio('only_speech.wav')"
 ]
 },
@@ -116,7 +120,7 @@
 "id": "iDKQbVr8jsry"
 },
 "source": [
-"### Stream imitation example"
+"## Stream imitation example"
 ]
 },
 {
@@ -130,11 +134,14 @@
 "## using VADIterator class\n",
 "\n",
 "vad_iterator = VADIterator(model)\n",
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "\n",
 "window_size_samples = 1536 # number of samples in a single audio chunk\n",
 "for i in range(0, len(wav), window_size_samples):\n",
-" speech_dict = vad_iterator(wav[i: i+ window_size_samples], return_seconds=True)\n",
+" chunk = wav[i: i+ window_size_samples]\n",
+" if len(chunk) < window_size_samples:\n",
+" break\n",
+" speech_dict = vad_iterator(chunk, return_seconds=True)\n",
 " if speech_dict:\n",
 " print(speech_dict, end=' ')\n",
 "vad_iterator.reset_states() # reset model states after each audio"
@@ -150,14 +157,18 @@
 "source": [
 "## just probabilities\n",
 "\n",
-"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "speech_probs = []\n",
 "window_size_samples = 1536\n",
 "for i in range(0, len(wav), window_size_samples):\n",
-" speech_prob = model(wav[i: i+ window_size_samples], SAMPLE_RATE).item()\n",
+" chunk = wav[i: i+ window_size_samples]\n",
+" if len(chunk) < window_size_samples:\n",
+" break\n",
+" speech_prob = model(chunk, SAMPLING_RATE).item()\n",
 " speech_probs.append(speech_prob)\n",
+"vad_iterator.reset_states() # reset model states after each audio\n",
 "\n",
-"pprint(speech_probs[:100])"
+"print(speech_probs[:10]) # first 10 chunks predicts"
 ]
 },
 {
@@ -167,7 +178,7 @@
 "id": "36jY0niD2Fww"
 },
 "source": [
-"## Number detector"
+"# Number detector"
 ]
 },
 {
@@ -178,7 +189,7 @@
 "id": "scd1DlS42Fwx"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -193,26 +204,41 @@
 "#@title Install and Import Dependencies\n",
 "\n",
 "# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
+"!pip install -q torchaudio\n",
+"\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en_num.wav', 'en_number_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "dPwCFHmFycUF"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_number_detector',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
 "(get_number_ts,\n",
 " save_audio,\n",
 " read_audio,\n",
 " collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+" drop_chunks) = utils\n"
 ]
 },
 {
@@ -223,7 +249,7 @@
 "id": "qhPa30ij2Fwy"
 },
 "source": [
-"### Full audio"
+"## Full audio"
 ]
 },
 {
@@ -235,7 +261,7 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
+"wav = read_audio('en_number_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "# get number timestamps from full audio file\n",
 "number_timestamps = get_number_ts(wav, model)\n",
 "pprint(number_timestamps)"
@@ -250,11 +276,10 @@
 },
 "outputs": [],
 "source": [
-"sample_rate = 16000\n",
 "# convert ms in timestamps to samples\n",
 "for timestamp in number_timestamps:\n",
-" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
+" timestamp['start'] = int(timestamp['start'] * SAMPLING_RATE / 1000)\n",
+" timestamp['end'] = int(timestamp['end'] * SAMPLING_RATE / 1000)"
 ]
 },
 {
@@ -268,7 +293,7 @@
 "source": [
 "# merge all number chunks to one audio\n",
 "save_audio('only_numbers.wav',\n",
-" collect_chunks(number_timestamps, wav), sample_rate) \n",
+" collect_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
 "Audio('only_numbers.wav')"
 ]
 },
@@ -283,7 +308,7 @@
 "source": [
 "# drop all number chunks from audio\n",
 "save_audio('no_numbers.wav',\n",
-" drop_chunks(number_timestamps, wav), sample_rate) \n",
+" drop_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
 "Audio('no_numbers.wav')"
 ]
 },
@@ -294,7 +319,7 @@
 "id": "PnKtJKbq2Fwz"
 },
 "source": [
-"## Language detector"
+"# Language detector"
 ]
 },
 {
@@ -305,7 +330,7 @@
 "id": "F5cAmMbP2Fwz"
 },
 "source": [
-"### Install Dependencies"
+"## Install Dependencies"
 ]
 },
 {
@@ -320,23 +345,37 @@
 "#@title Install and Import Dependencies\n",
 "\n",
 "# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
+"!pip install -q torchaudio\n",
+"\n",
+"SAMPLING_RATE = 16000\n",
 "\n",
-"import glob\n",
 "import torch\n",
 "torch.set_num_threads(1)\n",
 "\n",
 "from IPython.display import Audio\n",
 "from pprint import pprint\n",
-"\n",
+"# download example\n",
+"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "JfRKDZiRztFe"
+},
+"outputs": [],
+"source": [
+"USE_ONNX = False # change this to True if you want to test onnx model\n",
+"if USE_ONNX:\n",
+" !pip install -q onnxruntime\n",
+" \n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_lang_detector',\n",
-" force_reload=True)\n",
+" force_reload=True,\n",
+" onnx=USE_ONNX)\n",
 "\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+"get_language, read_audio = utils"
 ]
 },
 {
@@ -347,7 +386,7 @@
 "id": "iC696eMX2Fwz"
 },
 "source": [
-"### Full audio"
+"## Full audio"
 ]
 },
 {
@@ -359,266 +398,10 @@
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en.wav')\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "lang = get_language(wav, model)\n",
 "print(lang)"
 ]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "57avIBd6jsrz"
-},
-"source": [
-"# ONNX Example"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "hEhnfORV2Fw0"
-},
-"source": [
-"## VAD"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "Cy7y-NAyALSe"
-},
-"source": [
-"**TO BE DONE**"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "7QMvUvpg2Fw4"
-},
-"source": [
-"## Number detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "tBPDkpHr2Fw4"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "PdjGd56R2Fw5"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-" model='silero_number_detector',\n",
-" force_reload=True)\n",
-"\n",
-"(get_number_ts,\n",
-" save_audio,\n",
-" read_audio,\n",
-" collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-" return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-" with torch.no_grad():\n",
-" ort_inputs = {'input': inputs.cpu().numpy()}\n",
-" outs = model.run(None, ort_inputs)\n",
-" outs = [torch.Tensor(x) for x in outs]\n",
-" return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "I9QWSFZh2Fw5"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "_r6QZiwu2Fw5"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
-"\n",
-"# get number timestamps from full audio file\n",
-"number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
-"pprint(number_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "FN4aDwLV2Fw5"
-},
-"outputs": [],
-"source": [
-"sample_rate = 16000\n",
-"# convert ms in timestamps to samples\n",
-"for timestamp in number_timestamps:\n",
-" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "JnvS6WTK2Fw5"
-},
-"outputs": [],
-"source": [
-"# merge all number chunks to one audio\n",
-"save_audio('only_numbers.wav',\n",
-" collect_chunks(number_timestamps, wav), 16000) \n",
-"Audio('only_numbers.wav')"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "yUxOcOFG2Fw6"
-},
-"outputs": [],
-"source": [
-"# drop all number chunks from audio\n",
-"save_audio('no_numbers.wav',\n",
-" drop_chunks(number_timestamps, wav), 16000) \n",
-"Audio('no_numbers.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "SR8Bgcd52Fw6"
-},
-"source": [
-"## Language detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "PBnXPtKo2Fw6"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "iNkDWJ3H2Fw6"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-" model='silero_lang_detector',\n",
-" force_reload=True)\n",
-"\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-" return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-" with torch.no_grad():\n",
-" ort_inputs = {'input': inputs.cpu().numpy()}\n",
-" outs = model.run(None, ort_inputs)\n",
-" outs = [torch.Tensor(x) for x in outs]\n",
-" return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"hidden": true,
-"id": "G8N8oP4q2Fw6"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "WHXnh9IV2Fw6"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"\n",
-"lang = get_language(wav, model, run_function=validate_onnx)\n",
-"print(lang)"
-]
 }
 ],
 "metadata": {

View File

@@ -7,14 +7,113 @@ import warnings
 languages = ['ru', 'en', 'de', 'es']
-def validate(model,
-             inputs: torch.Tensor,
-             **kwargs):
-    with torch.no_grad():
-        outs = model(inputs, **kwargs)
-    if len(outs.shape) == 1:
-        return outs[1:]
-    return outs[:, 1] # 0 for noise, 1 for speech
+class OnnxWrapper():
+    def __init__(self, path, force_onnx_cpu=False):
+        import numpy as np
+        global np
+        import onnxruntime
+        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
+        else:
+            self.session = onnxruntime.InferenceSession(path)
+        self.session.intra_op_num_threads = 1
+        self.session.inter_op_num_threads = 1
+        self.reset_states()
+        self.sample_rates = [8000, 16000]
+    def _validate_input(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+        if sr != 16000 and (sr % 16000 == 0):
+            step = sr // 16000
+            x = x[::step]
+            sr = 16000
+        if sr not in self.sample_rates:
+            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+        return x, sr
+    def reset_states(self, batch_size=1):
+        self._h = np.zeros((2, batch_size, 64)).astype('float32')
+        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._last_sr = 0
+        self._last_batch_size = 0
+    def __call__(self, x, sr: int):
+        x, sr = self._validate_input(x, sr)
+        batch_size = x.shape[0]
+        if not self._last_batch_size:
+            self.reset_states(batch_size)
+        if (self._last_sr) and (self._last_sr != sr):
+            self.reset_states(batch_size)
+        if (self._last_batch_size) and (self._last_batch_size != batch_size):
+            self.reset_states(batch_size)
+        if sr in [8000, 16000]:
+            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr)}
+            ort_outs = self.session.run(None, ort_inputs)
+            out, self._h, self._c = ort_outs
+        else:
+            raise ValueError()
+        self._last_sr = sr
+        self._last_batch_size = batch_size
+        out = torch.tensor(out)
+        return out
+    def audio_forward(self, x, sr: int, num_samples: int = 512):
+        outs = []
+        x, sr = self._validate_input(x, sr)
+        if x.shape[1] % num_samples:
+            pad_num = num_samples - (x.shape[1] % num_samples)
+            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+        self.reset_states(x.shape[0])
+        for i in range(0, x.shape[1], num_samples):
+            wavs_batch = x[:, i:i+num_samples]
+            out_chunk = self.__call__(wavs_batch, sr)
+            outs.append(out_chunk)
+        stacked = torch.cat(outs, dim=1)
+        return stacked.cpu()
+class Validator():
+    def __init__(self, url, force_onnx_cpu):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
+            else:
+                self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+        return outs
 def read_audio(path: str,
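OnnxWrapper keeps the model's recurrent (h, c) state between calls and resets it whenever the batch size or sampling rate changes. A sketch of driving it directly, assuming a local checkout of the repo (hubconf normally constructs it for you, and the model path below is a placeholder):

```python
import torch
from utils_vad import OnnxWrapper  # shipped at the repo root

# Placeholder path; hubconf resolves files/silero_vad.onnx relative to itself.
model = OnnxWrapper('files/silero_vad.onnx', force_onnx_cpu=True)

wav = torch.randn(16000)  # stand-in for one second of 16 kHz audio

# audio_forward pads the tail, resets state once, then slides a
# num_samples window: one speech probability per window per batch item.
probs = model.audio_forward(wav, sr=16000, num_samples=512)
print(probs.shape)  # (1, 32): 16000 samples padded to 16384 = 32 * 512
```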
@@ -65,7 +164,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                           sampling_rate: int = 16000,
                           min_speech_duration_ms: int = 250,
                           min_silence_duration_ms: int = 100,
-                          window_size_samples: int = 1536,
+                          window_size_samples: int = 512,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
                           visualize_probs: bool = False):
@@ -125,8 +224,16 @@ def get_speech_timestamps(audio: torch.Tensor,
     if len(audio.shape) > 1:
         raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
+    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
+        step = sampling_rate // 16000
+        sampling_rate = 16000
+        audio = audio[::step]
+        warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')
+    else:
+        step = 1
     if sampling_rate == 8000 and window_size_samples > 768:
-        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!')
+        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
     if window_size_samples not in [256, 512, 768, 1024, 1536]:
         warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
@@ -174,7 +281,7 @@ def get_speech_timestamps(audio: torch.Tensor,
             triggered = False
             continue
-    if current_speech:
+    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
         current_speech['end'] = audio_length_samples
         speeches.append(current_speech)
@@ -187,7 +294,8 @@ def get_speech_timestamps(audio: torch.Tensor,
                 speech['end'] += int(silence_duration // 2)
                 speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
             else:
-                speech['end'] += int(speech_pad_samples)
+                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
         else:
             speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
@@ -195,6 +303,10 @@ def get_speech_timestamps(audio: torch.Tensor,
         for speech_dict in speeches:
             speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
             speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+    elif step > 1:
+        for speech_dict in speeches:
+            speech_dict['start'] *= step
+            speech_dict['end'] *= step
     if visualize_probs:
         make_visualization(speech_probs, window_size_samples / sampling_rate)
@@ -206,10 +318,9 @@ def get_number_ts(wav: torch.Tensor,
                   model,
                   model_stride=8,
                   hop_length=160,
-                  sample_rate=16000,
-                  run_function=validate):
+                  sample_rate=16000):
     wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = run_function(model, wav)[0]
+    perframe_logits = model(wav)[0]
     perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
     extended_preds = []
     for i in perframe_preds:
@@ -236,10 +347,9 @@ def get_number_ts(wav: torch.Tensor,
 def get_language(wav: torch.Tensor,
-                 model,
-                 run_function=validate):
+                 model):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = run_function(model, wav)[2]
+    lang_logits = model(wav)[2]
     lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
     assert lang_pred < len(languages)
     return languages[lang_pred]
@@ -249,10 +359,9 @@ def get_language_and_group(wav: torch.Tensor,
                            model,
                            lang_dict: dict,
                            lang_group_dict: dict,
-                           top_n=1,
-                           run_function=validate):
+                           top_n=1):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = run_function(model, wav)
+    lang_logits, lang_group_logits = model(wav)
     softm = torch.softmax(lang_logits, dim=1).squeeze()
     softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
@@ -304,6 +413,10 @@ class VADIterator:
         self.model = model
         self.threshold = threshold
         self.sampling_rate = sampling_rate
+        if sampling_rate not in [8000, 16000]:
+            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
         self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
         self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
         self.reset_states()
@@ -323,6 +436,13 @@ class VADIterator:
         return_seconds: bool (default - False)
             whether return timestamps in seconds (default - samples)
         """
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
         self.current_sample += window_size_samples
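With these changes VADIterator rejects sampling rates other than 8000/16000 and casts array-likes to tensors, while get_speech_timestamps decimates multiples of 16 kHz down to 16 kHz (with a warning). A streaming sketch in the spirit of the updated notebook (the filename is a placeholder):

```python
import torch

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad')
(get_speech_timestamps, save_audio, read_audio,
 VADIterator, collect_chunks) = utils

# Anything other than 8000 or 16000 now raises ValueError.
vad_iterator = VADIterator(model, sampling_rate=16000)
wav = read_audio('en_example.wav', sampling_rate=16000)

window_size_samples = 512  # the new default window in get_speech_timestamps
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # drop the short tail, as the updated notebooks do
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict)
vad_iterator.reset_states()  # reset model states between audios
```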