228 Commits

Author SHA1 Message Date
Dimitrii Voronin
64b863d2ff Update README.md 2024-09-24 14:48:35 +03:00
Dimitrii Voronin
8a3600665b Merge pull request #540 from snakers4/adamnsandle-patch-2
Update README.md
2024-09-24 13:45:31 +03:00
Dimitrii Voronin
9c2c90aa1c Update README.md 2024-09-24 13:45:16 +03:00
Dimitrii Voronin
1d48167271 Merge pull request #539 from gengyuchao/update/python_pyaudio_example
Fixed the pyaudio example can not run issue.
2024-09-11 12:27:15 +03:00
GengYuchao
d0139d94d9 Fixed the pyaudio example can not run issue.
Update the related packages.
2024-09-11 00:45:49 +08:00
Dimitrii Voronin
46f94b7d60 Merge pull request #529 from snakers4/adamnsandle
Adamnsandle
2024-08-22 17:31:42 +03:00
adamnsandle
3de3ee3abe Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-08-22 14:30:27 +00:00
adamnsandle
e680ea6633 add half onnx model 2024-08-22 14:30:13 +00:00
Dimitrii Voronin
199de226e5 Merge pull request #528 from snakers4/adamnsandle
add neg_threshold parameter explicitly
2024-08-22 16:39:33 +03:00
adamnsandle
4109b107c1 add neg_threshold parameter explicitly 2024-08-20 08:53:15 +00:00
Alexander Veysov
36854a90db Merge pull request #526 from snakers4/adamnsandle
код для тюнинга
2024-08-19 20:01:21 +03:00
adamnsandle
827e86e685 добавлен поиск порогов 2024-08-19 16:53:28 +00:00
Dimitrii Voronin
e706ec6fee Update README.md 2024-08-19 18:31:11 +03:00
adamnsandle
88df0ce1dd код для тюнинга 2024-08-19 14:36:45 +00:00
Dimitrii Voronin
d18b91e037 Merge pull request #521 from snakers4/adamnsandle
downgrade onnxruntime dependency
2024-08-09 14:23:16 +03:00
adamnsandle
1e3f343767 downgrade onnxruntime dependency 2024-08-09 11:15:22 +00:00
Alexander Veysov
6a8ee81ee0 Merge pull request #507 from nganju98/master
add csharp example
2024-07-21 09:03:38 +03:00
nick.ganju
cb25c0c047 add csharp example 2024-07-20 22:59:18 -04:00
Alexander Veysov
7af8628a27 Merge pull request #506 from yuguanqin/master
Add java example for wav file & support V5 model
2024-07-18 07:34:40 +03:00
yuguanqin
3682cb189c java example for whole wav file & compatible with V5 model 2024-07-18 10:34:02 +08:00
Dimitrii Voronin
57c0b51f9b Merge pull request #505 from snakers4/adamnsandle
VadIterator first chunk bag fx
2024-07-15 13:42:36 +03:00
adamnsandle
dd0b143803 VadIterator first chunk bag fx 2024-07-15 10:37:46 +00:00
Alexander Veysov
181cdf92b6 Merge pull request #497 from rumbleFTW/fix/rust-example-v5
fix: rust example for v5 checkpoint
2024-07-11 17:48:58 +03:00
rumbleFTW
a7bd2dd38f fix: rust example 2024-07-11 20:06:54 +05:30
Alexander Veysov
df7de797a5 Merge pull request #496 from streamer45/update-golang-example
Fix Golang example
2024-07-10 21:31:15 +03:00
streamer45
87ed11b508 Fix Golang example 2024-07-10 20:26:41 +02:00
Alexander Veysov
84768cefdf Merge pull request #493 from snakers4/adamnsandle
Adamnsandle
2024-07-09 16:16:40 +03:00
adamnsandle
6de3660f25 fx version 2024-07-09 10:27:00 +00:00
adamnsandle
d9a6941852 add pip examples to collab 2024-07-09 10:20:50 +00:00
adamnsandle
dfdc9a484e Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-07-09 09:51:42 +00:00
adamnsandle
f2e3a23d96 fx version 2024-07-09 09:45:10 +00:00
Dimitrii Voronin
2b97f61160 Merge pull request #492 from snakers4/adamnsandle-patch-1
Create python-publish.yml
2024-07-09 12:42:23 +03:00
Dimitrii Voronin
e8850d2b9b Create python-publish.yml 2024-07-09 12:41:49 +03:00
adamnsandle
657dac8736 add pyproject.toml 2024-07-09 09:31:18 +00:00
Dimitrii Voronin
412a478e29 Update README.md 2024-07-09 12:25:06 +03:00
adamnsandle
9adf6d2192 add abs import path 2024-07-09 09:06:05 +00:00
adamnsandle
8a2a73c14f fx package import 2024-07-09 09:02:33 +00:00
adamnsandle
3e0305559d fx hubconf 2024-07-09 08:32:18 +00:00
adamnsandle
f0d880d79c make package structure 2024-07-09 08:26:17 +00:00
Dimitrii Voronin
3888946c0c Merge pull request #489 from streamer45/update-golang-example
Update Golang example to support model v5
2024-07-08 09:03:12 +03:00
streamer45
24f51645d0 Update to support model v5 2024-07-08 07:43:42 +02:00
Dimitrii Voronin
fdbb0a3a81 Merge pull request #482 from filtercodes/v5_cpp_support
cpp example
2024-07-01 19:17:44 +03:00
Stefan Miletic
60ae7abfb7 v5 model cpp example 2024-07-01 15:32:40 +01:00
Stefan Miletic
0b3d43d432 cpp example v5 model 2024-07-01 15:04:48 +01:00
Dimitrii Voronin
a395853982 Merge pull request #475 from eltociear/patch-1
Update microphone_and_webRTC_integration.py
2024-07-01 12:09:08 +03:00
Dimitrii Voronin
78958b6fb6 Merge pull request #481 from snakers4/adamnsandle
Adamnsandle
2024-07-01 12:02:50 +03:00
adamnsandle
902cfc9248 fx dtype bug 2024-07-01 09:00:59 +00:00
adamnsandle
89e66a3474 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-07-01 08:54:27 +00:00
Alexander Veysov
a3bdebed16 Update README.md 2024-07-01 10:21:20 +03:00
Ikko Eltociear Ashimine
4bdcf31d17 Update microphone_and_webRTC_integration.py
nubmer -> number
2024-06-30 02:10:59 +09:00
adamnsandle
136cdcdf5b tst 2024-06-28 14:13:18 +00:00
Alexander Veysov
5cd2ba54db Update hubconf.py 2024-06-27 22:58:52 +03:00
Alexander Veysov
06b9e17c1e Update hubconf.py 2024-06-27 22:56:37 +03:00
Alexander Veysov
aace7e25b1 Update README.md 2024-06-27 22:25:19 +03:00
Alexander Veysov
74c3f7f3fb Remove old unused utils 2024-06-27 22:08:15 +03:00
Alexander Veysov
d77d0fd42c Merge pull request #468 from snakers4/adamnsandle
Adamnsandle
2024-06-27 21:03:38 +03:00
Dimitrii Voronin
49b421a9cd Update README.md 2024-06-27 19:12:27 +03:00
adamnsandle
fd1f1a62b7 v5 initial push 2024-06-27 15:41:20 +00:00
adamnsandle
8145ed9a91 Merge branch 'master' of github.com:snakers4/silero-vad 2024-06-25 08:13:21 +00:00
Dimitrii Voronin
4392725328 Merge pull request #467 from gau-nernst/fix_grad
Replace `torch.set_grad_enabled(False)` with `torch.no_grad()`
2024-06-21 16:31:27 +03:00
Thien Tran
8b0566682b use torch.no_grad() 2024-06-21 21:10:48 +08:00
Alexander Veysov
82342b8a4c Merge pull request #457 from akmitrich/rust-example
Rust example
2024-05-30 13:25:18 +03:00
Alexander Kalashnikov
4b8ce743a8 Rust example 2024-05-30 12:08:47 +03:00
Alexander Veysov
b4b6f2ab3e Update README.md 2024-04-02 07:18:09 +03:00
Alexander Veysov
5b02d84a4a Update README.md 2024-03-26 22:28:38 +03:00
Alexander Veysov
60465a7e61 Update README.md 2024-03-26 22:21:43 +03:00
Dimitrii Voronin
fcef5b3955 Merge pull request #435 from snakers4/adamnsandle-patch-1
Update README.md
2024-03-26 21:37:22 +03:00
Dimitrii Voronin
156436762f Update README.md 2024-03-26 21:37:06 +03:00
Dimitrii Voronin
a27b176e45 Update README.md 2024-03-26 21:24:06 +03:00
Alexander Veysov
8125483ef7 Merge pull request #433 from snakers4/snakers4-patch-3
Update README.md
2024-03-26 21:23:27 +03:00
Alexander Veysov
6969dcc2dc Update README.md 2024-03-26 21:23:16 +03:00
Alexander Veysov
6c816a05f0 Merge pull request #432 from snakers4/snakers4-patch-2
Update README.md
2024-03-26 21:20:03 +03:00
Alexander Veysov
9dc344df7f Update README.md 2024-03-26 21:19:40 +03:00
Dimitrii Voronin
41c5172dd9 Update README.md 2024-03-26 21:17:35 +03:00
Alexander Veysov
894ea259f9 Merge pull request #431 from snakers4/adamnsandle
add open datasets' annotation
2024-03-26 21:13:58 +03:00
adamnsandle
f56f56ffaa add open datasets' annotation 2024-03-26 18:06:11 +00:00
Alexander Veysov
6c8d844710 Merge pull request #424 from yairl/master
Support both sox and sox_io backends for in-place audio resampling.
2024-02-16 14:08:40 +03:00
Yair Lifshitz
d8cc947c73 Support both sox and sox_io backends for in-place audio resampling. 2024-02-16 05:58:09 -05:00
Alexander Veysov
797a88a386 Update utils_vad.py 2024-02-16 10:55:47 +03:00
Alexander Veysov
48b7c742dd Merge pull request #423 from snakers4/snakers4-patch-1
Update utils_vad.py
2024-02-16 10:53:21 +03:00
Alexander Veysov
c5ec6bae3d Update utils_vad.py 2024-02-16 10:53:12 +03:00
Alexander Veysov
af152c18f6 Merge pull request #421 from yairl/master
Perform in-place resampling during read_audio.
2024-02-16 09:49:28 +03:00
Yair Lifshitz
d391f4c302 Use SoX when possible for loading a file with in-place resampling, ffmpeg otherwise. 2024-02-15 12:25:22 -05:00
Yair Lifshitz
bf18ea6b56 Perform in-place resampling during read_audio. 2024-02-14 17:11:57 -05:00
Alexander Veysov
a65732a393 Merge pull request #417 from abinthomasonline/offset-bug-fix
fix window size offset bug
2024-01-22 09:15:44 +03:00
Abin Thomas
aae1e4f40d fix window size offset bug 2024-01-22 11:08:24 +05:30
Alexander Veysov
94504ece54 Merge pull request #407 from bygreencn/master
Fix a bug at c sample code and some bugs at wav.h.
2023-12-17 20:26:40 +03:00
bygreencn
0b7da6e74b Fix wav functions:
1. fix data_size is not correct and be 0.
2. detect data format of IEEE-float.
3. add PCMS8bit, PCMS16bit and PCMS32 convert to float 32bit at class WavReader.
2023-12-17 23:00:12 +08:00
bygreencn
efb5effc8f Fix a bug in c code sample when only one timestamp and start from 0. 2023-12-17 23:00:12 +08:00
Alexander Veysov
03dc3fae5c Merge pull request #406 from bygreencn/master
Make the c code sample have the same function as the python code
2023-12-15 16:46:58 +03:00
bygreencn
4a6d1701a4 Make the c code sample have the same function as the python code 2023-12-15 21:32:22 +08:00
Alexander Veysov
5e7ee10ee0 Merge pull request #392 from xiaoqiang306/fix_example_cpp
fix int16_t bytes normalized to float
2023-11-13 15:20:08 +03:00
jiqiang.fu
03fb810fab fix int16_t bytes normalized to float 2023-11-13 17:14:34 +08:00
Alexander Veysov
e30a7e32a9 Merge pull request #391 from streamer45/golang-example
Implement Golang example
2023-11-09 08:33:20 +03:00
streamer45
bbbc657dad Golang example 2023-11-08 17:48:36 -06:00
Alexander Veysov
cb92cdd1e3 Merge pull request #386 from VvvvvGH/master
add java onnx inference example
2023-10-18 09:29:04 +03:00
VvvvvGH
3780baf49f add java onnx example 2023-10-18 13:57:18 +08:00
Alexander Veysov
563106ef8c Merge pull request #350 from archive-r/archive-r-fix-typo
Fix typo
2023-06-16 09:45:49 +03:00
kh
f795bc479b Fix typo
https://github.com/snakers4/silero-vad/discussions/319#discussion-5081706
2023-06-16 11:47:59 +09:00
Alexander Veysov
7e9680bc83 Merge pull request #342 from AlexRainHao/master
fix #341 issue of cpp example coding mistake
2023-05-18 15:17:36 +03:00
AlexRainHao
3b4c02dfe3 fix #341 issue of cpp example coding mistake 2023-05-18 20:12:03 +08:00
Alexander Veysov
bc5a0a2dbf Merge pull request #340 from chenqianhe/master
fix speech and silence state transition
2023-05-12 12:03:50 +03:00
Qianhe Chen
b03fcb2ebe fix speech and silence state transition 2023-05-12 16:59:51 +08:00
Dimitrii Voronin
026bc3d292 Merge pull request #329 from mhThomsen/master
(Bug fix) Slices in correct dimension (audio dim), so batch size is not reduced
2023-04-28 15:00:12 +03:00
Dimitrii Voronin
e755baa3c2 Merge branch 'master' into master 2023-04-28 15:00:02 +03:00
Dimitrii Voronin
b88084c7ed Merge pull request #332 from snakers4/adamnsandle
fx https://github.com/snakers4/silero-vad/pull/329 bug
2023-04-28 14:54:44 +03:00
adamnsandle
a9d2b591de fx https://github.com/snakers4/silero-vad/pull/329 bug 2023-04-28 11:48:01 +00:00
Dimitrii Voronin
c3c67cdcb8 Merge pull request #330 from snakers4/adamnsandle
del deprecated models examples
2023-04-27 15:05:45 +03:00
adamnsandle
874c66ccbc del deprecated models examples 2023-04-27 12:00:01 +00:00
Alexander Veysov
51fbbcb32e Update README.md 2023-04-27 13:33:03 +03:00
Alexander Veysov
14a0715955 Deprecate lang detector and number detector models 2023-04-27 13:22:00 +03:00
Alexander Veysov
a0d26769e0 Update README.md 2023-04-27 13:19:39 +03:00
mhThomsen
e0c2015193 slices in correct dimension (audio dim), so batch size is not reduced 2023-04-27 07:37:54 +02:00
Alexander Veysov
5872cffd78 Update README.md 2023-03-29 21:56:23 +03:00
Dimitrii Voronin
86400b9a12 Merge pull request #313 from snakers4/adamnsandle
change int2float
2023-03-28 17:31:27 +03:00
adamnsandle
6ef43d1c5d change int2float 2023-03-28 14:30:35 +00:00
Alexander Veysov
540e092276 Merge pull request #308 from zzzacwork/example-parallelization
add a simple parallelization example
2023-03-13 08:18:23 +03:00
Ziyuan Wang
55c41abf46 use a process specific copy of model 2023-03-10 16:09:07 +00:00
Ziyuan Wang
17903cb41d add a initializer 2023-03-09 16:27:36 +00:00
Ziyuan Wang
c39dccc1fd add a initializer 2023-03-09 16:27:18 +00:00
Ziyuan Wang
a6a067de44 add a initializer 2023-03-09 16:16:51 +00:00
Ziyuan Wang
9865b3cb93 fix typos 2023-03-09 04:26:15 +00:00
Ziyuan Wang
3d10c2d950 add parallel example 2023-03-09 04:17:34 +00:00
Dimitrii Voronin
4f57fae3fa Merge pull request #289 from Tomiinek/patch-1
Fixing ONNX init
2023-02-28 15:43:18 +02:00
Tomáš Nekvinda
085d76f08e Update utils_vad.py 2023-02-09 11:13:02 +01:00
Dimitrii Voronin
262bcb4b40 Merge pull request #285 from snakers4/adamnsandle
fx versiontuple bug
2023-01-09 11:09:32 +02:00
adamnsandle
e84eca68d7 fx versiontuple bug 2023-01-09 09:07:59 +00:00
Dimitrii Voronin
e7c4539106 Merge pull request #280 from bclark-videra/test/hubconf-with-ref
Use the path to hubconf.py to find models
2022-12-29 14:02:09 +02:00
Dimitrii Voronin
a480e85aec Merge pull request #282 from saenyakorn/master
Add `progress_tracking_callback` argument to `get_speech_timestamps` function
2022-12-29 13:31:55 +02:00
Saenyakorn Siangsanoh
c69cb6c9c0 fix progress logic 2022-12-28 14:40:55 +07:00
Saenyakorn Siangsanoh
11da69d88b add progress_tracking callback to get_speech_timestamps 2022-12-28 14:18:42 +07:00
Byron Clark
df1d52042d Use the path to hubconf.py to find models
While `torch.hub.load` uses predictable names when downloading from github,
the name referenced only worked when using the `master` branch of the
repo. Using a ref like `snakers4/silero_vad:v4.0` results in a
different directory name and the model files not being found.

Since the model files are stored in a path relative to hubconf.py, use
the path of hubconf.py to find the models instead of assuming the
directory `torch.hub.load` will extract the files to.
2022-12-22 14:47:14 -07:00
Alexander Veysov
d5a944b9f1 Merge pull request #279 from pengzhendong/patch-1
fix sample rate of onnx input
2022-12-21 06:00:10 +03:00
彭震东
d90416e63e fix sample rate of onnx input 2022-12-21 08:36:52 +08:00
Alexander Veysov
91f0aaecef Merge pull request #277 from yuGAN6/yugan6
Yugan6
2022-12-11 16:27:48 +03:00
yuGAN6
015bfc8b21 Update README.md 2022-12-11 21:14:48 +08:00
yuGAN6
5d56b1ea40 Update README.md 2022-12-11 21:13:46 +08:00
yuGAN6
ff3c596cab Update README.md 2022-12-11 21:09:57 +08:00
yuGAN6
63e1be5a22 Update README.md 2022-12-11 21:08:31 +08:00
yuGAN6
1d8f8f38db Move to example 2022-12-11 21:05:11 +08:00
yuGAN6
7198087152 Move into examples 2022-12-11 13:06:21 +08:00
Alexander Veysov
ad57d17f5f Update README.md 2022-12-11 05:57:19 +03:00
yuGAN6
04e87c208a Move directory 2022-12-10 22:50:14 +08:00
yuGAN6
c583fd1e52 Add c++ onnxruntime example 2022-12-10 22:29:34 +08:00
Dimitrii Voronin
5814e548db Merge pull request #270 from kafan1986/master
Resolves Windows inference issue INVALID_ARGUMENT : Unexpected input data type.
2022-12-06 11:38:41 +02:00
Dimitrii Voronin
42565d5baa Merge pull request #271 from b-med/max_speech_duration_v4
Max speech duration + bits per sample
2022-12-06 11:29:15 +02:00
Mohamed Bouaziz
ab7af9745b delete commented lines 2022-11-17 17:52:34 +01:00
Mohamed Bouaziz
83e68c56ea Merge branch 'master' of https://github.com/snakers4/silero-vad into max_speech_duration_v4 2022-11-17 17:32:22 +01:00
kafan1986
d3882c9ebf Merge pull request #1 from kafan1986/kafan1986-patch-1
Solves data type mismatch issue on windows
2022-11-11 23:50:20 +05:30
kafan1986
25f04dda35 Solves data type mismatch issue on windows
Solves windows 11 error regarding mismatched data type. Expected data type is int64 and actual data type is int32
2022-11-11 23:46:51 +05:30
Mohamed Bouaziz
94b4c21874 utils_vad save_audio force bits_per_sample=16 2022-11-04 21:43:28 +01:00
Mohamed Bouaziz
324bc74a58 utils_vad max duration 2022-11-04 21:42:22 +01:00
Dimitrii Voronin
82d199ff22 Merge pull request #256 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:57:10 +03:00
adamnsandle
5ba388d894 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:55:59 +00:00
adamnsandle
790844ba0f revert to exception 2022-10-28 10:55:46 +00:00
Dimitrii Voronin
51b5245410 Merge pull request #255 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:33:18 +03:00
adamnsandle
888970e77d Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:32:08 +00:00
adamnsandle
cb6d308335 fx 2022-10-28 10:31:55 +00:00
adamnsandle
1b212c6e95 change exception to warning 2022-10-28 10:26:07 +00:00
adamnsandle
452060ad65 fx 2022-10-28 10:13:00 +00:00
Dimitrii Voronin
c7eab751b5 Merge pull request #253 from snakers4/adamnsandle
add torch version check
2022-10-28 13:09:18 +03:00
adamnsandle
d1714a9ff7 add torch version check 2022-10-28 10:08:07 +00:00
Dimitrii Voronin
94c79d899d Merge pull request #251 from snakers4/adamnsandle
v4 hotfix
2022-10-27 20:26:31 +03:00
adamnsandle
1baf307b35 v4 hotfix 2022-10-27 17:25:31 +00:00
Dimitrii Voronin
e324285cdc Merge pull request #247 from snakers4/adamnsandle
Adamnsandle
2022-10-26 19:17:44 +03:00
adamnsandle
13dce2d067 Merge branch 'MASTER' into adamnsandle 2022-10-26 16:13:37 +00:00
adamnsandle
081e6b9886 VAD v4 2022-10-26 16:10:20 +00:00
Alexander Veysov
572134fdf1 Update README.md 2022-10-25 05:52:53 +03:00
Dimitrii Voronin
a799dea837 Merge pull request #244 from owlsometech-kenyang/feature/support-force-onnx-cpu
Suggesting a new kwarg: force_onnx_cpu
2022-10-14 11:53:46 +03:00
ChiehKai Yang
17209e6c4f add new parameter: force_onnx_cpu 2022-10-12 01:56:43 +08:00
adamnsandle
6661cc9691 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:41:54 +00:00
Dimitrii Voronin
7c671a75c2 Merge pull request #199 from snakers4/adamnsandle
fx end of chunk may exceed audio length
2022-06-02 13:40:42 +03:00
adamnsandle
622016e672 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:40:11 +00:00
adamnsandle
8eba346bc9 fx end of chunk may exceed audio length 2022-06-02 10:39:16 +00:00
Dimitrii Voronin
900c71a109 Merge pull request #198 from snakers4/adamnsandle
fx get_speech ts start of an audio chunk pad
2022-06-02 13:33:36 +03:00
adamnsandle
bf0127e016 fx get_speech ts start of an audio chunk pad 2022-06-02 10:32:32 +00:00
Dimitrii Voronin
ea7af70fe9 Merge pull request #182 from snakers4/adamnsandle
Adamnsandle
2022-04-05 14:36:00 +03:00
adamnsandle
8cdc8d36c9 fx 2022-04-05 11:35:23 +00:00
adamnsandle
6e9fd77500 fx stram imitation example bug 2022-04-05 11:33:34 +00:00
Alexander Veysov
6cc08b1077 Merge pull request #170 from gabrielziegler3/169-fix-min-speech-duration-bug
Fix #169
2022-02-10 12:18:23 +03:00
Gabriel Ziegler
0e8e080894 Remove unnecessary if statement 2022-02-09 19:22:04 -03:00
Gabriel Ziegler
af6931d1de Fix bug where min_speech_duration_ms is not checked in the last speech segment
Signed-off-by: Gabriel Ziegler <gabrielziegler3@gmail.com>
2022-02-09 19:18:48 -03:00
Alexander Veysov
76687cbe25 Update README.md 2021-12-21 14:43:36 +03:00
Dimitrii Voronin
b2329fa5f2 Merge pull request #144 from snakers4/adamnsandle
Update README.md
2021-12-21 14:25:56 +03:00
Dimitrii Voronin
005886e7eb Update README.md 2021-12-21 13:25:14 +02:00
Dimitrii Voronin
f6b1294cb2 Merge pull request #143 from snakers4/adamnsandle
Adamnsandle
2021-12-21 14:02:25 +03:00
adamnsandle
2392ea33f4 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-21 11:01:25 +00:00
adamnsandle
45d72863b6 add multiple of 16k sr support 2021-12-21 11:01:07 +00:00
Alexander Veysov
f40cc128a4 Update utils_vad.py 2021-12-21 08:24:48 +03:00
Alexander Veysov
0d61e4cee1 Update README.md 2021-12-17 22:03:49 +03:00
Alexander Veysov
011268e492 Polish the copy a bit 2021-12-17 22:00:36 +03:00
Dimitrii Voronin
8ebaf139c6 Merge pull request #138 from snakers4/adamnsandle
Update README.md
2021-12-17 18:14:03 +03:00
Dimitrii Voronin
0a90316625 Update README.md 2021-12-17 17:13:33 +02:00
Dimitrii Voronin
35d8969322 Merge pull request #137 from snakers4/adamnsandle
Adamnsandle
2021-12-17 17:50:13 +03:00
adamnsandle
7c3eb8bfb5 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-17 14:48:58 +00:00
adamnsandle
74f759c8f8 add onnx vad 2021-12-17 14:48:32 +00:00
Dimitrii Voronin
5816eb08c4 Merge pull request #135 from snakers4/adamnsandle
Adamnsandle
2021-12-10 14:28:59 +03:00
adamnsandle
0feae6cbbe Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 11:28:25 +00:00
adamnsandle
fc0a70f42e imporved model 2021-12-10 11:28:07 +00:00
Dimitrii Voronin
13fd927b84 Merge pull request #134 from snakers4/adamnsandle
Update README.md
2021-12-10 13:57:17 +03:00
Dimitrii Voronin
124d6564a0 Update README.md 2021-12-10 12:56:59 +02:00
Dimitrii Voronin
56fa93a1c9 Merge pull request #133 from snakers4/adamnsandle
Adamnsandle
2021-12-10 13:08:54 +03:00
adamnsandle
1a93276208 fx example 2021-12-10 10:07:38 +00:00
Dimitrii Voronin
9fbd0c4c2d Merge pull request #132 from snakers4/adamnsandle
delete big files from repo
2021-12-10 12:53:52 +03:00
adamnsandle
7b05a183a3 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-10 09:53:37 +00:00
adamnsandle
f67e68efc3 delete big files from repo 2021-12-10 09:52:22 +00:00
Alexander Veysov
51b1365bb0 Merge pull request #131 from snakers4/adamnsandle
add collab record example
2021-12-10 12:20:26 +03:00
adamnsandle
79fdb55f1c add collab record example 2021-12-10 09:18:15 +00:00
Alexander Veysov
b17da75dac Merge pull request #129 from snakers4/adamnsandle
Adamnsandle
2021-12-07 15:40:07 +03:00
adamnsandle
184e384697 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-07 12:32:16 +00:00
adamnsandle
adf5d6d020 fx example 2021-12-07 12:32:04 +00:00
Alexander Veysov
41ee0f6b9f Update README.md 2021-12-07 15:26:13 +03:00
Alexander Veysov
236d250a11 Merge pull request #128 from snakers4/adamnsandle
Adamnsandle
2021-12-07 14:16:28 +03:00
adamnsandle
8794d6f835 fxx 2021-12-07 10:59:30 +00:00
adamnsandle
8f16c14066 Merge branch 'adamnsandle' of github.com:snakers4/silero-vad into adamnsandle 2021-12-07 10:55:31 +00:00
adamnsandle
f638c47595 collab fx 2021-12-07 10:54:50 +00:00
Alexander Veysov
1fad5f4ffb Merge pull request #127 from snakers4/adamnsandle
Adamnsandle
2021-12-07 13:47:01 +03:00
Dimitrii Voronin
7160ce99d3 Merge branch 'master' into adamnsandle 2021-12-07 13:29:31 +03:00
adamnsandle
8af246df49 file lowercase name 2021-12-07 10:27:12 +00:00
Dimitrii Voronin
b1142bcba4 Update README.md 2021-12-07 12:17:42 +02:00
Dimitrii Voronin
a243bd5dc8 Update README.md 2021-12-07 12:14:32 +02:00
Alexander Veysov
d4d2af5833 Update README.md 2021-12-07 13:13:39 +03:00
Alexander Veysov
469ca8a2f6 Improve copy 2021-12-07 13:11:13 +03:00
Dimitrii Voronin
8c1ae73ee7 Update README.md 2021-12-07 12:01:07 +02:00
adamnsandle
aba7862d58 Merge branch 'adamnsandle' of github.com:snakers4/silero-vad into adamnsandle 2021-12-07 09:44:52 +00:00
adamnsandle
b648546a21 get rid of soundifle dependency 2021-12-07 09:44:35 +00:00
Dimitrii Voronin
2e852d7d41 Update README.md 2021-12-07 10:59:09 +02:00
adamnsandle
044278aa12 initial 3.0 commit 2021-12-07 08:49:47 +00:00
78 changed files with 5484 additions and 12168 deletions

40
.github/workflows/python-publish.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
push:
tags:
- '*'
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

666
README.md
View File

@@ -1,623 +1,147 @@
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE)
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_vad/)
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png)
- [Silero VAD](#silero-vad)
- [TLDR](#tldr)
- [Live Demonstration](#live-demonstration)
- [Getting Started](#getting-started)
- [Pre-trained Models](#pre-trained-models)
- [Version History](#version-history)
- [PyTorch](#pytorch)
- [VAD](#vad)
- [Number Detector](#number-detector)
- [Language Classifier](#language-classifier)
- [ONNX](#onnx)
- [VAD](#vad-1)
- [Number Detector](#number-detector-1)
- [Language Classifier](#language-classifier-1)
- [Metrics](#metrics)
- [Performance Metrics](#performance-metrics)
- [Streaming Latency](#streaming-latency)
- [Full Audio Throughput](#full-audio-throughput)
- [VAD Quality Metrics](#vad-quality-metrics)
- [FAQ](#faq)
- [VAD Parameter Fine Tuning](#vad-parameter-fine-tuning)
- [Classic way](#classic-way)
- [Adaptive way](#adaptive-way)
- [How VAD Works](#how-vad-works)
- [VAD Quality Metrics Methodology](#vad-quality-metrics-methodology)
- [How Number Detector Works](#how-number-detector-works)
- [How Language Classifier Works](#how-language-classifier-works)
- [Contact](#contact)
- [Get in Touch](#get-in-touch)
- [Commercial Inquiries](#commercial-inquiries)
- [Further reading](#further-reading)
- [Citations](#citations)
<br/>
<h1 align="center">Silero VAD</h1>
<br/>
**Silero VAD** - pre-trained enterprise-grade [Voice Activity Detector](https://en.wikipedia.org/wiki/Voice_activity_detection) (also see our [STT models](https://github.com/snakers4/silero-models)).
<br/>
<p align="center">
<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
</p>
# Silero VAD
![image](https://user-images.githubusercontent.com/36505480/107667211-06cf2680-6c98-11eb-9ee5-37eb4596260f.png)
<details>
<summary>Real Time Example</summary>
## TLDR
https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
**Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier.**
Enterprise-grade Speech Products made refreshingly simple (also see our [STT](https://github.com/snakers4/silero-models) models).
</details>
Currently, there are hardly any high quality / modern / free / public voice activity detectors except for WebRTC Voice Activity Detector ([link](https://github.com/wiseman/py-webrtcvad)). WebRTC though starts to show its age and it suffers from many false positives.
<br/>
Also in some cases it is crucial to be able to anonymize large-scale spoken corpora (i.e. remove personal data). Typically personal data is considered to be private / sensitive if it contains (i) a name (ii) some private ID. Name recognition is a highly subjective matter and it depends on locale and business case, but Voice Activity and Number Detection are quite general tasks.
<h2 align="center">Fast start</h2>
<br/>
**Key features:**
<details>
<summary>Dependencies</summary>
- Modern, portable;
- Low memory footprint;
- Superior metrics to WebRTC;
- Trained on huge spoken corpora and noise / sound libraries;
- Slower than WebRTC, but fast enough for IOT / edge / mobile applications;
- Unlike WebRTC (which mostly tells silence from voice), our VAD can tell voice from noise / music / silence;
System requirements to run python examples:
- `python 3.8+`
- 1G+ RAM
- not too outdated cpu
**Typical use cases:**
Dependencies:
- `torch>=1.12.0`
- `torchaudio>=0.12.0` (for I/O functionalities only)
- `onnxruntime>=1.16.1` (for ONNX model usage)
Silero VAD uses torchaudio library for audio file I/O functionalities, which are torchaudio.info, torchaudio.load, and torchaudio.save, so a proper audio backend is required:
- Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`
- Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2.
- Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`
- Spoken corpora anonymization;
- Can be used together with WebRTC;
- Voice activity detection for IOT / edge / mobile use cases;
- Data cleaning and preparation, number and voice detection in general;
- PyTorch and ONNX can be used with a wide variety of deployment options and backends in mind;
</details>
### Live Demonstration
**Using pip**:
`pip install silero-vad`
For more information, please see [examples](https://github.com/snakers4/silero-vad/tree/master/examples).
```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
model = load_silero_vad()
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
https://user-images.githubusercontent.com/28188499/116685087-182ff100-a9b2-11eb-927d-ed9f621226ee.mp4
https://user-images.githubusercontent.com/8079748/117580455-4622dd00-b0f8-11eb-858d-e6368ed4eada.mp4
## Getting Started
The models are small enough to be included directly into this repository. Newer models will supersede older models directly.
### Pre-trained Models
**Currently we provide the following endpoints:**
| model= | Params | Model type | Streaming | Languages | PyTorch | ONNX | Colab |
| -------------------------- | ------ | ------------------- | --------- | -------------------------- | ------------------ | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `'silero_vad'` | 1.1M | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_micro'` | 10K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_micro_8k'` | 10K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_mini'` | 100K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_mini_8k'` | 100K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_number_detector'` | 1.1M | Number Detector | No | `ru`, `en`, `de`, `es` | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_lang_detector'` | 1.1M | Language Classifier | No | `ru`, `en`, `de`, `es` | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| ~~`'silero_lang_detector_116'`~~ | ~~1.7M~~ | ~~Language Classifier~~ ||| | ||
| `'silero_lang_detector_95'` | 4.7M | Language Classifier | No | [95 languages](https://github.com/snakers4/silero-vad/blob/master/files/lang_dict_95.json) | :heavy_check_mark: | :heavy_check_mark: | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
(*) Though explicitly trained on these languages, VAD should work on any Germanic, Romance or Slavic Languages out of the box.
What models do:
- VAD - detects speech;
- Number Detector - detects spoken numbers (i.e. thirty five);
- Language Classifier - classifies utterances between language;
- Language Classifier 95 - classifies among 95 languages as well as 58 language groups (mutually intelligible languages -> same group)
### Version History
**Version history:**
| Version | Date | Comment |
| ------- | ---------- | --------------------------------------------------------------------------------------------------------------------------- |
| `v1` | 2020-12-15 | Initial release |
| `v1.1` | 2020-12-24 | better vad models compatible with chunks shorter than 250 ms |
| `v1.2` | 2020-12-30 | Number Detector added |
| `v2` | 2021-01-11 | Add Language Classifier heads (en, ru, de, es) |
| `v2.1` | 2021-02-11 | Add micro (10k params) VAD models |
| `v2.2` | 2021-03-22 | Add micro 8000 sample rate VAD models |
| `v2.3` | 2021-04-12 | Add mini (100k params) VAD models (8k and 16k sample rate) + **new** adaptive utils for full audio and single audio stream |
| `v2.4` | 2021-07-09 | Add 116 languages classifier and group classifier |
| `v2.4` | 2021-07-09 | Deleted 116 language classifier, added 95 language classifier instead (get rid of lowspoken languages for quality improvement)
|
### PyTorch
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
We are keeping the colab examples up-to-date, but you can manually manage your dependencies:
- `pytorch` >= 1.7.1 (there were breaking changes in `torch.hub` introduced in 1.7);
- `torchaudio` >= 0.7.2 (used only for IO and resampling, can be easily replaced);
- `soundfile` >= 0.10.3 (used as a default backend for torchaudio, can be replaced);
All of the dependencies except for PyTorch are superficial and for utils / example only. You can use any libraries / pipelines that read files and resample into 16 kHz.
#### VAD
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_vad/)
```python
**Using torch.hub**:
```python3
import torch
torch.set_num_threads(1)
from pprint import pprint
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=True)
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils
(get_speech_ts,
get_speech_ts_adaptive,
_, read_audio,
_, _, _) = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
wav = read_audio(f'{files_dir}/en.wav')
# full audio
# get speech timestamps from full audio file
# classic way
speech_timestamps = get_speech_ts(wav, model,
num_steps=4)
pprint(speech_timestamps)
# adaptive way
speech_timestamps = get_speech_ts_adaptive(wav, model)
pprint(speech_timestamps)
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
#### Number Detector
<br/>
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_number/)
<h2 align="center">Key Features</h2>
<br/>
```python
import torch
torch.set_num_threads(1)
from pprint import pprint
- **Stellar accuracy**
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_number_detector',
force_reload=True)
Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks.
- **Fast**
(get_number_ts,
_, read_audio,
_, _) = utils
One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably. Under certain conditions ONNX may even run up to 4-5x faster.
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
- **Lightweight**
wav = read_audio(f'{files_dir}/en_num.wav')
# full audio
# get number timestamps from full audio file
number_timestamps = get_number_ts(wav, model)
JIT model is around two megabytes in size.
pprint(number_timestamps)
```
- **General**
#### Language Classifier
##### 4 languages
Silero VAD was trained on huge corpora that include over **6000** languages and it performs well on audios from different domains with various background noise and quality levels.
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_language/)
- **Flexible sampling rate**
```python
import torch
torch.set_num_threads(1)
from pprint import pprint
Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_lang_detector',
force_reload=True)
- **Highly Portable**
get_language, read_audio = utils
Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX** running everywhere where these runtimes are available.
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
- **No Strings Attached**
wav = read_audio(f'{files_dir}/de.wav')
language = get_language(wav, model)
Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.
pprint(language)
```
<br/>
##### 95 languages
<h2 align="center">Typical Use Cases</h2>
<br/>
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_language/)
- Voice activity detection for IOT / edge / mobile use cases
- Data cleaning and preparation, voice detection in general
- Telephony and call-center automation, voice bots
- Voice interfaces
```python
import torch
torch.set_num_threads(1)
from pprint import pprint
<br/>
<h2 align="center">Links</h2>
<br/>
model, lang_dict, lang_group_dict, utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad',
model='silero_lang_detector_95',
force_reload=True)
get_language_and_group, read_audio = utils
- [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
- [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
- [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
- [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
- [Further reading](https://github.com/snakers4/silero-models#further-reading)
- [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
wav = read_audio(f'{files_dir}/de.wav')
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2)
for i in languages:
pprint(f'Language: {i[0]} with prob {i[-1]}')
for i in language_groups:
pprint(f'Language group: {i[0]} with prob {i[-1]}')
```
### ONNX
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
You can run our models everywhere, where you can import the ONNX model or run ONNX runtime.
#### VAD
```python
import torch
import onnxruntime
from pprint import pprint
_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=True)
(get_speech_ts,
get_speech_ts_adaptive,
_, read_audio,
_, _, _) = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
def init_onnx_model(model_path: str):
return onnxruntime.InferenceSession(model_path)
def validate_onnx(model, inputs):
with torch.no_grad():
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
return outs[0]
model = init_onnx_model(f'{files_dir}/model.onnx')
wav = read_audio(f'{files_dir}/en.wav')
# get speech timestamps from full audio file
# classic way
speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx)
pprint(speech_timestamps)
# adaptive way
speech_timestamps = get_speech_ts(wav, model, run_function=validate_onnx)
pprint(speech_timestamps)
```
#### Number Detector
```python
import torch
import onnxruntime
from pprint import pprint
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_number_detector',
force_reload=True)
(get_number_ts,
_, read_audio,
_, _) = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
def init_onnx_model(model_path: str):
return onnxruntime.InferenceSession(model_path)
def validate_onnx(model, inputs):
with torch.no_grad():
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
return outs
model = init_onnx_model(f'{files_dir}/number_detector.onnx')
wav = read_audio(f'{files_dir}/en_num.wav')
# get speech timestamps from full audio file
number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)
pprint(number_timestamps)
```
#### Language Classifier
##### 4 languages
```python
import torch
import onnxruntime
from pprint import pprint
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_lang_detector',
force_reload=True)
get_language, read_audio = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
def init_onnx_model(model_path: str):
return onnxruntime.InferenceSession(model_path)
def validate_onnx(model, inputs):
with torch.no_grad():
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
return outs
model = init_onnx_model(f'{files_dir}/number_detector.onnx')
wav = read_audio(f'{files_dir}/de.wav')
language = get_language(wav, model, run_function=validate_onnx)
print(language)
```
##### 95 languages
```python
import torch
import onnxruntime
from pprint import pprint
model, lang_dict, lang_group_dict, utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad',
model='silero_lang_detector_95',
force_reload=True)
get_language_and_group, read_audio = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
def init_onnx_model(model_path: str):
return onnxruntime.InferenceSession(model_path)
def validate_onnx(model, inputs):
with torch.no_grad():
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
return outs
model = init_onnx_model(f'{files_dir}/lang_classifier_95.onnx')
wav = read_audio(f'{files_dir}/de.wav')
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2, run_function=validate_onnx)
for i in languages:
pprint(f'Language: {i[0]} with prob {i[-1]}')
for i in language_groups:
pprint(f'Language group: {i[0]} with prob {i[-1]}')
```
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_language/)
## Metrics
### Performance Metrics
All speed test were run on AMD Ryzen Threadripper 3960X using only 1 thread:
```
torch.set_num_threads(1) # pytorch
ort_session.intra_op_num_threads = 1 # onnx
ort_session.inter_op_num_threads = 1 # onnx
```
#### Streaming Latency
Streaming latency depends on 2 variables:
- **num_steps** - number of windows to split each audio chunk into. Our post-processing class keeps previous chunk in memory (250 ms), so new chunk (also 250 ms) is appended to it. The resulting big chunk (500 ms) is split into **num_steps** overlapping windows, each 250 ms long.
- **number of audio streams**
So **batch size** for streaming is **num_steps * number of audio streams**. Time between receiving new audio chunks and getting results is shown in picture:
| Batch size | Pytorch model time, ms | Onnx model time, ms |
| :--------: | :--------------------: | :-----------------: |
| **2** | 9 | 2 |
| **4** | 11 | 4 |
| **8** | 14 | 7 |
| **16** | 19 | 12 |
| **40** | 36 | 29 |
| **80** | 64 | 55 |
| **120** | 96 | 85 |
| **200** | 157 | 137 |
#### Full Audio Throughput
**RTS** (seconds of audio processed per second, real time speed, or 1 / RTF) for full audio processing depends on **num_steps** (see previous paragraph) and **batch size** (bigger is better).
| Batch size | num_steps | Pytorch model RTS | Onnx model RTS |
| :--------: | :-------: | :---------------: | :------------: |
| **40** | **4** | 68 | 86 |
| **40** | **8** | 34 | 43 |
| **80** | **4** | 78 | 91 |
| **80** | **8** | 39 | 45 |
| **120** | **4** | 78 | 88 |
| **120** | **8** | 39 | 44 |
| **200** | **4** | 80 | 91 |
| **200** | **8** | 40 | 46 |
### VAD Quality Metrics
We use random 250 ms audio chunks for validation. Speech to non-speech ratio among chunks is about ~50/50 (i.e. balanced). Speech chunks are sampled from real audios in four different languages (English, Russian, Spanish, German), then random background noise is added to some of them (~40%).
Since our VAD (only VAD, other networks are more flexible) was trained on chunks of the same length, model's output is just one float from 0 to 1 - **speech probability**. We use speech probabilities as thresholds for precision-recall curve. This can be extended to 100 - 150 ms. Less than 100 - 150 ms cannot be distinguished as speech with confidence.
[Webrtc](https://github.com/wiseman/py-webrtcvad) splits audio into frames, each frame has corresponding number (0 **or** 1). We use 30ms frames for webrtc, so each 250 ms chunk is split into 8 frames, their **mean** value is used as a threshold for plot.
[Auditok](https://github.com/amsehili/auditok) - logic same as Webrtc, but we use 50ms frames.
![image](https://user-images.githubusercontent.com/36505480/107667211-06cf2680-6c98-11eb-9ee5-37eb4596260f.png)
## FAQ
### VAD Parameter Fine Tuning
#### Classic way
**This is straightforward classic method `get_speech_ts` where thresholds (`trig_sum` and `neg_trig_sum`) are specified by users**
- Among others, we provide several [utils](https://github.com/snakers4/silero-vad/blob/8b28767292b424e3e505c55f15cd3c4b91e4804b/utils.py#L52-L59) to simplify working with VAD;
- We provide sensible basic hyper-parameters that work for us, but your case can be different;
- `trig_sum` - overlapping windows are used for each audio chunk, trig sum defines average probability among those windows for switching into triggered state (speech state);
- `neg_trig_sum` - same as `trig_sum`, but for switching from triggered to non-triggered state (non-speech)
- `num_steps` - nubmer of overlapping windows to split audio chunk into (we recommend 4 or 8)
- `num_samples_per_window` - number of samples in each window, our models were trained using `4000` samples (250 ms) per window, so this is preferable value (lesser values reduce [quality](https://github.com/snakers4/silero-vad/issues/2#issuecomment-750840434));
- `min_speech_samples` - minimum speech chunk duration in samples
- `min_silence_samples` - minimum silence duration in samples between to separate speech chunks
Optimal parameters may vary per domain, but we provided a tiny tool to learn the best parameters. You can invoke `speech_timestamps` with visualize_probs=True (`pandas` required):
```
speech_timestamps = get_speech_ts(wav, model,
num_samples_per_window=4000,
num_steps=4,
visualize_probs=True)
```
#### Adaptive way
**Adaptive algorithm (`get_speech_ts_adaptive`) automatically selects thresholds (`trig_sum` and `neg_trig_sum`) based on median speech probabilities over the whole audio, SOME ARGUMENTS VARY FROM THE CLASSIC WAY FUNCTION ARGUMENTS**
- `batch_size` - batch size to feed to silero VAD (default - `200`)
- `step` - step size in samples, (default - `500`) (`num_samples_per_window` / `num_steps` from classic method)
- `num_samples_per_window` - number of samples in each window, our models were trained using `4000` samples (250 ms) per window, so this is preferable value (lesser values reduce [quality](https://github.com/snakers4/silero-vad/issues/2#issuecomment-750840434));
- `min_speech_samples` - minimum speech chunk duration in samples (default - `10000`)
- `min_silence_samples` - minimum silence duration in samples between to separate speech chunks (default - `4000`)
- `speech_pad_samples` - widen speech by this amount of samples each side (default - `2000`)
```
speech_timestamps = get_speech_ts_adaptive(wav, model,
num_samples_per_window=4000,
step=500,
visualize_probs=True)
```
The chart should looks something like this:
![image](https://user-images.githubusercontent.com/12515440/106242896-79142580-6219-11eb-9add-fa7195d6fd26.png)
With this particular example you can try shorter chunks (`num_samples_per_window=1600`), but this results in too much noise:
![image](https://user-images.githubusercontent.com/12515440/106243014-a8c32d80-6219-11eb-8374-969f372807f1.png)
### How VAD Works
- Audio is split into 250 ms chunks (you can choose any chunk size, but quality with chunks shorter than 100ms will suffer and there will be more false positives and "unnatural" pauses);
- VAD keeps record of a previous chunk (or zeros at the beginning of the stream);
- Then this 500 ms audio (250 ms + 250 ms) is split into N (typically 4 or 8) windows and the model is applied to this window batch. Each window is 250 ms long (naturally, windows overlap);
- Then probability is averaged across these windows;
- Though typically pauses in speech are 300 ms+ or longer (pauses less than 200-300ms are typically not meaninful), it is hard to confidently classify speech vs noise / music on very short chunks (i.e. 30 - 50ms);
- ~~We are working on lifting this limitation, so that you can use 100 - 125ms windows~~;
### VAD Quality Metrics Methodology
Please see [Quality Metrics](#quality-metrics)
### How Number Detector Works
- It is recommended to split long audio into short ones (< 15s) and apply model on each of them;
- Number Detector can classify if the whole audio contains a number, or if each audio frame contains a number;
- Audio is splitted into frames in a certain way, so, having a per-frame output, we can restore timing bounds for a numbers with an accuracy of about 0.2s;
### How Language Classifier Works
- **99%** validation accuracy
- Language classifier was trained using audio samples in 4 languages: **Russian**, **English**, **Spanish**, **German**
- More languages TBD
- Arbitrary audio length can be used, although network was trained using audio shorter than 15 seconds
### How Language Classifier 95 Works
- **85%** validation accuracy among 95 languages, **90%** validation accuracy among [58 language groups](https://github.com/snakers4/silero-vad/blob/master/files/lang_group_dict_95.json)
- Language classifier 95 was trained using audio samples in [95 languages](https://github.com/snakers4/silero-vad/blob/master/files/lang_dict_95.json)
- Arbitrary audio length can be used, although network was trained using audio shorter than 20 seconds
## Contact
### Get in Touch
<br/>
<h2 align="center">Get In Touch</h2>
<br/>
Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news).
### Commercial Inquiries
Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for relevant information and [email](mailto:hello@silero.ai) us directly.
Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers](https://github.com/snakers4/silero-models/wiki/Licensing-and-Tiers) for relevant information and [email](mailto:hello@silero.ai) us directly.
## Further reading
### General
- Silero-models - https://github.com/snakers4/silero-models
- Nice [thread](https://github.com/snakers4/silero-vad/discussions/16#discussioncomment-305830) in discussions
### English
- STT:
- Towards an Imagenet Moment For Speech-To-Text - [link](https://thegradient.pub/towards-an-imagenet-moment-for-speech-to-text/)
- A Speech-To-Text Practitioners Criticisms of Industry and Academia - [link](https://thegradient.pub/a-speech-to-text-practitioners-criticisms-of-industry-and-academia/)
- Modern Google-level STT Models Released - [link](https://habr.com/ru/post/519562/)
- TTS:
- High-Quality Text-to-Speech Made Accessible, Simple and Fast - [link](https://habr.com/ru/post/549482/)
- VAD:
- Modern Portable Voice Activity Detector Released - [link](https://habr.com/ru/post/537276/)
- Text Enhancement:
- We have published a model for text repunctuation and recapitalization for four languages - [link](https://habr.com/ru/post/581960/)
### Chinese
- STT:
- 迈向语音识别领域的 ImageNet 时刻 - [link](https://www.infoq.cn/article/4u58WcFCs0RdpoXev1E2)
- 语音领域学术界和工业界的七宗罪 - [link](https://www.infoq.cn/article/lEe6GCRjF1CNToVITvNw)
### Russian
- STT
- Последние обновления моделей распознавания речи из Silero Models - [link](https://habr.com/ru/post/577630/)
- Сжимаем трансформеры: простые, универсальные и прикладные способы елать их компактными и быстрыми - [link](https://habr.com/ru/post/563778/)
- Ультимативное сравнение систем распознавания речи: Ashmanov, Google, Sber, Silero, Tinkoff, Yandex - [link](https://habr.com/ru/post/559640/)
- Мы опубликовали современные STT модели сравнимые по качеству с Google - [link](https://habr.com/ru/post/519564/)
- Понижаем барьеры на вход в распознавание речи - [link](https://habr.com/ru/post/494006/)
- Огромный открытый датасет русской речи версия 1.0 - [link](https://habr.com/ru/post/474462/)
- Насколько Быстрой Можно Сделать Систему STT? - [link](https://habr.com/ru/post/531524/)
- Наша система Speech-To-Text - [link](https://www.silero.ai/tag/our-speech-to-text/)
- Speech To Text - [link](https://www.silero.ai/tag/speech-to-text/)
- TTS:
- Мы сделали наш публичный синтез речи еще лучше - [link](https://habr.com/ru/post/563484/)
- Мы Опубликовали Качественный, Простой, Доступный и Быстрый Синтез Речи - [link](https://habr.com/ru/post/549480/)
- VAD:
- Модели для Детекции Речи, Чисел и Распознавания Языков - [link](https://www.silero.ai/vad-lang-classifier-number-detector/)
- Мы опубликовали современный Voice Activity Detector и не только -[link](https://habr.com/ru/post/537274/)
- Text Enhancement:
- Мы опубликовали модель, расставляющую знаки препинания и заглавные буквы в тексте на четырех языках - [link](https://habr.com/ru/post/581946/)
## Citations
**Citations**
```
@misc{Silero VAD,
author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2021},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}},
@@ -625,3 +149,13 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers
email = {hello@silero.ai}
}
```
<br/>
<h2 align="center">Examples and VAD-based Community Apps</h2>
<br/>
- Example of VAD ONNX Runtime model usage in [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples

84
datasets/README.md Normal file
View File

@@ -0,0 +1,84 @@
# Датасет Silero-VAD
> Датасет создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный
интеллект» национальной программы «Цифровая экономика Российской Федерации».
По ссылкам ниже представлены `.feather` файлы, содержащие размеченные с помощью Silero VAD открытые наборы аудиоданных, а также короткое описание каждого набора данных с примерами загрузки. `.feather` файлы можно открыть с помощью библиотеки `pandas`:
```python3
import pandas as pd
dataframe = pd.read_feather(PATH_TO_FEATHER_FILE)
```
Каждый `.feather` файл с разметкой содержит следующие колонки:
- `speech_timings` - разметка данного аудио. Это список, содержащий словари вида `{'start': START_SECOND, 'end': END_SECOND}`, где `START_SECOND` и `END_SECOND` - время начала и конца речи в секундах. Количество данных словарей равно количеству речевых аудио отрывков, найденных в данном аудио;
- `language` - ISO код языка данного аудио.
Колонки, содержащие информацию о загрузке аудио файла различаются и описаны для каждого набора данных ниже.
**Все данные размечены при временной дискретизации в ~30 миллисекунд (`num_samples` - 512)**
| Название | Число часов | Число языков | Ссылка | Лицензия | md5sum |
|----------------------|-------------|-------------|--------|----------|----------|
| **Bible.is** | 53,138 | 1,596 | [URL](https://live.bible.is/) | [Уникальная](https://live.bible.is/terms) | ea404eeaf2cd283b8223f63002be11f9 |
| **globalrecordings.net** | 9,743 | 6,171[^1] | [URL](https://globalrecordings.net/en) | CC BY-NC-SA 4.0 | 3c5c0f31b0abd9fe94ddbe8b1e2eb326 |
| **VoxLingua107** | 6,628 | 107 | [URL](https://bark.phon.ioc.ee/voxlingua107/) | CC BY 4.0 | 5dfef33b4d091b6d399cfaf3d05f2140 |
| **Common Voice** | 30,329 | 120 | [URL](https://commonvoice.mozilla.org/en/datasets) | CC0 | 5e30a85126adf74a5fd1496e6ac8695d |
| **MLS** | 50,709 | 8 | [URL](https://www.openslr.org/94/) | CC BY 4.0 | a339d0e94bdf41bba3c003756254ac4e |
| **Итого** | **150,547** | **6,171+** | | | |
## Bible.is
[Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/BibleIs.feather)
- Колонка `audio_link` содержит ссылки на конкретные аудио файлы.
## globalrecordings.net
[Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/globalrecordings.feather)
- Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио.
- Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link`
``Количество уникальных ISO кодов данного датасета не совпадает с фактическим количеством представленных языков, т.к некоторые близкие языки могут кодироваться одним и тем же ISO кодом.``
## VoxLingua107
[Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/VoxLingua107.feather)
- Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио.
- Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link`
## Common Voice
[Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/common_voice.feather)
Этот датасет невозможно скачать по статичным ссылкам. Для загрузки необходимо перейти по [ссылке](https://commonvoice.mozilla.org/en/datasets) и, получив доступ в соответствующей форме, скачать архивы для каждого доступного языка. Внимание! Представленная разметка актуальна для версии исходного датасета `Common Voice Corpus 16.1`.
- Колонка `audio_path` содержит уникальные названия `.mp3` файлов, полученных после скачивания соответствующего датасета.
## MLS
[Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/MLS.feather)
- Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио.
- Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link`
## Лицензия
Данный датасет распространяется под [лицензией](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) `CC BY-NC-SA 4.0`.
## Цитирование
```
@misc{Silero VAD Dataset,
author = {Silero Team},
title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}},
email = {hello@silero.ai}
}
```
[^1]: ``Количество уникальных ISO кодов данного датасета не совпадает с фактическим количеством представленных языков, т.к некоторые близкие языки могут кодироваться одним и тем же ISO кодом.``

View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bccAucKjnPHm"
},
"source": [
"### Dependencies and inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cSih95WFmwgi"
},
"outputs": [],
"source": [
"!pip -q install pydub\n",
"from google.colab import output\n",
"from base64 import b64decode, b64encode\n",
"from io import BytesIO\n",
"import numpy as np\n",
"from pydub import AudioSegment\n",
"from IPython.display import HTML, display\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import moviepy.editor as mpe\n",
"from matplotlib.animation import FuncAnimation, FFMpegWriter\n",
"import matplotlib\n",
"matplotlib.use('Agg')\n",
"\n",
"torch.set_num_threads(1)\n",
"\n",
"model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"def int2float(sound):\n",
" abs_max = np.abs(sound).max()\n",
" sound = sound.astype('float32')\n",
" if abs_max > 0:\n",
" sound *= 1/32768\n",
" sound = sound.squeeze()\n",
" return sound\n",
"\n",
"AUDIO_HTML = \"\"\"\n",
"<script>\n",
"var my_div = document.createElement(\"DIV\");\n",
"var my_p = document.createElement(\"P\");\n",
"var my_btn = document.createElement(\"BUTTON\");\n",
"var t = document.createTextNode(\"Press to start recording\");\n",
"\n",
"my_btn.appendChild(t);\n",
"//my_p.appendChild(my_btn);\n",
"my_div.appendChild(my_btn);\n",
"document.body.appendChild(my_div);\n",
"\n",
"var base64data = 0;\n",
"var reader;\n",
"var recorder, gumStream;\n",
"var recordButton = my_btn;\n",
"\n",
"var handleSuccess = function(stream) {\n",
" gumStream = stream;\n",
" var options = {\n",
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
" mimeType : 'audio/webm;codecs=opus'\n",
" //mimeType : 'audio/webm;codecs=pcm'\n",
" }; \n",
" //recorder = new MediaRecorder(stream, options);\n",
" recorder = new MediaRecorder(stream);\n",
" recorder.ondataavailable = function(e) { \n",
" var url = URL.createObjectURL(e.data);\n",
" // var preview = document.createElement('audio');\n",
" // preview.controls = true;\n",
" // preview.src = url;\n",
" // document.body.appendChild(preview);\n",
"\n",
" reader = new FileReader();\n",
" reader.readAsDataURL(e.data); \n",
" reader.onloadend = function() {\n",
" base64data = reader.result;\n",
" //console.log(\"Inside FileReader:\" + base64data);\n",
" }\n",
" };\n",
" recorder.start();\n",
" };\n",
"\n",
"recordButton.innerText = \"Recording... press to stop\";\n",
"\n",
"navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
"\n",
"\n",
"function toggleRecording() {\n",
" if (recorder && recorder.state == \"recording\") {\n",
" recorder.stop();\n",
" gumStream.getAudioTracks()[0].stop();\n",
" recordButton.innerText = \"Saving recording...\"\n",
" }\n",
"}\n",
"\n",
"// https://stackoverflow.com/a/951057\n",
"function sleep(ms) {\n",
" return new Promise(resolve => setTimeout(resolve, ms));\n",
"}\n",
"\n",
"var data = new Promise(resolve=>{\n",
"//recordButton.addEventListener(\"click\", toggleRecording);\n",
"recordButton.onclick = ()=>{\n",
"toggleRecording()\n",
"\n",
"sleep(2000).then(() => {\n",
" // wait 2000ms for the data to be available...\n",
" // ideally this should use something like await...\n",
" //console.log(\"Inside data:\" + base64data)\n",
" resolve(base64data.toString())\n",
"\n",
"});\n",
"\n",
"}\n",
"});\n",
" \n",
"</script>\n",
"\"\"\"\n",
"\n",
"def record(sec=10):\n",
" display(HTML(AUDIO_HTML))\n",
" s = output.eval_js(\"data\")\n",
" b = b64decode(s.split(',')[1])\n",
" audio = AudioSegment.from_file(BytesIO(b))\n",
" audio.export('test.mp3', format='mp3')\n",
" audio = audio.set_channels(1)\n",
" audio = audio.set_frame_rate(16000)\n",
" audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
" audio_tens = torch.tensor(audio_float )\n",
" return audio_tens\n",
"\n",
"def make_animation(probs, audio_duration, interval=40):\n",
" fig = plt.figure(figsize=(16, 9))\n",
" ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n",
" line, = ax.plot([], [], lw=2)\n",
" x = [i / 16000 * 512 for i in range(len(probs))]\n",
" plt.xlabel('Time, seconds', fontsize=16)\n",
" plt.ylabel('Speech Probability', fontsize=16)\n",
"\n",
" def init():\n",
" plt.fill_between(x, probs, color='#064273')\n",
" line.set_data([], [])\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" def animate(i):\n",
" x = i * interval / 1000 - 0.04\n",
" y = np.linspace(0, 1.02, 2)\n",
" \n",
" line.set_data(x, y)\n",
" line.set_color('#990000')\n",
" return line,\n",
"\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
"\n",
" f = r\"animation.mp4\" \n",
" writervideo = FFMpegWriter(fps=1000/interval) \n",
" anim.save(f, writer=writervideo)\n",
" plt.close('all')\n",
"\n",
"def combine_audio(vidname, audname, outname, fps=25): \n",
" my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
" audio_background = mpe.AudioFileClip(audname)\n",
" final_clip = my_clip.set_audio(audio_background)\n",
" final_clip.write_videofile(outname,fps=fps,verbose=False)\n",
"\n",
"def record_make_animation():\n",
" tensor = record()\n",
"\n",
" print('Calculating probabilities...')\n",
" speech_probs = []\n",
" window_size_samples = 512\n",
" for i in range(0, len(tensor), window_size_samples):\n",
" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
" break\n",
" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
" speech_probs.append(speech_prob)\n",
" model.reset_states()\n",
" print('Making animation...')\n",
" make_animation(speech_probs, len(tensor) / 16000)\n",
"\n",
" print('Merging your voice with animation...')\n",
" combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n",
" print('Done!')\n",
" mp4 = open('merged.mp4','rb').read()\n",
" data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
" display(HTML(\"\"\"\n",
" <video width=800 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFVs3GvTnpB1"
},
"source": [
"## Record example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5EBjrTwiqAaQ"
},
"outputs": [],
"source": [
"record_make_animation()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"bccAucKjnPHm"
],
"name": "Untitled2.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

43
examples/cpp/README.md Normal file
View File

@@ -0,0 +1,43 @@
# Stream example in C++
Here's a simple example of the vad model in c++ onnxruntime.
## Requirements
Code are tested in the environments bellow, feel free to try others.
- WSL2 + Debian-bullseye (docker)
- gcc 12.2.0
- onnxruntime-linux-x64-1.12.1
## Usage
1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`
2. Install onnxruntime-linux-x64-1.12.1
- Download lib onnxruntime:
`wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`
- Unzip. Assume the path is `/root/onnxruntime-linux-x64-1.12.1`
3. Modify wav path & Test configs in main function
`wav::WavReader wav_reader("${path_to_your_wav_file}");`
test sample rate, frame per ms, threshold...
4. Build with gcc and run
```bash
# Build
g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test
# Run
./test
```

View File

@@ -0,0 +1,478 @@
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <limits>
#include <chrono>
#include <memory>
#include <string>
#include <stdexcept>
#include <iostream>
#include <string>
#include "onnxruntime_cxx_api.h"
#include "wav.h"
#include <cstdio>
#include <cstdarg>
#if __cplusplus < 201703L
#include <memory>
#endif
//#define __DEBUG_SPEECH_PROB___
class timestamp_t
{
public:
int start;
int end;
// default + parameterized constructor
timestamp_t(int start = -1, int end = -1)
: start(start), end(end)
{
};
// assignment operator modifies object, therefore non-const
timestamp_t& operator=(const timestamp_t& a)
{
start = a.start;
end = a.end;
return *this;
};
// equality comparison. doesn't modify object. therefore const.
bool operator==(const timestamp_t& a) const
{
return (start == a.start && end == a.end);
};
std::string c_str()
{
//return std::format("timestamp {:08d}, {:08d}", start, end);
return format("{start:%08d,end:%08d}", start, end);
};
private:
std::string format(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
va_end(args);
if (r < 0)
// conversion failed
return {};
const size_t len = r;
if (len < sizeof buf)
// we fit in the buffer
return { buf, len };
#if __cplusplus >= 201703L
// C++17: Create a string and write to its underlying array
std::string s(len, '\0');
va_start(args, fmt);
std::vsnprintf(s.data(), len + 1, fmt, args);
va_end(args);
return s;
#else
// C++11 or C++14: We need to allocate scratch memory
auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
va_start(args, fmt);
std::vsnprintf(vbuf.get(), len + 1, fmt, args);
va_end(args);
return { vbuf.get(), len };
#endif
};
};
class VadIterator
{
private:
// OnnxRuntime resources
Ort::Env env;
Ort::SessionOptions session_options;
std::shared_ptr<Ort::Session> session = nullptr;
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
private:
void init_engine_threads(int inter_threads, int intra_threads)
{
// The method should be called in each thread/proc in multi-thread/proc work
session_options.SetIntraOpNumThreads(intra_threads);
session_options.SetInterOpNumThreads(inter_threads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
};
void init_onnx_model(const std::wstring& model_path)
{
// Init threads = 1 for
init_engine_threads(1, 1);
// Load model
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};
void reset_states()
{
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false;
temp_end = 0;
current_sample = 0;
prev_end = next_start = 0;
speeches.clear();
current_speech = timestamp_t();
};
void predict(const std::vector<float> &data)
{
// Infer
// Create ort tensors
input.assign(data.begin(), data.end());
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
memory_info, input.data(), input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
// Clear and add inputs
ort_inputs.clear();
ort_inputs.emplace_back(std::move(input_ort));
ort_inputs.emplace_back(std::move(state_ort));
ort_inputs.emplace_back(std::move(sr_ort));
// Infer
ort_outputs = session->Run(
Ort::RunOptions{nullptr},
input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
output_node_names.data(), output_node_names.size());
// Output probability & update h,c recursively
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
// Push forward sample index
current_sample += window_size_samples;
// Reset temp_end when > threshold
if ((speech_prob >= threshold))
{
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (temp_end != 0)
{
temp_end = 0;
if (next_start < prev_end)
next_start = current_sample - window_size_samples;
}
if (triggered == false)
{
triggered = true;
current_speech.start = current_sample - window_size_samples;
}
return;
}
if (
(triggered == true)
&& ((current_sample - current_speech.start) > max_speech_samples)
) {
if (prev_end > 0) {
current_speech.end = prev_end;
speeches.push_back(current_speech);
current_speech = timestamp_t();
// previously reached silence(< neg_thres) and is still not speech(< thres)
if (next_start < prev_end)
triggered = false;
else{
current_speech.start = next_start;
}
prev_end = 0;
next_start = 0;
temp_end = 0;
}
else{
current_speech.end = current_sample;
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
return;
}
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
{
if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
else {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
return;
}
// 4) End
if ((speech_prob < (threshold - 0.15)))
{
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (triggered == true)
{
if (temp_end == 0)
{
temp_end = current_sample;
}
if (current_sample - temp_end > min_silence_samples_at_max_speech)
prev_end = temp_end;
// a. silence < min_slience_samples, continue speaking
if ((current_sample - temp_end) < min_silence_samples)
{
}
// b. silence >= min_slience_samples, end speaking
else
{
current_speech.end = temp_end;
if (current_speech.end - current_speech.start > min_speech_samples)
{
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
}
}
else {
// may first windows see end state.
}
return;
}
};
public:
void process(const std::vector<float>& input_wav)
{
reset_states();
audio_length_samples = input_wav.size();
for (int j = 0; j < audio_length_samples; j += window_size_samples)
{
if (j + window_size_samples > audio_length_samples)
break;
std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
predict(r);
}
if (current_speech.start >= 0) {
current_speech.end = audio_length_samples;
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
};
void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
process(input_wav);
collect_chunks(input_wav, output_wav);
}
void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
for (int i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
output_wav.insert(output_wav.end(),slice.begin(),slice.end());
}
};
const std::vector<timestamp_t> get_speech_timestamps() const
{
return speeches;
}
void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
int current_start = 0;
for (int i = 0; i < speeches.size(); i++) {
std::vector<float> slice(&input_wav[current_start],&input_wav[speeches[i].start]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
current_start = speeches[i].end;
}
std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
};
private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
int sample_rate; //Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16
float threshold;
int min_silence_samples; // sr_per_ms * #ms
int min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
float max_speech_samples;
int speech_pad_samples; // usually a
int audio_length_samples;
// model states
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
int prev_end;
int next_start = 0;
//Output timestamp
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Onnx model
// Inputs
std::vector<Ort::Value> ort_inputs;
std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
// Outputs
std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "stateN"};
public:
// Construction
VadIterator(const std::wstring ModelPath,
int Sample_rate = 16000, int windows_frame_size = 32,
float Threshold = 0.5, int min_silence_duration_ms = 0,
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
float max_speech_duration_s = std::numeric_limits<float>::infinity())
{
init_onnx_model(ModelPath);
threshold = Threshold;
sample_rate = Sample_rate;
sr_per_ms = sample_rate / 1000;
window_size_samples = windows_frame_size * sr_per_ms;
min_speech_samples = sr_per_ms * min_speech_duration_ms;
speech_pad_samples = sr_per_ms * speech_pad_ms;
max_speech_samples = (
sample_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
input.resize(window_size_samples);
input_node_dims[0] = 1;
input_node_dims[1] = window_size_samples;
_state.resize(size_state);
sr.resize(1);
sr[0] = sample_rate;
};
};
int main()
{
std::vector<timestamp_t> stamps;
// Read wav
wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
std::vector<float> input_wav(wav_reader.num_samples());
std::vector<float> output_wav;
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
// ===== Test configs =====
std::wstring path = L"silero_vad.onnx";
VadIterator vad(path);
// ==============================================
// ==== = Example 1 of full function =====
// ==============================================
vad.process(input_wav);
// 1.a get_speech_timestamps
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
std::cout << stamps[i].c_str() << std::endl;
}
// 1.b collect_chunks output wav
vad.collect_chunks(input_wav, output_wav);
// 1.c drop_chunks output wav
vad.drop_chunks(input_wav, output_wav);
// ==============================================
// ===== Example 2 of simple full function =====
// ==============================================
vad.process(input_wav, output_wav);
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
std::cout << stamps[i].c_str() << std::endl;
}
// ==============================================
// ===== Example 3 of full function =====
// ==============================================
for(int i = 0; i<2; i++)
vad.process(input_wav, output_wav);
}

235
examples/cpp/wav.h Normal file
View File

@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
// #include "utils/log.h"
namespace wav {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
if (NULL == fp) {
std::cout << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
printf("WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_;
std::cout << "num_channel_ :" << num_channel_ << std::endl;
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: {
char sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 16: {
int16_t sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 32:
{
if (header.format == 1) //S32
{
int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
}
default:
printf("unsupported quantization bits\n");
break;
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

View File

@@ -0,0 +1,35 @@
using System.Text;
namespace VadDotNet;
class Program
{
private const string MODEL_PATH = "./resources/silero_vad.onnx";
private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
private const int SAMPLE_RATE = 16000;
private const float THRESHOLD = 0.5f;
private const int MIN_SPEECH_DURATION_MS = 250;
private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
private const int MIN_SILENCE_DURATION_MS = 100;
private const int SPEECH_PAD_MS = 30;
public static void Main(string[] args)
{
var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
//Console.WriteLine(speechTimeList.ToJson());
StringBuilder sb = new StringBuilder();
foreach (var speechSegment in speechTimeList)
{
sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
}
Console.WriteLine(sb.ToString());
}
}

View File

@@ -0,0 +1,21 @@
namespace VadDotNet;
public class SileroSpeechSegment
{
public int? StartOffset { get; set; }
public int? EndOffset { get; set; }
public float? StartSecond { get; set; }
public float? EndSecond { get; set; }
public SileroSpeechSegment()
{
}
public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
{
StartOffset = startOffset;
EndOffset = endOffset;
StartSecond = startSecond;
EndSecond = endSecond;
}
}

View File

@@ -0,0 +1,250 @@
using NAudio.Wave;
using VADdotnet;
namespace VadDotNet;
public class SileroVadDetector
{
private readonly SileroVadOnnxModel _model;
private readonly float _threshold;
private readonly float _negThreshold;
private readonly int _samplingRate;
private readonly int _windowSizeSample;
private readonly float _minSpeechSamples;
private readonly float _speechPadSamples;
private readonly float _maxSpeechSamples;
private readonly float _minSilenceSamples;
private readonly float _minSilenceSamplesAtMaxSpeech;
private int _audioLengthSamples;
private const float THRESHOLD_GAP = 0.15f;
// ReSharper disable once InconsistentNaming
private const int SAMPLING_RATE_8K = 8000;
// ReSharper disable once InconsistentNaming
private const int SAMPLING_RATE_16K = 16000;
public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
int minSpeechDurationMs, float maxSpeechDurationSeconds,
int minSilenceDurationMs, int speechPadMs)
{
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
{
throw new ArgumentException("Sampling rate not support, only available for [8000, 16000]");
}
this._model = new SileroVadOnnxModel(onnxModelPath);
this._samplingRate = samplingRate;
this._threshold = threshold;
this._negThreshold = threshold - THRESHOLD_GAP;
this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
this._speechPadSamples = samplingRate * speechPadMs / 1000f;
this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
this.Reset();
}
public void Reset()
{
_model.ResetStates();
}
public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
{
Reset();
using (var audioFile = new AudioFileReader(wavFile.FullName))
{
List<float> speechProbList = new List<float>();
this._audioLengthSamples = (int)(audioFile.Length / 2);
float[] buffer = new float[this._windowSizeSample];
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
{
float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
speechProbList.Add(speechProb);
}
return CalculateProb(speechProbList);
}
}
private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
{
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
bool triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment();
for (int i = 0; i < speechProbList.Count; i++)
{
float speechProb = speechProbList[i];
if (speechProb >= _threshold && (tempEnd != 0))
{
tempEnd = 0;
if (nextStart < prevEnd)
{
nextStart = _windowSizeSample * i;
}
}
if (speechProb >= _threshold && !triggered)
{
triggered = true;
segment.StartOffset = _windowSizeSample * i;
continue;
}
if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
{
if (prevEnd != 0)
{
segment.EndOffset = prevEnd;
result.Add(segment);
segment = new SileroSpeechSegment();
if (nextStart < prevEnd)
{
triggered = false;
}
else
{
segment.StartOffset = nextStart;
}
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
}
else
{
segment.EndOffset = _windowSizeSample * i;
result.Add(segment);
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
if (speechProb < _negThreshold && triggered)
{
if (tempEnd == 0)
{
tempEnd = _windowSizeSample * i;
}
if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
{
prevEnd = tempEnd;
}
if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
{
continue;
}
else
{
segment.EndOffset = tempEnd;
if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
{
result.Add(segment);
}
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
}
if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
{
segment.EndOffset = _audioLengthSamples;
result.Add(segment);
}
for (int i = 0; i < result.Count; i++)
{
SileroSpeechSegment item = result[i];
if (i == 0)
{
item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
}
if (i != result.Count - 1)
{
SileroSpeechSegment nextItem = result[i + 1];
int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
if (silenceDuration < 2 * _speechPadSamples)
{
item.EndOffset = item.EndOffset + (silenceDuration / 2);
nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
}
else
{
item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
}
}
else
{
item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
}
}
return MergeListAndCalculateSecond(result, _samplingRate);
}
private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
{
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
if (original == null || original.Count == 0)
{
return result;
}
int left = original[0].StartOffset.Value;
int right = original[0].EndOffset.Value;
if (original.Count > 1)
{
original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
for (int i = 1; i < original.Count; i++)
{
SileroSpeechSegment segment = original[i];
if (segment.StartOffset > right)
{
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
left = segment.StartOffset.Value;
right = segment.EndOffset.Value;
}
else
{
right = Math.Max(right, segment.EndOffset.Value);
}
}
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
}
else
{
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
}
return result;
}
private float CalculateSecondByOffset(int offset, int samplingRate)
{
float secondValue = offset * 1.0f / samplingRate;
return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
}
}

View File

@@ -0,0 +1,220 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Linq;
namespace VADdotnet;
public class SileroVadOnnxModel : IDisposable
{
private readonly InferenceSession session;
private float[][][] state;
private float[][] context;
private int lastSr = 0;
private int lastBatchSize = 0;
private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };
public SileroVadOnnxModel(string modelPath)
{
var sessionOptions = new SessionOptions();
sessionOptions.InterOpNumThreads = 1;
sessionOptions.IntraOpNumThreads = 1;
sessionOptions.EnableCpuMemArena = true;
session = new InferenceSession(modelPath, sessionOptions);
ResetStates();
}
public void ResetStates()
{
state = new float[2][][];
state[0] = new float[1][];
state[1] = new float[1][];
state[0][0] = new float[128];
state[1][0] = new float[128];
context = Array.Empty<float[]>();
lastSr = 0;
lastBatchSize = 0;
}
public void Dispose()
{
session?.Dispose();
}
public class ValidationResult
{
public float[][] X { get; }
public int Sr { get; }
public ValidationResult(float[][] x, int sr)
{
X = x;
Sr = sr;
}
}
private ValidationResult ValidateInput(float[][] x, int sr)
{
if (x.Length == 1)
{
x = new float[][] { x[0] };
}
if (x.Length > 2)
{
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
}
if (sr != 16000 && (sr % 16000 == 0))
{
int step = sr / 16000;
float[][] reducedX = new float[x.Length][];
for (int i = 0; i < x.Length; i++)
{
float[] current = x[i];
float[] newArr = new float[(current.Length + step - 1) / step];
for (int j = 0, index = 0; j < current.Length; j += step, index++)
{
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
if (!SAMPLE_RATES.Contains(sr))
{
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
}
if (((float)sr) / x[0].Length > 31.25)
{
throw new ArgumentException("Input audio is too short");
}
return new ValidationResult(x, sr);
}
private static float[][] Concatenate(float[][] a, float[][] b)
{
if (a.Length != b.Length)
{
throw new ArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.Length;
int colsA = a[0].Length;
int colsB = b[0].Length;
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[colsA + colsB];
Array.Copy(a[i], 0, result[i], 0, colsA);
Array.Copy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] GetLastColumns(float[][] array, int contextSize)
{
int rows = array.Length;
int cols = array[0].Length;
if (contextSize > cols)
{
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[contextSize];
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
public float[] Call(float[][] x, int sr)
{
var result = ValidateInput(x, sr);
x = result.X;
sr = result.Sr;
int numberSamples = sr == 16000 ? 512 : 256;
if (x[0].Length != numberSamples)
{
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.Length;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0)
{
ResetStates();
}
if (lastSr != 0 && lastSr != sr)
{
ResetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize)
{
ResetStates();
}
if (context.Length == 0)
{
context = new float[batchSize][];
for (int i = 0; i < batchSize; i++)
{
context[i] = new float[contextSize];
}
}
x = Concatenate(context, x);
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
};
using (var outputs = session.Run(inputs))
{
var output = outputs.First(o => o.Name == "output").AsTensor<float>();
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
context = GetLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
state = new float[newState.Dimensions[0]][][];
for (int i = 0; i < newState.Dimensions[0]; i++)
{
state[i] = new float[newState.Dimensions[1]][];
for (int j = 0; j < newState.Dimensions[1]; j++)
{
state[i][j] = new float[newState.Dimensions[2]];
for (int k = 0; k < newState.Dimensions[2]; k++)
{
state[i][j][k] = newState[i, j, k];
}
}
}
return output.ToArray();
}
}
}

View File

@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.18.1" />
<PackageReference Include="NAudio" Version="2.2.1" />
</ItemGroup>
<ItemGroup>
<Folder Include="resources\" />
</ItemGroup>
<ItemGroup>
<Content Include="resources\**">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>
</Project>

View File

@@ -0,0 +1 @@
place onnx model file and example.wav file in this folder

19
examples/go/README.md Normal file
View File

@@ -0,0 +1,19 @@
## Golang Example
This is a sample program of how to run speech detection using `silero-vad` from Golang (CGO + ONNX Runtime).
### Requirements
- Golang >= v1.21
- ONNX Runtime
### Usage
```sh
go run ./cmd/main.go test.wav
```
> **_Note_**
>
> Make sure you have the ONNX Runtime library and C headers installed in your path.

63
examples/go/cmd/main.go Normal file
View File

@@ -0,0 +1,63 @@
package main
import (
"log"
"os"
"github.com/streamer45/silero-vad-go/speech"
"github.com/go-audio/wav"
)
func main() {
sd, err := speech.NewDetector(speech.DetectorConfig{
ModelPath: "../../src/silero_vad/data/silero_vad.onnx",
SampleRate: 16000,
Threshold: 0.5,
MinSilenceDurationMs: 100,
SpeechPadMs: 30,
})
if err != nil {
log.Fatalf("failed to create speech detector: %s", err)
}
if len(os.Args) != 2 {
log.Fatalf("invalid arguments provided: expecting one file path")
}
f, err := os.Open(os.Args[1])
if err != nil {
log.Fatalf("failed to open sample audio file: %s", err)
}
defer f.Close()
dec := wav.NewDecoder(f)
if ok := dec.IsValidFile(); !ok {
log.Fatalf("invalid WAV file")
}
buf, err := dec.FullPCMBuffer()
if err != nil {
log.Fatalf("failed to get PCM buffer")
}
pcmBuf := buf.AsFloat32Buffer()
segments, err := sd.Detect(pcmBuf.Data)
if err != nil {
log.Fatalf("Detect failed: %s", err)
}
for _, s := range segments {
log.Printf("speech starts at %0.2fs", s.SpeechStartAt)
if s.SpeechEndAt > 0 {
log.Printf("speech ends at %0.2fs", s.SpeechEndAt)
}
}
err = sd.Destroy()
if err != nil {
log.Fatalf("failed to destroy detector: %s", err)
}
}

13
examples/go/go.mod Normal file
View File

@@ -0,0 +1,13 @@
module silero
go 1.21.4
require (
github.com/go-audio/wav v1.1.0
github.com/streamer45/silero-vad-go v0.2.1
)
require (
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
)

18
examples/go/go.sum Normal file
View File

@@ -0,0 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,30 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>java-example</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>sliero-vad-example</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.16.0-rc1</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,69 @@
package org.example;
import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.util.Map;
public class App {
private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
private static final int SAMPLE_RATE = 16000;
private static final float START_THRESHOLD = 0.6f;
private static final float END_THRESHOLD = 0.45f;
private static final int MIN_SILENCE_DURATION_MS = 600;
private static final int SPEECH_PAD_MS = 500;
private static final int WINDOW_SIZE_SAMPLES = 2048;
public static void main(String[] args) {
// Initialize the Voice Activity Detector
SlieroVadDetector vadDetector;
try {
vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
} catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage());
return;
}
// Set audio format
AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
// Get the target data line and open it with the specified format
TargetDataLine targetDataLine;
try {
targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
targetDataLine.open(format);
targetDataLine.start();
} catch (LineUnavailableException e) {
System.err.println("Error opening target data line: " + e.getMessage());
return;
}
// Main loop to continuously read data and apply Voice Activity Detection
while (targetDataLine.isOpen()) {
byte[] data = new byte[WINDOW_SIZE_SAMPLES];
int numBytesRead = targetDataLine.read(data, 0, data.length);
if (numBytesRead <= 0) {
System.err.println("Error reading data from target data line.");
continue;
}
// Apply the Voice Activity Detector to the data and get the result
Map<String, Double> detectResult;
try {
detectResult = vadDetector.apply(data, true);
} catch (Exception e) {
System.err.println("Error applying VAD detector: " + e.getMessage());
continue;
}
if (!detectResult.isEmpty()) {
System.out.println(detectResult);
}
}
// Close the target data line to release audio resources
targetDataLine.close();
}
}

View File

@@ -0,0 +1,145 @@
package org.example;
import ai.onnxruntime.OrtException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class SlieroVadDetector {
// OnnxModel model used for speech processing
private final SlieroVadOnnxModel model;
// Threshold for speech start
private final float startThreshold;
// Threshold for speech end
private final float endThreshold;
// Sampling rate
private final int samplingRate;
// Minimum number of silence samples to determine the end threshold of speech
private final float minSilenceSamples;
// Additional number of samples for speech start or end to calculate speech start or end time
private final float speechPadSamples;
// Whether in the triggered state (i.e. whether speech is being detected)
private boolean triggered;
// Temporarily stored number of speech end samples
private int tempEnd;
// Number of samples currently being processed
private int currentSample;
public SlieroVadDetector(String modelPath,
float startThreshold,
float endThreshold,
int samplingRate,
int minSilenceDurationMs,
int speechPadMs) throws OrtException {
// Check if the sampling rate is 8000 or 16000, if not, throw an exception
if (samplingRate != 8000 && samplingRate != 16000) {
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
}
// Initialize the parameters
this.model = new SlieroVadOnnxModel(modelPath);
this.startThreshold = startThreshold;
this.endThreshold = endThreshold;
this.samplingRate = samplingRate;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
// Reset the state
reset();
}
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
public void reset() {
model.resetStates();
triggered = false;
tempEnd = 0;
currentSample = 0;
}
// apply method for processing the audio array, returning possible speech start or end times
public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
// Convert the byte array to a float array
float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
}
// Get the length of the audio array as the window size
int windowSizeSamples = audioData.length;
// Update the current sample count
currentSample += windowSizeSamples;
// Call the model to get the prediction probability of speech
float speechProb = 0;
try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
} catch (OrtException e) {
throw new RuntimeException(e);
}
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
if (speechProb >= startThreshold && tempEnd != 0) {
tempEnd = 0;
}
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
if (speechProb >= startThreshold && !triggered) {
triggered = true;
int speechStart = (int) (currentSample - speechPadSamples);
speechStart = Math.max(speechStart, 0);
Map<String, Double> result = new HashMap<>();
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
if (returnSeconds) {
double speechStartSeconds = speechStart / (double) samplingRate;
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
result.put("start", roundedSpeechStart);
} else {
result.put("start", (double) speechStart);
}
return result;
}
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
if (speechProb < endThreshold && triggered) {
// Initialize or update the temporary end time
if (tempEnd == 0) {
tempEnd = currentSample;
}
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
// This indicates that it is not yet possible to determine whether the speech has ended
if (currentSample - tempEnd < minSilenceSamples) {
return Collections.emptyMap();
} else {
// Calculate the speech end time, reset the trigger state and temporary end time
int speechEnd = (int) (tempEnd + speechPadSamples);
tempEnd = 0;
triggered = false;
Map<String, Double> result = new HashMap<>();
if (returnSeconds) {
double speechEndSeconds = speechEnd / (double) samplingRate;
double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
result.put("end", roundedSpeechEnd);
} else {
result.put("end", (double) speechEnd);
}
return result;
}
}
// If the above conditions are not met, return null by default
return Collections.emptyMap();
}
public void close() throws OrtException {
reset();
model.close();
}
}

View File

@@ -0,0 +1,180 @@
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SlieroVadOnnxModel {
// Define private variable OrtSession
private final OrtSession session;
private float[][][] h;
private float[][][] c;
// Define the last sample rate
private int lastSr = 0;
// Define the last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SlieroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
*/
void resetStates() {
h = new float[2][1][64];
c = new float[2][1][64];
lastSr = 0;
lastBatchSize = 0;
}
public void close() throws OrtException {
session.close();
}
/**
* Define inner class ValidationResult
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
}
}
/**
* Function to validate input data
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
for (int i = 0; i < x.length; i++) {
float[] current = x[i];
float[] newArr = new float[(current.length + step - 1) / step];
for (int j = 0, index = 0; j < current.length; j += step, index++) {
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
/**
* Method to call the ONNX model
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
x = result.x;
sr = result.sr;
int batchSize = x.length;
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
resetStates();
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor hTensor = null;
OnnxTensor cTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
hTensor = OnnxTensor.createTensor(env, h);
cTensor = OnnxTensor.createTensor(env, c);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("h", hTensor);
inputs.put("c", cTensor);
// Call the ONNX model for calculation
ortOutputs = session.run(inputs);
// Get the output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
h = (float[][][]) ortOutputs.get(1).getValue();
c = (float[][][]) ortOutputs.get(2).getValue();
lastSr = sr;
lastBatchSize = batchSize;
return output[0];
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (hTensor != null) {
hTensor.close();
}
if (cTensor != null) {
cTensor.close();
}
if (srTensor != null) {
srTensor.close();
}
if (ortOutputs != null) {
ortOutputs.close();
}
}
}
}

View File

@@ -0,0 +1,37 @@
package org.example;
import ai.onnxruntime.OrtException;
import java.io.File;
import java.util.List;
public class App {
private static final String MODEL_PATH = "/path/silero_vad.onnx";
private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
private static final int SAMPLE_RATE = 16000;
private static final float THRESHOLD = 0.5f;
private static final int MIN_SPEECH_DURATION_MS = 250;
private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
private static final int MIN_SILENCE_DURATION_MS = 100;
private static final int SPEECH_PAD_MS = 30;
public static void main(String[] args) {
// Initialize the Voice Activity Detector
SileroVadDetector vadDetector;
try {
vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
} catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage());
}
}
public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
for (SileroSpeechSegment speechSegment : speechTimeList) {
System.out.println(String.format("start second: %f, end second: %f",
speechSegment.getStartSecond(), speechSegment.getEndSecond()));
}
}
}

View File

@@ -0,0 +1,51 @@
package org.example;
public class SileroSpeechSegment {
private Integer startOffset;
private Integer endOffset;
private Float startSecond;
private Float endSecond;
public SileroSpeechSegment() {
}
public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
this.startOffset = startOffset;
this.endOffset = endOffset;
this.startSecond = startSecond;
this.endSecond = endSecond;
}
public Integer getStartOffset() {
return startOffset;
}
public Integer getEndOffset() {
return endOffset;
}
public Float getStartSecond() {
return startSecond;
}
public Float getEndSecond() {
return endSecond;
}
public void setStartOffset(Integer startOffset) {
this.startOffset = startOffset;
}
public void setEndOffset(Integer endOffset) {
this.endOffset = endOffset;
}
public void setStartSecond(Float startSecond) {
this.startSecond = startSecond;
}
public void setEndSecond(Float endSecond) {
this.endSecond = endSecond;
}
}

View File

@@ -0,0 +1,244 @@
package org.example;
import ai.onnxruntime.OrtException;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
public class SileroVadDetector {
private final SileroVadOnnxModel model;
private final float threshold;
private final float negThreshold;
private final int samplingRate;
private final int windowSizeSample;
private final float minSpeechSamples;
private final float speechPadSamples;
private final float maxSpeechSamples;
private final float minSilenceSamples;
private final float minSilenceSamplesAtMaxSpeech;
private int audioLengthSamples;
private static final float THRESHOLD_GAP = 0.15f;
private static final Integer SAMPLING_RATE_8K = 8000;
private static final Integer SAMPLING_RATE_16K = 16000;
/**
* Constructor
* @param onnxModelPath the path of silero-vad onnx model
* @param threshold threshold for speech start
* @param samplingRate audio sampling rate, only available for [8k, 16k]
* @param minSpeechDurationMs Minimum speech length in millis, any speech duration that smaller than this value would not be considered as speech
* @param maxSpeechDurationSeconds Maximum speech length in millis, recommend to be set as Float.POSITIVE_INFINITY
* @param minSilenceDurationMs Minimum silence length in millis, any silence duration that smaller than this value would not be considered as silence
* @param speechPadMs Additional pad millis for speech start and end
* @throws OrtException
*/
public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
int minSpeechDurationMs, float maxSpeechDurationSeconds,
int minSilenceDurationMs, int speechPadMs) throws OrtException {
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]");
}
this.model = new SileroVadOnnxModel(onnxModelPath);
this.samplingRate = samplingRate;
this.threshold = threshold;
this.negThreshold = threshold - THRESHOLD_GAP;
if (samplingRate == SAMPLING_RATE_16K) {
this.windowSizeSample = 512;
} else {
this.windowSizeSample = 256;
}
this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
this.reset();
}
/**
* Method to reset the state
*/
public void reset() {
model.resetStates();
}
/**
* Get speech segment list by given wav-format file
* @param wavFile wav file
* @return list of speech segment
*/
public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
reset();
try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
List<Float> speechProbList = new ArrayList<>();
this.audioLengthSamples = audioInputStream.available() / 2;
byte[] data = new byte[this.windowSizeSample * 2];
int numBytesRead = 0;
while ((numBytesRead = audioInputStream.read(data)) != -1) {
if (numBytesRead <= 0) {
break;
}
// Convert the byte array to a float array
float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
}
float speechProb = 0;
try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
speechProbList.add(speechProb);
} catch (OrtException e) {
throw e;
}
}
return calculateProb(speechProbList);
} catch (Exception e) {
throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
}
}
/**
* Calculate speech segement by probability
* @param speechProbList speech probability list
* @return list of speech segment
*/
private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
List<SileroSpeechSegment> result = new ArrayList<>();
boolean triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment();
for (int i = 0; i < speechProbList.size(); i++) {
Float speechProb = speechProbList.get(i);
if (speechProb >= threshold && (tempEnd != 0)) {
tempEnd = 0;
if (nextStart < prevEnd) {
nextStart = windowSizeSample * i;
}
}
if (speechProb >= threshold && !triggered) {
triggered = true;
segment.setStartOffset(windowSizeSample * i);
continue;
}
if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
if (prevEnd != 0) {
segment.setEndOffset(prevEnd);
result.add(segment);
segment = new SileroSpeechSegment();
if (nextStart < prevEnd) {
triggered = false;
}else {
segment.setStartOffset(nextStart);
}
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
}else {
segment.setEndOffset(windowSizeSample * i);
result.add(segment);
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
if (speechProb < negThreshold && triggered) {
if (tempEnd == 0) {
tempEnd = windowSizeSample * i;
}
if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
prevEnd = tempEnd;
}
if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
continue;
}else {
segment.setEndOffset(tempEnd);
if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
result.add(segment);
}
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
}
if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
segment.setEndOffset(audioLengthSamples);
result.add(segment);
}
for (int i = 0; i < result.size(); i++) {
SileroSpeechSegment item = result.get(i);
if (i == 0) {
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
}
if (i != result.size() - 1) {
SileroSpeechSegment nextItem = result.get(i + 1);
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
if(silenceDuration < 2 * speechPadSamples){
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
} else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
}
}else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
}
}
return mergeListAndCalculateSecond(result, samplingRate);
}
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
List<SileroSpeechSegment> result = new ArrayList<>();
if (original == null || original.size() == 0) {
return result;
}
Integer left = original.get(0).getStartOffset();
Integer right = original.get(0).getEndOffset();
if (original.size() > 1) {
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
for (int i = 1; i < original.size(); i++) {
SileroSpeechSegment segment = original.get(i);
if (segment.getStartOffset() > right) {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
left = segment.getStartOffset();
right = segment.getEndOffset();
} else {
right = Math.max(right, segment.getEndOffset());
}
}
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}else {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}
return result;
}
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
float secondValue = offset * 1.0f / samplingRate;
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
}
}

View File

@@ -0,0 +1,234 @@
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SileroVadOnnxModel {
// Define private variable OrtSession
private final OrtSession session;
private float[][][] state;
private float[][] context;
// Define the last sample rate
private int lastSr = 0;
// Define the last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SileroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
*/
void resetStates() {
state = new float[2][1][128];
context = new float[0][];
lastSr = 0;
lastBatchSize = 0;
}
public void close() throws OrtException {
session.close();
}
/**
* Define inner class ValidationResult
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
}
}
/**
* Function to validate input data
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
for (int i = 0; i < x.length; i++) {
float[] current = x[i];
float[] newArr = new float[(current.length + step - 1) / step];
for (int j = 0, index = 0; j < current.length; j += step, index++) {
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
private static float[][] concatenate(float[][] a, float[][] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.length;
int colsA = a[0].length;
int colsB = b[0].length;
float[][] result = new float[rows][colsA + colsB];
for (int i = 0; i < rows; i++) {
System.arraycopy(a[i], 0, result[i], 0, colsA);
System.arraycopy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] getLastColumns(float[][] array, int contextSize) {
int rows = array.length;
int cols = array[0].length;
if (contextSize > cols) {
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][contextSize];
for (int i = 0; i < rows; i++) {
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
/**
* Method to call the ONNX model
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
x = result.x;
sr = result.sr;
int numberSamples = 256;
if (sr == 16000) {
numberSamples = 512;
}
if (x[0].length != numberSamples) {
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.length;
int contextSize = 32;
if (sr == 16000) {
contextSize = 64;
}
if (lastBatchSize == 0) {
resetStates();
}
if (lastSr != 0 && lastSr != sr) {
resetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates();
}
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
x = concatenate(context, x);
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor stateTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
stateTensor = OnnxTensor.createTensor(env, state);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("state", stateTensor);
// Call the ONNX model for calculation
ortOutputs = session.run(inputs);
// Get the output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
state = (float[][][]) ortOutputs.get(1).getValue();
context = getLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
return output[0];
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (stateTensor != null) {
stateTensor.close();
}
if (srTensor != null) {
srTensor.close();
}
if (ortOutputs != null) {
ortOutputs.close();
}
}
}
}

View File

@@ -186,7 +186,7 @@ if __name__ == '__main__':
help="same as trig_sum, but for switching from triggered to non-triggered state (non-speech)")
parser.add_argument('-N', '--num_steps', type=int, default=8,
help="nubmer of overlapping windows to split audio chunk into (we recommend 4 or 8)")
help="number of overlapping windows to split audio chunk into (we recommend 4 or 8)")
parser.add_argument('-nspw', '--num_samples_per_window', type=int, default=4000,
help="number of samples in each window, our models were trained using 4000 samples (250 ms) per window, so this is preferable value (lesser values reduce quality)")
@@ -198,4 +198,4 @@ if __name__ == '__main__':
help=" minimum silence duration in samples between to separate speech chunks")
ARGS = parser.parse_args()
ARGS.rate=DEFAULT_SAMPLE_RATE
main(ARGS)
main(ARGS)

View File

@@ -0,0 +1,149 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install -q torchaudio\n",
"SAMPLING_RATE = 16000\n",
"import torch\n",
"from pprint import pprint\n",
"\n",
"torch.set_num_threads(1)\n",
"NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
"NUM_COPIES=8\n",
"# download wav files, make multiple copies\n",
"for idx in range(NUM_COPIES):\n",
" torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load VAD model from torch hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True,\n",
" onnx=False)\n",
"\n",
"(get_speech_timestamps,\n",
"save_audio,\n",
"read_audio,\n",
"VADIterator,\n",
"collect_chunks) = utils"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define a vad process function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import multiprocessing\n",
"\n",
"vad_models = dict()\n",
"\n",
"def init_model(model):\n",
" pid = multiprocessing.current_process().pid\n",
" model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=False,\n",
" onnx=False)\n",
" vad_models[pid] = model\n",
"\n",
"def vad_process(audio_file: str):\n",
" \n",
" pid = multiprocessing.current_process().pid\n",
" \n",
" with torch.no_grad():\n",
" wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
" return get_speech_timestamps(\n",
" wav,\n",
" vad_models[pid],\n",
" 0.46, # speech prob threshold\n",
" 16000, # sample rate\n",
" 300, # min speech duration in ms\n",
" 20, # max speech duration in seconds\n",
" 600, # min silence duration\n",
" 512, # window size\n",
" 200, # spech pad ms\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parallelization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"\n",
"futures = []\n",
"\n",
"with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n",
" for i in range(NUM_COPIES):\n",
" futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n",
"\n",
"for finished in as_completed(futures):\n",
" pprint(finished.result())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "diarization",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

2
examples/rust-example/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
target/
recorder.wav

781
examples/rust-example/Cargo.lock generated Normal file
View File

@@ -0,0 +1,781 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "cc"
version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cpufeatures"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]]
name = "errno"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "filetime"
version = "0.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"windows-sys",
]
[[package]]
name = "flate2"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hound"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
[[package]]
name = "idna"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "js-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
dependencies = [
"wasm-bindgen",
]
[[package]]
name = "libc"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "libloading"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
"cfg-if",
"windows-targets",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "log"
version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "matrixmultiply"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "miniz_oxide"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae"
dependencies = [
"adler",
]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "ort"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
dependencies = [
"half",
"js-sys",
"libloading",
"ndarray",
"ort-sys",
"thiserror",
"tracing",
"web-sys",
]
[[package]]
name = "ort-sys"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
dependencies = [
"flate2",
"sha2",
"tar",
"ureq",
]
[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project-lite"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "proc-macro2"
version = "1.0.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "redox_syscall"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
"getrandom",
"libc",
"spin",
"untrusted",
"windows-sys",
]
[[package]]
name = "rust-example"
version = "0.1.0"
dependencies = [
"hound",
"ndarray",
"ort",
]
[[package]]
name = "rustix"
version = "0.38.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
dependencies = [
"bitflags 2.5.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
]
[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pki-types"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
[[package]]
name = "rustls-webpki"
version = "0.102.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "subtle"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
name = "syn"
version = "2.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tar"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
dependencies = [
"filetime",
"libc",
"xattr",
]
[[package]]
name = "thiserror"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicode-bidi"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unicode-normalization"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
dependencies = [
"tinyvec",
]
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
dependencies = [
"base64",
"log",
"once_cell",
"rustls",
"rustls-pki-types",
"rustls-webpki",
"url",
"webpki-roots",
]
[[package]]
name = "url"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
[[package]]
name = "web-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "xattr"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
dependencies = [
"libc",
"linux-raw-sys",
"rustix",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"

View File

@@ -0,0 +1,9 @@
[package]
name = "rust-example"
version = "0.1.0"
edition = "2021"
[dependencies]
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] }
ndarray = "0.15"
hound = "3"

View File

@@ -0,0 +1,19 @@
# Stream example in Rust
Made after [C++ stream example](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
## Dependencies
- To build Rust crate `ort` you need `cc` installed.
## Usage
Just
```
cargo run
```
If you run example outside of this repo adjust environment variable
```
SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run
```
If you need to test against other wav file, not `recorder.wav`, specify it as the first argument
```
cargo run -- /path/to/audio/file.wav
```

View File

@@ -0,0 +1,36 @@
mod silero;
mod utils;
mod vad_iter;
fn main() {
let model_path = std::env::var("SILERO_MODEL_PATH")
.unwrap_or_else(|_| String::from("../../files/silero_vad.onnx"));
let audio_path = std::env::args()
.nth(1)
.unwrap_or_else(|| String::from("recorder.wav"));
let mut wav_reader = hound::WavReader::open(audio_path).unwrap();
let sample_rate = match wav_reader.spec().sample_rate {
8000 => utils::SampleRate::EightkHz,
16000 => utils::SampleRate::SixteenkHz,
_ => panic!("Unsupported sample rate. Expect 8 kHz or 16 kHz."),
};
if wav_reader.spec().sample_format != hound::SampleFormat::Int {
panic!("Unsupported sample format. Expect Int.");
}
let content = wav_reader
.samples()
.filter_map(|x| x.ok())
.collect::<Vec<i16>>();
assert!(!content.is_empty());
let silero = silero::Silero::new(sample_rate, model_path).unwrap();
let vad_params = utils::VadParams {
sample_rate: sample_rate.into(),
..Default::default()
};
let mut vad_iterator = vad_iter::VadIter::new(silero, vad_params);
vad_iterator.process(&content).unwrap();
for timestamp in vad_iterator.speeches() {
println!("{}", timestamp);
}
println!("Finished.");
}

View File

@@ -0,0 +1,54 @@
use crate::utils;
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use std::path::Path;
#[derive(Debug)]
pub struct Silero {
session: ort::Session,
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
}
impl Silero {
pub fn new(
sample_rate: utils::SampleRate,
model_path: impl AsRef<Path>,
) -> Result<Self, ort::Error> {
let session = ort::Session::builder()?.commit_from_file(model_path)?;
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
Ok(Self {
session,
sample_rate,
state,
})
}
pub fn reset(&mut self) {
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
}
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
let data = audio_frame
.iter()
.map(|x| (*x as f32) / (i16::MAX as f32))
.collect::<Vec<_>>();
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
frame = frame.slice(s![.., ..480]).to_owned();
let inps = ort::inputs![
frame,
std::mem::take(&mut self.state),
self.sample_rate.clone(),
]?;
let res = self
.session
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
Ok(*res["output"]
.try_extract_raw_tensor::<f32>()
.unwrap()
.1
.first()
.unwrap())
}
}

View File

@@ -0,0 +1,60 @@
#[derive(Debug, Clone, Copy)]
pub enum SampleRate {
EightkHz,
SixteenkHz,
}
impl From<SampleRate> for i64 {
fn from(value: SampleRate) -> Self {
match value {
SampleRate::EightkHz => 8000,
SampleRate::SixteenkHz => 16000,
}
}
}
impl From<SampleRate> for usize {
fn from(value: SampleRate) -> Self {
match value {
SampleRate::EightkHz => 8000,
SampleRate::SixteenkHz => 16000,
}
}
}
#[derive(Debug)]
pub struct VadParams {
pub frame_size: usize,
pub threshold: f32,
pub min_silence_duration_ms: usize,
pub speech_pad_ms: usize,
pub min_speech_duration_ms: usize,
pub max_speech_duration_s: f32,
pub sample_rate: usize,
}
impl Default for VadParams {
fn default() -> Self {
Self {
frame_size: 64,
threshold: 0.5,
min_silence_duration_ms: 0,
speech_pad_ms: 64,
min_speech_duration_ms: 64,
max_speech_duration_s: f32::INFINITY,
sample_rate: 16000,
}
}
}
#[derive(Debug, Default)]
pub struct TimeStamp {
pub start: i64,
pub end: i64,
}
impl std::fmt::Display for TimeStamp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[start:{:08}, end:{:08}]", self.start, self.end)
}
}

View File

@@ -0,0 +1,223 @@
use crate::{silero, utils};
const DEBUG_SPEECH_PROB: bool = true;
#[derive(Debug)]
pub struct VadIter {
silero: silero::Silero,
params: Params,
state: State,
}
impl VadIter {
pub fn new(silero: silero::Silero, params: utils::VadParams) -> Self {
Self {
silero,
params: Params::from(params),
state: State::new(),
}
}
pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
self.reset_states();
for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
self.state.update(&self.params, speech_prob);
}
self.state.check_for_last_speech(samples.len());
Ok(())
}
pub fn speeches(&self) -> &[utils::TimeStamp] {
&self.state.speeches
}
}
impl VadIter {
fn reset_states(&mut self) {
self.silero.reset();
self.state = State::new()
}
}
#[allow(unused)]
#[derive(Debug)]
struct Params {
frame_size: usize,
threshold: f32,
min_silence_duration_ms: usize,
speech_pad_ms: usize,
min_speech_duration_ms: usize,
max_speech_duration_s: f32,
sample_rate: usize,
sr_per_ms: usize,
frame_size_samples: usize,
min_speech_samples: usize,
speech_pad_samples: usize,
max_speech_samples: f32,
min_silence_samples: usize,
min_silence_samples_at_max_speech: usize,
}
impl From<utils::VadParams> for Params {
fn from(value: utils::VadParams) -> Self {
let frame_size = value.frame_size;
let threshold = value.threshold;
let min_silence_duration_ms = value.min_silence_duration_ms;
let speech_pad_ms = value.speech_pad_ms;
let min_speech_duration_ms = value.min_speech_duration_ms;
let max_speech_duration_s = value.max_speech_duration_s;
let sample_rate = value.sample_rate;
let sr_per_ms = sample_rate / 1000;
let frame_size_samples = frame_size * sr_per_ms;
let min_speech_samples = sr_per_ms * min_speech_duration_ms;
let speech_pad_samples = sr_per_ms * speech_pad_ms;
let max_speech_samples = sample_rate as f32 * max_speech_duration_s
- frame_size_samples as f32
- 2.0 * speech_pad_samples as f32;
let min_silence_samples = sr_per_ms * min_silence_duration_ms;
let min_silence_samples_at_max_speech = sr_per_ms * 98;
Self {
frame_size,
threshold,
min_silence_duration_ms,
speech_pad_ms,
min_speech_duration_ms,
max_speech_duration_s,
sample_rate,
sr_per_ms,
frame_size_samples,
min_speech_samples,
speech_pad_samples,
max_speech_samples,
min_silence_samples,
min_silence_samples_at_max_speech,
}
}
}
#[derive(Debug, Default)]
struct State {
current_sample: usize,
temp_end: usize,
next_start: usize,
prev_end: usize,
triggered: bool,
current_speech: utils::TimeStamp,
speeches: Vec<utils::TimeStamp>,
}
impl State {
fn new() -> Self {
Default::default()
}
fn update(&mut self, params: &Params, speech_prob: f32) {
self.current_sample += params.frame_size_samples;
if speech_prob > params.threshold {
if self.temp_end != 0 {
self.temp_end = 0;
if self.next_start < self.prev_end {
self.next_start = self
.current_sample
.saturating_sub(params.frame_size_samples)
}
}
if !self.triggered {
self.debug(speech_prob, params, "start");
self.triggered = true;
self.current_speech.start =
self.current_sample as i64 - params.frame_size_samples as i64;
}
return;
}
if self.triggered
&& (self.current_sample as i64 - self.current_speech.start) as f32
> params.max_speech_samples
{
if self.prev_end > 0 {
self.current_speech.end = self.prev_end as _;
self.take_speech();
if self.next_start < self.prev_end {
self.triggered = false
} else {
self.current_speech.start = self.next_start as _;
}
self.prev_end = 0;
self.next_start = 0;
self.temp_end = 0;
} else {
self.current_speech.end = self.current_sample as _;
self.take_speech();
self.prev_end = 0;
self.next_start = 0;
self.temp_end = 0;
self.triggered = false;
}
return;
}
if speech_prob >= (params.threshold - 0.15) && (speech_prob < params.threshold) {
if self.triggered {
self.debug(speech_prob, params, "speaking")
} else {
self.debug(speech_prob, params, "silence")
}
}
if self.triggered && speech_prob < (params.threshold - 0.15) {
self.debug(speech_prob, params, "end");
if self.temp_end == 0 {
self.temp_end = self.current_sample;
}
if self.current_sample.saturating_sub(self.temp_end)
> params.min_silence_samples_at_max_speech
{
self.prev_end = self.temp_end;
}
if self.current_sample.saturating_sub(self.temp_end) >= params.min_silence_samples {
self.current_speech.end = self.temp_end as _;
if self.current_speech.end - self.current_speech.start
> params.min_speech_samples as _
{
self.take_speech();
self.prev_end = 0;
self.next_start = 0;
self.temp_end = 0;
self.triggered = false;
}
}
}
}
fn take_speech(&mut self) {
self.speeches.push(std::mem::take(&mut self.current_speech)); // current speech becomes TimeStamp::default() due to take()
}
fn check_for_last_speech(&mut self, last_sample: usize) {
if self.current_speech.start > 0 {
self.current_speech.end = last_sample as _;
self.take_speech();
self.prev_end = 0;
self.next_start = 0;
self.temp_end = 0;
self.triggered = false;
}
}
fn debug(&self, speech_prob: f32, params: &Params, title: &str) {
if DEBUG_SPEECH_PROB {
let speech = self.current_sample as f32
- params.frame_size_samples as f32
- if title == "end" {
params.speech_pad_samples
} else {
0
} as f32; // minus window_size_samples to get precise start time point.
println!(
"[{:10}: {:.3} s ({:.3}) {:8}]",
title,
speech / params.sample_rate as f32,
speech_prob,
self.current_sample - params.frame_size_samples,
);
}
}
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1 +0,0 @@
{"59": "mg, Malagasy", "76": "tk, Turkmen", "20": "lb, Luxembourgish, Letzeburgesch", "62": "or, Oriya", "30": "en, English", "26": "oc, Occitan", "69": "no, Norwegian", "77": "sr, Serbian", "90": "bs, Bosnian", "71": "el, Greek, Modern (1453\u2013)", "15": "az, Azerbaijani", "12": "lo, Lao", "85": "zh-HK, Chinese", "79": "cs, Czech", "43": "sv, Swedish", "37": "mn, Mongolian", "32": "fi, Finnish", "51": "tg, Tajik", "46": "am, Amharic", "17": "nn, Norwegian Nynorsk", "40": "ja, Japanese", "8": "it, Italian", "21": "ha, Hausa", "11": "as, Assamese", "29": "fa, Persian", "82": "bn, Bengali", "54": "mk, Macedonian", "31": "sw, Swahili", "45": "vi, Vietnamese", "41": "ur, Urdu", "74": "bo, Tibetan", "4": "hi, Hindi", "86": "mr, Marathi", "3": "fy-NL, Western Frisian", "65": "sk, Slovak", "2": "ln, Lingala", "92": "gl, Galician", "53": "sn, Shona", "87": "su, Sundanese", "35": "tt, Tatar", "93": "kn, Kannada", "6": "yo, Yoruba", "27": "ps, Pashto, Pushto", "34": "hy, Armenian", "25": "pa-IN, Punjabi, Panjabi", "23": "nl, Dutch, Flemish", "48": "th, Thai", "73": "mt, Maltese", "55": "ar, Arabic", "89": "ba, Bashkir", "78": "bg, Bulgarian", "42": "yi, Yiddish", "5": "ru, Russian", "84": "sv-SE, Swedish", "80": "tr, Turkish", "33": "sq, Albanian", "38": "kk, Kazakh", "50": "pl, Polish", "9": "hr, Croatian", "66": "ky, Kirghiz, Kyrgyz", "49": "hu, Hungarian", "10": "si, Sinhala, Sinhalese", "56": "la, Latin", "75": "de, German", "14": "ko, Korean", "22": "id, Indonesian", "47": "sl, Slovenian", "57": "be, Belarusian", "36": "ta, Tamil", "7": "da, Danish", "91": "sd, Sindhi", "28": "et, Estonian", "63": "pt, Portuguese", "60": "ne, Nepali", "94": "zh-TW, Chinese", "18": "zh-CN, Chinese", "88": "rw, Kinyarwanda", "19": "es, Spanish, Castilian", "39": "ht, Haitian, Haitian Creole", "64": "tl, Tagalog", "83": "ms, Malay", "70": "ro, Romanian, Moldavian, Moldovan", "68": "pa, Punjabi, Panjabi", "52": "uz, Uzbek", "58": "km, Central Khmer", "67": "my, Burmese", "0": "fr, French", "24": "af, Afrikaans", "16": "gu, Gujarati", "81": "so, Somali", "13": "uk, Ukrainian", "44": "ca, Catalan, Valencian", "72": "ml, Malayalam", "61": "te, Telugu", "1": "zh, Chinese"}

View File

@@ -1 +0,0 @@
{"0": ["Afrikaans", "Dutch, Flemish", "Western Frisian"], "1": ["Turkish", "Azerbaijani"], "2": ["Russian", "Slovak", "Ukrainian", "Czech", "Polish", "Belarusian"], "3": ["Bulgarian", "Macedonian", "Serbian", "Croatian", "Bosnian", "Slovenian"], "4": ["Norwegian Nynorsk", "Swedish", "Danish", "Norwegian"], "5": ["English"], "6": ["Finnish", "Estonian"], "7": ["Yiddish", "Luxembourgish, Letzeburgesch", "German"], "8": ["Spanish", "Occitan", "Portuguese", "Catalan, Valencian", "Galician", "Spanish, Castilian", "Italian"], "9": ["Maltese", "Arabic"], "10": ["Marathi"], "11": ["Hindi", "Urdu"], "12": ["Lao", "Thai"], "13": ["Malay", "Indonesian"], "14": ["Romanian, Moldavian, Moldovan"], "15": ["Tagalog"], "16": ["Tajik", "Persian"], "17": ["Kazakh", "Uzbek", "Kirghiz, Kyrgyz"], "18": ["Kinyarwanda"], "19": ["Tatar", "Bashkir"], "20": ["French"], "21": ["Chinese"], "22": ["Lingala"], "23": ["Yoruba"], "24": ["Sinhala, Sinhalese"], "25": ["Assamese"], "26": ["Korean"], "27": ["Gujarati"], "28": ["Hausa"], "29": ["Punjabi, Panjabi"], "30": ["Pashto, Pushto"], "31": ["Swahili"], "32": ["Albanian"], "33": ["Armenian"], "34": ["Mongolian"], "35": ["Tamil"], "36": ["Haitian, Haitian Creole"], "37": ["Japanese"], "38": ["Vietnamese"], "39": ["Amharic"], "40": ["Hungarian"], "41": ["Shona"], "42": ["Latin"], "43": ["Central Khmer"], "44": ["Malagasy"], "45": ["Nepali"], "46": ["Telugu"], "47": ["Oriya"], "48": ["Burmese"], "49": ["Greek, Modern (1453\u2013)"], "50": ["Malayalam"], "51": ["Tibetan"], "52": ["Turkmen"], "53": ["Somali"], "54": ["Bengali"], "55": ["Sundanese"], "56": ["Sindhi"], "57": ["Kannada"]}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,154 +1,49 @@
dependencies = ['torch', 'torchaudio']
import torch
import json
from utils_vad import (init_jit_model,
get_speech_ts,
get_speech_ts_adaptive,
get_number_ts,
get_language,
get_language_and_group,
save_audio,
read_audio,
state_generator,
single_audio_stream,
collect_chunks,
drop_chunks)
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from silero_vad.utils_vad import (init_jit_model,
get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks,
OnnxWrapper)
def silero_vad(**kwargs):
def versiontuple(v):
splitted = v.split('+')[0].split(".")
version_list = []
for i in splitted:
try:
version_list.append(int(i))
except:
version_list.append(0)
return tuple(version_list)
def silero_vad(onnx=False, force_onnx_cpu=False):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model.jit')
utils = (get_speech_ts,
get_speech_ts_adaptive,
if not onnx:
installed_version = torch.__version__
supported_version = '1.12.0'
if versiontuple(installed_version) < versiontuple(supported_version):
raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
if onnx:
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
else:
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
utils = (get_speech_timestamps,
save_audio,
read_audio,
state_generator,
single_audio_stream,
VADIterator,
collect_chunks)
return model, utils
def silero_vad_micro(**kwargs):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_micro.jit')
utils = (get_speech_ts,
get_speech_ts_adaptive,
save_audio,
read_audio,
state_generator,
single_audio_stream,
collect_chunks)
return model, utils
def silero_vad_micro_8k(**kwargs):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_micro_8k.jit')
utils = (get_speech_ts,
get_speech_ts_adaptive,
save_audio,
read_audio,
state_generator,
single_audio_stream,
collect_chunks)
return model, utils
def silero_vad_mini(**kwargs):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_mini.jit')
utils = (get_speech_ts,
get_speech_ts_adaptive,
save_audio,
read_audio,
state_generator,
single_audio_stream,
collect_chunks)
return model, utils
def silero_vad_mini_8k(**kwargs):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_mini_8k.jit')
utils = (get_speech_ts,
get_speech_ts_adaptive,
save_audio,
read_audio,
state_generator,
single_audio_stream,
collect_chunks)
return model, utils
def silero_number_detector(**kwargs):
"""Silero Number Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
utils = (get_number_ts,
save_audio,
read_audio,
collect_chunks,
drop_chunks)
return model, utils
def silero_lang_detector(**kwargs):
"""Silero Language Classifier
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
utils = (get_language,
read_audio)
return model, utils
def silero_lang_detector_95(**kwargs):
"""Silero Language Classifier (95 languages)
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/lang_classifier_95.jit')
with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
lang_dict = json.load(f)
with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
lang_group_dict = json.load(f)
utils = (get_language_and_group, read_audio)
return model, lang_dict, lang_group_dict, utils

35
pyproject.toml Normal file
View File

@@ -0,0 +1,35 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "silero-vad"
version = "5.1"
authors = [
{name="Silero Team", email="hello@silero.ai"},
]
description = "Voice Activity Detector (VAD) by Silero"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering",
]
dependencies = [
"torch>=1.12.0",
"torchaudio>=0.12.0",
"onnxruntime>=1.16.1",
]
[project.urls]
Homepage = "https://github.com/snakers4/silero-vad"
Issues = "https://github.com/snakers4/silero-vad/issues"

View File

@@ -1,23 +1,5 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sVNOuHQQjsrp"
},
"source": [
"# PyTorch Examples"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FpMplOCA2Fwp"
},
"source": [
"## VAD"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -25,7 +7,7 @@
"id": "62A6F_072Fwq"
},
"source": [
"### Install Dependencies"
"## Install Dependencies"
]
},
{
@@ -40,28 +22,51 @@
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile\n",
"!pip install -q torchaudio\n",
"\n",
"SAMPLING_RATE = 16000\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"# download example\n",
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pSifus5IilRp"
},
"outputs": [],
"source": [
"USE_PIP = True # download model using pip package or torch.hub\n",
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
"if USE_PIP:\n",
" !pip install -q silero-vad\n",
" from silero_vad import (load_silero_vad,\n",
" read_audio,\n",
" get_speech_timestamps,\n",
" save_audio,\n",
" VADIterator,\n",
" collect_chunks)\n",
" model = load_silero_vad(onnx=USE_ONNX)\n",
"else:\n",
" model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"(get_speech_ts,\n",
" get_speech_ts_adaptive,\n",
" save_audio,\n",
" read_audio,\n",
" state_generator,\n",
" single_audio_stream,\n",
" collect_chunks) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
" (get_speech_timestamps,\n",
" save_audio,\n",
" read_audio,\n",
" VADIterator,\n",
" collect_chunks) = utils"
]
},
{
@@ -70,16 +75,7 @@
"id": "fXbbaUO3jsrw"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dY2Us3_Q2Fws"
},
"source": [
"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
"## Speech timestapms from full audio"
]
},
{
@@ -90,10 +86,9 @@
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav')\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"# get speech timestamps from full audio file\n",
"speech_timestamps = get_speech_ts(wav, model,\n",
" num_steps=4)\n",
"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
"pprint(speech_timestamps)"
]
},
@@ -107,45 +102,31 @@
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav',\n",
" collect_chunks(speech_timestamps, wav), 16000) \n",
" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
"Audio('only_speech.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n8plzbJU2Fws"
"id": "zeO1xCqxUC6w"
},
"source": [
"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
"## Entire audio inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SQOtu2Vl2Fwt"
"id": "LjZBcsaTT7Mk"
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav')\n",
"# get speech timestamps from full audio file\n",
"speech_timestamps = get_speech_ts_adaptive(wav, model, step=500, num_samples_per_window=4000)\n",
"pprint(speech_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Lr6zCGXh2Fwt"
},
"outputs": [],
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav',\n",
" collect_chunks(speech_timestamps, wav), 16000) \n",
"Audio('only_speech.wav')"
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"# audio is being splitted into 31.25 ms long pieces\n",
"# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
"predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
]
},
{
@@ -154,16 +135,7 @@
"id": "iDKQbVr8jsry"
},
"source": [
"### Single Audio Stream"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xCM-HrUR2Fwu"
},
"source": [
"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
"## Stream imitation example"
]
},
{
@@ -174,20 +146,20 @@
},
"outputs": [],
"source": [
"wav = f'{files_dir}/en.wav'\n",
"## using VADIterator class\n",
"\n",
"for batch in single_audio_stream(model, wav):\n",
" if batch:\n",
" print(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t8TXtnvk2Fwv"
},
"source": [
"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
"vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"\n",
"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
"for i in range(0, len(wav), window_size_samples):\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
" break\n",
" speech_dict = vad_iterator(chunk, return_seconds=True)\n",
" if speech_dict:\n",
" print(speech_dict, end=' ')\n",
"vad_iterator.reset_states() # reset model states after each audio"
]
},
{
@@ -198,755 +170,20 @@
},
"outputs": [],
"source": [
"wav = f'{files_dir}/en.wav'\n",
"## just probabilities\n",
"\n",
"for batch in single_audio_stream(model, wav, iterator_type='adaptive'):\n",
" if batch:\n",
" print(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "KBDVybJCjsrz"
},
"source": [
"### Multiple Audio Streams"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "BK4tGfWgjsrz"
},
"outputs": [],
"source": [
"audios_for_stream = glob.glob(f'{files_dir}/*.wav')\n",
"len(audios_for_stream) # total 4 audios"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "v1l8sam1jsrz"
},
"outputs": [],
"source": [
"for batch in state_generator(model, audios_for_stream, audios_in_stream=2): # 2 audio stream\n",
" if batch:\n",
" pprint(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "36jY0niD2Fww"
},
"source": [
"## Number detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "scd1DlS42Fwx"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "Kq5gQuYq2Fwx"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"speech_probs = []\n",
"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
"for i in range(0, len(wav), window_size_samples):\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
" break\n",
" speech_prob = model(chunk, SAMPLING_RATE).item()\n",
" speech_probs.append(speech_prob)\n",
"vad_iterator.reset_states() # reset model states after each audio\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"\n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_number_detector',\n",
" force_reload=True)\n",
"\n",
"(get_number_ts,\n",
" save_audio,\n",
" read_audio,\n",
" collect_chunks,\n",
" drop_chunks) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "qhPa30ij2Fwy"
},
"source": [
"### Full audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "EXpau6xq2Fwy"
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en_num.wav')\n",
"# get number timestamps from full audio file\n",
"number_timestamps = get_number_ts(wav, model)\n",
"pprint(number_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "u-KfXRhZ2Fwy"
},
"outputs": [],
"source": [
"sample_rate = 16000\n",
"# convert ms in timestamps to samples\n",
"for timestamp in number_timestamps:\n",
" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "iwYEC4aZ2Fwy"
},
"outputs": [],
"source": [
"# merge all number chunks to one audio\n",
"save_audio('only_numbers.wav',\n",
" collect_chunks(number_timestamps, wav), sample_rate) \n",
"Audio('only_numbers.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "fHaYejX12Fwy"
},
"outputs": [],
"source": [
"# drop all number chunks from audio\n",
"save_audio('no_numbers.wav',\n",
" drop_chunks(number_timestamps, wav), sample_rate) \n",
"Audio('no_numbers.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "PnKtJKbq2Fwz"
},
"source": [
"## Language detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "F5cAmMbP2Fwz"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "Zu9D0t6n2Fwz"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"\n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_lang_detector',\n",
" force_reload=True)\n",
"\n",
"(get_language,\n",
" read_audio) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "iC696eMX2Fwz"
},
"source": [
"### Full audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "c8UYnYBF2Fw0"
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav')\n",
"lang = get_language(wav, model)\n",
"print(lang)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "57avIBd6jsrz"
},
"source": [
"# ONNX Example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hEhnfORV2Fw0"
},
"source": [
"## VAD"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "bL4kn4KJrlyL"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"hidden": true,
"id": "Q4QIfSpprnkI"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile onnxruntime\n",
"\n",
"import glob\n",
"import onnxruntime\n",
"from pprint import pprint\n",
"\n",
"from IPython.display import Audio\n",
"\n",
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"(get_speech_ts,\n",
" get_speech_ts_adaptive,\n",
" save_audio,\n",
" read_audio,\n",
" state_generator,\n",
" single_audio_stream,\n",
" collect_speeches) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5JHErdB7jsr0"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNEtK5zi2Fw2"
},
"source": [
"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "krnGoA6Kjsr0"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
"wav = read_audio(f'{files_dir}/en.wav')\n",
"\n",
"# get speech timestamps from full audio file\n",
"speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx) \n",
"pprint(speech_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B176Lzfnjsr1"
},
"outputs": [],
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), 16000)\n",
"Audio('only_speech.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "21RE8KEC2Fw2"
},
"source": [
"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uIVs56rb2Fw2"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
"wav = read_audio(f'{files_dir}/en.wav')\n",
"\n",
"# get speech timestamps from full audio file\n",
"speech_timestamps = get_speech_ts_adaptive(wav, model, run_function=validate_onnx) \n",
"pprint(speech_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cox6oumC2Fw3"
},
"outputs": [],
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), 16000)\n",
"Audio('only_speech.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rio9W50gjsr1"
},
"source": [
"### Single Audio Stream"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i8EZwtaA2Fw3"
},
"source": [
"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IPkl8Yy1jsr1"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
"wav = f'{files_dir}/en.wav'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NC6Jim0hjsr1"
},
"outputs": [],
"source": [
"for batch in single_audio_stream(model, wav, run_function=validate_onnx):\n",
" if batch:\n",
" pprint(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0pSKslpz2Fw3"
},
"source": [
"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RZwc-Khk2Fw4"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
"wav = f'{files_dir}/en.wav'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4lzFPs02Fw4"
},
"outputs": [],
"source": [
"for batch in single_audio_stream(model, wav, iterator_type='adaptive', run_function=validate_onnx):\n",
" if batch:\n",
" pprint(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "WNZ42u0ajsr1"
},
"source": [
"### Multiple Audio Streams"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "XjhGQGppjsr1"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
"audios_for_stream = glob.glob(f'{files_dir}/*.wav')\n",
"pprint(len(audios_for_stream)) # total 4 audios"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "QI7-arlqjsr2"
},
"outputs": [],
"source": [
"for batch in state_generator(model, audios_for_stream, audios_in_stream=2, run_function=validate_onnx): # 2 audio stream\n",
" if batch:\n",
" pprint(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "7QMvUvpg2Fw4"
},
"source": [
"## Number detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "tBPDkpHr2Fw4"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"hidden": true,
"id": "PdjGd56R2Fw5"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile onnxruntime\n",
"\n",
"import glob\n",
"import torch\n",
"import onnxruntime\n",
"from pprint import pprint\n",
"\n",
"from IPython.display import Audio\n",
"\n",
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_number_detector',\n",
" force_reload=True)\n",
"\n",
"(get_number_ts,\n",
" save_audio,\n",
" read_audio,\n",
" collect_chunks,\n",
" drop_chunks) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "I9QWSFZh2Fw5"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "_r6QZiwu2Fw5"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
"wav = read_audio(f'{files_dir}/en_num.wav')\n",
"\n",
"# get number timestamps from full audio file\n",
"number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
"pprint(number_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "FN4aDwLV2Fw5"
},
"outputs": [],
"source": [
"sample_rate = 16000\n",
"# convert ms in timestamps to samples\n",
"for timestamp in number_timestamps:\n",
" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "JnvS6WTK2Fw5"
},
"outputs": [],
"source": [
"# merge all number chunks to one audio\n",
"save_audio('only_numbers.wav',\n",
" collect_chunks(number_timestamps, wav), 16000) \n",
"Audio('only_numbers.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "yUxOcOFG2Fw6"
},
"outputs": [],
"source": [
"# drop all number chunks from audio\n",
"save_audio('no_numbers.wav',\n",
" drop_chunks(number_timestamps, wav), 16000) \n",
"Audio('no_numbers.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "SR8Bgcd52Fw6"
},
"source": [
"## Language detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "PBnXPtKo2Fw6"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"hidden": true,
"id": "iNkDWJ3H2Fw6"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile onnxruntime\n",
"\n",
"import glob\n",
"import torch\n",
"import onnxruntime\n",
"from pprint import pprint\n",
"\n",
"from IPython.display import Audio\n",
"\n",
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_lang_detector',\n",
" force_reload=True)\n",
"\n",
"(get_language,\n",
" read_audio) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true,
"id": "G8N8oP4q2Fw6"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "WHXnh9IV2Fw6"
},
"outputs": [],
"source": [
"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
"wav = read_audio(f'{files_dir}/en.wav')\n",
"\n",
"lang = get_language(wav, model, run_function=validate_onnx)\n",
"print(lang)"
"print(speech_probs[:10]) # first 10 chunks predicts"
]
}
],

View File

@@ -0,0 +1,12 @@
from importlib.metadata import version
try:
__version__ = version(__name__)
except:
pass
from silero_vad.model import load_silero_vad
from silero_vad.utils_vad import (get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks)

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

25
src/silero_vad/model.py Normal file
View File

@@ -0,0 +1,25 @@
from .utils_vad import init_jit_model, OnnxWrapper
import torch
torch.set_num_threads(1)
def load_silero_vad(onnx=False):
model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
package_path = "silero_vad.data"
try:
import importlib_resources as impresources
model_file_path = str(impresources.files(package_path).joinpath(model_name))
except:
from importlib import resources as impresources
try:
with impresources.path(package_path, model_name) as f:
model_file_path = f
except:
model_file_path = str(impresources.files(package_path).joinpath(model_name))
if onnx:
model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
else:
model = init_jit_model(model_file_path)
return model

495
src/silero_vad/utils_vad.py Normal file
View File

@@ -0,0 +1,495 @@
import torch
import torchaudio
from typing import Callable, List
import warnings
languages = ['ru', 'en', 'de', 'es']
class OnnxWrapper():
def __init__(self, path, force_onnx_cpu=False):
import numpy as np
global np
import onnxruntime
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states()
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[:,::step]
sr = 16000
if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
return x, sr
def reset_states(self, batch_size=1):
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
if not self._last_batch_size:
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)
if not len(self._context):
self._context = torch.zeros(batch_size, context_size)
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()
self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.from_numpy(out)
return out
def audio_forward(self, x, sr: int):
outs = []
x, sr = self._validate_input(x, sr)
self.reset_states()
num_samples = 512 if sr == 16000 else 256
if x.shape[1] % num_samples:
pad_num = num_samples - (x.shape[1] % num_samples)
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
for i in range(0, x.shape[1], num_samples):
wavs_batch = x[:, i:i+num_samples]
out_chunk = self.__call__(wavs_batch, sr)
outs.append(out_chunk)
stacked = torch.cat(outs, dim=1)
return stacked.cpu()
class Validator():
def __init__(self, url, force_onnx_cpu):
self.onnx = True if url.endswith('.onnx') else False
torch.hub.download_url_to_file(url, 'inf.model')
if self.onnx:
import onnxruntime
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
else:
self.model = onnxruntime.InferenceSession('inf.model')
else:
self.model = init_jit_model(model_path='inf.model')
def __call__(self, inputs: torch.Tensor):
with torch.no_grad():
if self.onnx:
ort_inputs = {'input': inputs.cpu().numpy()}
outs = self.model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
else:
outs = self.model(inputs)
return outs
def read_audio(path: str,
sampling_rate: int = 16000):
list_backends = torchaudio.list_audio_backends()
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
try:
effects = [
['channels', '1'],
['rate', str(sampling_rate)]
]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
assert sr == sampling_rate
return wav.squeeze(0)
def save_audio(path: str,
tensor: torch.Tensor,
sampling_rate: int = 16000):
torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def make_visualization(probs, step):
import pandas as pd
pd.DataFrame({'probs': probs},
index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
xlabel='seconds',
ylabel='speech probability',
colormap='tab20')
@torch.no_grad()
def get_speech_timestamps(audio: torch.Tensor,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_speech_duration_ms: int = 250,
max_speech_duration_s: float = float('inf'),
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30,
return_seconds: bool = False,
visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None,
neg_threshold: float = None,
window_size_samples: int = 512,):
"""
This method is used for splitting long audios into speech chunks using silero VAD
Parameters
----------
audio: torch.Tensor, one dimensional
One dimensional float torch.Tensor, other types are casted to torch if possible
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates
min_speech_duration_ms: int (default - 250 milliseconds)
Final speech chunks shorter min_speech_duration_ms are thrown out
max_speech_duration_s: int (default - inf)
Maximum duration of speech chunks in seconds
Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent agressive cutting.
Otherwise, they will be split aggressively just before max_speech_duration_s.
min_silence_duration_ms: int (default - 100 milliseconds)
In the end of each speech chunk wait for min_silence_duration_ms before separating it
speech_pad_ms: int (default - 30 milliseconds)
Final speech chunks are padded by speech_pad_ms each side
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
visualize_probs: bool (default - False)
whether draw prob hist or not
progress_tracking_callback: Callable[[float], None] (default - None)
callback function taking progress in percents as an argument
neg_threshold: float (default = threshold - 0.15)
Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
window_size_samples: int (default - 512 samples)
!!! DEPRECATED, DOES NOTHING !!!
Returns
----------
speeches: list of dicts
list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
"""
if not torch.is_tensor(audio):
try:
audio = torch.Tensor(audio)
except:
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
if len(audio.shape) > 1:
for i in range(len(audio.shape)): # trying to squeeze empty dimensions
audio = audio.squeeze(0)
if len(audio.shape) > 1:
raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
step = sampling_rate // 16000
sampling_rate = 16000
audio = audio[::step]
warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')
else:
step = 1
if sampling_rate not in [8000, 16000]:
raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
window_size_samples = 512 if sampling_rate == 16000 else 256
model.reset_states()
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
audio_length_samples = len(audio)
speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
if len(chunk) < window_size_samples:
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob = model(chunk, sampling_rate).item()
speech_probs.append(speech_prob)
# caculate progress and seng it to callback function
progress = current_start_sample + window_size_samples
if progress > audio_length_samples:
progress = audio_length_samples
progress_percent = (progress / audio_length_samples) * 100
if progress_tracking_callback:
progress_tracking_callback(progress_percent)
triggered = False
speeches = []
current_speech = {}
if neg_threshold is None:
neg_threshold = threshold - 0.15
temp_end = 0 # to save potential segment end (and tolerate some silence)
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
if (speech_prob >= threshold) and not triggered:
triggered = True
current_speech['start'] = window_size_samples * i
continue
if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
if prev_end:
current_speech['end'] = prev_end
speeches.append(current_speech)
current_speech = {}
if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
triggered = False
else:
current_speech['start'] = next_start
prev_end = next_start = temp_end = 0
else:
current_speech['end'] = window_size_samples * i
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if (speech_prob < neg_threshold) and triggered:
if not temp_end:
temp_end = window_size_samples * i
if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
prev_end = temp_end
if (window_size_samples * i) - temp_end < min_silence_samples:
continue
else:
current_speech['end'] = temp_end
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
current_speech['end'] = audio_length_samples
speeches.append(current_speech)
for i, speech in enumerate(speeches):
if i == 0:
speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
if i != len(speeches) - 1:
silence_duration = speeches[i+1]['start'] - speech['end']
if silence_duration < 2 * speech_pad_samples:
speech['end'] += int(silence_duration // 2)
speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
else:
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
else:
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
if return_seconds:
for speech_dict in speeches:
speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
elif step > 1:
for speech_dict in speeches:
speech_dict['start'] *= step
speech_dict['end'] *= step
if visualize_probs:
make_visualization(speech_probs, window_size_samples / sampling_rate)
return speeches
class VADIterator:
def __init__(self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30
):
"""
Class for stream imitation
Parameters
----------
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates
min_silence_duration_ms: int (default - 100 milliseconds)
In the end of each speech chunk wait for min_silence_duration_ms before separating it
speech_pad_ms: int (default - 30 milliseconds)
Final speech chunks are padded by speech_pad_ms each side
"""
self.model = model
self.threshold = threshold
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
self.reset_states()
def reset_states(self):
self.model.reset_states()
self.triggered = False
self.temp_end = 0
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
"""
if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
except:
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples
speech_prob = self.model(x, self.sampling_rate).item()
if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end < self.min_silence_samples:
return None
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return None
def collect_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
for i in tss:
chunks.append(wav[i['start']: i['end']])
return torch.cat(chunks)
def drop_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
cur_start = 0
for i in tss:
chunks.append((wav[cur_start: i['start']]))
cur_start = i['end']
return torch.cat(chunks)

74
tuning/README.md Normal file
View File

@@ -0,0 +1,74 @@
# Тюнинг Silero-VAD модели
> Код тюнинга создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный
интеллект» национальной программы «Цифровая экономика Российской Федерации».
Тюнинг используется для улучшения качества детекции речи Silero-VAD модели на кастомных данных.
## Зависимости
Следующие зависимости используются при тюнинге VAD модели:
- `torchaudio>=0.12.0`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`
## Подготовка данных
Датафреймы для тюнинга должны быть подготовлены и сохранены в формате `.feather`. Следующие колонки в `.feather` файлах тренировки и валидации являются обязательными:
- **audio_path** - абсолютный путь до аудиофайла в дисковой системе. Аудиофайлы должны представлять собой `PCM` данные, предпочтительно в форматах `.wav` или `.opus` (иные популярные форматы аудио тоже поддерживаются). Для ускорения темпа дообучения рекомендуется предварительно выполнить ресемплинг аудиофайлов (изменить частоту дискретизации) до 16000 Гц;
- **speech_ts** - разметка для соответствующего аудиофайла. Список, состоящий из словарей формата `{'start': START_SEC, 'end': 'END_SEC'}`, где `START_SEC` и `END_SEC` - время начало и конца речевого отрезка в секундах соответственно. Для качественного дообучения рекомендуется использовать разметку с точностью до 30 миллисекунд.
Чем больше данных используется на этапе дообучения, тем эффективнее показывает себя адаптированная модель на целевом домене. Длина аудио не ограничена, т.к. каждое аудио будет обрезано до `max_train_length_sec` секунд перед подачей в нейросеть. Длинные аудио лучше предварительно порезать на кусочки длины `max_train_length_sec`.
Пример `.feather` датафрейма можно посмотреть в файле `example_dataframe.feather`
## Файл конфигурации `config.yml`
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
- `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub`
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
- `model_save_path` - путь сохранения добученной модели;
- `noise_loss` - коэффициент лосса, применяемый для неречевых окон аудио;
- `max_train_length_sec` - максимальная длина аудио в секундах на этапе дообучения. Более длительные аудио будут обрезаны до этого показателя;
- `aug_prob` - вероятность применения аугментаций к аудиофайлу на этапе дообучения;
- `learning_rate` - темп дообучения;
- `batch_size` - размер батча при дообучении и валидации;
- `num_workers` - количество потоков, используемых для загрузки данных;
- `num_epochs` - количество эпох дообучения. За одну эпоху прогоняются все тренировочные данные;
- `device` - `cpu` или `cuda`.
## Дообучение
Дообучение запускается командой
`python tune.py`
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
## Поиск пороговых значений
Порог на вход и порог на выход можно подобрать, используя команду
`python search_thresholds`
Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете.
## Цитирование
```
@misc{Silero VAD,
author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}},
commit = {insert_some_commit_here},
email = {hello@silero.ai}
}
```

0
tuning/__init__.py Normal file
View File

17
tuning/config.yml Normal file
View File

@@ -0,0 +1,17 @@
jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
train_dataset_path: 'train_dataset_path.feather' # путь до датасета в формате feather для дообучения, подробности в README
val_dataset_path: 'val_dataset_path.feather' # путь до датасета в формате feather для валидации, подробности в README
model_save_path: 'model_save_path.jit' # путь сохранения дообученной модели
noise_loss: 0.5 # коэффициент, применяемый к лоссу на неречевых окнах
max_train_length_sec: 8 # во время тюнинга аудио длиннее будут обрезаны до данного значения
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
learning_rate: 5e-4 # темп дообучения модели
batch_size: 128 # размер батча при дообучении и валидации
num_workers: 4 # количество потоков, используемых для даталоадеров
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
device: 'cuda' # cpu или cuda, на чем будет производится дообучение

Binary file not shown.

View File

@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
print('Making predicts...')
all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
print('Calculating thresholds...')
best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')

65
tuning/tune.py Normal file
View File

@@ -0,0 +1,65 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch.nn as nn
import torch
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
train_dataset = SileroVadDataset(config, mode='train')
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
val_dataset = SileroVadDataset(config, mode='val')
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
decoder = VADDecoderRNNJIT().to(config.device)
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
decoder.train()
params = decoder.parameters()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
lr=config.learning_rate)
criterion = nn.BCELoss(reduction='none')
best_val_roc = 0
for i in range(config.num_epochs):
print(f'Starting epoch {i + 1}')
train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
print(f'Metrics after epoch {i + 1}:\n'
f'\tTrain loss: {round(train_loss, 3)}\n',
f'\tValidation loss: {round(val_loss, 3)}\n'
f'\tValidation ROC-AUC: {round(val_roc, 3)}')
if val_roc > best_val_roc:
print('New best ROC-AUC, saving model')
best_val_roc = val_roc
if config.tune_8k:
model._model_8k.decoder.load_state_dict(decoder.state_dict())
else:
model._model.decoder.load_state_dict(decoder.state_dict())
torch.jit.save(model, config.model_save_path)
print('Done')

357
tuning/utils.py Normal file
View File

@@ -0,0 +1,357 @@
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')
def read_audio(path: str,
sampling_rate: int = 16000,
normalize=False):
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sampling_rate:
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
if normalize and wav.abs().max() != 0:
wav = wav / wav.abs().max()
return wav.squeeze(0)
def build_audiomentations_augs(p):
from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
Aliasing, AddGaussianNoise
transforms = [Aliasing(p=1),
AddGaussianNoise(p=1),
AirAbsorption(p=1),
BandPassFilter(p=1),
BandStopFilter(p=1),
ClippingDistortion(p=1),
HighPassFilter(p=1),
HighShelfFilter(p=1),
LowPassFilter(p=1),
LowShelfFilter(p=1),
Mp3Compression(p=1),
PeakingFilter(p=1),
PitchShift(p=1),
RoomSimulator(p=1, leave_length_unchanged=True),
SevenBandParametricEQ(p=1)]
tr = SomeOf((1, 3), transforms=transforms, p=p)
return tr
class SileroVadDataset(Dataset):
def __init__(self,
config,
mode='train'):
self.num_samples = 512 # constant, do not change
self.sr = 16000 # constant, do not change
self.resample_to_8k = config.tune_8k
self.noise_loss = config.noise_loss
self.max_train_length_sec = config.max_train_length_sec
self.max_train_length_samples = config.max_train_length_sec * self.sr
assert self.max_train_length_samples % self.num_samples == 0
assert mode in ['train', 'val']
dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
self.index_dict = self.dataframe.to_dict('index')
self.mode = mode
print(f'DATASET SIZE : {len(self.dataframe)}')
if mode == 'train':
self.augs = build_audiomentations_augs(p=config.aug_prob)
else:
self.augs = None
def __getitem__(self, idx):
idx = None if self.mode == 'train' else idx
wav, gt, mask = self.load_speech_sample(idx)
if self.mode == 'train':
wav = self.add_augs(wav)
if len(wav) > self.max_train_length_samples:
wav = wav[:self.max_train_length_samples]
gt = gt[:int(self.max_train_length_samples / self.num_samples)]
mask = mask[:int(self.max_train_length_samples / self.num_samples)]
wav = torch.FloatTensor(wav)
if self.resample_to_8k:
transform = torchaudio.transforms.Resample(orig_freq=self.sr,
new_freq=8000)
wav = transform(wav)
return wav, torch.FloatTensor(gt), torch.from_numpy(mask)
def __len__(self):
return len(self.index_dict)
def load_speech_sample(self, idx=None):
if idx is None:
idx = random.randint(0, len(self.index_dict) - 1)
wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()
if len(wav) % self.num_samples != 0:
pad_num = self.num_samples - (len(wav) % (self.num_samples))
wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)
gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))
assert len(gt) == len(wav) / self.num_samples
mask[gt == 0]
return wav, gt, mask
def get_ground_truth_annotated(self, annotation, audio_length_samples):
gt = np.zeros(audio_length_samples)
for i in annotation:
gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1
squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
mask = np.ones(len(squeezed_predicts))
mask[squeezed_predicts == 0] = self.noise_loss
return squeezed_predicts, mask
def add_augs(self, wav):
while True:
try:
wav_aug = self.augs(wav, self.sr)
if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
return wav
return wav_aug
except Exception as e:
continue
def SileroVadPadder(batch):
wavs = [batch[i][0] for i in range(len(batch))]
labels = [batch[i][1] for i in range(len(batch))]
masks = [batch[i][2] for i in range(len(batch))]
wavs = torch.nn.utils.rnn.pad_sequence(
wavs, batch_first=True, padding_value=0)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=0)
masks = torch.nn.utils.rnn.pad_sequence(
masks, batch_first=True, padding_value=0)
return wavs, labels, masks
class VADDecoderRNNJIT(nn.Module):
def __init__(self):
super(VADDecoderRNNJIT, self).__init__()
self.rnn = nn.LSTMCell(128, 128)
self.decoder = nn.Sequential(nn.Dropout(0.1),
nn.ReLU(),
nn.Conv1d(128, 1, kernel_size=1),
nn.Sigmoid())
def forward(self, x, state=torch.zeros(0)):
x = x.squeeze(-1)
if len(state):
h, c = self.rnn(x, (state[0], state[1]))
else:
h, c = self.rnn(x)
x = h.unsqueeze(-1).float()
state = torch.stack([h, c])
x = self.decoder(x)
return x, state
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def train(config,
loader,
jit_model,
decoder,
criterion,
optimizer,
device):
losses = AverageMeter()
decoder.train()
context_size = 32 if config.tune_8k else 64
num_samples = 256 if config.tune_8k else 512
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
with torch.enable_grad():
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
targets = targets.to(device)
x = x.to(device)
masks = masks.to(device)
x = torch.nn.functional.pad(x, (context_size, 0))
outs = []
state = torch.zeros(0)
for i in range(context_size, x.shape[1], num_samples):
input_ = x[:, i-context_size:i+num_samples]
out = stft_layer(input_)
out = encoder_layer(out)
out, state = decoder(out, state)
outs.append(out)
stacked = torch.cat(outs, dim=2).squeeze(1)
loss = criterion(stacked, targets)
loss = (loss * masks).mean()
loss.backward()
optimizer.step()
losses.update(loss.item(), masks.numel())
torch.cuda.empty_cache()
gc.collect()
return losses.avg
def validate(config,
loader,
jit_model,
decoder,
criterion,
device):
losses = AverageMeter()
decoder.eval()
predicts = []
gts = []
context_size = 32 if config.tune_8k else 64
num_samples = 256 if config.tune_8k else 512
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
with torch.no_grad():
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
targets = targets.to(device)
x = x.to(device)
masks = masks.to(device)
x = torch.nn.functional.pad(x, (context_size, 0))
outs = []
state = torch.zeros(0)
for i in range(context_size, x.shape[1], num_samples):
input_ = x[:, i-context_size:i+num_samples]
out = stft_layer(input_)
out = encoder_layer(out)
out, state = decoder(out, state)
outs.append(out)
stacked = torch.cat(outs, dim=2).squeeze(1)
predicts.extend(stacked[masks != 0].tolist())
gts.extend(targets[masks != 0].tolist())
loss = criterion(stacked, targets)
loss = (loss * masks).mean()
losses.update(loss.item(), masks.numel())
score = roc_auc_score(gts, predicts)
torch.cuda.empty_cache()
gc.collect()
return losses.avg, round(score, 3)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def predict(model, loader, device, sr):
with torch.no_grad():
all_predicts = []
all_gts = []
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
x = x.to(device)
out = model.audio_forward(x, sr=sr)
for i, out_chunk in enumerate(out):
predict = out_chunk[masks[i] != 0].cpu().tolist()
gt = targets[i, masks[i] != 0].cpu().tolist()
all_predicts.append(predict)
all_gts.append(gt)
return all_predicts, all_gts
def calculate_best_thresholds(all_predicts, all_gts):
best_acc = 0
for ths_enter in tqdm(np.linspace(0, 1, 20)):
for ths_exit in np.linspace(0, 1, 20):
if ths_exit >= ths_enter:
continue
accs = []
for j, predict in enumerate(all_predicts):
predict_bool = []
is_speech = False
for i in predict:
if i >= ths_enter:
is_speech = True
predict_bool.append(1)
elif i <= ths_exit:
is_speech = False
predict_bool.append(0)
else:
val = 1 if is_speech else 0
predict_bool.append(val)
score = round(accuracy_score(all_gts[j], predict_bool), 4)
accs.append(score)
mean_acc = round(np.mean(accs), 3)
if mean_acc > best_acc:
best_acc = mean_acc
best_ths_enter = round(ths_enter, 2)
best_ths_exit = round(ths_exit, 2)
return best_ths_enter, best_ths_exit, best_acc

View File

@@ -1,632 +0,0 @@
import torch
import torchaudio
from typing import List
from itertools import repeat
from collections import deque
import torch.nn.functional as F
torchaudio.set_audio_backend("soundfile") # switch backend
languages = ['ru', 'en', 'de', 'es']
class IterativeMedianMeter():
def __init__(self):
self.reset()
def reset(self):
self.median = 0
self.counts = {}
for i in range(0, 101, 1):
self.counts[i / 100] = 0
self.total_values = 0
def __call__(self, val):
self.total_values += 1
rounded = round(abs(val), 2)
self.counts[rounded] += 1
bin_sum = 0
for j in self.counts:
bin_sum += self.counts[j]
if bin_sum >= self.total_values / 2:
self.median = j
break
return self.median
def validate(model,
inputs: torch.Tensor):
with torch.no_grad():
outs = model(inputs)
return outs
def read_audio(path: str,
target_sr: int = 16000):
assert torchaudio.get_audio_backend() == 'soundfile'
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sr != target_sr:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=target_sr)
wav = transform(wav)
sr = target_sr
assert sr == target_sr
return wav.squeeze(0)
def save_audio(path: str,
tensor: torch.Tensor,
sr: int = 16000):
torchaudio.save(path, tensor.unsqueeze(0), sr)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def get_speech_ts(wav: torch.Tensor,
model,
trig_sum: float = 0.25,
neg_trig_sum: float = 0.07,
num_steps: int = 8,
batch_size: int = 200,
num_samples_per_window: int = 4000,
min_speech_samples: int = 10000, #samples
min_silence_samples: int = 500,
run_function=validate,
visualize_probs=False,
smoothed_prob_func='mean',
device='cpu'):
assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
num_samples = num_samples_per_window
assert num_samples % num_steps == 0
step = int(num_samples / num_steps) # stride / hop
outs = []
to_concat = []
for i in range(0, len(wav), step):
chunk = wav[i: i+num_samples]
if len(chunk) < num_samples:
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
to_concat.append(chunk.unsqueeze(0))
if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
to_concat = []
if to_concat:
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
outs = torch.cat(outs, dim=0)
buffer = deque(maxlen=num_steps) # maxlen reached => first element dropped
triggered = False
speeches = []
current_speech = {}
if visualize_probs:
import pandas as pd
smoothed_probs = []
speech_probs = outs[:, 1] # this is very misleading
temp_end = 0
for i, predict in enumerate(speech_probs): # add name
buffer.append(predict)
if smoothed_prob_func == 'mean':
smoothed_prob = (sum(buffer) / len(buffer))
elif smoothed_prob_func == 'max':
smoothed_prob = max(buffer)
if visualize_probs:
smoothed_probs.append(float(smoothed_prob))
if (smoothed_prob >= trig_sum) and temp_end:
temp_end=0
if (smoothed_prob >= trig_sum) and not triggered:
triggered = True
current_speech['start'] = step * max(0, i-num_steps)
continue
if (smoothed_prob < neg_trig_sum) and triggered:
if not temp_end:
temp_end = step * i
if step * i - temp_end < min_silence_samples:
continue
else:
current_speech['end'] = temp_end
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
speeches.append(current_speech)
temp_end = 0
current_speech = {}
triggered = False
continue
if current_speech:
current_speech['end'] = len(wav)
speeches.append(current_speech)
if visualize_probs:
pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
return speeches
def get_speech_ts_adaptive(wav: torch.Tensor,
model,
batch_size: int = 200,
step: int = 500,
num_samples_per_window: int = 4000, # Number of samples per audio chunk to feed to NN (4000 for 16k SR, 2000 for 8k SR is optimal)
min_speech_samples: int = 10000, # samples
min_silence_samples: int = 4000,
speech_pad_samples: int = 2000,
run_function=validate,
visualize_probs=False,
device='cpu'):
"""
This function is used for splitting long audios into speech chunks using silero VAD
Attention! All default sample rate values are optimal for 16000 sample rate model, if you are using 8000 sample rate model optimal values are half as much!
Parameters
----------
batch_size: int
batch size to feed to silero VAD (default - 200)
step: int
step size in samples, (default - 500)
num_samples_per_window: int
window size in samples (chunk length in samples to feed to NN, default - 4000)
min_speech_samples: int
if speech duration is shorter than this value, do not consider it speech (default - 10000)
min_silence_samples: int
number of samples to wait before considering as the end of speech (default - 4000)
speech_pad_samples: int
widen speech by this amount of samples each side (default - 2000)
run_function: function
function to use for the model call
visualize_probs: bool
whether draw prob hist or not (default: False)
device: string
torch device to use for the model call (default - "cpu")
Returns
----------
speeches: list
list containing ends and beginnings of speech chunks (in samples)
"""
if visualize_probs:
import pandas as pd
num_samples = num_samples_per_window
num_steps = int(num_samples / step)
assert min_silence_samples >= step
outs = []
to_concat = []
for i in range(0, len(wav), step):
chunk = wav[i: i+num_samples]
if len(chunk) < num_samples:
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
to_concat.append(chunk.unsqueeze(0))
if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
to_concat = []
if to_concat:
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
outs = torch.cat(outs, dim=0).cpu()
buffer = deque(maxlen=num_steps)
triggered = False
speeches = []
smoothed_probs = []
current_speech = {}
speech_probs = outs[:, 1] # 0 index for silence probs, 1 index for speech probs
median_probs = speech_probs.median()
trig_sum = 0.89 * median_probs + 0.08 # 0.08 when median is zero, 0.97 when median is 1
temp_end = 0
for i, predict in enumerate(speech_probs):
buffer.append(predict)
smoothed_prob = max(buffer)
if visualize_probs:
smoothed_probs.append(float(smoothed_prob))
if (smoothed_prob >= trig_sum) and temp_end:
temp_end = 0
if (smoothed_prob >= trig_sum) and not triggered:
triggered = True
current_speech['start'] = step * max(0, i-num_steps)
continue
if (smoothed_prob < trig_sum) and triggered:
if not temp_end:
temp_end = step * i
if step * i - temp_end < min_silence_samples:
continue
else:
current_speech['end'] = temp_end
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
speeches.append(current_speech)
temp_end = 0
current_speech = {}
triggered = False
continue
if current_speech:
current_speech['end'] = len(wav)
speeches.append(current_speech)
if visualize_probs:
pd.DataFrame({'probs': smoothed_probs}).plot(figsize=(16, 8))
for i, ts in enumerate(speeches):
if i == 0:
ts['start'] = max(0, ts['start'] - speech_pad_samples)
if i != len(speeches) - 1:
silence_duration = speeches[i+1]['start'] - ts['end']
if silence_duration < 2 * speech_pad_samples:
ts['end'] += silence_duration // 2
speeches[i+1]['start'] = max(0, speeches[i+1]['start'] - silence_duration // 2)
else:
ts['end'] += speech_pad_samples
else:
ts['end'] = min(len(wav), ts['end'] + speech_pad_samples)
return speeches
def get_number_ts(wav: torch.Tensor,
model,
model_stride=8,
hop_length=160,
sample_rate=16000,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
perframe_logits = run_function(model, wav)[0]
perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
extended_preds = []
for i in perframe_preds:
extended_preds.extend([i.item()] * model_stride)
# len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
triggered = False
timings = []
cur_timing = {}
for i, pred in enumerate(extended_preds):
if pred == 1:
if not triggered:
cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
triggered = True
elif pred == 0:
if triggered:
cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
timings.append(cur_timing)
cur_timing = {}
triggered = False
if cur_timing:
cur_timing['end'] = int(len(wav) / (sample_rate / 1000))
timings.append(cur_timing)
return timings
def get_language(wav: torch.Tensor,
model,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
lang_logits = run_function(model, wav)[2]
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
assert lang_pred < len(languages)
return languages[lang_pred]
def get_language_and_group(wav: torch.Tensor,
model,
lang_dict: dict,
lang_group_dict: dict,
top_n=1,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
lang_logits, lang_group_logits = run_function(model, wav)
softm = torch.softmax(lang_logits, dim=1).squeeze()
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
srtd = torch.argsort(softm, descending=True)
srtd_group = torch.argsort(softm_group, descending=True)
outs = []
outs_group = []
for i in range(top_n):
prob = round(softm[srtd[i]].item(), 2)
prob_group = round(softm_group[srtd_group[i]].item(), 2)
outs.append((lang_dict[str(srtd[i].item())], prob))
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
return outs, outs_group
class VADiterator:
def __init__(self,
trig_sum: float = 0.26,
neg_trig_sum: float = 0.07,
num_steps: int = 8,
num_samples_per_window: int = 4000):
self.num_samples = num_samples_per_window
self.num_steps = num_steps
assert self.num_samples % num_steps == 0
self.step = int(self.num_samples / num_steps) # 500 samples is good enough
self.prev = torch.zeros(self.num_samples)
self.last = False
self.triggered = False
self.buffer = deque(maxlen=num_steps)
self.num_frames = 0
self.trig_sum = trig_sum
self.neg_trig_sum = neg_trig_sum
self.current_name = ''
def refresh(self):
self.prev = torch.zeros(self.num_samples)
self.last = False
self.triggered = False
self.buffer = deque(maxlen=self.num_steps)
self.num_frames = 0
def prepare_batch(self, wav_chunk, name=None):
if (name is not None) and (name != self.current_name):
self.refresh()
self.current_name = name
assert len(wav_chunk) <= self.num_samples
self.num_frames += len(wav_chunk)
if len(wav_chunk) < self.num_samples:
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
self.last = True
stacked = torch.cat([self.prev, wav_chunk])
self.prev = wav_chunk
overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
for i in range(self.step, self.num_samples+1, self.step)]
return torch.cat(overlap_chunks, dim=0)
def state(self, model_out):
current_speech = {}
speech_probs = model_out[:, 1] # this is very misleading
for i, predict in enumerate(speech_probs):
self.buffer.append(predict)
if ((sum(self.buffer) / len(self.buffer)) >= self.trig_sum) and not self.triggered:
self.triggered = True
current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'start'
if ((sum(self.buffer) / len(self.buffer)) < self.neg_trig_sum) and self.triggered:
current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'end'
self.triggered = False
if self.triggered and self.last:
current_speech[self.num_frames] = 'end'
if self.last:
self.refresh()
return current_speech, self.current_name
class VADiteratorAdaptive:
def __init__(self,
trig_sum: float = 0.26,
neg_trig_sum: float = 0.06,
step: int = 500,
num_samples_per_window: int = 4000,
speech_pad_samples: int = 1000,
accum_period: int = 50):
"""
This class is used for streaming silero VAD usage
Parameters
----------
trig_sum: float
trigger value for speech probability, probs above this value are considered speech, switch to TRIGGERED state (default - 0.26)
neg_trig_sum: float
in triggered state probabilites below this value are considered nonspeech, switch to NONTRIGGERED state (default - 0.06)
step: int
step size in samples, (default - 500)
num_samples_per_window: int
window size in samples (chunk length in samples to feed to NN, default - 4000)
speech_pad_samples: int
widen speech by this amount of samples each side (default - 1000)
accum_period: int
number of chunks / iterations to wait before switching from constant (initial) trig and neg_trig coeffs to adaptive median coeffs (default - 50)
"""
self.num_samples = num_samples_per_window
self.num_steps = int(num_samples_per_window / step)
self.step = step
self.prev = torch.zeros(self.num_samples)
self.last = False
self.triggered = False
self.buffer = deque(maxlen=self.num_steps)
self.num_frames = 0
self.trig_sum = trig_sum
self.neg_trig_sum = neg_trig_sum
self.current_name = ''
self.median_meter = IterativeMedianMeter()
self.median = 0
self.total_steps = 0
self.accum_period = accum_period
self.speech_pad_samples = speech_pad_samples
def refresh(self):
self.prev = torch.zeros(self.num_samples)
self.last = False
self.triggered = False
self.buffer = deque(maxlen=self.num_steps)
self.num_frames = 0
self.median_meter.reset()
self.median = 0
self.total_steps = 0
def prepare_batch(self, wav_chunk, name=None):
if (name is not None) and (name != self.current_name):
self.refresh()
self.current_name = name
assert len(wav_chunk) <= self.num_samples
self.num_frames += len(wav_chunk)
if len(wav_chunk) < self.num_samples:
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
self.last = True
stacked = torch.cat([self.prev, wav_chunk])
self.prev = wav_chunk
overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
for i in range(self.step, self.num_samples+1, self.step)]
return torch.cat(overlap_chunks, dim=0)
def state(self, model_out):
current_speech = {}
speech_probs = model_out[:, 1] # 0 index for silence probs, 1 index for speech probs
for i, predict in enumerate(speech_probs):
self.median = self.median_meter(predict.item())
if self.total_steps < self.accum_period:
trig_sum = self.trig_sum
neg_trig_sum = self.neg_trig_sum
else:
trig_sum = 0.89 * self.median + 0.08 # 0.08 when median is zero, 0.97 when median is 1
neg_trig_sum = 0.6 * self.median
self.total_steps += 1
self.buffer.append(predict)
smoothed_prob = max(self.buffer)
if (smoothed_prob >= trig_sum) and not self.triggered:
self.triggered = True
current_speech[max(0, self.num_frames - (self.num_steps-i) * self.step - self.speech_pad_samples)] = 'start'
if (smoothed_prob < neg_trig_sum) and self.triggered:
current_speech[self.num_frames - (self.num_steps-i) * self.step + self.speech_pad_samples] = 'end'
self.triggered = False
if self.triggered and self.last:
current_speech[self.num_frames] = 'end'
if self.last:
self.refresh()
return current_speech, self.current_name
def state_generator(model,
audios: List[str],
onnx: bool = False,
trig_sum: float = 0.26,
neg_trig_sum: float = 0.07,
num_steps: int = 8,
num_samples_per_window: int = 4000,
audios_in_stream: int = 2,
run_function=validate):
VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps, num_samples_per_window) for i in range(audios_in_stream)]
for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream, num_samples_per_window)):
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
batch = torch.cat(for_batch)
outs = run_function(model, batch)
vad_outs = torch.split(outs, num_steps)
states = []
for x, y in zip(VADiters, vad_outs):
cur_st = x.state(y)
if cur_st[0]:
states.append(cur_st)
yield states
def stream_imitator(audios: List[str],
audios_in_stream: int,
num_samples_per_window: int = 4000):
audio_iter = iter(audios)
iterators = []
num_samples = num_samples_per_window
# initial wavs
for i in range(audios_in_stream):
next_wav = next(audio_iter)
wav = read_audio(next_wav)
wav_chunks = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
iterators.append(wav_chunks)
print('Done initial Loading')
good_iters = audios_in_stream
while True:
values = []
for i, it in enumerate(iterators):
try:
out, wav_name = next(it)
except StopIteration:
try:
next_wav = next(audio_iter)
print('Loading next wav: ', next_wav)
wav = read_audio(next_wav)
iterators[i] = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
out, wav_name = next(iterators[i])
except StopIteration:
good_iters -= 1
iterators[i] = repeat((torch.zeros(num_samples), 'junk'))
out, wav_name = next(iterators[i])
if good_iters == 0:
return
values.append((out, wav_name))
yield values
def single_audio_stream(model,
audio: torch.Tensor,
num_samples_per_window:int = 4000,
run_function=validate,
iterator_type='basic',
**kwargs):
num_samples = num_samples_per_window
if iterator_type == 'basic':
VADiter = VADiterator(num_samples_per_window=num_samples_per_window, **kwargs)
elif iterator_type == 'adaptive':
VADiter = VADiteratorAdaptive(num_samples_per_window=num_samples_per_window, **kwargs)
wav = read_audio(audio)
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
for chunk in wav_chunks:
batch = VADiter.prepare_batch(chunk)
outs = run_function(model, batch)
states = []
state = VADiter.state(outs)
if state[0]:
states.append(state[0])
yield states
def collect_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
for i in tss:
chunks.append(wav[i['start']: i['end']])
return torch.cat(chunks)
def drop_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
cur_start = 0
for i in tss:
chunks.append((wav[cur_start: i['start']]))
cur_start = i['end']
return torch.cat(chunks)

View File

@@ -1,56 +0,0 @@
from utils_vad import *
import sys
import os
from pathlib import Path
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/pipelines/align/bin/')
from align_utils import load_audio_norm
import torch
import pandas as pd
import numpy as np
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/utils/')
from open_stt import soundfile_opus as sf
def split_save_audio_chunks(audio_path, model_path, save_path=None, device='cpu', absolute=True, max_duration=10, adaptive=False, **kwargs):
if not save_path:
save_path = str(Path(audio_path).with_name('after_vad'))
print(f'No save path specified! Using {save_path} to save audio chunks!')
SAMPLE_RATE = 16000
if type(model_path) == str:
#print('Loading model...')
model = init_jit_model(model_path, device)
else:
#print('Using loaded model')
model = model_path
save_name = Path(audio_path).stem
audio, sr = load_audio_norm(audio_path)
wav = torch.tensor(audio)
if adaptive:
speech_timestamps = get_speech_ts_adaptive(wav, model, device=device, **kwargs)
else:
speech_timestamps = get_speech_ts(wav, model, device=device, **kwargs)
full_save_path = Path(save_path, save_name)
if not os.path.exists(full_save_path):
os.makedirs(full_save_path, exist_ok=True)
chunks = []
if not speech_timestamps:
return pd.DataFrame()
for ts in speech_timestamps:
start_ts = int(ts['start'])
end_ts = int(ts['end'])
for i in range(start_ts, end_ts, max_duration * SAMPLE_RATE):
new_start = i
new_end = min(end_ts, i + max_duration * SAMPLE_RATE)
duration = round((new_end - new_start) / SAMPLE_RATE, 2)
chunk_path = Path(full_save_path, f'{save_name}_{new_start}-{new_end}.opus')
chunk_path = chunk_path.absolute() if absolute else chunk_path
sf.write(str(chunk_path), audio[new_start: new_end], 16000, format='OGG', subtype='OPUS')
chunks.append({'audio_path': chunk_path,
'text': '',
'duration': duration,
'domain': ''})
return pd.DataFrame(chunks)