153 Commits

Author SHA1 Message Date
Dimitrii Voronin
be95df9152 Merge pull request #719 from snakers4/adamnsandle
Adamnsandle
2025-11-06 11:25:49 +03:00
adamnsandle
ec56fe50a5 fx workflow 2025-11-06 08:18:46 +00:00
adamnsandle
dea5980320 fx workflow 2025-11-06 08:04:02 +00:00
adamnsandle
90d9ce7695 fx workflow 2025-11-06 07:49:44 +00:00
adamnsandle
c56dbb11ac Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-11-06 07:36:38 +00:00
adamnsandle
9b686893ad fx test workflow 2025-11-06 07:36:23 +00:00
Dimitrii Voronin
6979fbd535 Merge pull request #717 from snakers4/adamnsandle
v6.2.0 release
2025-11-06 10:28:00 +03:00
adamnsandle
1cff663de5 fix version to 6.2.0 2025-11-06 07:27:07 +00:00
adamnsandle
bfdc019302 add v6.2 model 2025-11-06 07:23:43 +00:00
Alexander Veysov
c0c0ffa0c5 Merge pull request #714 from Purfview/patch-4
Fix type hint for min_silence_at_max_speech (float -> int)
2025-11-05 08:44:00 +03:00
Alexander Veysov
3f0c9ead54 Update pyproject.toml 2025-11-05 08:38:07 +03:00
Purfview
556a442942 Fix type hint for min_silence_at_max_speech (float -> int) 2025-11-04 08:30:01 +00:00
Dimitrii Voronin
9623ce72da Merge pull request #710 from Purfview/patch-3
Fixes and refines - use_max_poss_sil_at_max_speech arg
2025-10-29 12:36:58 +03:00
Dimitrii Voronin
b6dd0599fc Merge pull request #712 from snakers4/adamnsandle
drop_chunks fix
2025-10-29 12:16:10 +03:00
adamnsandle
d8f88c9157 drop_chunks fix 2025-10-29 09:14:45 +00:00
Purfview
b15a216b47 Reword a comment 2025-10-24 10:30:34 +01:00
Purfview
2389039408 Fixes and refines - use_max_poss_sil_at_max_speech arg
Removed redundant "if temp_end != 0:" check.
Multiple "window_size_samples * i" - assigned to a variable.
Restored the previous functionality (which was broken) when use_max_poss_sil_at_max_speech=False.

@shashank14k was your https://github.com/snakers4/silero-vad/pull/664 PR still WIP when it was merged?
Anyway, please test if use_max_poss_sil_at_max_speech=True behaviour is same, and "False" is same as before your PR.
2025-10-24 07:46:41 +01:00
Alexander Veysov
df22fcaec8 Merge pull request #708 from Purfview/patch-2
Removes redundant hop_size_samples variable
2025-10-23 15:58:00 +03:00
Purfview
81e8a48e25 Removes redundant hop_size_samples variable
Remove redundant hop_size_samples variable
2025-10-23 05:23:18 +01:00
Alexander Veysov
a14a23faa7 Merge pull request #707 from Purfview/patch-1
Fixes few typos
2025-10-23 06:35:58 +03:00
Purfview
a30b5843c1 Fixes various typos 2025-10-23 04:02:13 +01:00
Dimitrii Voronin
a66c890188 Merge pull request #704 from snakers4/adamnsandle
resolve torchaudio 2.9 utils
2025-10-17 15:50:20 +03:00
adamnsandle
77c91a91fa resolve torchaudio 2.9 utils 2025-10-17 12:35:40 +00:00
Alexander Veysov
33093c6f1b Update utils.py 2025-10-14 14:51:23 +03:00
Alexander Veysov
dc0b62e1e4 Merge pull request #699 from JiJiJiang/master
fix bug in tuning/utils.py: add optimizer.zero_grad() before loss.bac…
2025-10-14 14:50:58 +03:00
Hongji Wang
64fb49e1c8 fix bug in tuning/utils.py: add optimizer.zero_grad() before loss.backward() 2025-10-13 20:50:29 +08:00
Alexander Veysov
55ba6e2825 Merge pull request #697 from VvvvvGH/java-example-v6
Update java example for v6
2025-10-11 11:41:15 +03:00
GH
b90f8c012f Update SlieroVadOnnxModel.java 2025-10-11 16:21:57 +08:00
GH
25a778c798 Update SlieroVadDetector.java 2025-10-11 16:21:45 +08:00
GH
3d860e6ace Update App.java 2025-10-11 16:21:32 +08:00
GH
f5ea01bfda Update pom.xml 2025-10-11 16:21:03 +08:00
Alexander Veysov
dd651a54a5 Merge pull request #695 from mpariente/master
Remove ipdb and raise error directly in get_speech_timestamps
2025-10-11 08:07:18 +03:00
Manuel Pariente
f1175c902f Remove ipdb and raise error directly 2025-10-10 10:46:44 +02:00
Alexander Veysov
7819fd911b Update README.md 2025-10-09 17:34:33 +03:00
Dimitrii Voronin
fba061dc55 Merge pull request #677 from snakers4/adamnsandle
get rid of hop_size_ratio
2025-08-26 09:54:35 +03:00
adamnsandle
11631356a2 get rid of hop_size_ratio 2025-08-26 06:53:53 +00:00
Dimitrii Voronin
34dea51680 Merge pull request #664 from shashank14k/master
Adding additional params to get_speech_timestamps
2025-08-26 09:50:44 +03:00
Dimitrii Voronin
51fd43130a Update README.md 2025-08-25 19:30:20 +03:00
Dimitrii Voronin
3080062489 Update README.md 2025-08-25 18:07:06 +03:00
Dimitrii Voronin
f974f2d6bc Merge pull request #676 from snakers4/adamnsandle
Adamnsandle
2025-08-25 17:59:19 +03:00
adamnsandle
f1886d9088 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-08-25 14:57:11 +00:00
adamnsandle
4c00cd14be add v6 models 2025-08-25 14:56:50 +00:00
Dimitrii Voronin
5d70880844 Merge pull request #675 from snakers4/adamnsandle
Adamnsandle
2025-08-25 17:28:38 +03:00
adamnsandle
a16f3ed079 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-08-25 14:27:26 +00:00
adamnsandle
b0fbf4bec6 fx 2025-08-25 14:27:15 +00:00
Dimitrii Voronin
ab02267584 Merge pull request #674 from snakers4/adamnsandle
Adamnsandle
2025-08-25 17:09:07 +03:00
adamnsandle
485a7d91b0 git push Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-08-25 14:08:15 +00:00
adamnsandle
1da76acfc3 fx 2025-08-25 14:07:32 +00:00
Dimitrii Voronin
3c70b587e8 Merge pull request #673 from snakers4/adamnsandle
Adamnsandle
2025-08-25 16:56:19 +03:00
adamnsandle
7aff370d68 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-08-25 13:55:30 +00:00
adamnsandle
931eddfdab fx 2025-08-25 13:55:24 +00:00
Dimitrii Voronin
6143b9a5d9 Merge pull request #672 from snakers4/adamnsandle
fx
2025-08-25 16:46:24 +03:00
adamnsandle
8ca8cf7d9b fx 2025-08-25 13:45:36 +00:00
Dimitrii Voronin
ad0fdbe4ac Merge pull request #671 from snakers4/adamnsandle
Adamnsandle
2025-08-25 16:40:10 +03:00
adamnsandle
06806eb70b Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-08-25 13:39:32 +00:00
adamnsandle
c90e1603c5 fx 2025-08-25 13:39:15 +00:00
Dimitrii Voronin
023d3a36f0 Merge pull request #670 from snakers4/adamnsandle
fx
2025-08-25 16:25:39 +03:00
adamnsandle
aa2a66cf46 fx 2025-08-25 13:24:43 +00:00
Dimitrii Voronin
b1cd34aae2 Merge pull request #669 from snakers4/adamnsandle
Adamnsandle
2025-08-25 16:17:17 +03:00
adamnsandle
50be3744fe fix 2025-08-25 13:08:02 +00:00
adamnsandle
fce776f872 fix workflow 2025-08-25 12:59:58 +00:00
adamnsandle
fbddc91a5d initial autotest commit 2025-08-25 12:54:47 +00:00
shashank14k
bbf22a0064 Added params for hop_size, and min_silence_at_max speech to cut at a possible silence when max_dur reached to avoid abrupt cuts 2025-07-25 20:51:40 +05:30
Alexander Veysov
94811cbe12 Merge pull request #656 from davidrs/patch-1
Surface drop_chunks in init
2025-06-11 07:45:36 +03:00
David Rust-Smith
22a2362b4c Surface drop_chunks in init 2025-06-10 11:36:10 -07:00
Dimitrii Voronin
0dd45f0bcd Merge pull request #626 from b3by/feature/process_chunks_in_seconds
Use second coordinates for audio concatenation in collect_chunks and drop_chunks
2025-03-24 19:02:56 +03:00
Dimitrii Voronin
feba8cd5c4 Merge pull request #627 from b3by/feature/time_coordinates_resolution
Specify time resolution when returning speech coordinates in seconds
2025-03-24 18:59:25 +03:00
Antonio Bevilacqua
6622e562e4 time resolution can be specified when coordinates are returned in seconds 2025-03-24 08:53:28 +01:00
Antonio Bevilacqua
d5625d5c38 added audio concatenation for collect_chunks and drop_chunks based on second coordinates 2025-03-21 13:06:59 +01:00
Alexander Veysov
cd92290a15 Merge pull request #605 from OJRYK/fix/cpp-vad-context
Fix/cpp vad context
2025-02-17 11:01:04 +03:00
Ojuro Yokoyama
33a9d190fe Update wav.h 2025-02-17 16:03:42 +09:00
Ojuro Yokoyama
7440bc4689 Update silero-vad-onnx.cpp
I fixed bug of silero-vad-onnx.cpp
2025-02-17 16:02:24 +09:00
Alexander Veysov
10e7e8a8bc Merge pull request #601 from kiwamizamurai/master
Add CITATION.cff file for proper citation
2025-02-11 08:42:10 +03:00
きわみざむらい
5a5b662496 Create CITATION.cff 2025-02-11 08:54:16 +09:00
Alexander Veysov
9060f664f2 Merge pull request #591 from qwbarch/master
Add haskell example
2024-12-26 19:05:13 +03:00
qwbarch
94271e9096 Add haskell example 2024-12-26 11:18:10 -05:00
Dimitrii Voronin
3f9fffc261 Merge pull request #581 from snakers4/adamnsandle
fx negative ths bug
2024-11-25 16:55:38 +03:00
adamnsandle
eaf633ec9d fx negative ths bug 2024-11-25 13:54:46 +00:00
Alexander Veysov
cff5eb2980 Merge pull request #578 from NathanJHLee/add-torch-cpp
Add cpp source based on libtorch
2024-11-22 11:26:49 +03:00
Dimitrii Voronin
f356a8081a Merge pull request #579 from snakers4/adamnsandle
fx https://github.com/snakers4/silero-vad/issues/576
2024-11-22 11:18:26 +03:00
adamnsandle
782e30d28f fx https://github.com/snakers4/silero-vad/issues/576 2024-11-22 08:17:25 +00:00
Nathan Lee
caee535cf6 ReadMe v4 2024-11-22 06:48:27 +00:00
Nathan Lee
8ab5be005f ReadMe v3 2024-11-22 06:46:28 +00:00
Nathan Lee
9f67a54e87 ReadMe v2 2024-11-22 06:42:20 +00:00
Nathan Lee
c8df1dee3f modified Readme 2024-11-22 06:35:16 +00:00
Nathan Lee
0189ebd8af Changed some source. 2024-11-22 06:21:49 +00:00
Nathan Lee
05e380c1de add c++ inference based on libtorch 2024-11-22 00:10:13 +00:00
Alexander Veysov
93b9782f28 Merge pull request #573 from snakers4/adamnsandle
Adamnsandle
2024-11-13 12:32:55 +03:00
adamnsandle
d2ab7c254e add just 16k model 2024-11-13 08:53:27 +00:00
adamnsandle
6217b08bbb add other opsets 2024-11-12 08:25:06 +00:00
adamnsandle
d53ba1ea11 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-11-12 08:19:54 +00:00
Alexander Veysov
102e6d0962 Add downloads shield 2024-11-07 14:40:33 +03:00
Alexander Veysov
e531cd3462 Update README.md 2024-10-21 10:22:02 +03:00
Alexander Veysov
fd41da0b15 Merge pull request #553 from EarningsCall/master
Improve documentation.
2024-10-12 18:25:46 +03:00
EarningsCall
9db72c35bd Update README.md
update again
2024-10-12 09:23:29 -05:00
EarningsCall
867a067bee Update README.md
I assume most people want seconds, so it's useful to show example to return seconds in README file.
2024-10-12 09:22:39 -05:00
Alexander Veysov
2c43391b17 Update README.md 2024-10-09 12:56:22 +03:00
Alexander Veysov
6478567951 Update pyproject.toml 2024-10-09 12:49:27 +03:00
adamnsandle
add6e3028e Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-10-09 09:48:51 +00:00
adamnsandle
e7025ed8c5 5.1.1 tag 2024-10-09 09:48:37 +00:00
Alexander Veysov
35d601adc6 Update pyproject.toml 2024-10-09 12:47:08 +03:00
Dimitrii Voronin
032ca21a70 Merge pull request #549 from snakers4/adamnsandle
Adamnsandle
2024-10-09 12:32:09 +03:00
adamnsandle
001d57d6ff fx dependencies 2024-10-09 09:26:39 +00:00
adamnsandle
6e6da04e7a fix pyaudio streaming example 2024-10-09 08:49:39 +00:00
Alexander Veysov
9c1eff9169 Delete files/real_time_example.mp4 2024-10-09 10:10:03 +03:00
Alexander Veysov
36b759d053 Add files via upload 2024-10-09 10:02:04 +03:00
Dimitrii Voronin
1a7499607a Merge pull request #543 from snakers4/adamnsandle
Adamnsandle
2024-09-24 15:19:30 +03:00
Alexander Veysov
87451b059f Update README.md 2024-09-24 15:16:18 +03:00
Alexander Veysov
becc7770c7 Update README.md 2024-09-24 15:15:10 +03:00
Alexander Veysov
3f2eff0303 Merge pull request #542 from snakers4/snakers4-patch-1
Update README.md
2024-09-24 15:14:18 +03:00
Alexander Veysov
3a25110cf9 Update README.md 2024-09-24 15:13:34 +03:00
adamnsandle
d23867da10 fx parallel example 2024-09-24 12:03:07 +00:00
adamnsandle
2043282182 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-09-24 12:02:00 +00:00
adamnsandle
fa8036ae1c fx old examples 2024-09-24 12:01:47 +00:00
Dimitrii Voronin
2fff4b8ce8 Merge pull request #541 from snakers4/adamnsandle-1
Update README.md
2024-09-24 14:48:51 +03:00
Dimitrii Voronin
64b863d2ff Update README.md 2024-09-24 14:48:35 +03:00
Dimitrii Voronin
8a3600665b Merge pull request #540 from snakers4/adamnsandle-patch-2
Update README.md
2024-09-24 13:45:31 +03:00
Dimitrii Voronin
9c2c90aa1c Update README.md 2024-09-24 13:45:16 +03:00
Dimitrii Voronin
1d48167271 Merge pull request #539 from gengyuchao/update/python_pyaudio_example
Fixed the pyaudio example can not run issue.
2024-09-11 12:27:15 +03:00
GengYuchao
d0139d94d9 Fixed the pyaudio example can not run issue.
Update the related packages.
2024-09-11 00:45:49 +08:00
Dimitrii Voronin
46f94b7d60 Merge pull request #529 from snakers4/adamnsandle
Adamnsandle
2024-08-22 17:31:42 +03:00
adamnsandle
3de3ee3abe Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-08-22 14:30:27 +00:00
adamnsandle
e680ea6633 add half onnx model 2024-08-22 14:30:13 +00:00
Dimitrii Voronin
199de226e5 Merge pull request #528 from snakers4/adamnsandle
add neg_threshold parameter explicitly
2024-08-22 16:39:33 +03:00
adamnsandle
4109b107c1 add neg_threshold parameter explicitly 2024-08-20 08:53:15 +00:00
Alexander Veysov
36854a90db Merge pull request #526 from snakers4/adamnsandle
код для тюнинга
2024-08-19 20:01:21 +03:00
adamnsandle
827e86e685 добавлен поиск порогов 2024-08-19 16:53:28 +00:00
Dimitrii Voronin
e706ec6fee Update README.md 2024-08-19 18:31:11 +03:00
adamnsandle
88df0ce1dd код для тюнинга 2024-08-19 14:36:45 +00:00
Dimitrii Voronin
d18b91e037 Merge pull request #521 from snakers4/adamnsandle
downgrade onnxruntime dependency
2024-08-09 14:23:16 +03:00
adamnsandle
1e3f343767 downgrade onnxruntime dependency 2024-08-09 11:15:22 +00:00
Alexander Veysov
6a8ee81ee0 Merge pull request #507 from nganju98/master
add csharp example
2024-07-21 09:03:38 +03:00
nick.ganju
cb25c0c047 add csharp example 2024-07-20 22:59:18 -04:00
Alexander Veysov
7af8628a27 Merge pull request #506 from yuguanqin/master
Add java example for wav file & support V5 model
2024-07-18 07:34:40 +03:00
yuguanqin
3682cb189c java example for whole wav file & compatible with V5 model 2024-07-18 10:34:02 +08:00
Dimitrii Voronin
57c0b51f9b Merge pull request #505 from snakers4/adamnsandle
VadIterator first chunk bag fx
2024-07-15 13:42:36 +03:00
adamnsandle
dd0b143803 VadIterator first chunk bag fx 2024-07-15 10:37:46 +00:00
Alexander Veysov
181cdf92b6 Merge pull request #497 from rumbleFTW/fix/rust-example-v5
fix: rust example for v5 checkpoint
2024-07-11 17:48:58 +03:00
rumbleFTW
a7bd2dd38f fix: rust example 2024-07-11 20:06:54 +05:30
Alexander Veysov
df7de797a5 Merge pull request #496 from streamer45/update-golang-example
Fix Golang example
2024-07-10 21:31:15 +03:00
streamer45
87ed11b508 Fix Golang example 2024-07-10 20:26:41 +02:00
Alexander Veysov
84768cefdf Merge pull request #493 from snakers4/adamnsandle
Adamnsandle
2024-07-09 16:16:40 +03:00
adamnsandle
6de3660f25 fx version 2024-07-09 10:27:00 +00:00
adamnsandle
d9a6941852 add pip examples to collab 2024-07-09 10:20:50 +00:00
adamnsandle
dfdc9a484e Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2024-07-09 09:51:42 +00:00
adamnsandle
f2e3a23d96 fx version 2024-07-09 09:45:10 +00:00
Dimitrii Voronin
2b97f61160 Merge pull request #492 from snakers4/adamnsandle-patch-1
Create python-publish.yml
2024-07-09 12:42:23 +03:00
adamnsandle
657dac8736 add pyproject.toml 2024-07-09 09:31:18 +00:00
Dimitrii Voronin
412a478e29 Update README.md 2024-07-09 12:25:06 +03:00
adamnsandle
9adf6d2192 add abs import path 2024-07-09 09:06:05 +00:00
adamnsandle
8a2a73c14f fx package import 2024-07-09 09:02:33 +00:00
adamnsandle
3e0305559d fx hubconf 2024-07-09 08:32:18 +00:00
adamnsandle
f0d880d79c make package structure 2024-07-09 08:26:17 +00:00
66 changed files with 3727 additions and 645 deletions

40
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
name: Test Package
on:
workflow_dispatch: # запуск вручную
jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.8","3.9","3.10","3.11","3.12","3.13"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build hatchling pytest soundfile
pip install .[test]
- name: Build package
run: python -m build --wheel --outdir dist
- name: Install package
run: |
import glob, subprocess, sys
whl = glob.glob("dist/*.whl")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", whl])
shell: python
- name: Run tests
run: pytest tests

20
CITATION.cff Normal file
View File

@@ -0,0 +1,20 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "Silero VAD"
authors:
- family-names: "Silero Team"
email: "hello@silero.ai"
type: software
repository-code: "https://github.com/snakers4/silero-vad"
license: MIT
abstract: "Pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
preferred-citation:
type: software
authors:
- family-names: "Silero Team"
email: "hello@silero.ai"
title: "Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
year: 2024
publisher: "GitHub"
journal: "GitHub repository"
howpublished: "https://github.com/snakers4/silero-vad"

View File

@@ -1,6 +1,6 @@
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE)
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [![downloads](https://img.shields.io/pypi/dm/silero-vad?style=for-the-badge)](https://pypi.org/project/silero-vad/)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [![Test Package](https://github.com/snakers4/silero-vad/actions/workflows/test.yml/badge.svg)](https://github.com/snakers4/silero-vad/actions/workflows/test.yml) [![Pypi version](https://img.shields.io/pypi/v/silero-vad)](https://pypi.org/project/silero-vad/) [![Python version](https://img.shields.io/pypi/pyversions/silero-vad)](https://pypi.org/project/silero-vad)
![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png)
@@ -13,7 +13,7 @@
<br/>
<p align="center">
<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
<img src="https://github.com/user-attachments/assets/f2940867-0a51-4bdb-8c14-1129d3c44e64" />
</p>
@@ -22,9 +22,75 @@
https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
Please note, that video loads only if you are logged in your GitHub account.
</details>
<br/>
<h2 align="center">Fast start</h2>
<br/>
<details>
<summary>Dependencies</summary>
System requirements to run python examples on `x86-64` systems:
- `python 3.8+`;
- 1G+ RAM;
- A modern CPU with AVX, AVX2, AVX-512 or AMX instruction sets.
Dependencies:
- `torch>=1.12.0`;
- `torchaudio>=0.12.0` (for I/O only);
- `onnxruntime>=1.16.1` (for ONNX model usage).
Silero VAD uses torchaudio library for audio I/O (`torchaudio.info`, `torchaudio.load`, and `torchaudio.save`), so a proper audio backend is required:
- Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`;
- Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2;
- Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`.
If you are planning to run the VAD using solely the `onnx-runtime`, it will run on any other system architectures where onnx-runtume is [supported](https://onnxruntime.ai/getting-started). In this case please note that:
- You will have to implement the I/O;
- You will have to adapt the existing wrappers / examples / post-processing for your use-case.
</details>
**Using pip**:
`pip install silero-vad`
```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
model = load_silero_vad()
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(
wav,
model,
return_seconds=True, # Return speech timestamps in seconds (default is samples)
)
```
**Using torch.hub**:
```python3
import torch
torch.set_num_threads(1)
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(
wav,
model,
return_seconds=True, # Return speech timestamps in seconds (default is samples)
)
```
<br/>
<h2 align="center">Key Features</h2>
<br/>
@@ -57,21 +123,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.
<br/>
<h2 align="center">Fast start</h2>
<br/>
```python3
import torch
torch.set_num_threads(1)
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
<br/>
<h2 align="center">Typical Use Cases</h2>
<br/>
@@ -106,7 +158,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
@misc{Silero VAD,
author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2021},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}},
@@ -123,4 +175,4 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples
- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example), [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp), [C#](https://github.com/snakers4/silero-vad/tree/master/examples/csharp) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) community examples

View File

@@ -17,6 +17,7 @@
},
"outputs": [],
"source": [
"#!apt install ffmpeg\n",
"!pip -q install pydub\n",
"from google.colab import output\n",
"from base64 import b64decode, b64encode\n",
@@ -37,13 +38,12 @@
" model='silero_vad',\n",
" force_reload=True)\n",
"\n",
"def int2float(sound):\n",
" abs_max = np.abs(sound).max()\n",
" sound = sound.astype('float32')\n",
" if abs_max > 0:\n",
" sound *= 1/32768\n",
" sound = sound.squeeze()\n",
" return sound\n",
"def int2float(audio):\n",
" samples = audio.get_array_of_samples()\n",
" new_sound = audio._spawn(samples)\n",
" arr = np.array(samples).astype(np.float32)\n",
" arr = arr / np.abs(arr).max()\n",
" return arr\n",
"\n",
"AUDIO_HTML = \"\"\"\n",
"<script>\n",
@@ -68,10 +68,10 @@
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
" mimeType : 'audio/webm;codecs=opus'\n",
" //mimeType : 'audio/webm;codecs=pcm'\n",
" }; \n",
" };\n",
" //recorder = new MediaRecorder(stream, options);\n",
" recorder = new MediaRecorder(stream);\n",
" recorder.ondataavailable = function(e) { \n",
" recorder.ondataavailable = function(e) {\n",
" var url = URL.createObjectURL(e.data);\n",
" // var preview = document.createElement('audio');\n",
" // preview.controls = true;\n",
@@ -79,7 +79,7 @@
" // document.body.appendChild(preview);\n",
"\n",
" reader = new FileReader();\n",
" reader.readAsDataURL(e.data); \n",
" reader.readAsDataURL(e.data);\n",
" reader.onloadend = function() {\n",
" base64data = reader.result;\n",
" //console.log(\"Inside FileReader:\" + base64data);\n",
@@ -121,7 +121,7 @@
"\n",
"}\n",
"});\n",
" \n",
"\n",
"</script>\n",
"\"\"\"\n",
"\n",
@@ -133,8 +133,8 @@
" audio.export('test.mp3', format='mp3')\n",
" audio = audio.set_channels(1)\n",
" audio = audio.set_frame_rate(16000)\n",
" audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
" audio_tens = torch.tensor(audio_float )\n",
" audio_float = int2float(audio)\n",
" audio_tens = torch.tensor(audio_float)\n",
" return audio_tens\n",
"\n",
"def make_animation(probs, audio_duration, interval=40):\n",
@@ -154,19 +154,18 @@
" def animate(i):\n",
" x = i * interval / 1000 - 0.04\n",
" y = np.linspace(0, 1.02, 2)\n",
" \n",
"\n",
" line.set_data(x, y)\n",
" line.set_color('#990000')\n",
" return line,\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n",
"\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
"\n",
" f = r\"animation.mp4\" \n",
" writervideo = FFMpegWriter(fps=1000/interval) \n",
" f = r\"animation.mp4\"\n",
" writervideo = FFMpegWriter(fps=1000/interval)\n",
" anim.save(f, writer=writervideo)\n",
" plt.close('all')\n",
"\n",
"def combine_audio(vidname, audname, outname, fps=25): \n",
"def combine_audio(vidname, audname, outname, fps=25):\n",
" my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
" audio_background = mpe.AudioFileClip(audname)\n",
" final_clip = my_clip.set_audio(audio_background)\n",
@@ -174,15 +173,10 @@
"\n",
"def record_make_animation():\n",
" tensor = record()\n",
"\n",
" print('Calculating probabilities...')\n",
" speech_probs = []\n",
" window_size_samples = 512\n",
" for i in range(0, len(tensor), window_size_samples):\n",
" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
" break\n",
" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
" speech_probs.append(speech_prob)\n",
" speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()\n",
" model.reset_states()\n",
" print('Making animation...')\n",
" make_animation(speech_probs, len(tensor) / 16000)\n",
@@ -196,7 +190,9 @@
" <video width=800 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))"
" \"\"\" % data_url))\n",
"\n",
" return speech_probs"
]
},
{
@@ -216,7 +212,7 @@
},
"outputs": [],
"source": [
"record_make_animation()"
"speech_probs = record_make_animation()"
]
}
],

View File

@@ -1,211 +1,227 @@
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <limits>
#include <chrono>
#include <iomanip>
#include <memory>
#include <string>
#include <stdexcept>
#include <iostream>
#include <string>
#include "onnxruntime_cxx_api.h"
#include "wav.h"
#include <cstdio>
#include <cstdarg>
#include <cmath> // for std::rint
#if __cplusplus < 201703L
#include <memory>
#endif
//#define __DEBUG_SPEECH_PROB___
class timestamp_t
{
#include "onnxruntime_cxx_api.h"
#include "wav.h" // For reading WAV files
// timestamp_t class: stores the start and end (in samples) of a speech segment.
class timestamp_t {
public:
int start;
int end;
// default + parameterized constructor
timestamp_t(int start = -1, int end = -1)
: start(start), end(end)
{
};
: start(start), end(end) { }
// assignment operator modifies object, therefore non-const
timestamp_t& operator=(const timestamp_t& a)
{
timestamp_t& operator=(const timestamp_t& a) {
start = a.start;
end = a.end;
return *this;
};
}
// equality comparison. doesn't modify object. therefore const.
bool operator==(const timestamp_t& a) const
{
bool operator==(const timestamp_t& a) const {
return (start == a.start && end == a.end);
};
std::string c_str()
{
//return std::format("timestamp {:08d}, {:08d}", start, end);
return format("{start:%08d,end:%08d}", start, end);
};
}
// Returns a formatted string of the timestamp.
std::string c_str() const {
return format("{start:%08d, end:%08d}", start, end);
}
private:
std::string format(const char* fmt, ...)
{
// Helper function for formatting.
std::string format(const char* fmt, ...) const {
char buf[256];
va_list args;
va_start(args, fmt);
const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
const auto r = std::vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
if (r < 0)
// conversion failed
return {};
const size_t len = r;
if (len < sizeof buf)
// we fit in the buffer
return { buf, len };
if (len < sizeof(buf))
return std::string(buf, len);
#if __cplusplus >= 201703L
// C++17: Create a string and write to its underlying array
std::string s(len, '\0');
va_start(args, fmt);
std::vsnprintf(s.data(), len + 1, fmt, args);
va_end(args);
return s;
#else
// C++11 or C++14: We need to allocate scratch memory
auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
va_start(args, fmt);
std::vsnprintf(vbuf.get(), len + 1, fmt, args);
va_end(args);
return { vbuf.get(), len };
return std::string(vbuf.get(), len);
#endif
};
}
};
class VadIterator
{
// VadIterator class: uses ONNX Runtime to detect speech segments.
class VadIterator {
private:
// OnnxRuntime resources
// ONNX Runtime resources
Ort::Env env;
Ort::SessionOptions session_options;
std::shared_ptr<Ort::Session> session = nullptr;
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
private:
void init_engine_threads(int inter_threads, int intra_threads)
{
// The method should be called in each thread/proc in multi-thread/proc work
// ----- Context-related additions -----
const int context_samples = 64; // For 16kHz, 64 samples are added as context.
std::vector<float> _context; // Holds the last 64 samples from the previous chunk (initialized to zero).
// Original window size (e.g., 32ms corresponds to 512 samples)
int window_size_samples;
// Effective window size = window_size_samples + context_samples
int effective_window_size;
// Additional declaration: samples per millisecond
int sr_per_ms;
// ONNX Runtime input/output buffers
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_node_names = { "input", "state", "sr" };
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128;
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = { 2, 1, 128 };
const int64_t sr_node_dims[1] = { 1 };
std::vector<Ort::Value> ort_outputs;
std::vector<const char*> output_node_names = { "output", "stateN" };
// Model configuration parameters
int sample_rate;
float threshold;
int min_silence_samples;
int min_silence_samples_at_max_speech;
int min_speech_samples;
float max_speech_samples;
int speech_pad_samples;
int audio_length_samples;
// State management
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
int prev_end;
int next_start = 0;
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Loads the ONNX model.
void init_onnx_model(const std::wstring& model_path) {
init_engine_threads(1, 1);
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
}
// Initializes threading settings.
void init_engine_threads(int inter_threads, int intra_threads) {
session_options.SetIntraOpNumThreads(intra_threads);
session_options.SetInterOpNumThreads(inter_threads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
};
}
void init_onnx_model(const std::wstring& model_path)
{
// Init threads = 1 for
init_engine_threads(1, 1);
// Load model
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};
void reset_states()
{
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
// Resets internal state (_state, _context, etc.)
void reset_states() {
std::memset(_state.data(), 0, _state.size() * sizeof(float));
triggered = false;
temp_end = 0;
current_sample = 0;
prev_end = next_start = 0;
speeches.clear();
current_speech = timestamp_t();
};
std::fill(_context.begin(), _context.end(), 0.0f);
}
void predict(const std::vector<float> &data)
{
// Infer
// Create ort tensors
input.assign(data.begin(), data.end());
// Inference: runs inference on one chunk of input data.
// data_chunk is expected to have window_size_samples samples.
void predict(const std::vector<float>& data_chunk) {
// Build new input: first context_samples from _context, followed by the current chunk (window_size_samples).
std::vector<float> new_data(effective_window_size, 0.0f);
std::copy(_context.begin(), _context.end(), new_data.begin());
std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
input = new_data;
// Create input tensor (input_node_dims[1] is already set to effective_window_size).
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
memory_info, input.data(), input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
// Clear and add inputs
ort_inputs.clear();
ort_inputs.emplace_back(std::move(input_ort));
ort_inputs.emplace_back(std::move(state_ort));
ort_inputs.emplace_back(std::move(sr_ort));
// Infer
// Run inference.
ort_outputs = session->Run(
Ort::RunOptions{nullptr},
Ort::RunOptions{ nullptr },
input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
output_node_names.data(), output_node_names.size());
// Output probability & update h,c recursively
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
float* stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
current_sample += static_cast<unsigned int>(window_size_samples); // Advance by the original window size.
// Push forward sample index
current_sample += window_size_samples;
// Reset temp_end when > threshold
if ((speech_prob >= threshold))
{
// If speech is detected (probability >= threshold)
if (speech_prob >= threshold) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (temp_end != 0)
{
float speech = current_sample - window_size_samples;
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif
if (temp_end != 0) {
temp_end = 0;
if (next_start < prev_end)
next_start = current_sample - window_size_samples;
}
if (triggered == false)
{
if (!triggered) {
triggered = true;
current_speech.start = current_sample - window_size_samples;
}
// Update context: copy the last context_samples from new_data.
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return;
}
if (
(triggered == true)
&& ((current_sample - current_speech.start) > max_speech_samples)
) {
// If the speech segment becomes too long.
if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) {
if (prev_end > 0) {
current_speech.end = prev_end;
speeches.push_back(current_speech);
current_speech = timestamp_t();
// previously reached silence(< neg_thres) and is still not speech(< thres)
if (next_start < prev_end)
triggered = false;
else{
else
current_speech.start = next_start;
}
prev_end = 0;
next_start = 0;
temp_end = 0;
}
else{
else {
current_speech.end = current_sample;
speeches.push_back(current_speech);
current_speech = timestamp_t();
@@ -214,53 +230,29 @@ private:
temp_end = 0;
triggered = false;
}
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return;
}
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
{
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) {
// When the speech probability temporarily drops but is still in speech, update context without changing state.
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return;
}
if (speech_prob < (threshold - 0.15)) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples;
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif
if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
else {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
return;
}
// 4) End
if ((speech_prob < (threshold - 0.15)))
{
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (triggered == true)
{
if (temp_end == 0)
{
temp_end = current_sample;
}
if (current_sample - temp_end > min_silence_samples_at_max_speech)
prev_end = temp_end;
// a. silence < min_slience_samples, continue speaking
if ((current_sample - temp_end) < min_silence_samples)
{
}
// b. silence >= min_slience_samples, end speaking
else
{
if ((current_sample - temp_end) >= min_silence_samples) {
current_speech.end = temp_end;
if (current_speech.end - current_speech.start > min_speech_samples)
{
if (current_speech.end - current_speech.start > min_speech_samples) {
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
@@ -270,27 +262,23 @@ private:
}
}
}
else {
// may first windows see end state.
}
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return;
}
};
}
public:
void process(const std::vector<float>& input_wav)
{
// Process the entire audio input.
void process(const std::vector<float>& input_wav) {
reset_states();
audio_length_samples = input_wav.size();
for (int j = 0; j < audio_length_samples; j += window_size_samples)
{
if (j + window_size_samples > audio_length_samples)
audio_length_samples = static_cast<int>(input_wav.size());
// Process audio in chunks of window_size_samples (e.g., 512 samples)
for (size_t j = 0; j < static_cast<size_t>(audio_length_samples); j += static_cast<size_t>(window_size_samples)) {
if (j + static_cast<size_t>(window_size_samples) > static_cast<size_t>(audio_length_samples))
break;
std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
predict(r);
std::vector<float> chunk(&input_wav[j], &input_wav[j] + window_size_samples);
predict(chunk);
}
if (current_speech.start >= 0) {
current_speech.end = audio_length_samples;
speeches.push_back(current_speech);
@@ -300,179 +288,80 @@ public:
temp_end = 0;
triggered = false;
}
};
void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
process(input_wav);
collect_chunks(input_wav, output_wav);
}
void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
for (int i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
output_wav.insert(output_wav.end(),slice.begin(),slice.end());
}
};
const std::vector<timestamp_t> get_speech_timestamps() const
{
// Returns the detected speech timestamps.
const std::vector<timestamp_t> get_speech_timestamps() const {
return speeches;
}
void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
int current_start = 0;
for (int i = 0; i < speeches.size(); i++) {
std::vector<float> slice(&input_wav[current_start],&input_wav[speeches[i].start]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
current_start = speeches[i].end;
}
std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
};
private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
int sample_rate; //Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16
float threshold;
int min_silence_samples; // sr_per_ms * #ms
int min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
float max_speech_samples;
int speech_pad_samples; // usually a
int audio_length_samples;
// model states
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
int prev_end;
int next_start = 0;
//Output timestamp
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Onnx model
// Inputs
std::vector<Ort::Value> ort_inputs;
std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
// Outputs
std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "stateN"};
// Public method to reset the internal state.
void reset() {
reset_states();
}
public:
// Construction
// Constructor: sets model path, sample rate, window size (ms), and other parameters.
// The parameters are set to match the Python version.
VadIterator(const std::wstring ModelPath,
int Sample_rate = 16000, int windows_frame_size = 32,
float Threshold = 0.5, int min_silence_duration_ms = 0,
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
float Threshold = 0.5, int min_silence_duration_ms = 100,
int speech_pad_ms = 30, int min_speech_duration_ms = 250,
float max_speech_duration_s = std::numeric_limits<float>::infinity())
: sample_rate(Sample_rate), threshold(Threshold), speech_pad_samples(speech_pad_ms), prev_end(0)
{
init_onnx_model(ModelPath);
threshold = Threshold;
sample_rate = Sample_rate;
sr_per_ms = sample_rate / 1000;
window_size_samples = windows_frame_size * sr_per_ms;
min_speech_samples = sr_per_ms * min_speech_duration_ms;
speech_pad_samples = sr_per_ms * speech_pad_ms;
max_speech_samples = (
sample_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
input.resize(window_size_samples);
sr_per_ms = sample_rate / 1000; // e.g., 16000 / 1000 = 16
window_size_samples = windows_frame_size * sr_per_ms; // e.g., 32ms * 16 = 512 samples
effective_window_size = window_size_samples + context_samples; // e.g., 512 + 64 = 576 samples
input_node_dims[0] = 1;
input_node_dims[1] = window_size_samples;
input_node_dims[1] = effective_window_size;
_state.resize(size_state);
sr.resize(1);
sr[0] = sample_rate;
};
_context.assign(context_samples, 0.0f);
min_speech_samples = sr_per_ms * min_speech_duration_ms;
max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
init_onnx_model(ModelPath);
}
};
int main()
{
std::vector<timestamp_t> stamps;
// Read wav
wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
std::vector<float> input_wav(wav_reader.num_samples());
std::vector<float> output_wav;
for (int i = 0; i < wav_reader.num_samples(); i++)
{
int main() {
// Read the WAV file (expects 16000 Hz, mono, PCM).
wav::WavReader wav_reader("audio/recorder.wav"); // File located in the "audio" folder.
int numSamples = wav_reader.num_samples();
std::vector<float> input_wav(static_cast<size_t>(numSamples));
for (size_t i = 0; i < static_cast<size_t>(numSamples); i++) {
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
// Set the ONNX model path (file located in the "model" folder).
std::wstring model_path = L"model/silero_vad.onnx";
// Initialize the VadIterator.
VadIterator vad(model_path);
// ===== Test configs =====
std::wstring path = L"silero_vad.onnx";
VadIterator vad(path);
// ==============================================
// ==== = Example 1 of full function =====
// ==============================================
// Process the audio.
vad.process(input_wav);
// 1.a get_speech_timestamps
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
// Retrieve the speech timestamps (in samples).
std::vector<timestamp_t> stamps = vad.get_speech_timestamps();
std::cout << stamps[i].c_str() << std::endl;
// Convert timestamps to seconds and round to one decimal place (for 16000 Hz).
const float sample_rate_float = 16000.0f;
for (size_t i = 0; i < stamps.size(); i++) {
float start_sec = std::rint((stamps[i].start / sample_rate_float) * 10.0f) / 10.0f;
float end_sec = std::rint((stamps[i].end / sample_rate_float) * 10.0f) / 10.0f;
std::cout << "Speech detected from "
<< std::fixed << std::setprecision(1) << start_sec
<< " s to "
<< std::fixed << std::setprecision(1) << end_sec
<< " s" << std::endl;
}
// 1.b collect_chunks output wav
vad.collect_chunks(input_wav, output_wav);
// Optionally, reset the internal state.
vad.reset();
// 1.c drop_chunks output wav
vad.drop_chunks(input_wav, output_wav);
// ==============================================
// ===== Example 2 of simple full function =====
// ==============================================
vad.process(input_wav, output_wav);
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
std::cout << stamps[i].c_str() << std::endl;
}
// ==============================================
// ===== Example 3 of full function =====
// ==============================================
for(int i = 0; i<2; i++)
vad.process(input_wav, output_wav);
return 0;
}

View File

@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
@@ -24,6 +24,8 @@
#include <string>
#include <iostream>
// #include "utils/log.h"
namespace wav {
@@ -230,6 +232,6 @@ class WavWriter {
int bits_per_sample_;
};
} // namespace wenet
} // namespace wav
#endif // FRONTEND_WAV_H_

View File

@@ -0,0 +1,45 @@
# Silero-VAD V5 in C++ (based on LibTorch)
This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested.
Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
## Requirements
- GCC 11.4.0 (GCC >= 5.1)
- LibTorch 1.13.0 (other versions are also acceptable)
## Download LibTorch
```bash
-CPU Version
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
unzip libtorch-shared-with-deps-1.13.0+cpu.zip'
-CUDA Version
wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
unzip libtorch-shared-with-deps-1.13.0+cu116.zip
```
## Compilation
```bash
-CPU Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
-CUDA Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
```
## Optional Compilation Flags
-DUSE_BATCH: Enable batch inference
-DUSE_GPU: Use GPU for inference
## Run the Program
To run the program, use the following command:
`./silero aepyx.wav 16000 0.5`
The sample file aepyx.wav is part of the Voxconverse dataset.
File details: aepyx.wav is a 16kHz, 16-bit audio file.

Binary file not shown.

View File

@@ -0,0 +1,54 @@
#include <iostream>
#include "silero_torch.h"
#include "wav.h"
int main(int argc, char* argv[]) {
if(argc != 4){
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
return 1;
}
std::string wav_path = argv[1];
float sample_rate = std::stof(argv[2]);
float threshold = std::stof(argv[3]);
//Load Model
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
silero::VadIterator vad(model_path);
vad.threshold=threshold; //(Default:0.5)
vad.sample_rate=sample_rate; //16000Hz,8000Hz. (Default:16000)
vad.print_as_samples=true; //if true, it prints time-stamp with samples. otherwise, in seconds
//(Default:false)
vad.SetVariables();
// Read wav
wav::WavReader wav_reader(wav_path);
std::vector<float> input_wav(wav_reader.num_samples());
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
vad.SpeechProbs(input_wav);
std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
for(const auto& speech : speeches){
if(vad.print_as_samples){
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
}
else{
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
}
}
return 0;
}

BIN
examples/cpp_libtorch/silero Executable file

Binary file not shown.

View File

@@ -0,0 +1,285 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#include "silero_torch.h"
namespace silero {
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
init_torch_model(model_path);
//init_engine(window_size_ms);
}
VadIterator::~VadIterator(){
}
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
// Set the sample rate (must match the model's expected sample rate)
// Process the waveform in chunks of 512 samples
int num_samples = input_wav.size();
int num_chunks = num_samples / window_size_samples;
int remainder_samples = num_samples % window_size_samples;
total_sample_size += num_samples;
torch::Tensor output;
std::vector<torch::Tensor> chunks;
for (int i = 0; i < num_chunks; i++) {
float* chunk_start = input_wav.data() + i *window_size_samples;
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
chunks.push_back(chunk);
if(i==num_chunks-1 && remainder_samples>0){//마지막 chunk && 나머지가 존재
int remaining_samples = num_samples - num_chunks * window_size_samples;
//std::cout<<"Remainder size : "<<remaining_samples;
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
torch::kFloat32);
// Pad the remainder chunk to match window_size_samples
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
- remaining_samples}, torch::kFloat32)}, 1);
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
chunks.push_back(padded_chunk);
}
}
if (!chunks.empty()) {
#ifdef USE_BATCH
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
//batched_chunks = batched_chunks.squeeze(1);
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
#endif
// Prepare input for model
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks); // Batch of chunks
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
// Run inference on the batch
torch::NoGradGuard no_grad;
torch::Tensor output = model.forward(inputs).toTensor();
#ifdef USE_GPU
output = output.to(at::kCPU); // Move the output back to CPU once
#endif
// Collect output probabilities
for (int i = 0; i < chunks.size(); i++) {
float output_f = output[i].item<float>();
outputs_prob.push_back(output_f);
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
}
#else
std::vector<torch::Tensor> outputs;
torch::Tensor batched_chunks = torch::stack(chunks);
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA);
#endif
for (int i = 0; i < chunks.size(); i++) {
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks[i]);
inputs.push_back(sample_rate);
torch::Tensor output = model.forward(inputs).toTensor();
outputs.push_back(output);
}
torch::Tensor all_outputs = torch::stack(outputs);
#ifdef USE_GPU
all_outputs = all_outputs.to(at::kCPU);
#endif
for (int i = 0; i < chunks.size(); i++) {
float output_f = all_outputs[i].item<float>();
outputs_prob.push_back(output_f);
}
#endif
}
}
std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
std::vector<SpeechSegment> speeches = DoVad();
#ifdef USE_BATCH
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
//It could be better get reasonable output because of distorted probs.
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
if(!print_as_samples){
for (auto& speech : speeches_merge) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches_merge;
#else
if(!print_as_samples){
for (auto& speech : speeches) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches;
#endif
}
void VadIterator::SetVariables(){
init_engine(window_size_ms);
}
void VadIterator::init_engine(int window_size_ms) {
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
window_size_samples = sample_rate / 1000 * window_size_ms;
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}
void VadIterator::init_torch_model(const std::string& model_path) {
at::set_num_threads(1);
model = torch::jit::load(model_path);
#ifdef USE_GPU
if (!torch::cuda::is_available()) {
std::cout<<"CUDA is not available! Please check your GPU settings"<<std::endl;
throw std::runtime_error("CUDA is not available!");
model.to(at::Device(at::kCPU));
} else {
std::cout<<"CUDA available! Running on '0'th GPU"<<std::endl;
model.to(at::Device(at::kCUDA, 0)); //select 0'th machine
}
#endif
model.eval();
torch::NoGradGuard no_grad;
std::cout << "Model loaded successfully"<<std::endl;
}
void VadIterator::reset_states() {
triggered = false;
current_sample = 0;
temp_end = 0;
outputs_prob.clear();
model.run_method("reset_states");
total_sample_size = 0;
}
std::vector<SpeechSegment> VadIterator::DoVad() {
std::vector<SpeechSegment> speeches;
for (size_t i = 0; i < outputs_prob.size(); ++i) {
float speech_prob = outputs_prob[i];
//std::cout << speech_prob << std::endl;
//std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
//std::cout << speech_prob << " ";
current_sample += window_size_samples;
if (speech_prob >= threshold && temp_end != 0) {
temp_end = 0;
}
if (speech_prob >= threshold && !triggered) {
triggered = true;
SpeechSegment segment;
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
speeches.push_back(segment);
continue;
}
if (speech_prob < threshold - 0.15f && triggered) {
if (temp_end == 0) {
temp_end = current_sample;
}
if (current_sample - temp_end < min_silence_samples) {
continue;
} else {
SpeechSegment& segment = speeches.back();
segment.end = temp_end + speech_pad_samples - window_size_samples;
temp_end = 0;
triggered = false;
}
}
}
if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나?
std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
SpeechSegment& segment = speeches.back();
segment.end = total_sample_size; // 현재 샘플을 마지막 구간의 종료 시간으로 설정
triggered = false; // VAD 상태 초기화
}
speeches.erase(
std::remove_if(
speeches.begin(),
speeches.end(),
[this](const SpeechSegment& speech) {
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
//min_speech_samples is 4000samples(0.25sec)
//여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함.
}
),
speeches.end()
);
//std::cout<<std::endl;
//std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;
reset_states();
return speeches;
}
std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
std::vector<SpeechSegment> mergedSpeeches;
if (speeches.empty()) {
return mergedSpeeches; // 빈 벡터 반환
}
// 첫 번째 구간으로 초기화
SpeechSegment currentSegment = speeches[0];
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
// 현재 구간의 끝점을 업데이트
currentSegment.end = speeches[i].end;
} else {
// 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작
mergedSpeeches.push_back(currentSegment);
currentSegment = speeches[i];
}
}
// 마지막 구간 추가
mergedSpeeches.push_back(currentSegment);
return mergedSpeeches;
}
}

View File

@@ -0,0 +1,75 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H
#include <string>
#include <memory>
#include <stdexcept>
#include <iostream>
#include <memory>
#include <vector>
#include <fstream>
#include <chrono>
#include <torch/torch.h>
#include <torch/script.h>
namespace silero{
struct SpeechSegment{
int start;
int end;
};
class VadIterator{
public:
VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
~VadIterator();
void SpeechProbs(std::vector<float>& input_wav);
std::vector<silero::SpeechSegment> GetSpeechTimestamps();
void SetVariables();
float threshold;
int sample_rate;
int window_size_ms;
int min_speech_duration_ms;
int max_duration_merge_ms;
bool print_as_samples;
private:
torch::jit::script::Module model;
std::vector<float> outputs_prob;
int min_silence_samples;
int min_speech_samples;
int speech_pad_samples;
int window_size_samples;
int duration_merge_samples;
int current_sample = 0;
int total_sample_size=0;
int min_silence_duration_ms;
int speech_pad_ms;
bool triggered = false;
int temp_end = 0;
void init_engine(int window_size_ms);
void init_torch_model(const std::string& model_path);
void reset_states();
std::vector<SpeechSegment> DoVad();
std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
};
}
#endif // SILERO_TORCH_H

235
examples/cpp_libtorch/wav.h Normal file
View File

@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
// #include "utils/log.h"
namespace wav {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
if (NULL == fp) {
std::cout << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
printf("WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_;
std::cout << "num_channel_ :" << num_channel_ << std::endl;
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: {
char sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 16: {
int16_t sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 32:
{
if (header.format == 1) //S32
{
int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
}
default:
printf("unsupported quantization bits\n");
break;
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

View File

@@ -0,0 +1,35 @@
using System.Text;
namespace VadDotNet;
class Program
{
private const string MODEL_PATH = "./resources/silero_vad.onnx";
private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
private const int SAMPLE_RATE = 16000;
private const float THRESHOLD = 0.5f;
private const int MIN_SPEECH_DURATION_MS = 250;
private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
private const int MIN_SILENCE_DURATION_MS = 100;
private const int SPEECH_PAD_MS = 30;
public static void Main(string[] args)
{
var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
//Console.WriteLine(speechTimeList.ToJson());
StringBuilder sb = new StringBuilder();
foreach (var speechSegment in speechTimeList)
{
sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
}
Console.WriteLine(sb.ToString());
}
}

View File

@@ -0,0 +1,21 @@
namespace VadDotNet;
public class SileroSpeechSegment
{
public int? StartOffset { get; set; }
public int? EndOffset { get; set; }
public float? StartSecond { get; set; }
public float? EndSecond { get; set; }
public SileroSpeechSegment()
{
}
public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
{
StartOffset = startOffset;
EndOffset = endOffset;
StartSecond = startSecond;
EndSecond = endSecond;
}
}

View File

@@ -0,0 +1,250 @@
using NAudio.Wave;
using VADdotnet;
namespace VadDotNet;
public class SileroVadDetector
{
private readonly SileroVadOnnxModel _model;
private readonly float _threshold;
private readonly float _negThreshold;
private readonly int _samplingRate;
private readonly int _windowSizeSample;
private readonly float _minSpeechSamples;
private readonly float _speechPadSamples;
private readonly float _maxSpeechSamples;
private readonly float _minSilenceSamples;
private readonly float _minSilenceSamplesAtMaxSpeech;
private int _audioLengthSamples;
private const float THRESHOLD_GAP = 0.15f;
// ReSharper disable once InconsistentNaming
private const int SAMPLING_RATE_8K = 8000;
// ReSharper disable once InconsistentNaming
private const int SAMPLING_RATE_16K = 16000;
public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
int minSpeechDurationMs, float maxSpeechDurationSeconds,
int minSilenceDurationMs, int speechPadMs)
{
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
{
throw new ArgumentException("Sampling rate not support, only available for [8000, 16000]");
}
this._model = new SileroVadOnnxModel(onnxModelPath);
this._samplingRate = samplingRate;
this._threshold = threshold;
this._negThreshold = threshold - THRESHOLD_GAP;
this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
this._speechPadSamples = samplingRate * speechPadMs / 1000f;
this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
this.Reset();
}
public void Reset()
{
_model.ResetStates();
}
public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
{
Reset();
using (var audioFile = new AudioFileReader(wavFile.FullName))
{
List<float> speechProbList = new List<float>();
this._audioLengthSamples = (int)(audioFile.Length / 2);
float[] buffer = new float[this._windowSizeSample];
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
{
float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
speechProbList.Add(speechProb);
}
return CalculateProb(speechProbList);
}
}
private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
{
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
bool triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment();
for (int i = 0; i < speechProbList.Count; i++)
{
float speechProb = speechProbList[i];
if (speechProb >= _threshold && (tempEnd != 0))
{
tempEnd = 0;
if (nextStart < prevEnd)
{
nextStart = _windowSizeSample * i;
}
}
if (speechProb >= _threshold && !triggered)
{
triggered = true;
segment.StartOffset = _windowSizeSample * i;
continue;
}
if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
{
if (prevEnd != 0)
{
segment.EndOffset = prevEnd;
result.Add(segment);
segment = new SileroSpeechSegment();
if (nextStart < prevEnd)
{
triggered = false;
}
else
{
segment.StartOffset = nextStart;
}
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
}
else
{
segment.EndOffset = _windowSizeSample * i;
result.Add(segment);
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
if (speechProb < _negThreshold && triggered)
{
if (tempEnd == 0)
{
tempEnd = _windowSizeSample * i;
}
if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
{
prevEnd = tempEnd;
}
if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
{
continue;
}
else
{
segment.EndOffset = tempEnd;
if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
{
result.Add(segment);
}
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
}
if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
{
segment.EndOffset = _audioLengthSamples;
result.Add(segment);
}
for (int i = 0; i < result.Count; i++)
{
SileroSpeechSegment item = result[i];
if (i == 0)
{
item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
}
if (i != result.Count - 1)
{
SileroSpeechSegment nextItem = result[i + 1];
int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
if (silenceDuration < 2 * _speechPadSamples)
{
item.EndOffset = item.EndOffset + (silenceDuration / 2);
nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
}
else
{
item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
}
}
else
{
item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
}
}
return MergeListAndCalculateSecond(result, _samplingRate);
}
private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
{
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
if (original == null || original.Count == 0)
{
return result;
}
int left = original[0].StartOffset.Value;
int right = original[0].EndOffset.Value;
if (original.Count > 1)
{
original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
for (int i = 1; i < original.Count; i++)
{
SileroSpeechSegment segment = original[i];
if (segment.StartOffset > right)
{
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
left = segment.StartOffset.Value;
right = segment.EndOffset.Value;
}
else
{
right = Math.Max(right, segment.EndOffset.Value);
}
}
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
}
else
{
result.Add(new SileroSpeechSegment(left, right,
CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
}
return result;
}
private float CalculateSecondByOffset(int offset, int samplingRate)
{
float secondValue = offset * 1.0f / samplingRate;
return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
}
}

View File

@@ -0,0 +1,220 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Linq;
namespace VADdotnet;
public class SileroVadOnnxModel : IDisposable
{
private readonly InferenceSession session;
private float[][][] state;
private float[][] context;
private int lastSr = 0;
private int lastBatchSize = 0;
private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };
public SileroVadOnnxModel(string modelPath)
{
var sessionOptions = new SessionOptions();
sessionOptions.InterOpNumThreads = 1;
sessionOptions.IntraOpNumThreads = 1;
sessionOptions.EnableCpuMemArena = true;
session = new InferenceSession(modelPath, sessionOptions);
ResetStates();
}
public void ResetStates()
{
state = new float[2][][];
state[0] = new float[1][];
state[1] = new float[1][];
state[0][0] = new float[128];
state[1][0] = new float[128];
context = Array.Empty<float[]>();
lastSr = 0;
lastBatchSize = 0;
}
public void Dispose()
{
session?.Dispose();
}
public class ValidationResult
{
public float[][] X { get; }
public int Sr { get; }
public ValidationResult(float[][] x, int sr)
{
X = x;
Sr = sr;
}
}
private ValidationResult ValidateInput(float[][] x, int sr)
{
if (x.Length == 1)
{
x = new float[][] { x[0] };
}
if (x.Length > 2)
{
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
}
if (sr != 16000 && (sr % 16000 == 0))
{
int step = sr / 16000;
float[][] reducedX = new float[x.Length][];
for (int i = 0; i < x.Length; i++)
{
float[] current = x[i];
float[] newArr = new float[(current.Length + step - 1) / step];
for (int j = 0, index = 0; j < current.Length; j += step, index++)
{
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
if (!SAMPLE_RATES.Contains(sr))
{
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
}
if (((float)sr) / x[0].Length > 31.25)
{
throw new ArgumentException("Input audio is too short");
}
return new ValidationResult(x, sr);
}
private static float[][] Concatenate(float[][] a, float[][] b)
{
if (a.Length != b.Length)
{
throw new ArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.Length;
int colsA = a[0].Length;
int colsB = b[0].Length;
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[colsA + colsB];
Array.Copy(a[i], 0, result[i], 0, colsA);
Array.Copy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] GetLastColumns(float[][] array, int contextSize)
{
int rows = array.Length;
int cols = array[0].Length;
if (contextSize > cols)
{
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[contextSize];
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
public float[] Call(float[][] x, int sr)
{
var result = ValidateInput(x, sr);
x = result.X;
sr = result.Sr;
int numberSamples = sr == 16000 ? 512 : 256;
if (x[0].Length != numberSamples)
{
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.Length;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0)
{
ResetStates();
}
if (lastSr != 0 && lastSr != sr)
{
ResetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize)
{
ResetStates();
}
if (context.Length == 0)
{
context = new float[batchSize][];
for (int i = 0; i < batchSize; i++)
{
context[i] = new float[contextSize];
}
}
x = Concatenate(context, x);
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
};
using (var outputs = session.Run(inputs))
{
var output = outputs.First(o => o.Name == "output").AsTensor<float>();
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
context = GetLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
state = new float[newState.Dimensions[0]][][];
for (int i = 0; i < newState.Dimensions[0]; i++)
{
state[i] = new float[newState.Dimensions[1]][];
for (int j = 0; j < newState.Dimensions[1]; j++)
{
state[i][j] = new float[newState.Dimensions[2]];
for (int k = 0; k < newState.Dimensions[2]; k++)
{
state[i][j][k] = newState[i, j, k];
}
}
}
return output.ToArray();
}
}
}

View File

@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.18.1" />
<PackageReference Include="NAudio" Version="2.2.1" />
</ItemGroup>
<ItemGroup>
<Folder Include="resources\" />
</ItemGroup>
<ItemGroup>
<Content Include="resources\**">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>
</Project>

View File

@@ -0,0 +1 @@
place onnx model file and example.wav file in this folder

View File

@@ -11,11 +11,11 @@ import (
func main() {
sd, err := speech.NewDetector(speech.DetectorConfig{
ModelPath: "../../files/silero_vad.onnx",
ModelPath: "../../src/silero_vad/data/silero_vad.onnx",
SampleRate: 16000,
Threshold: 0.5,
MinSilenceDurationMs: 0,
SpeechPadMs: 0,
MinSilenceDurationMs: 100,
SpeechPadMs: 30,
})
if err != nil {
log.Fatalf("failed to create speech detector: %s", err)

View File

@@ -4,7 +4,7 @@ go 1.21.4
require (
github.com/go-audio/wav v1.1.0
github.com/streamer45/silero-vad-go v0.2.0
github.com/streamer45/silero-vad-go v0.2.1
)
require (

View File

@@ -10,6 +10,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,13 @@
# Haskell example
To run the example, make sure you put an ``example.wav`` in this directory, and then run the following:
```bash
stack run
```
The ``example.wav`` file must have the following requirements:
- Must be 16khz sample rate.
- Must be mono channel.
- Must be 16-bit audio.
This uses the [silero-vad](https://hackage.haskell.org/package/silero-vad) package, a haskell implementation based on the C# example.

View File

@@ -0,0 +1,22 @@
module Main (main) where
import qualified Data.Vector.Storable as Vector
import Data.WAVE
import Data.Function
import Silero
main :: IO ()
main =
withModel $ \model -> do
wav <- getWAVEFile "example.wav"
let samples =
concat (waveSamples wav)
& Vector.fromList
& Vector.map (realToFrac . sampleToDouble)
let vad =
(defaultVad model)
{ startThreshold = 0.5
, endThreshold = 0.35
}
segments <- detectSegments vad samples
print segments

View File

@@ -0,0 +1,23 @@
cabal-version: 1.12
-- This file has been generated from package.yaml by hpack version 0.37.0.
--
-- see: https://github.com/sol/hpack
name: example
version: 0.1.0.0
build-type: Simple
executable example-exe
main-is: Main.hs
other-modules:
Paths_example
hs-source-dirs:
app
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
WAVE
, base >=4.7 && <5
, silero-vad
, vector
default-language: Haskell2010

View File

@@ -0,0 +1,28 @@
name: example
version: 0.1.0.0
dependencies:
- base >= 4.7 && < 5
- silero-vad
- WAVE
- vector
ghc-options:
- -Wall
- -Wcompat
- -Widentities
- -Wincomplete-record-updates
- -Wincomplete-uni-patterns
- -Wmissing-export-lists
- -Wmissing-home-modules
- -Wpartial-fields
- -Wredundant-constraints
executables:
example-exe:
main: Main.hs
source-dirs: app
ghc-options:
- -threaded
- -rtsopts
- -with-rtsopts=-N

View File

@@ -0,0 +1,11 @@
snapshot:
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
packages:
- .
extra-deps:
- silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481

View File

@@ -0,0 +1,41 @@
# This file was autogenerated by Stack.
# You should not edit this file by hand.
# For more information, please see the documentation at:
# https://docs.haskellstack.org/en/stable/lock_files
packages:
- completed:
hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
pantry-tree:
sha256: a62e813f978d32c87769796fded981d25fcf2875bb2afdf60ed6279f931ccd7f
size: 1391
original:
hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- completed:
hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
pantry-tree:
sha256: ee5ccd70fa7fe6ffc360ebd762b2e3f44ae10406aa27f3842d55b8cbd1a19498
size: 405
original:
hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- completed:
hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
pantry-tree:
sha256: 48e35a72d1bb593173890616c8d7efd636a650a306a50bb3e1513e679939d27e
size: 902
original:
hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- completed:
hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
pantry-tree:
sha256: 2176fd677a02a4c47337f7dca5aeca2745dbb821a6ea5c7099b3a991ecd7f4f0
size: 4478
original:
hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
snapshots:
- completed:
sha256: 5a59b2a405b3aba3c00188453be172b85893cab8ebc352b1ef58b0eae5d248a2
size: 650475
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
original:
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml

View File

@@ -1,30 +1,31 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>java-example</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<groupId>org.example</groupId>
<artifactId>java-example</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>sliero-vad-example</name>
<url>http://maven.apache.org</url>
<name>sliero-vad-example</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.16.0-rc1</version>
</dependency>
</dependencies>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.23.1</version>
</dependency>
</dependencies>
</project>

View File

@@ -2,68 +2,263 @@ package org.example;
import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Silero VAD Java Example
* Voice Activity Detection using ONNX model
*
* @author VvvvvGH
*/
public class App {
private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
// ONNX model path - using the model file from the project
private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx";
// Test audio file path
private static final String AUDIO_FILE_PATH = "../../en_example.wav";
// Sampling rate
private static final int SAMPLE_RATE = 16000;
private static final float START_THRESHOLD = 0.6f;
private static final float END_THRESHOLD = 0.45f;
private static final int MIN_SILENCE_DURATION_MS = 600;
private static final int SPEECH_PAD_MS = 500;
private static final int WINDOW_SIZE_SAMPLES = 2048;
// Speech threshold (consistent with Python default)
private static final float THRESHOLD = 0.5f;
// Negative threshold (used to determine speech end)
private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15
// Minimum speech duration (milliseconds)
private static final int MIN_SPEECH_DURATION_MS = 250;
// Minimum silence duration (milliseconds)
private static final int MIN_SILENCE_DURATION_MS = 100;
// Speech padding (milliseconds)
private static final int SPEECH_PAD_MS = 30;
// Window size (samples) - 512 samples for 16kHz
private static final int WINDOW_SIZE_SAMPLES = 512;
public static void main(String[] args) {
// Initialize the Voice Activity Detector
SlieroVadDetector vadDetector;
System.out.println("=".repeat(60));
System.out.println("Silero VAD Java ONNX Example");
System.out.println("=".repeat(60));
// Load ONNX model
SlieroVadOnnxModel model;
try {
vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
System.out.println("Loading ONNX model: " + MODEL_PATH);
model = new SlieroVadOnnxModel(MODEL_PATH);
System.out.println("Model loaded successfully!");
} catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage());
System.err.println("Failed to load model: " + e.getMessage());
e.printStackTrace();
return;
}
// Set audio format
AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
// Get the target data line and open it with the specified format
TargetDataLine targetDataLine;
// Read WAV file
float[] audioData;
try {
targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
targetDataLine.open(format);
targetDataLine.start();
} catch (LineUnavailableException e) {
System.err.println("Error opening target data line: " + e.getMessage());
System.out.println("\nReading audio file: " + AUDIO_FILE_PATH);
audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH);
System.out.println("Audio file read successfully, samples: " + audioData.length);
System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds");
} catch (Exception e) {
System.err.println("Failed to read audio file: " + e.getMessage());
e.printStackTrace();
return;
}
// Main loop to continuously read data and apply Voice Activity Detection
while (targetDataLine.isOpen()) {
byte[] data = new byte[WINDOW_SIZE_SAMPLES];
int numBytesRead = targetDataLine.read(data, 0, data.length);
if (numBytesRead <= 0) {
System.err.println("Error reading data from target data line.");
continue;
}
// Apply the Voice Activity Detector to the data and get the result
Map<String, Double> detectResult;
try {
detectResult = vadDetector.apply(data, true);
} catch (Exception e) {
System.err.println("Error applying VAD detector: " + e.getMessage());
continue;
}
if (!detectResult.isEmpty()) {
System.out.println(detectResult);
}
// Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps)
System.out.println("\nDetecting speech segments...");
List<Map<String, Integer>> speechTimestamps;
try {
speechTimestamps = getSpeechTimestamps(
audioData,
model,
THRESHOLD,
SAMPLE_RATE,
MIN_SPEECH_DURATION_MS,
MIN_SILENCE_DURATION_MS,
SPEECH_PAD_MS,
NEG_THRESHOLD
);
} catch (OrtException e) {
System.err.println("Failed to detect speech timestamps: " + e.getMessage());
e.printStackTrace();
return;
}
// Close the target data line to release audio resources
targetDataLine.close();
// Output detection results
System.out.println("\nDetected speech timestamps (in samples):");
for (Map<String, Integer> timestamp : speechTimestamps) {
System.out.println(timestamp);
}
// Output summary
System.out.println("\n" + "=".repeat(60));
System.out.println("Detection completed!");
System.out.println("Total detected " + speechTimestamps.size() + " speech segments");
System.out.println("=".repeat(60));
// Close model
try {
model.close();
} catch (OrtException e) {
System.err.println("Error closing model: " + e.getMessage());
}
}
/**
* Get speech timestamps
* Implements the same logic as Python's get_speech_timestamps
*
* @param audio Audio data (float array)
* @param model ONNX model
* @param threshold Speech threshold
* @param samplingRate Sampling rate
* @param minSpeechDurationMs Minimum speech duration (milliseconds)
* @param minSilenceDurationMs Minimum silence duration (milliseconds)
* @param speechPadMs Speech padding (milliseconds)
* @param negThreshold Negative threshold (used to determine speech end)
* @return List of speech timestamps
*/
private static List<Map<String, Integer>> getSpeechTimestamps(
float[] audio,
SlieroVadOnnxModel model,
float threshold,
int samplingRate,
int minSpeechDurationMs,
int minSilenceDurationMs,
int speechPadMs,
float negThreshold) throws OrtException {
// Reset model states
model.resetStates();
// Calculate parameters
int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000;
int speechPadSamples = samplingRate * speechPadMs / 1000;
int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000;
int windowSizeSamples = samplingRate == 16000 ? 512 : 256;
int audioLengthSamples = audio.length;
// Calculate speech probabilities for all audio chunks
List<Float> speechProbs = new ArrayList<>();
for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) {
float[] chunk = new float[windowSizeSamples];
int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart);
System.arraycopy(audio, currentStart, chunk, 0, chunkLength);
// Pad with zeros if chunk is shorter than window size
if (chunkLength < windowSizeSamples) {
for (int i = chunkLength; i < windowSizeSamples; i++) {
chunk[i] = 0.0f;
}
}
float speechProb = model.call(new float[][]{chunk}, samplingRate)[0];
speechProbs.add(speechProb);
}
// Detect speech segments using the same algorithm as Python
boolean triggered = false;
List<Map<String, Integer>> speeches = new ArrayList<>();
Map<String, Integer> currentSpeech = null;
int tempEnd = 0;
for (int i = 0; i < speechProbs.size(); i++) {
float speechProb = speechProbs.get(i);
// Reset temporary end if speech probability exceeds threshold
if (speechProb >= threshold && tempEnd != 0) {
tempEnd = 0;
}
// Detect speech start
if (speechProb >= threshold && !triggered) {
triggered = true;
currentSpeech = new HashMap<>();
currentSpeech.put("start", windowSizeSamples * i);
continue;
}
// Detect speech end
if (speechProb < negThreshold && triggered) {
if (tempEnd == 0) {
tempEnd = windowSizeSamples * i;
}
if (windowSizeSamples * i - tempEnd < minSilenceSamples) {
continue;
} else {
currentSpeech.put("end", tempEnd);
if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) {
speeches.add(currentSpeech);
}
currentSpeech = null;
tempEnd = 0;
triggered = false;
}
}
}
// Handle the last speech segment
if (currentSpeech != null &&
(audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) {
currentSpeech.put("end", audioLengthSamples);
speeches.add(currentSpeech);
}
// Add speech padding - same logic as Python
for (int i = 0; i < speeches.size(); i++) {
Map<String, Integer> speech = speeches.get(i);
if (i == 0) {
speech.put("start", Math.max(0, speech.get("start") - speechPadSamples));
}
if (i != speeches.size() - 1) {
int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end");
if (silenceDuration < 2 * speechPadSamples) {
speech.put("end", speech.get("end") + silenceDuration / 2);
speeches.get(i + 1).put("start",
Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2));
} else {
speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
speeches.get(i + 1).put("start",
Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples));
}
} else {
speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
}
}
return speeches;
}
/**
* Read WAV file and return as float array
*
* @param filePath WAV file path
* @return Audio data as float array (normalized to -1.0 to 1.0)
*/
private static float[] readWavFileAsFloatArray(String filePath)
throws UnsupportedAudioFileException, IOException {
File audioFile = new File(filePath);
AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile);
// Get audio format information
AudioFormat format = audioStream.getFormat();
System.out.println("Audio format: " + format);
// Read all audio data
byte[] audioBytes = audioStream.readAllBytes();
audioStream.close();
// Convert to float array
float[] audioData = new float[audioBytes.length / 2];
for (int i = 0; i < audioData.length; i++) {
// 16-bit PCM: two bytes per sample (little-endian)
short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8));
audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0
}
return audioData;
}
}

View File

@@ -8,25 +8,30 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
* Silero VAD Detector
* Real-time voice activity detection
*
* @author VvvvvGH
*/
public class SlieroVadDetector {
// OnnxModel model used for speech processing
// ONNX model for speech processing
private final SlieroVadOnnxModel model;
// Threshold for speech start
// Speech start threshold
private final float startThreshold;
// Threshold for speech end
// Speech end threshold
private final float endThreshold;
// Sampling rate
private final int samplingRate;
// Minimum number of silence samples to determine the end threshold of speech
// Minimum silence samples to determine speech end
private final float minSilenceSamples;
// Additional number of samples for speech start or end to calculate speech start or end time
// Speech padding samples for calculating speech boundaries
private final float speechPadSamples;
// Whether in the triggered state (i.e. whether speech is being detected)
// Triggered state (whether speech is being detected)
private boolean triggered;
// Temporarily stored number of speech end samples
// Temporary speech end sample position
private int tempEnd;
// Number of samples currently being processed
// Current sample position
private int currentSample;
@@ -36,23 +41,25 @@ public class SlieroVadDetector {
int samplingRate,
int minSilenceDurationMs,
int speechPadMs) throws OrtException {
// Check if the sampling rate is 8000 or 16000, if not, throw an exception
// Validate sampling rate
if (samplingRate != 8000 && samplingRate != 16000) {
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
throw new IllegalArgumentException("Does not support sampling rates other than [8000, 16000]");
}
// Initialize the parameters
// Initialize parameters
this.model = new SlieroVadOnnxModel(modelPath);
this.startThreshold = startThreshold;
this.endThreshold = endThreshold;
this.samplingRate = samplingRate;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
// Reset the state
// Reset state
reset();
}
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
/**
* Reset detector state
*/
public void reset() {
model.resetStates();
triggered = false;
@@ -60,21 +67,27 @@ public class SlieroVadDetector {
currentSample = 0;
}
// apply method for processing the audio array, returning possible speech start or end times
/**
* Process audio data and detect speech events
*
* @param data Audio data as byte array
* @param returnSeconds Whether to return timestamps in seconds
* @return Speech event (start or end) or empty map if no event
*/
public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
// Convert the byte array to a float array
// Convert byte array to float array
float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
}
// Get the length of the audio array as the window size
// Get window size from audio data length
int windowSizeSamples = audioData.length;
// Update the current sample count
// Update current sample position
currentSample += windowSizeSamples;
// Call the model to get the prediction probability of speech
// Get speech probability from model
float speechProb = 0;
try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
@@ -82,19 +95,18 @@ public class SlieroVadDetector {
throw new RuntimeException(e);
}
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
// Reset temporary end if speech probability exceeds threshold
if (speechProb >= startThreshold && tempEnd != 0) {
tempEnd = 0;
}
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
// Detect speech start
if (speechProb >= startThreshold && !triggered) {
triggered = true;
int speechStart = (int) (currentSample - speechPadSamples);
speechStart = Math.max(speechStart, 0);
Map<String, Double> result = new HashMap<>();
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
// Return in seconds or samples based on returnSeconds parameter
if (returnSeconds) {
double speechStartSeconds = speechStart / (double) samplingRate;
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
@@ -106,18 +118,17 @@ public class SlieroVadDetector {
return result;
}
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
// Detect speech end
if (speechProb < endThreshold && triggered) {
// Initialize or update the temporary end time
// Initialize or update temporary end position
if (tempEnd == 0) {
tempEnd = currentSample;
}
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
// This indicates that it is not yet possible to determine whether the speech has ended
// Wait for minimum silence duration before confirming speech end
if (currentSample - tempEnd < minSilenceSamples) {
return Collections.emptyMap();
} else {
// Calculate the speech end time, reset the trigger state and temporary end time
// Calculate speech end time and reset state
int speechEnd = (int) (tempEnd + speechPadSamples);
tempEnd = 0;
triggered = false;
@@ -134,7 +145,7 @@ public class SlieroVadDetector {
}
}
// If the above conditions are not met, return null by default
// No speech event detected
return Collections.emptyMap();
}

View File

@@ -9,42 +9,58 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Silero VAD ONNX Model Wrapper
*
* @author VvvvvGH
*/
public class SlieroVadOnnxModel {
// Define private variable OrtSession
// ONNX runtime session
private final OrtSession session;
private float[][][] h;
private float[][][] c;
// Define the last sample rate
// Model state - dimensions: [2, batch_size, 128]
private float[][][] state;
// Context - stores the tail of the previous audio chunk
private float[][] context;
// Last sample rate
private int lastSr = 0;
// Define the last batch size
// Last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
// Supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SlieroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
// Create ONNX session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
// Set InterOp thread count to 1 (for parallel processing of different graph operations)
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
// Set IntraOp thread count to 1 (for parallel processing within a single operation)
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
// Enable CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
// Create ONNX session with the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
* Reset states with default batch size
*/
void resetStates() {
h = new float[2][1][64];
c = new float[2][1][64];
resetStates(1);
}
/**
* Reset states with specific batch size
*
* @param batchSize Batch size for state initialization
*/
void resetStates(int batchSize) {
state = new float[2][batchSize][128];
context = new float[0][]; // Empty context
lastSr = 0;
lastBatchSize = 0;
}
@@ -54,13 +70,12 @@ public class SlieroVadOnnxModel {
}
/**
* Define inner class ValidationResult
* Inner class for validation result
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
@@ -68,19 +83,23 @@ public class SlieroVadOnnxModel {
}
/**
* Function to validate input data
* Validate input data
*
* @param x Audio data array
* @param sr Sample rate
* @return Validated input data and sample rate
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
// Ensure input is at least 2D
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
// Check if input dimension is valid
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
// Downsample if sample rate is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
@@ -100,22 +119,26 @@ public class SlieroVadOnnxModel {
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
// Validate sample rate
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
// Check if audio chunk is too short
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
/**
* Method to call the ONNX model
* Call the ONNX model for inference
*
* @param x Audio data array
* @param sr Sample rate
* @return Speech probability output
* @throws OrtException If ONNX runtime error occurs
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
@@ -123,38 +146,62 @@ public class SlieroVadOnnxModel {
sr = result.sr;
int batchSize = x.length;
int numSamples = sr == 16000 ? 512 : 256;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
resetStates();
// Reset states only when sample rate or batch size changes
if (lastSr != 0 && lastSr != sr) {
resetStates(batchSize);
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates(batchSize);
} else if (lastBatchSize == 0) {
// First call - state is already initialized, just set batch size
lastBatchSize = batchSize;
}
// Initialize context if needed
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
// Concatenate context and input
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
for (int i = 0; i < batchSize; i++) {
// Copy context
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
// Copy input
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor hTensor = null;
OnnxTensor cTensor = null;
OnnxTensor stateTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
hTensor = OnnxTensor.createTensor(env, h);
cTensor = OnnxTensor.createTensor(env, c);
inputTensor = OnnxTensor.createTensor(env, xWithContext);
stateTensor = OnnxTensor.createTensor(env, state);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("h", hTensor);
inputs.put("c", cTensor);
inputs.put("state", stateTensor);
// Call the ONNX model for calculation
// Run ONNX model inference
ortOutputs = session.run(inputs);
// Get the output results
// Get output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
h = (float[][][]) ortOutputs.get(1).getValue();
c = (float[][][]) ortOutputs.get(2).getValue();
state = (float[][][]) ortOutputs.get(1).getValue();
// Update context - save the last contextSize samples from input
for (int i = 0; i < batchSize; i++) {
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
context[i], 0, contextSize);
}
lastSr = sr;
lastBatchSize = batchSize;
@@ -163,11 +210,8 @@ public class SlieroVadOnnxModel {
if (inputTensor != null) {
inputTensor.close();
}
if (hTensor != null) {
hTensor.close();
}
if (cTensor != null) {
cTensor.close();
if (stateTensor != null) {
stateTensor.close();
}
if (srTensor != null) {
srTensor.close();

View File

@@ -0,0 +1,37 @@
package org.example;
import ai.onnxruntime.OrtException;
import java.io.File;
import java.util.List;
public class App {
private static final String MODEL_PATH = "/path/silero_vad.onnx";
private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
private static final int SAMPLE_RATE = 16000;
private static final float THRESHOLD = 0.5f;
private static final int MIN_SPEECH_DURATION_MS = 250;
private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
private static final int MIN_SILENCE_DURATION_MS = 100;
private static final int SPEECH_PAD_MS = 30;
public static void main(String[] args) {
// Initialize the Voice Activity Detector
SileroVadDetector vadDetector;
try {
vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
} catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage());
}
}
public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
for (SileroSpeechSegment speechSegment : speechTimeList) {
System.out.println(String.format("start second: %f, end second: %f",
speechSegment.getStartSecond(), speechSegment.getEndSecond()));
}
}
}

View File

@@ -0,0 +1,51 @@
package org.example;
public class SileroSpeechSegment {
private Integer startOffset;
private Integer endOffset;
private Float startSecond;
private Float endSecond;
public SileroSpeechSegment() {
}
public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
this.startOffset = startOffset;
this.endOffset = endOffset;
this.startSecond = startSecond;
this.endSecond = endSecond;
}
public Integer getStartOffset() {
return startOffset;
}
public Integer getEndOffset() {
return endOffset;
}
public Float getStartSecond() {
return startSecond;
}
public Float getEndSecond() {
return endSecond;
}
public void setStartOffset(Integer startOffset) {
this.startOffset = startOffset;
}
public void setEndOffset(Integer endOffset) {
this.endOffset = endOffset;
}
public void setStartSecond(Float startSecond) {
this.startSecond = startSecond;
}
public void setEndSecond(Float endSecond) {
this.endSecond = endSecond;
}
}

View File

@@ -0,0 +1,244 @@
package org.example;
import ai.onnxruntime.OrtException;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
public class SileroVadDetector {
private final SileroVadOnnxModel model;
private final float threshold;
private final float negThreshold;
private final int samplingRate;
private final int windowSizeSample;
private final float minSpeechSamples;
private final float speechPadSamples;
private final float maxSpeechSamples;
private final float minSilenceSamples;
private final float minSilenceSamplesAtMaxSpeech;
private int audioLengthSamples;
private static final float THRESHOLD_GAP = 0.15f;
private static final Integer SAMPLING_RATE_8K = 8000;
private static final Integer SAMPLING_RATE_16K = 16000;
/**
* Constructor
* @param onnxModelPath the path of silero-vad onnx model
* @param threshold threshold for speech start
* @param samplingRate audio sampling rate, only available for [8k, 16k]
* @param minSpeechDurationMs Minimum speech length in millis, any speech duration that smaller than this value would not be considered as speech
* @param maxSpeechDurationSeconds Maximum speech length in millis, recommend to be set as Float.POSITIVE_INFINITY
* @param minSilenceDurationMs Minimum silence length in millis, any silence duration that smaller than this value would not be considered as silence
* @param speechPadMs Additional pad millis for speech start and end
* @throws OrtException
*/
public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
int minSpeechDurationMs, float maxSpeechDurationSeconds,
int minSilenceDurationMs, int speechPadMs) throws OrtException {
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]");
}
this.model = new SileroVadOnnxModel(onnxModelPath);
this.samplingRate = samplingRate;
this.threshold = threshold;
this.negThreshold = threshold - THRESHOLD_GAP;
if (samplingRate == SAMPLING_RATE_16K) {
this.windowSizeSample = 512;
} else {
this.windowSizeSample = 256;
}
this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
this.reset();
}
/**
* Method to reset the state
*/
public void reset() {
model.resetStates();
}
/**
* Get speech segment list by given wav-format file
* @param wavFile wav file
* @return list of speech segment
*/
public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
reset();
try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
List<Float> speechProbList = new ArrayList<>();
this.audioLengthSamples = audioInputStream.available() / 2;
byte[] data = new byte[this.windowSizeSample * 2];
int numBytesRead = 0;
while ((numBytesRead = audioInputStream.read(data)) != -1) {
if (numBytesRead <= 0) {
break;
}
// Convert the byte array to a float array
float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
}
float speechProb = 0;
try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
speechProbList.add(speechProb);
} catch (OrtException e) {
throw e;
}
}
return calculateProb(speechProbList);
} catch (Exception e) {
throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
}
}
/**
* Calculate speech segement by probability
* @param speechProbList speech probability list
* @return list of speech segment
*/
private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
List<SileroSpeechSegment> result = new ArrayList<>();
boolean triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment();
for (int i = 0; i < speechProbList.size(); i++) {
Float speechProb = speechProbList.get(i);
if (speechProb >= threshold && (tempEnd != 0)) {
tempEnd = 0;
if (nextStart < prevEnd) {
nextStart = windowSizeSample * i;
}
}
if (speechProb >= threshold && !triggered) {
triggered = true;
segment.setStartOffset(windowSizeSample * i);
continue;
}
if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
if (prevEnd != 0) {
segment.setEndOffset(prevEnd);
result.add(segment);
segment = new SileroSpeechSegment();
if (nextStart < prevEnd) {
triggered = false;
}else {
segment.setStartOffset(nextStart);
}
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
}else {
segment.setEndOffset(windowSizeSample * i);
result.add(segment);
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
if (speechProb < negThreshold && triggered) {
if (tempEnd == 0) {
tempEnd = windowSizeSample * i;
}
if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
prevEnd = tempEnd;
}
if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
continue;
}else {
segment.setEndOffset(tempEnd);
if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
result.add(segment);
}
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
}
if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
segment.setEndOffset(audioLengthSamples);
result.add(segment);
}
for (int i = 0; i < result.size(); i++) {
SileroSpeechSegment item = result.get(i);
if (i == 0) {
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
}
if (i != result.size() - 1) {
SileroSpeechSegment nextItem = result.get(i + 1);
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
if(silenceDuration < 2 * speechPadSamples){
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
} else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
}
}else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
}
}
return mergeListAndCalculateSecond(result, samplingRate);
}
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
List<SileroSpeechSegment> result = new ArrayList<>();
if (original == null || original.size() == 0) {
return result;
}
Integer left = original.get(0).getStartOffset();
Integer right = original.get(0).getEndOffset();
if (original.size() > 1) {
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
for (int i = 1; i < original.size(); i++) {
SileroSpeechSegment segment = original.get(i);
if (segment.getStartOffset() > right) {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
left = segment.getStartOffset();
right = segment.getEndOffset();
} else {
right = Math.max(right, segment.getEndOffset());
}
}
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}else {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}
return result;
}
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
float secondValue = offset * 1.0f / samplingRate;
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
}
}

View File

@@ -0,0 +1,234 @@
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SileroVadOnnxModel {
// Define private variable OrtSession
private final OrtSession session;
private float[][][] state;
private float[][] context;
// Define the last sample rate
private int lastSr = 0;
// Define the last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SileroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
*/
void resetStates() {
state = new float[2][1][128];
context = new float[0][];
lastSr = 0;
lastBatchSize = 0;
}
public void close() throws OrtException {
session.close();
}
/**
* Define inner class ValidationResult
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
}
}
/**
* Function to validate input data
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
for (int i = 0; i < x.length; i++) {
float[] current = x[i];
float[] newArr = new float[(current.length + step - 1) / step];
for (int j = 0, index = 0; j < current.length; j += step, index++) {
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
private static float[][] concatenate(float[][] a, float[][] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.length;
int colsA = a[0].length;
int colsB = b[0].length;
float[][] result = new float[rows][colsA + colsB];
for (int i = 0; i < rows; i++) {
System.arraycopy(a[i], 0, result[i], 0, colsA);
System.arraycopy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] getLastColumns(float[][] array, int contextSize) {
int rows = array.length;
int cols = array[0].length;
if (contextSize > cols) {
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][contextSize];
for (int i = 0; i < rows; i++) {
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
/**
* Method to call the ONNX model
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
x = result.x;
sr = result.sr;
int numberSamples = 256;
if (sr == 16000) {
numberSamples = 512;
}
if (x[0].length != numberSamples) {
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.length;
int contextSize = 32;
if (sr == 16000) {
contextSize = 64;
}
if (lastBatchSize == 0) {
resetStates();
}
if (lastSr != 0 && lastSr != sr) {
resetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates();
}
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
x = concatenate(context, x);
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor stateTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
stateTensor = OnnxTensor.createTensor(env, state);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("state", stateTensor);
// Call the ONNX model for calculation
ortOutputs = session.run(inputs);
// Get the output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
state = (float[][][]) ortOutputs.get(1).getValue();
context = getLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
return output[0];
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (stateTensor != null) {
stateTensor.close();
}
if (srTensor != null) {
srTensor.close();
}
if (ortOutputs != null) {
ortOutputs.close();
}
}
}
}

View File

@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -18,17 +17,19 @@
"SAMPLING_RATE = 16000\n",
"import torch\n",
"from pprint import pprint\n",
"import time\n",
"import shutil\n",
"\n",
"torch.set_num_threads(1)\n",
"NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
"NUM_COPIES=8\n",
"# download wav files, make multiple copies\n",
"for idx in range(NUM_COPIES):\n",
" torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example0.wav\")\n",
"for idx in range(NUM_COPIES-1):\n",
" shutil.copy(f\"en_example0.wav\", f\"en_example{idx+1}.wav\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -54,7 +55,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -99,7 +99,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -127,7 +126,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "diarization",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -141,7 +140,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.10.14"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,

View File

@@ -7,6 +7,8 @@ It has been designed as a low-level example for binary real-time streaming using
Currently, the notebook consits of two examples:
- One that records audio of a predefined length from the microphone, process it with Silero-VAD, and plots it afterwards.
- The other one plots the speech probabilities in real-time (using jupyterplot) and records the audio until you press enter.
This example does not work in google colab! For local usage only.
## Example Video for the Real-Time Visualization

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "62a0cccb",
"id": "76aa55ba",
"metadata": {},
"source": [
"# Pyaudio Microphone Streaming Examples\n",
@@ -12,12 +12,14 @@
"I created it as an example on how binary data from a stream could be feed into Silero VAD.\n",
"\n",
"\n",
"Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required."
"Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required.\n",
"\n",
"This notebook does not work in google colab! For local usage only."
]
},
{
"cell_type": "markdown",
"id": "64cbe1eb",
"id": "4a4e15c2",
"metadata": {},
"source": [
"## Dependencies\n",
@@ -26,22 +28,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "57bc2aac",
"metadata": {},
"execution_count": 1,
"id": "24205cce",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-09T08:47:34.056898Z",
"start_time": "2024-10-09T08:47:34.053418Z"
}
},
"outputs": [],
"source": [
"#!pip install numpy==1.20.2\n",
"#!pip install torch==1.9.0\n",
"#!pip install matplotlib==3.4.2\n",
"#!pip install torchaudio==0.9.0\n",
"#!pip install soundfile==0.10.3.post1\n",
"#!pip install pyaudio==0.2.11"
"#!pip install numpy>=1.24.0\n",
"#!pip install torch>=1.12.0\n",
"#!pip install matplotlib>=3.6.0\n",
"#!pip install torchaudio>=0.12.0\n",
"#!pip install soundfile==0.12.1\n",
"#!apt install python3-pyaudio (linux) or pip install pyaudio (windows)"
]
},
{
"cell_type": "markdown",
"id": "110de761",
"id": "cd22818f",
"metadata": {},
"source": [
"## Imports"
@@ -49,10 +56,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a647d8d",
"metadata": {},
"outputs": [],
"execution_count": 2,
"id": "994d7f3a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-09T08:47:39.005032Z",
"start_time": "2024-10-09T08:47:36.489952Z"
}
},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pyaudio'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpylab\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyaudio\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyaudio'"
]
}
],
"source": [
"import io\n",
"import numpy as np\n",
@@ -61,14 +85,13 @@
"import torchaudio\n",
"import matplotlib\n",
"import matplotlib.pylab as plt\n",
"torchaudio.set_audio_backend(\"soundfile\")\n",
"import pyaudio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "725d7066",
"id": "ac5c52f7",
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +103,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1c0b2ea7",
"id": "ad5919dc",
"metadata": {},
"outputs": [],
"source": [
@@ -93,7 +116,7 @@
},
{
"cell_type": "markdown",
"id": "f9112603",
"id": "784d1ab6",
"metadata": {},
"source": [
"### Helper Methods"
@@ -102,7 +125,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5abc6330",
"id": "af4bca64",
"metadata": {},
"outputs": [],
"source": [
@@ -125,7 +148,7 @@
},
{
"cell_type": "markdown",
"id": "5124095e",
"id": "ca13e514",
"metadata": {},
"source": [
"## Pyaudio Set-up"
@@ -134,7 +157,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a845356e",
"id": "75f99022",
"metadata": {},
"outputs": [],
"source": [
@@ -148,7 +171,7 @@
},
{
"cell_type": "markdown",
"id": "0b910c99",
"id": "4da7d2ef",
"metadata": {},
"source": [
"## Simple Example\n",
@@ -158,17 +181,17 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9d3d2c10",
"id": "6fe77661",
"metadata": {},
"outputs": [],
"source": [
"num_samples = 1536"
"num_samples = 512"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cb44a4a",
"id": "23f4da3e",
"metadata": {},
"outputs": [],
"source": [
@@ -180,6 +203,8 @@
"data = []\n",
"voiced_confidences = []\n",
"\n",
"frames_to_record = 50\n",
"\n",
"print(\"Started Recording\")\n",
"for i in range(0, frames_to_record):\n",
" \n",
@@ -206,7 +231,7 @@
},
{
"cell_type": "markdown",
"id": "a3dda982",
"id": "fd243e8f",
"metadata": {},
"source": [
"## Real Time Visualization\n",
@@ -219,7 +244,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "05ef4100",
"id": "d36980c2",
"metadata": {},
"outputs": [],
"source": [
@@ -229,7 +254,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d1d4cdd6",
"id": "5607b616",
"metadata": {},
"outputs": [],
"source": [
@@ -286,7 +311,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1e398009",
"id": "dc4f0108",
"metadata": {},
"outputs": [],
"source": [
@@ -296,7 +321,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -310,7 +335,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.10.14"
},
"toc": {
"base_numbering": 1,

View File

@@ -1,13 +1,12 @@
use crate::utils;
use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use std::path::Path;
#[derive(Debug)]
pub struct Silero {
session: ort::Session,
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
h: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
c: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
}
impl Silero {
@@ -16,20 +15,17 @@ impl Silero {
model_path: impl AsRef<Path>,
) -> Result<Self, ort::Error> {
let session = ort::Session::builder()?.commit_from_file(model_path)?;
let h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
let c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
Ok(Self {
session,
sample_rate,
h,
c,
state,
})
}
pub fn reset(&mut self) {
self.h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
self.c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
}
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
@@ -37,18 +33,17 @@ impl Silero {
.iter()
.map(|x| (*x as f32) / (i16::MAX as f32))
.collect::<Vec<_>>();
let frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
frame = frame.slice(s![.., ..480]).to_owned();
let inps = ort::inputs![
frame,
std::mem::take(&mut self.state),
self.sample_rate.clone(),
std::mem::take(&mut self.h),
std::mem::take(&mut self.c)
]?;
let res = self
.session
.run(ort::SessionInputs::ValueSlice::<4>(&inps))?;
self.h = res["hn"].try_extract_tensor().unwrap().to_owned();
self.c = res["cn"].try_extract_tensor().unwrap().to_owned();
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
Ok(*res["output"]
.try_extract_raw_tensor::<f32>()
.unwrap()

View File

@@ -20,7 +20,7 @@ impl VadIter {
pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
self.reset_states();
for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
let speech_prob = self.silero.calc_level(audio_frame)?;
let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
self.state.update(&self.params, speech_prob);
}
self.state.check_for_last_speech(samples.len());

View File

@@ -1 +0,0 @@
{"59": "mg, Malagasy", "76": "tk, Turkmen", "20": "lb, Luxembourgish, Letzeburgesch", "62": "or, Oriya", "30": "en, English", "26": "oc, Occitan", "69": "no, Norwegian", "77": "sr, Serbian", "90": "bs, Bosnian", "71": "el, Greek, Modern (1453\u2013)", "15": "az, Azerbaijani", "12": "lo, Lao", "85": "zh-HK, Chinese", "79": "cs, Czech", "43": "sv, Swedish", "37": "mn, Mongolian", "32": "fi, Finnish", "51": "tg, Tajik", "46": "am, Amharic", "17": "nn, Norwegian Nynorsk", "40": "ja, Japanese", "8": "it, Italian", "21": "ha, Hausa", "11": "as, Assamese", "29": "fa, Persian", "82": "bn, Bengali", "54": "mk, Macedonian", "31": "sw, Swahili", "45": "vi, Vietnamese", "41": "ur, Urdu", "74": "bo, Tibetan", "4": "hi, Hindi", "86": "mr, Marathi", "3": "fy-NL, Western Frisian", "65": "sk, Slovak", "2": "ln, Lingala", "92": "gl, Galician", "53": "sn, Shona", "87": "su, Sundanese", "35": "tt, Tatar", "93": "kn, Kannada", "6": "yo, Yoruba", "27": "ps, Pashto, Pushto", "34": "hy, Armenian", "25": "pa-IN, Punjabi, Panjabi", "23": "nl, Dutch, Flemish", "48": "th, Thai", "73": "mt, Maltese", "55": "ar, Arabic", "89": "ba, Bashkir", "78": "bg, Bulgarian", "42": "yi, Yiddish", "5": "ru, Russian", "84": "sv-SE, Swedish", "80": "tr, Turkish", "33": "sq, Albanian", "38": "kk, Kazakh", "50": "pl, Polish", "9": "hr, Croatian", "66": "ky, Kirghiz, Kyrgyz", "49": "hu, Hungarian", "10": "si, Sinhala, Sinhalese", "56": "la, Latin", "75": "de, German", "14": "ko, Korean", "22": "id, Indonesian", "47": "sl, Slovenian", "57": "be, Belarusian", "36": "ta, Tamil", "7": "da, Danish", "91": "sd, Sindhi", "28": "et, Estonian", "63": "pt, Portuguese", "60": "ne, Nepali", "94": "zh-TW, Chinese", "18": "zh-CN, Chinese", "88": "rw, Kinyarwanda", "19": "es, Spanish, Castilian", "39": "ht, Haitian, Haitian Creole", "64": "tl, Tagalog", "83": "ms, Malay", "70": "ro, Romanian, Moldavian, Moldovan", "68": "pa, Punjabi, Panjabi", "52": "uz, Uzbek", "58": "km, Central Khmer", "67": "my, Burmese", "0": "fr, French", "24": "af, Afrikaans", "16": "gu, Gujarati", "81": "so, Somali", "13": "uk, Ukrainian", "44": "ca, Catalan, Valencian", "72": "ml, Malayalam", "61": "te, Telugu", "1": "zh, Chinese"}

View File

@@ -1 +0,0 @@
{"0": ["Afrikaans", "Dutch, Flemish", "Western Frisian"], "1": ["Turkish", "Azerbaijani"], "2": ["Russian", "Slovak", "Ukrainian", "Czech", "Polish", "Belarusian"], "3": ["Bulgarian", "Macedonian", "Serbian", "Croatian", "Bosnian", "Slovenian"], "4": ["Norwegian Nynorsk", "Swedish", "Danish", "Norwegian"], "5": ["English"], "6": ["Finnish", "Estonian"], "7": ["Yiddish", "Luxembourgish, Letzeburgesch", "German"], "8": ["Spanish", "Occitan", "Portuguese", "Catalan, Valencian", "Galician", "Spanish, Castilian", "Italian"], "9": ["Maltese", "Arabic"], "10": ["Marathi"], "11": ["Hindi", "Urdu"], "12": ["Lao", "Thai"], "13": ["Malay", "Indonesian"], "14": ["Romanian, Moldavian, Moldovan"], "15": ["Tagalog"], "16": ["Tajik", "Persian"], "17": ["Kazakh", "Uzbek", "Kirghiz, Kyrgyz"], "18": ["Kinyarwanda"], "19": ["Tatar", "Bashkir"], "20": ["French"], "21": ["Chinese"], "22": ["Lingala"], "23": ["Yoruba"], "24": ["Sinhala, Sinhalese"], "25": ["Assamese"], "26": ["Korean"], "27": ["Gujarati"], "28": ["Hausa"], "29": ["Punjabi, Panjabi"], "30": ["Pashto, Pushto"], "31": ["Swahili"], "32": ["Albanian"], "33": ["Armenian"], "34": ["Mongolian"], "35": ["Tamil"], "36": ["Haitian, Haitian Creole"], "37": ["Japanese"], "38": ["Vietnamese"], "39": ["Amharic"], "40": ["Hungarian"], "41": ["Shona"], "42": ["Latin"], "43": ["Central Khmer"], "44": ["Malagasy"], "45": ["Nepali"], "46": ["Telugu"], "47": ["Oriya"], "48": ["Burmese"], "49": ["Greek, Modern (1453\u2013)"], "50": ["Malayalam"], "51": ["Tibetan"], "52": ["Turkmen"], "53": ["Somali"], "54": ["Bengali"], "55": ["Sundanese"], "56": ["Sindhi"], "57": ["Kannada"]}

Binary file not shown.

View File

@@ -1,16 +1,15 @@
dependencies = ['torch', 'torchaudio']
import torch
import json
import os
from utils_vad import (init_jit_model,
get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks,
drop_chunks,
Validator,
OnnxWrapper)
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from silero_vad.utils_vad import (init_jit_model,
get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks,
OnnxWrapper)
def versiontuple(v):
@@ -24,11 +23,14 @@ def versiontuple(v):
return tuple(version_list)
def silero_vad(onnx=False, force_onnx_cpu=False):
def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if not onnx:
installed_version = torch.__version__
@@ -36,9 +38,13 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
if versiontuple(installed_version) < versiontuple(supported_version):
raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
model_dir = os.path.join(os.path.dirname(__file__), 'files')
model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
if onnx:
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
else:
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
utils = (get_speech_timestamps,

46
pyproject.toml Normal file
View File

@@ -0,0 +1,46 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "silero-vad"
version = "6.2.0"
authors = [
{name="Silero Team", email="hello@silero.ai"},
]
description = "Voice Activity Detector (VAD) by Silero"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering",
]
dependencies = [
"packaging",
"torch>=1.12.0",
"torchaudio>=0.12.0",
"onnxruntime>=1.16.1",
]
[project.urls]
Homepage = "https://github.com/snakers4/silero-vad"
Issues = "https://github.com/snakers4/silero-vad/issues"
[project.optional-dependencies]
test = [
"pytest",
"soundfile",
"torch<2.9",
]

View File

@@ -43,20 +43,30 @@
},
"outputs": [],
"source": [
"USE_PIP = True # download model using pip package or torch.hub\n",
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
"if USE_PIP:\n",
" !pip install -q silero-vad\n",
" from silero_vad import (load_silero_vad,\n",
" read_audio,\n",
" get_speech_timestamps,\n",
" save_audio,\n",
" VADIterator,\n",
" collect_chunks)\n",
" model = load_silero_vad(onnx=USE_ONNX)\n",
"else:\n",
" model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"(get_speech_timestamps,\n",
" save_audio,\n",
" read_audio,\n",
" VADIterator,\n",
" collect_chunks) = utils"
" (get_speech_timestamps,\n",
" save_audio,\n",
" read_audio,\n",
" VADIterator,\n",
" collect_chunks) = utils"
]
},
{

View File

@@ -0,0 +1,13 @@
from importlib.metadata import version
try:
__version__ = version(__name__)
except:
pass
from silero_vad.model import load_silero_vad
from silero_vad.utils_vad import (get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks,
drop_chunks)

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

36
src/silero_vad/model.py Normal file
View File

@@ -0,0 +1,36 @@
from .utils_vad import init_jit_model, OnnxWrapper
import torch
torch.set_num_threads(1)
def load_silero_vad(onnx=False, opset_version=16):
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else:
model_name = 'silero_vad.jit'
package_path = "silero_vad.data"
try:
import importlib_resources as impresources
model_file_path = str(impresources.files(package_path).joinpath(model_name))
except:
from importlib import resources as impresources
try:
with impresources.path(package_path, model_name) as f:
model_file_path = f
except:
model_file_path = str(impresources.files(package_path).joinpath(model_name))
if onnx:
model = OnnxWrapper(str(model_file_path), force_onnx_cpu=True)
else:
model = init_jit_model(model_file_path)
return model

View File

@@ -2,6 +2,7 @@ import torch
import torchaudio
from typing import Callable, List
import warnings
from packaging import version
languages = ['ru', 'en', 'de', 'es']
@@ -23,7 +24,11 @@ class OnnxWrapper():
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states()
self.sample_rates = [8000, 16000]
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int):
if x.dim() == 1:
@@ -53,10 +58,10 @@ class OnnxWrapper():
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
@@ -130,39 +135,60 @@ class Validator():
return outs
def read_audio(path: str,
sampling_rate: int = 16000):
sox_backends = set(['sox', 'sox_io'])
audio_backends = torchaudio.list_audio_backends()
if len(sox_backends.intersection(audio_backends)) > 0:
effects = [
['channels', '1'],
['rate', str(sampling_rate)]
]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
ta_ver = version.parse(torchaudio.__version__)
if ta_ver < version.parse("2.9"):
try:
effects = [['channels', '1'],['rate', str(sampling_rate)]]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
else:
wav, sr = torchaudio.load(path)
try:
wav, sr = torchaudio.load(path)
except:
try:
from torchcodec.decoders import AudioDecoder
samples = AudioDecoder(path).get_all_samples()
wav = samples.data
sr = samples.sample_rate
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if wav.ndim > 1 and wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
if sr != sampling_rate:
wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
assert sr == sampling_rate
return wav.squeeze(0)
def save_audio(path: str,
tensor: torch.Tensor,
sampling_rate: int = 16000):
torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
tensor = tensor.detach().cpu()
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
ta_ver = version.parse(torchaudio.__version__)
try:
torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
except Exception:
if ta_ver >= version.parse("2.9"):
try:
from torchcodec.encoders import AudioEncoder
encoder = AudioEncoder(tensor, sample_rate=16000)
encoder.to_file(path)
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
else:
raise
def init_jit_model(model_path: str,
@@ -192,9 +218,13 @@ def get_speech_timestamps(audio: torch.Tensor,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30,
return_seconds: bool = False,
time_resolution: int = 1,
visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None,
window_size_samples: int = 512,):
neg_threshold: float = None,
window_size_samples: int = 512,
min_silence_at_max_speech: int = 98,
use_max_poss_sil_at_max_speech: bool = True):
"""
This method is used for splitting long audios into speech chunks using silero VAD
@@ -218,7 +248,7 @@ def get_speech_timestamps(audio: torch.Tensor,
max_speech_duration_s: int (default - inf)
Maximum duration of speech chunks in seconds
Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent agressive cutting.
Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
Otherwise, they will be split aggressively just before max_speech_duration_s.
min_silence_duration_ms: int (default - 100 milliseconds)
@@ -230,12 +260,24 @@ def get_speech_timestamps(audio: torch.Tensor,
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: bool (default - 1)
time resolution of speech coordinates when requested as seconds
visualize_probs: bool (default - False)
whether draw prob hist or not
progress_tracking_callback: Callable[[float], None] (default - None)
callback function taking progress in percents as an argument
neg_threshold: float (default = threshold - 0.15)
Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
min_silence_at_max_speech: int (default - 98ms)
Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached
use_max_poss_sil_at_max_speech: bool (default - True)
Whether to use the maximum possible silence at max_speech_duration_s or not. If not, the last silence is used.
window_size_samples: int (default - 512 samples)
!!! DEPRECATED, DOES NOTHING !!!
@@ -244,7 +286,6 @@ def get_speech_timestamps(audio: torch.Tensor,
speeches: list of dicts
list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
"""
if not torch.is_tensor(audio):
try:
audio = torch.Tensor(audio)
@@ -275,7 +316,7 @@ def get_speech_timestamps(audio: torch.Tensor,
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000
audio_length_samples = len(audio)
@@ -286,7 +327,7 @@ def get_speech_timestamps(audio: torch.Tensor,
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob = model(chunk, sampling_rate).item()
speech_probs.append(speech_prob)
# caculate progress and seng it to callback function
# calculate progress and send it to callback function
progress = current_start_sample + window_size_samples
if progress > audio_length_samples:
progress = audio_length_samples
@@ -297,45 +338,78 @@ def get_speech_timestamps(audio: torch.Tensor,
triggered = False
speeches = []
current_speech = {}
neg_threshold = threshold - 0.15
temp_end = 0 # to save potential segment end (and tolerate some silence)
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
if neg_threshold is None:
neg_threshold = max(threshold - 0.15, 0.01)
temp_end = 0 # to save potential segment end (and tolerate some silence)
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
possible_ends = []
for i, speech_prob in enumerate(speech_probs):
cur_sample = window_size_samples * i
# If speech returns after a temp_end, record candidate silence if long enough and clear temp_end
if (speech_prob >= threshold) and temp_end:
sil_dur = cur_sample - temp_end
if sil_dur > min_silence_samples_at_max_speech:
possible_ends.append((temp_end, sil_dur))
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
next_start = cur_sample
# Start of speech
if (speech_prob >= threshold) and not triggered:
triggered = True
current_speech['start'] = window_size_samples * i
current_speech['start'] = cur_sample
continue
if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
if prev_end:
# Max speech length reached: decide where to cut
if triggered and (cur_sample - current_speech['start'] > max_speech_samples):
if use_max_poss_sil_at_max_speech and possible_ends:
prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk
current_speech['end'] = prev_end
speeches.append(current_speech)
current_speech = {}
if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
triggered = False
else:
current_speech['start'] = next_start
prev_end = next_start = temp_end = 0
else:
current_speech['end'] = window_size_samples * i
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
next_start = prev_end + dur
if next_start < prev_end + cur_sample: # previously reached silence (< neg_thres) and is still not speech (< thres)
current_speech['start'] = next_start
else:
triggered = False
prev_end = next_start = temp_end = 0
possible_ends = []
else:
# Legacy max-speech cut (use_max_poss_sil_at_max_speech=False): prefer last valid silence (prev_end) if available
if prev_end:
current_speech['end'] = prev_end
speeches.append(current_speech)
current_speech = {}
if next_start < prev_end:
triggered = False
else:
current_speech['start'] = next_start
prev_end = next_start = temp_end = 0
possible_ends = []
else:
# No prev_end -> fallback to cutting at current sample
current_speech['end'] = cur_sample
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
possible_ends = []
continue
# Silence detection while in speech
if (speech_prob < neg_threshold) and triggered:
if not temp_end:
temp_end = window_size_samples * i
if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech : # condition to avoid cutting in very short silence
temp_end = cur_sample
sil_dur_now = cur_sample - temp_end
if not use_max_poss_sil_at_max_speech and sil_dur_now > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
prev_end = temp_end
if (window_size_samples * i) - temp_end < min_silence_samples:
if sil_dur_now < min_silence_samples:
continue
else:
current_speech['end'] = temp_end
@@ -344,6 +418,7 @@ def get_speech_timestamps(audio: torch.Tensor,
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
possible_ends = []
continue
if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -365,9 +440,10 @@ def get_speech_timestamps(audio: torch.Tensor,
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
if return_seconds:
audio_length_seconds = audio_length_samples / sampling_rate
for speech_dict in speeches:
speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0)
speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds)
elif step > 1:
for speech_dict in speeches:
speech_dict['start'] *= step
@@ -428,13 +504,16 @@ class VADIterator:
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False):
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
"""
if not torch.is_tensor(x):
@@ -453,8 +532,8 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
@@ -465,24 +544,112 @@ class VADIterator:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
return None
def collect_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
for i in tss:
chunks.append(wav[i['start']: i['end']])
wav: torch.Tensor,
seconds: bool = False,
sampling_rate: int = None) -> torch.Tensor:
"""Collect audio chunks from a longer audio clip
This method extracts audio chunks from an audio clip, using a list of
provided coordinates, and concatenates them together. Coordinates can be
passed either as sample numbers or in seconds, in which case the audio
sampling rate is also needed.
Parameters
----------
tss: List[dict]
Coordinate list of the clips to collect from the audio.
wav: torch.Tensor, one dimensional
One dimensional float torch.Tensor, containing the audio to clip.
seconds: bool (default - False)
Whether input coordinates are passed as seconds or samples.
sampling_rate: int (default - None)
Input audio sampling rate. Required if seconds is True.
Returns
-------
torch.Tensor, one dimensional
One dimensional float torch.Tensor of the concatenated clipped audio
chunks.
Raises
------
ValueError
Raised if sampling_rate is not provided when seconds is True.
"""
if seconds and not sampling_rate:
raise ValueError('sampling_rate must be provided when seconds is True')
chunks = list()
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
for i in _tss:
chunks.append(wav[i['start']:i['end']])
return torch.cat(chunks)
def drop_chunks(tss: List[dict],
wav: torch.Tensor):
chunks = []
wav: torch.Tensor,
seconds: bool = False,
sampling_rate: int = None) -> torch.Tensor:
"""Drop audio chunks from a longer audio clip
This method extracts audio chunks from an audio clip, using a list of
provided coordinates, and drops them. Coordinates can be passed either as
sample numbers or in seconds, in which case the audio sampling rate is also
needed.
Parameters
----------
tss: List[dict]
Coordinate list of the clips to drop from from the audio.
wav: torch.Tensor, one dimensional
One dimensional float torch.Tensor, containing the audio to clip.
seconds: bool (default - False)
Whether input coordinates are passed as seconds or samples.
sampling_rate: int (default - None)
Input audio sampling rate. Required if seconds is True.
Returns
-------
torch.Tensor, one dimensional
One dimensional float torch.Tensor of the input audio minus the dropped
chunks.
Raises
------
ValueError
Raised if sampling_rate is not provided when seconds is True.
"""
if seconds and not sampling_rate:
raise ValueError('sampling_rate must be provided when seconds is True')
chunks = list()
cur_start = 0
for i in tss:
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
for i in _tss:
chunks.append((wav[cur_start: i['start']]))
cur_start = i['end']
chunks.append(wav[cur_start:])
return torch.cat(chunks)
def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]:
"""Convert coordinates expressed in seconds to sample coordinates.
"""
return [{
'start': round(crd['start']) * sampling_rate,
'end': round(crd['end']) * sampling_rate
} for crd in tss]

BIN
tests/data/test.mp3 Normal file

Binary file not shown.

BIN
tests/data/test.opus Normal file

Binary file not shown.

BIN
tests/data/test.wav Normal file

Binary file not shown.

22
tests/test_basic.py Normal file
View File

@@ -0,0 +1,22 @@
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
import torch
torch.set_num_threads(1)
def test_jit_model():
model = load_silero_vad(onnx=False)
for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
audio = read_audio(path, sampling_rate=16000)
speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
assert speech_timestamps is not None
out = model.audio_forward(audio, sr=16000)
assert out is not None
def test_onnx_model():
model = load_silero_vad(onnx=True)
for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
audio = read_audio(path, sampling_rate=16000)
speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
assert speech_timestamps is not None
out = model.audio_forward(audio, sr=16000)
assert out is not None

74
tuning/README.md Normal file
View File

@@ -0,0 +1,74 @@
# Тюнинг Silero-VAD модели
> Код тюнинга создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный
интеллект» национальной программы «Цифровая экономика Российской Федерации».
Тюнинг используется для улучшения качества детекции речи Silero-VAD модели на кастомных данных.
## Зависимости
Следующие зависимости используются при тюнинге VAD модели:
- `torchaudio>=0.12.0`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`
## Подготовка данных
Датафреймы для тюнинга должны быть подготовлены и сохранены в формате `.feather`. Следующие колонки в `.feather` файлах тренировки и валидации являются обязательными:
- **audio_path** - абсолютный путь до аудиофайла в дисковой системе. Аудиофайлы должны представлять собой `PCM` данные, предпочтительно в форматах `.wav` или `.opus` (иные популярные форматы аудио тоже поддерживаются). Для ускорения темпа дообучения рекомендуется предварительно выполнить ресемплинг аудиофайлов (изменить частоту дискретизации) до 16000 Гц;
- **speech_ts** - разметка для соответствующего аудиофайла. Список, состоящий из словарей формата `{'start': START_SEC, 'end': 'END_SEC'}`, где `START_SEC` и `END_SEC` - время начало и конца речевого отрезка в секундах соответственно. Для качественного дообучения рекомендуется использовать разметку с точностью до 30 миллисекунд.
Чем больше данных используется на этапе дообучения, тем эффективнее показывает себя адаптированная модель на целевом домене. Длина аудио не ограничена, т.к. каждое аудио будет обрезано до `max_train_length_sec` секунд перед подачей в нейросеть. Длинные аудио лучше предварительно порезать на кусочки длины `max_train_length_sec`.
Пример `.feather` датафрейма можно посмотреть в файле `example_dataframe.feather`
## Файл конфигурации `config.yml`
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
- `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub`
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
- `model_save_path` - путь сохранения добученной модели;
- `noise_loss` - коэффициент лосса, применяемый для неречевых окон аудио;
- `max_train_length_sec` - максимальная длина аудио в секундах на этапе дообучения. Более длительные аудио будут обрезаны до этого показателя;
- `aug_prob` - вероятность применения аугментаций к аудиофайлу на этапе дообучения;
- `learning_rate` - темп дообучения;
- `batch_size` - размер батча при дообучении и валидации;
- `num_workers` - количество потоков, используемых для загрузки данных;
- `num_epochs` - количество эпох дообучения. За одну эпоху прогоняются все тренировочные данные;
- `device` - `cpu` или `cuda`.
## Дообучение
Дообучение запускается командой
`python tune.py`
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
## Поиск пороговых значений
Порог на вход и порог на выход можно подобрать, используя команду
`python search_thresholds`
Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете.
## Цитирование
```
@misc{Silero VAD,
author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}},
commit = {insert_some_commit_here},
email = {hello@silero.ai}
}
```

0
tuning/__init__.py Normal file
View File

17
tuning/config.yml Normal file
View File

@@ -0,0 +1,17 @@
jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
train_dataset_path: 'train_dataset_path.feather' # путь до датасета в формате feather для дообучения, подробности в README
val_dataset_path: 'val_dataset_path.feather' # путь до датасета в формате feather для валидации, подробности в README
model_save_path: 'model_save_path.jit' # путь сохранения дообученной модели
noise_loss: 0.5 # коэффициент, применяемый к лоссу на неречевых окнах
max_train_length_sec: 8 # во время тюнинга аудио длиннее будут обрезаны до данного значения
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
learning_rate: 5e-4 # темп дообучения модели
batch_size: 128 # размер батча при дообучении и валидации
num_workers: 4 # количество потоков, используемых для даталоадеров
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
device: 'cuda' # cpu или cuda, на чем будет производится дообучение

Binary file not shown.

View File

@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
print('Making predicts...')
all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
print('Calculating thresholds...')
best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')

65
tuning/tune.py Normal file
View File

@@ -0,0 +1,65 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch.nn as nn
import torch
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
train_dataset = SileroVadDataset(config, mode='train')
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
val_dataset = SileroVadDataset(config, mode='val')
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
decoder = VADDecoderRNNJIT().to(config.device)
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
decoder.train()
params = decoder.parameters()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
lr=config.learning_rate)
criterion = nn.BCELoss(reduction='none')
best_val_roc = 0
for i in range(config.num_epochs):
print(f'Starting epoch {i + 1}')
train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
print(f'Metrics after epoch {i + 1}:\n'
f'\tTrain loss: {round(train_loss, 3)}\n',
f'\tValidation loss: {round(val_loss, 3)}\n'
f'\tValidation ROC-AUC: {round(val_roc, 3)}')
if val_roc > best_val_roc:
print('New best ROC-AUC, saving model')
best_val_roc = val_roc
if config.tune_8k:
model._model_8k.decoder.load_state_dict(decoder.state_dict())
else:
model._model.decoder.load_state_dict(decoder.state_dict())
torch.jit.save(model, config.model_save_path)
print('Done')

356
tuning/utils.py Normal file
View File

@@ -0,0 +1,356 @@
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')
def read_audio(path: str,
sampling_rate: int = 16000,
normalize=False):
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sampling_rate:
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
if normalize and wav.abs().max() != 0:
wav = wav / wav.abs().max()
return wav.squeeze(0)
def build_audiomentations_augs(p):
from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
Aliasing, AddGaussianNoise
transforms = [Aliasing(p=1),
AddGaussianNoise(p=1),
AirAbsorption(p=1),
BandPassFilter(p=1),
BandStopFilter(p=1),
ClippingDistortion(p=1),
HighPassFilter(p=1),
HighShelfFilter(p=1),
LowPassFilter(p=1),
LowShelfFilter(p=1),
Mp3Compression(p=1),
PeakingFilter(p=1),
PitchShift(p=1),
RoomSimulator(p=1, leave_length_unchanged=True),
SevenBandParametricEQ(p=1)]
tr = SomeOf((1, 3), transforms=transforms, p=p)
return tr
class SileroVadDataset(Dataset):
def __init__(self,
config,
mode='train'):
self.num_samples = 512 # constant, do not change
self.sr = 16000 # constant, do not change
self.resample_to_8k = config.tune_8k
self.noise_loss = config.noise_loss
self.max_train_length_sec = config.max_train_length_sec
self.max_train_length_samples = config.max_train_length_sec * self.sr
assert self.max_train_length_samples % self.num_samples == 0
assert mode in ['train', 'val']
dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
self.index_dict = self.dataframe.to_dict('index')
self.mode = mode
print(f'DATASET SIZE : {len(self.dataframe)}')
if mode == 'train':
self.augs = build_audiomentations_augs(p=config.aug_prob)
else:
self.augs = None
def __getitem__(self, idx):
idx = None if self.mode == 'train' else idx
wav, gt, mask = self.load_speech_sample(idx)
if self.mode == 'train':
wav = self.add_augs(wav)
if len(wav) > self.max_train_length_samples:
wav = wav[:self.max_train_length_samples]
gt = gt[:int(self.max_train_length_samples / self.num_samples)]
mask = mask[:int(self.max_train_length_samples / self.num_samples)]
wav = torch.FloatTensor(wav)
if self.resample_to_8k:
transform = torchaudio.transforms.Resample(orig_freq=self.sr,
new_freq=8000)
wav = transform(wav)
return wav, torch.FloatTensor(gt), torch.from_numpy(mask)
def __len__(self):
return len(self.index_dict)
def load_speech_sample(self, idx=None):
if idx is None:
idx = random.randint(0, len(self.index_dict) - 1)
wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()
if len(wav) % self.num_samples != 0:
pad_num = self.num_samples - (len(wav) % (self.num_samples))
wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)
gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))
assert len(gt) == len(wav) / self.num_samples
return wav, gt, mask
def get_ground_truth_annotated(self, annotation, audio_length_samples):
gt = np.zeros(audio_length_samples)
for i in annotation:
gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1
squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
mask = np.ones(len(squeezed_predicts))
mask[squeezed_predicts == 0] = self.noise_loss
return squeezed_predicts, mask
def add_augs(self, wav):
while True:
try:
wav_aug = self.augs(wav, self.sr)
if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
return wav
return wav_aug
except Exception as e:
continue
def SileroVadPadder(batch):
wavs = [batch[i][0] for i in range(len(batch))]
labels = [batch[i][1] for i in range(len(batch))]
masks = [batch[i][2] for i in range(len(batch))]
wavs = torch.nn.utils.rnn.pad_sequence(
wavs, batch_first=True, padding_value=0)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=0)
masks = torch.nn.utils.rnn.pad_sequence(
masks, batch_first=True, padding_value=0)
return wavs, labels, masks
class VADDecoderRNNJIT(nn.Module):
def __init__(self):
super(VADDecoderRNNJIT, self).__init__()
self.rnn = nn.LSTMCell(128, 128)
self.decoder = nn.Sequential(nn.Dropout(0.1),
nn.ReLU(),
nn.Conv1d(128, 1, kernel_size=1),
nn.Sigmoid())
def forward(self, x, state=torch.zeros(0)):
x = x.squeeze(-1)
if len(state):
h, c = self.rnn(x, (state[0], state[1]))
else:
h, c = self.rnn(x)
x = h.unsqueeze(-1).float()
state = torch.stack([h, c])
x = self.decoder(x)
return x, state
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def train(config,
loader,
jit_model,
decoder,
criterion,
optimizer,
device):
losses = AverageMeter()
decoder.train()
context_size = 32 if config.tune_8k else 64
num_samples = 256 if config.tune_8k else 512
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
with torch.enable_grad():
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
targets = targets.to(device)
x = x.to(device)
masks = masks.to(device)
x = torch.nn.functional.pad(x, (context_size, 0))
outs = []
state = torch.zeros(0)
for i in range(context_size, x.shape[1], num_samples):
input_ = x[:, i-context_size:i+num_samples]
out = stft_layer(input_)
out = encoder_layer(out)
out, state = decoder(out, state)
outs.append(out)
stacked = torch.cat(outs, dim=2).squeeze(1)
loss = criterion(stacked, targets)
loss = (loss * masks).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.item(), masks.numel())
torch.cuda.empty_cache()
gc.collect()
return losses.avg
def validate(config,
loader,
jit_model,
decoder,
criterion,
device):
losses = AverageMeter()
decoder.eval()
predicts = []
gts = []
context_size = 32 if config.tune_8k else 64
num_samples = 256 if config.tune_8k else 512
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
with torch.no_grad():
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
targets = targets.to(device)
x = x.to(device)
masks = masks.to(device)
x = torch.nn.functional.pad(x, (context_size, 0))
outs = []
state = torch.zeros(0)
for i in range(context_size, x.shape[1], num_samples):
input_ = x[:, i-context_size:i+num_samples]
out = stft_layer(input_)
out = encoder_layer(out)
out, state = decoder(out, state)
outs.append(out)
stacked = torch.cat(outs, dim=2).squeeze(1)
predicts.extend(stacked[masks != 0].tolist())
gts.extend(targets[masks != 0].tolist())
loss = criterion(stacked, targets)
loss = (loss * masks).mean()
losses.update(loss.item(), masks.numel())
score = roc_auc_score(gts, predicts)
torch.cuda.empty_cache()
gc.collect()
return losses.avg, round(score, 3)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def predict(model, loader, device, sr):
with torch.no_grad():
all_predicts = []
all_gts = []
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
x = x.to(device)
out = model.audio_forward(x, sr=sr)
for i, out_chunk in enumerate(out):
predict = out_chunk[masks[i] != 0].cpu().tolist()
gt = targets[i, masks[i] != 0].cpu().tolist()
all_predicts.append(predict)
all_gts.append(gt)
return all_predicts, all_gts
def calculate_best_thresholds(all_predicts, all_gts):
best_acc = 0
for ths_enter in tqdm(np.linspace(0, 1, 20)):
for ths_exit in np.linspace(0, 1, 20):
if ths_exit >= ths_enter:
continue
accs = []
for j, predict in enumerate(all_predicts):
predict_bool = []
is_speech = False
for i in predict:
if i >= ths_enter:
is_speech = True
predict_bool.append(1)
elif i <= ths_exit:
is_speech = False
predict_bool.append(0)
else:
val = 1 if is_speech else 0
predict_bool.append(val)
score = round(accuracy_score(all_gts[j], predict_bool), 4)
accs.append(score)
mean_acc = round(np.mean(accs), 3)
if mean_acc > best_acc:
best_acc = mean_acc
best_ths_enter = round(ths_enter, 2)
best_ths_exit = round(ths_exit, 2)
return best_ths_enter, best_ths_exit, best_acc