Mirror of https://github.com/snakers4/silero-vad.git (synced 2026-02-04 17:39:22 +08:00)

Compare: v4.0stable...adamnsandl (150 commits)
.github/workflows/python-publish.yml (vendored, new file, 40 lines)
@@ -0,0 +1,40 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  push:
    tags:
      - '*'

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
README.md (76 changed lines)
@@ -10,22 +10,66 @@
**Silero VAD** - pre-trained enterprise-grade [Voice Activity Detector](https://en.wikipedia.org/wiki/Voice_activity_detection) (also see our [STT models](https://github.com/snakers4/silero-models)).

This repository also includes Number Detector and Language classifier [models](https://github.com/snakers4/silero-vad/wiki/Other-Models)

<br/>

<p align="center">
-<img src="https://user-images.githubusercontent.com/36505480/198026365-8da383e0-5398-4a12-b7f8-22c2c0059512.png" />
+<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
</p>

<details>
<summary>Real Time Example</summary>

https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4

</details>

<br/>

<h2 align="center">Fast start</h2>
<br/>

<details>
<summary>Dependencies</summary>

**Silero VAD uses the torchaudio library for audio file I/O (torchaudio.info, torchaudio.load and torchaudio.save), so a proper audio backend is required:**

- Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`
- Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`; torchaudio is tested on libsox 14.4.2.
- Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`

**Additional dependencies:**

- **torch>=1.12.0**
- **torchaudio>=0.12.0** (for I/O functionalities only)
- **onnxruntime>=1.16.1** (for ONNX model usage)

</details>

**Using pip**:
`pip install silero-vad`

```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
model = load_silero_vad()
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
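
By default the timestamps above are expressed in samples. If the installed package matches this repository's utils, `get_speech_timestamps` also accepts a `return_seconds` flag; a minimal sketch, reusing the placeholder path from the example above:

```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio('path_to_audio_file')  # placeholder path, as above
# return_seconds=True yields {'start': ..., 'end': ...} in seconds instead of samples
speech_timestamps = get_speech_timestamps(wav, model, return_seconds=True)
print(speech_timestamps)
```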

**Using torch.hub**:
```python3
import torch
torch.set_num_threads(1)

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils

wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
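
The `utils` tuple above also carries a streaming helper as its fourth element (`VADIterator` in this repository's utils). A hedged sketch of chunk-by-chunk usage, assuming 16 kHz audio and the 512-sample window used by the examples later in this page:

```python3
import torch

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(_, _, read_audio, VADIterator, _) = utils

vad_iterator = VADIterator(model)  # streaming wrapper from this repo's utils
wav = read_audio('path_to_audio_file')  # placeholder path, as above

window_size_samples = 512  # one window at 16 kHz, matching the C++/C# examples below
for i in range(0, len(wav) - window_size_samples + 1, window_size_samples):
    chunk = wav[i: i + window_size_samples]
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:  # {'start': ...} or {'end': ...} emitted at speech boundaries
        print(speech_dict)
vad_iterator.reset_states()  # reset the internal model state between audio files
```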

<br/>

<h2 align="center">Key Features</h2>
<br/>

@@ -39,20 +83,16 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

- **Lightweight**

-  JIT model is around one megabyte in size.
+  JIT model is around two megabytes in size.

- **General**

-  Silero VAD was trained on huge corpora that include over **100** languages and it performs well on audios from different domains with various background noise and quality levels.
+  Silero VAD was trained on huge corpora that include over **6000** languages and it performs well on audios from different domains with various background noise and quality levels.

- **Flexible sampling rate**

  Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).

- **Flexible chunk size**

  Model was trained on **30 ms** chunks. Longer chunks are supported directly, others may work as well.

- **Highly Portable**

  Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX**, running everywhere these runtimes are available.

@@ -62,6 +102,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

Published under a permissive license (MIT), Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no vendor lock.

<br/>

<h2 align="center">Typical Use Cases</h2>
<br/>
@@ -78,7 +119,6 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

- [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
- [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
- [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
-- [Number Detector and Language classifier models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
- [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
- [Further reading](https://github.com/snakers4/silero-models#further-reading)
- [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)

@@ -89,7 +129,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news).

-Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers](https://github.com/snakers4/silero-models/wiki/Licensing-and-Tiers) for relevant information and [email](mailto:hello@silero.ai) us directly.
+Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for relevant information and [email](mailto:hello@silero.ai) us directly.

**Citations**

@@ -97,7 +137,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers

@misc{Silero VAD,
  author = {Silero Team},
  title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
-  year = {2021},
+  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad}},
@@ -107,7 +147,11 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers

```

<br/>
-<h2 align="center">VAD-based Community Apps</h2>
+<h2 align="center">Examples and VAD-based Community Apps</h2>
<br/>

-- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
-- Example of VAD ONNX Runtime model usage in [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
+- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
+- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples
datasets/README.md (new file, 84 lines)
@@ -0,0 +1,84 @@
# The Silero-VAD Dataset

> The dataset was created with the support of the Innovation Assistance Foundation (Фонд содействия инновациям) within the federal project "Artificial Intelligence" of the national program "Digital Economy of the Russian Federation".

The links below point to `.feather` files containing open audio datasets annotated with Silero VAD, together with a short description of each dataset and loading examples. The `.feather` files can be opened with the `pandas` library:
```python3
import pandas as pd
dataframe = pd.read_feather(PATH_TO_FEATHER_FILE)
```

Each annotation `.feather` file contains the following columns:
- `speech_timings` - the annotation of the given audio: a list of dictionaries of the form `{'start': START_SECOND, 'end': END_SECOND}`, where `START_SECOND` and `END_SECOND` are the start and end times of speech in seconds. The list contains one such dictionary per speech fragment detected in the audio (see the sketch below);
- `language` - the ISO language code of the audio.
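
For illustration, a minimal sketch of reading the segments from one row (the file path is a placeholder; the column names are those listed above):

```python3
import pandas as pd

df = pd.read_feather('PATH_TO_FEATHER_FILE')  # placeholder, as above
row = df.iloc[0]
print(row['language'])  # ISO language code of this audio
for segment in row['speech_timings']:
    # one dictionary per detected speech fragment, times in seconds
    print(f"speech from {segment['start']:.2f}s to {segment['end']:.2f}s")
```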

The columns describing how to download the audio files differ between datasets and are documented for each dataset below.

**All data was annotated with a time step of ~30 milliseconds (`num_samples` = 512).**

| Name | Hours | Languages | Link | License | md5sum |
|----------------------|-------------|-------------|--------|----------|----------|
| **Bible.is** | 53,138 | 1,596 | [URL](https://live.bible.is/) | [Custom](https://live.bible.is/terms) | ea404eeaf2cd283b8223f63002be11f9 |
| **globalrecordings.net** | 9,743 | 6,171[^1] | [URL](https://globalrecordings.net/en) | CC BY-NC-SA 4.0 | 3c5c0f31b0abd9fe94ddbe8b1e2eb326 |
| **VoxLingua107** | 6,628 | 107 | [URL](https://bark.phon.ioc.ee/voxlingua107/) | CC BY 4.0 | 5dfef33b4d091b6d399cfaf3d05f2140 |
| **Common Voice** | 30,329 | 120 | [URL](https://commonvoice.mozilla.org/en/datasets) | CC0 | 5e30a85126adf74a5fd1496e6ac8695d |
| **MLS** | 50,709 | 8 | [URL](https://www.openslr.org/94/) | CC BY 4.0 | a339d0e94bdf41bba3c003756254ac4e |
| **Total** | **150,547** | **6,171+** | | | |

## Bible.is

[Link to the `.feather` annotation file](https://models.silero.ai/vad_datasets/BibleIs.feather)

- The `audio_link` column contains links to the individual audio files.

## globalrecordings.net

[Link to the `.feather` annotation file](https://models.silero.ai/vad_datasets/globalrecordings.feather)

- The `folder_link` column contains links for downloading a `.zip` archive for a given language. Note: archive links are duplicated across rows, since each archive can contain many audio files.
- The `audio_path` column contains the path to a specific audio file after unpacking the corresponding archive from the `folder_link` column.

``The number of unique ISO codes in this dataset does not match the actual number of languages represented, since some closely related languages may share the same ISO code.``

## VoxLingua107

[Link to the `.feather` annotation file](https://models.silero.ai/vad_datasets/VoxLingua107.feather)

- The `folder_link` column contains links for downloading a `.zip` archive for a given language. Note: archive links are duplicated across rows, since each archive can contain many audio files.
- The `audio_path` column contains the path to a specific audio file after unpacking the corresponding archive from the `folder_link` column.

## Common Voice

[Link to the `.feather` annotation file](https://models.silero.ai/vad_datasets/common_voice.feather)

This dataset cannot be downloaded via static links. To obtain it, follow this [link](https://commonvoice.mozilla.org/en/datasets) and, after requesting access via the corresponding form, download the archives for each available language. Note: the provided annotation corresponds to version `Common Voice Corpus 16.1` of the source dataset.

- The `audio_path` column contains the unique names of the `.mp3` files obtained after downloading the corresponding dataset.

## MLS

[Link to the `.feather` annotation file](https://models.silero.ai/vad_datasets/MLS.feather)

- The `folder_link` column contains links for downloading a `.zip` archive for a given language. Note: archive links are duplicated across rows, since each archive can contain many audio files.
- The `audio_path` column contains the path to a specific audio file after unpacking the corresponding archive from the `folder_link` column.

## License

This dataset is distributed under the `CC BY-NC-SA 4.0` [license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).

## Citation

```
@misc{Silero VAD Dataset,
  author = {Silero Team},
  title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}},
  email = {hello@silero.ai}
}
```

[^1]: ``The number of unique ISO codes in this dataset does not match the actual number of languages represented, since some closely related languages may share the same ISO code.``
@@ -41,7 +41,7 @@
"    abs_max = np.abs(sound).max()\n",
"    sound = sound.astype('float32')\n",
"    if abs_max > 0:\n",
-"        sound *= 1/abs_max\n",
+"        sound *= 1/32768\n",
"    sound = sound.squeeze()\n",
"    return sound\n",
"\n",
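
The hunk above only shows part of a notebook helper. A plain-Python reconstruction of the full function follows; the function name `int2float` and the int16-PCM input assumption are inferred from context, not shown in the hunk:

```python3
import numpy as np

def int2float(sound: np.ndarray) -> np.ndarray:
    # convert int16 PCM samples to float32 in roughly [-1, 1]
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1/32768  # the new scaling from the hunk above
    sound = sound.squeeze()  # drop extra dimensions, e.g. a channel axis
    return sound
```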
examples/cpp/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# Stream example in C++

A simple example of running the Silero VAD model in C++ with ONNX Runtime.

## Requirements

The code was tested in the environments below; feel free to try others.

- WSL2 + Debian-bullseye (docker)
- gcc 12.2.0
- onnxruntime-linux-x64-1.12.1

## Usage

1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`

2. Install onnxruntime-linux-x64-1.12.1

   - Download the ONNX Runtime library:

     `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`

   - Unpack it. The commands below assume the path `/root/onnxruntime-linux-x64-1.12.1`.

3. Modify the wav path and the test configs in the main function

   `wav::WavReader wav_reader("${path_to_your_wav_file}");`

   Adjust the test sample rate, frame length in ms, threshold, etc.

4. Build with gcc and run

```bash
# Build
g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test

# Run
./test
```
examples/cpp/silero-vad-onnx.cpp (new file, 478 lines)
@@ -0,0 +1,478 @@
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <limits>
#include <chrono>
#include <memory>
#include <string>
#include <stdexcept>
#include <cstdio>
#include <cstdarg>
#include "onnxruntime_cxx_api.h"
#include "wav.h"

//#define __DEBUG_SPEECH_PROB___

class timestamp_t
{
public:
    int start;
    int end;

    // default + parameterized constructor
    timestamp_t(int start = -1, int end = -1)
        : start(start), end(end)
    {
    };

    // assignment operator modifies object, therefore non-const
    timestamp_t& operator=(const timestamp_t& a)
    {
        start = a.start;
        end = a.end;
        return *this;
    };

    // equality comparison; doesn't modify the object, therefore const
    bool operator==(const timestamp_t& a) const
    {
        return (start == a.start && end == a.end);
    };

    std::string c_str()
    {
        //return std::format("timestamp {:08d}, {:08d}", start, end);
        return format("{start:%08d,end:%08d}", start, end);
    };

private:
    std::string format(const char* fmt, ...)
    {
        char buf[256];

        va_list args;
        va_start(args, fmt);
        const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
        va_end(args);

        if (r < 0)
            // conversion failed
            return {};

        const size_t len = r;
        if (len < sizeof buf)
            // we fit in the buffer
            return { buf, len };

#if __cplusplus >= 201703L
        // C++17: create a string and write to its underlying array
        std::string s(len, '\0');
        va_start(args, fmt);
        std::vsnprintf(s.data(), len + 1, fmt, args);
        va_end(args);

        return s;
#else
        // C++11 or C++14: we need to allocate scratch memory
        auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
        va_start(args, fmt);
        std::vsnprintf(vbuf.get(), len + 1, fmt, args);
        va_end(args);

        return { vbuf.get(), len };
#endif
    };
};

class VadIterator
{
private:
    // OnnxRuntime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

private:
    void init_engine_threads(int inter_threads, int intra_threads)
    {
        // The method should be called in each thread/proc in multi-thread/proc work
        session_options.SetIntraOpNumThreads(intra_threads);
        session_options.SetInterOpNumThreads(inter_threads);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    };

    void init_onnx_model(const std::wstring& model_path)
    {
        // Init threads = 1 for single-threaded inference
        init_engine_threads(1, 1);
        // Load model
        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
    };

    void reset_states()
    {
        // Call reset before each audio start
        std::memset(_state.data(), 0, _state.size() * sizeof(float));
        triggered = false;
        temp_end = 0;
        current_sample = 0;

        prev_end = next_start = 0;

        speeches.clear();
        current_speech = timestamp_t();
    };

    void predict(const std::vector<float>& data)
    {
        // Create ort tensors
        input.assign(data.begin(), data.end());
        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
            memory_info, input.data(), input.size(), input_node_dims, 2);
        Ort::Value state_ort = Ort::Value::CreateTensor<float>(
            memory_info, _state.data(), _state.size(), state_node_dims, 3);
        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
            memory_info, sr.data(), sr.size(), sr_node_dims, 1);

        // Clear and add inputs
        ort_inputs.clear();
        ort_inputs.emplace_back(std::move(input_ort));
        ort_inputs.emplace_back(std::move(state_ort));
        ort_inputs.emplace_back(std::move(sr_ort));

        // Infer
        ort_outputs = session->Run(
            Ort::RunOptions{nullptr},
            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
            output_node_names.data(), output_node_names.size());

        // Output probability & update the recurrent state
        float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
        float* stateN = ort_outputs[1].GetTensorMutableData<float>();
        std::memcpy(_state.data(), stateN, size_state * sizeof(float));

        // Push forward sample index
        current_sample += window_size_samples;

        // Reset temp_end when > threshold
        if (speech_prob >= threshold)
        {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point
            printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            if (temp_end != 0)
            {
                temp_end = 0;
                if (next_start < prev_end)
                    next_start = current_sample - window_size_samples;
            }
            if (triggered == false)
            {
                triggered = true;
                current_speech.start = current_sample - window_size_samples;
            }
            return;
        }

        if (
            (triggered == true)
            && ((current_sample - current_speech.start) > max_speech_samples)
            ) {
            if (prev_end > 0) {
                current_speech.end = prev_end;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();

                // previously reached silence (< neg_thres) and is still not speech (< thres)
                if (next_start < prev_end)
                    triggered = false;
                else {
                    current_speech.start = next_start;
                }
                prev_end = 0;
                next_start = 0;
                temp_end = 0;
            }
            else {
                current_speech.end = current_sample;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();
                prev_end = 0;
                next_start = 0;
                temp_end = 0;
                triggered = false;
            }
            return;
        }

        if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
        {
            if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
                float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point
                printf("{ speaking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            }
            else {
#ifdef __DEBUG_SPEECH_PROB___
                float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point
                printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            }
            return;
        }

        // 4) End
        if (speech_prob < (threshold - 0.15))
        {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get the precise start time point
            printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            if (triggered == true)
            {
                if (temp_end == 0)
                {
                    temp_end = current_sample;
                }
                if (current_sample - temp_end > min_silence_samples_at_max_speech)
                    prev_end = temp_end;
                // a. silence < min_silence_samples, continue speaking
                if ((current_sample - temp_end) < min_silence_samples)
                {
                }
                // b. silence >= min_silence_samples, end speaking
                else
                {
                    current_speech.end = temp_end;
                    if (current_speech.end - current_speech.start > min_speech_samples)
                    {
                        speeches.push_back(current_speech);
                        current_speech = timestamp_t();
                        prev_end = 0;
                        next_start = 0;
                        temp_end = 0;
                        triggered = false;
                    }
                }
            }
            else {
                // the very first windows may already be in the end state
            }
            return;
        }
    };

public:
    void process(const std::vector<float>& input_wav)
    {
        reset_states();

        audio_length_samples = input_wav.size();

        for (int j = 0; j < audio_length_samples; j += window_size_samples)
        {
            if (j + window_size_samples > audio_length_samples)
                break;
            std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
            predict(r);
        }

        if (current_speech.start >= 0) {
            current_speech.end = audio_length_samples;
            speeches.push_back(current_speech);
            current_speech = timestamp_t();
            prev_end = 0;
            next_start = 0;
            temp_end = 0;
            triggered = false;
        }
    };

    void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        process(input_wav);
        collect_chunks(input_wav, output_wav);
    }

    void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        output_wav.clear();
        for (size_t i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
            std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
            std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
        }
    };

    const std::vector<timestamp_t> get_speech_timestamps() const
    {
        return speeches;
    }

    void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        output_wav.clear();
        int current_start = 0;
        for (size_t i = 0; i < speeches.size(); i++) {
            std::vector<float> slice(&input_wav[current_start], &input_wav[speeches[i].start]);
            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
            current_start = speeches[i].end;
        }

        std::vector<float> slice(input_wav.data() + current_start, input_wav.data() + input_wav.size());
        output_wav.insert(output_wav.end(), slice.begin(), slice.end());
    };

private:
    // model config
    int64_t window_size_samples;  // Assign when init; supports 256/512/768 for 8k, 512/1024/1536 for 16k
    int sample_rate;              // Assign when init; supports 16000 or 8000
    int sr_per_ms;                // Assign when init; 8 or 16
    float threshold;
    int min_silence_samples;               // sr_per_ms * #ms
    int min_silence_samples_at_max_speech; // sr_per_ms * 98
    int min_speech_samples;                // sr_per_ms * #ms
    float max_speech_samples;
    int speech_pad_samples;       // padding added to each side of a speech segment
    int audio_length_samples;

    // model states
    bool triggered = false;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
    int prev_end;
    int next_start = 0;

    // Output timestamps
    std::vector<timestamp_t> speeches;
    timestamp_t current_speech;

    // Onnx model
    // Inputs
    std::vector<Ort::Value> ort_inputs;

    std::vector<const char*> input_node_names = { "input", "state", "sr" };
    std::vector<float> input;
    unsigned int size_state = 2 * 1 * 128; // It's FIXED.
    std::vector<float> _state;
    std::vector<int64_t> sr;

    int64_t input_node_dims[2] = {};
    const int64_t state_node_dims[3] = { 2, 1, 128 };
    const int64_t sr_node_dims[1] = { 1 };

    // Outputs
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char*> output_node_names = { "output", "stateN" };

public:
    // Construction
    VadIterator(const std::wstring ModelPath,
        int Sample_rate = 16000, int windows_frame_size = 32,
        float Threshold = 0.5, int min_silence_duration_ms = 0,
        int speech_pad_ms = 32, int min_speech_duration_ms = 32,
        float max_speech_duration_s = std::numeric_limits<float>::infinity())
    {
        init_onnx_model(ModelPath);
        threshold = Threshold;
        sample_rate = Sample_rate;
        sr_per_ms = sample_rate / 1000;

        window_size_samples = windows_frame_size * sr_per_ms;

        min_speech_samples = sr_per_ms * min_speech_duration_ms;
        speech_pad_samples = sr_per_ms * speech_pad_ms;

        max_speech_samples = (
            sample_rate * max_speech_duration_s
            - window_size_samples
            - 2 * speech_pad_samples
            );

        min_silence_samples = sr_per_ms * min_silence_duration_ms;
        min_silence_samples_at_max_speech = sr_per_ms * 98;

        input.resize(window_size_samples);
        input_node_dims[0] = 1;
        input_node_dims[1] = window_size_samples;

        _state.resize(size_state);
        sr.resize(1);
        sr[0] = sample_rate;
    };
};

int main()
{
    std::vector<timestamp_t> stamps;

    // Read wav
    wav::WavReader wav_reader("recorder.wav"); // 16000 Hz, 1 channel, 32-bit float
    std::vector<float> input_wav(wav_reader.num_samples());
    std::vector<float> output_wav;

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
    }

    // ===== Test configs =====
    std::wstring path = L"silero_vad.onnx";
    VadIterator vad(path);

    // ==============================================
    // ===== Example 1 of full function =====
    // ==============================================
    vad.process(input_wav);

    // 1.a get_speech_timestamps
    stamps = vad.get_speech_timestamps();
    for (size_t i = 0; i < stamps.size(); i++) {
        std::cout << stamps[i].c_str() << std::endl;
    }

    // 1.b collect_chunks output wav
    vad.collect_chunks(input_wav, output_wav);

    // 1.c drop_chunks output wav
    vad.drop_chunks(input_wav, output_wav);

    // ==============================================
    // ===== Example 2 of simple full function =====
    // ==============================================
    vad.process(input_wav, output_wav);

    stamps = vad.get_speech_timestamps();
    for (size_t i = 0; i < stamps.size(); i++) {
        std::cout << stamps[i].c_str() << std::endl;
    }

    // ==============================================
    // ===== Example 3 of full function =====
    // ==============================================
    for (int i = 0; i < 2; i++)
        vad.process(input_wav, output_wav);
}
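
For comparison, a hedged Python sketch of the tensor contract the C++ class above feeds to ONNX Runtime: input `[1, window]`, state `[2, 1, 128]`, scalar `sr`; outputs `output` and `stateN`. Whether the exported model additionally expects a leading context window (as the C# example later on this page concatenates) depends on the model version, so treat this as a sketch rather than a reference implementation:

```python3
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('silero_vad.onnx')  # same model file name as in main() above

state = np.zeros((2, 1, 128), dtype=np.float32)  # fixed state shape, as in the C++ code
sr = np.array([16000], dtype=np.int64)
chunk = np.zeros((1, 512), dtype=np.float32)     # one 32 ms window at 16 kHz (use real audio here)

prob, state = sess.run(['output', 'stateN'],
                       {'input': chunk, 'state': state, 'sr': sr})
print(prob.item())  # speech probability for this window; feed `state` back on the next call
```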
examples/cpp/wav.h (new file, 235 lines)
@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>  // needed for the std::cout diagnostics below
#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "RIFF"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    std::cout << "num_channel_    :" << num_channel_ << std::endl;
    std::cout << "sample_rate_    :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples     :" << num_data << std::endl;
    std::cout << "num_data_size   :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        char sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 16: {
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32: {
        if (header.format == 1) {  // S32
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            data_[i] = static_cast<float>(sample) / 32768;
          }
        } else if (header.format == 3) {  // IEEE float
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        } else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "wb");  // binary mode, so samples are not mangled on Windows
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_
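
As a cross-check of the `WavHeader` layout above, a small Python sketch that parses the same canonical 44-byte header fields with `struct` (little-endian, matching the in-memory layout the reader relies on; the sub-chunk skipping done by `WavReader::Open` is omitted here):

```python3
import struct

def read_wav_header(path):
    # field order mirrors the WavHeader struct above
    with open(path, 'rb') as f:
        riff, size, wave = struct.unpack('<4sI4s', f.read(12))
        fmt, fmt_size = struct.unpack('<4sI', f.read(8))
        (audio_format, channels, sample_rate,
         bytes_per_second, block_size, bits) = struct.unpack('<HHIIHH', f.read(16))
        data_tag, data_size = struct.unpack('<4sI', f.read(8))
    assert riff == b'RIFF' and wave == b'WAVE'
    return channels, sample_rate, bits, data_size

# e.g. read_wav_header('recorder.wav')  # the file name used by the C++ example above
```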
examples/csharp/Program.cs (new file, 35 lines)
@@ -0,0 +1,35 @@
using System.Text;

namespace VadDotNet;

class Program
{
    private const string MODEL_PATH = "./resources/silero_vad.onnx";
    private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
    private const int SAMPLE_RATE = 16000;
    private const float THRESHOLD = 0.5f;
    private const int MIN_SPEECH_DURATION_MS = 250;
    private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
    private const int MIN_SILENCE_DURATION_MS = 100;
    private const int SPEECH_PAD_MS = 30;

    public static void Main(string[] args)
    {
        var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
            MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
        List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
        //Console.WriteLine(speechTimeList.ToJson());
        StringBuilder sb = new StringBuilder();
        foreach (var speechSegment in speechTimeList)
        {
            sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
        }
        Console.WriteLine(sb.ToString());
    }
}
examples/csharp/SileroSpeechSegment.cs (new file, 21 lines)
@@ -0,0 +1,21 @@
namespace VadDotNet;

public class SileroSpeechSegment
{
    public int? StartOffset { get; set; }
    public int? EndOffset { get; set; }
    public float? StartSecond { get; set; }
    public float? EndSecond { get; set; }

    public SileroSpeechSegment()
    {
    }

    public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
    {
        StartOffset = startOffset;
        EndOffset = endOffset;
        StartSecond = startSecond;
        EndSecond = endSecond;
    }
}
examples/csharp/SileroVadDetector.cs (new file, 250 lines)
@@ -0,0 +1,250 @@
using NAudio.Wave;
using VADdotnet;

namespace VadDotNet;

public class SileroVadDetector
{
    private readonly SileroVadOnnxModel _model;
    private readonly float _threshold;
    private readonly float _negThreshold;
    private readonly int _samplingRate;
    private readonly int _windowSizeSample;
    private readonly float _minSpeechSamples;
    private readonly float _speechPadSamples;
    private readonly float _maxSpeechSamples;
    private readonly float _minSilenceSamples;
    private readonly float _minSilenceSamplesAtMaxSpeech;
    private int _audioLengthSamples;
    private const float THRESHOLD_GAP = 0.15f;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_8K = 8000;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_16K = 16000;

    public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
        int minSpeechDurationMs, float maxSpeechDurationSeconds,
        int minSilenceDurationMs, int speechPadMs)
    {
        if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
        {
            throw new ArgumentException("Sampling rate not supported, only available for [8000, 16000]");
        }

        this._model = new SileroVadOnnxModel(onnxModelPath);
        this._samplingRate = samplingRate;
        this._threshold = threshold;
        this._negThreshold = threshold - THRESHOLD_GAP;
        this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
        this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
        this._speechPadSamples = samplingRate * speechPadMs / 1000f;
        this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
        this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
        this.Reset();
    }

    public void Reset()
    {
        _model.ResetStates();
    }

    public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
    {
        Reset();

        using (var audioFile = new AudioFileReader(wavFile.FullName))
        {
            List<float> speechProbList = new List<float>();
            this._audioLengthSamples = (int)(audioFile.Length / 2);
            float[] buffer = new float[this._windowSizeSample];

            while (audioFile.Read(buffer, 0, buffer.Length) > 0)
            {
                float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
                speechProbList.Add(speechProb);
            }

            return CalculateProb(speechProbList);
        }
    }

    private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        bool triggered = false;
        int tempEnd = 0, prevEnd = 0, nextStart = 0;
        SileroSpeechSegment segment = new SileroSpeechSegment();

        for (int i = 0; i < speechProbList.Count; i++)
        {
            float speechProb = speechProbList[i];
            if (speechProb >= _threshold && (tempEnd != 0))
            {
                tempEnd = 0;
                if (nextStart < prevEnd)
                {
                    nextStart = _windowSizeSample * i;
                }
            }

            if (speechProb >= _threshold && !triggered)
            {
                triggered = true;
                segment.StartOffset = _windowSizeSample * i;
                continue;
            }

            if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
            {
                if (prevEnd != 0)
                {
                    segment.EndOffset = prevEnd;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    if (nextStart < prevEnd)
                    {
                        triggered = false;
                    }
                    else
                    {
                        segment.StartOffset = nextStart;
                    }

                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                }
                else
                {
                    segment.EndOffset = _windowSizeSample * i;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }

            if (speechProb < _negThreshold && triggered)
            {
                if (tempEnd == 0)
                {
                    tempEnd = _windowSizeSample * i;
                }

                if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
                {
                    prevEnd = tempEnd;
                }

                if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
                {
                    continue;
                }
                else
                {
                    segment.EndOffset = tempEnd;
                    if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
                    {
                        result.Add(segment);
                    }

                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }
        }

        if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
        {
            segment.EndOffset = _audioLengthSamples;
            result.Add(segment);
        }

        for (int i = 0; i < result.Count; i++)
        {
            SileroSpeechSegment item = result[i];
            if (i == 0)
            {
                item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
            }

            if (i != result.Count - 1)
            {
                SileroSpeechSegment nextItem = result[i + 1];
                int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
                if (silenceDuration < 2 * _speechPadSamples)
                {
                    item.EndOffset = item.EndOffset + (silenceDuration / 2);
                    nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
                }
                else
                {
                    item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
                    nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
                }
            }
            else
            {
                item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
            }
        }

        return MergeListAndCalculateSecond(result, _samplingRate);
    }

    private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        if (original == null || original.Count == 0)
        {
            return result;
        }

        int left = original[0].StartOffset.Value;
        int right = original[0].EndOffset.Value;
        if (original.Count > 1)
        {
            original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
            for (int i = 1; i < original.Count; i++)
            {
                SileroSpeechSegment segment = original[i];

                if (segment.StartOffset > right)
                {
                    result.Add(new SileroSpeechSegment(left, right,
                        CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
                    left = segment.StartOffset.Value;
                    right = segment.EndOffset.Value;
                }
                else
                {
                    right = Math.Max(right, segment.EndOffset.Value);
                }
            }

            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }
        else
        {
            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }

        return result;
    }

    private float CalculateSecondByOffset(int offset, int samplingRate)
    {
        float secondValue = offset * 1.0f / samplingRate;
        return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
    }
}
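
The segment logic in `CalculateProb` above is a hysteresis state machine: speech starts when the probability crosses `_threshold` and only ends after enough silence below `_negThreshold` (= threshold - 0.15). A deliberately stripped-down Python sketch of just that core loop follows; padding, max-duration handling and segment merging are omitted, and the window-counting simplification is mine:

```python3
def simple_segments(probs, window, threshold=0.5, min_silence_windows=3):
    # minimal hysteresis: enter speech above threshold, leave only after
    # min_silence_windows consecutive windows below (threshold - 0.15)
    neg_threshold = threshold - 0.15
    segments, start, silence = [], None, 0
    for i, p in enumerate(probs):
        if p >= threshold and start is None:
            start, silence = i * window, 0          # speech onset
        elif start is not None:
            if p < neg_threshold:
                silence += 1
                if silence >= min_silence_windows:  # enough silence: close the segment
                    segments.append((start, (i - silence + 1) * window))
                    start, silence = None, 0
            elif p >= threshold:
                silence = 0                         # speech resumed, forget the pause
    if start is not None:                           # audio ended mid-speech
        segments.append((start, len(probs) * window))
    return segments
```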
220
examples/csharp/SileroVadOnnxModel.cs
Normal file
220
examples/csharp/SileroVadOnnxModel.cs
Normal file
@@ -0,0 +1,220 @@
|
||||
using Microsoft.ML.OnnxRuntime;
|
||||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace VADdotnet;
|
||||
|
||||
|
||||
public class SileroVadOnnxModel : IDisposable
|
||||
{
|
||||
    private readonly InferenceSession session;
    private float[][][] state;
    private float[][] context;
    private int lastSr = 0;
    private int lastBatchSize = 0;
    private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };

    public SileroVadOnnxModel(string modelPath)
    {
        var sessionOptions = new SessionOptions();
        sessionOptions.InterOpNumThreads = 1;
        sessionOptions.IntraOpNumThreads = 1;
        sessionOptions.EnableCpuMemArena = true;

        session = new InferenceSession(modelPath, sessionOptions);
        ResetStates();
    }

    // Zero the recurrent state and drop the carried-over audio context.
    public void ResetStates()
    {
        state = new float[2][][];
        state[0] = new float[1][];
        state[1] = new float[1][];
        state[0][0] = new float[128];
        state[1][0] = new float[128];
        context = Array.Empty<float[]>();
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void Dispose()
    {
        session?.Dispose();
    }

    public class ValidationResult
    {
        public float[][] X { get; }
        public int Sr { get; }

        public ValidationResult(float[][] x, int sr)
        {
            X = x;
            Sr = sr;
        }
    }

    private ValidationResult ValidateInput(float[][] x, int sr)
    {
        if (x.Length == 1)
        {
            x = new float[][] { x[0] };
        }
        if (x.Length > 2)
        {
            throw new ArgumentException($"Incorrect audio data dimension: {x.Length}");
        }

        // Sample rates that are multiples of 16000 are decimated down to 16000.
        if (sr != 16000 && (sr % 16000 == 0))
        {
            int step = sr / 16000;
            float[][] reducedX = new float[x.Length][];

            for (int i = 0; i < x.Length; i++)
            {
                float[] current = x[i];
                float[] newArr = new float[(current.Length + step - 1) / step];

                for (int j = 0, index = 0; j < current.Length; j += step, index++)
                {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        if (!SAMPLE_RATES.Contains(sr))
        {
            throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
        }

        if (((float)sr) / x[0].Length > 31.25)
        {
            throw new ArgumentException("Input audio is too short");
        }

        return new ValidationResult(x, sr);
    }

    private static float[][] Concatenate(float[][] a, float[][] b)
    {
        if (a.Length != b.Length)
        {
            throw new ArgumentException("The number of rows in both arrays must be the same.");
        }

        int rows = a.Length;
        int colsA = a[0].Length;
        int colsB = b[0].Length;
        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[colsA + colsB];
            Array.Copy(a[i], 0, result[i], 0, colsA);
            Array.Copy(b[i], 0, result[i], colsA, colsB);
        }

        return result;
    }

    private static float[][] GetLastColumns(float[][] array, int contextSize)
    {
        int rows = array.Length;
        int cols = array[0].Length;

        if (contextSize > cols)
        {
            throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
        }

        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[contextSize];
            Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
        }

        return result;
    }

    public float[] Call(float[][] x, int sr)
    {
        var result = ValidateInput(x, sr);
        x = result.X;
        sr = result.Sr;
        int numberSamples = sr == 16000 ? 512 : 256;

        if (x[0].Length != numberSamples)
        {
            throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
        }

        int batchSize = x.Length;
        int contextSize = sr == 16000 ? 64 : 32;

        // Any change in sample rate or batch size invalidates the recurrent state.
        if (lastBatchSize == 0)
        {
            ResetStates();
        }
        if (lastSr != 0 && lastSr != sr)
        {
            ResetStates();
        }
        if (lastBatchSize != 0 && lastBatchSize != batchSize)
        {
            ResetStates();
        }

        if (context.Length == 0)
        {
            context = new float[batchSize][];
            for (int i = 0; i < batchSize; i++)
            {
                context[i] = new float[contextSize];
            }
        }

        // Prepend the tail of the previous chunk so the model sees a continuous signal.
        x = Concatenate(context, x);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
            NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
            NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
        };

        using (var outputs = session.Run(inputs))
        {
            var output = outputs.First(o => o.Name == "output").AsTensor<float>();
            var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();

            context = GetLastColumns(x, contextSize);
            lastSr = sr;
            lastBatchSize = batchSize;

            // Copy the returned state tensor back into the jagged state array.
            state = new float[newState.Dimensions[0]][][];
            for (int i = 0; i < newState.Dimensions[0]; i++)
            {
                state[i] = new float[newState.Dimensions[1]][];
                for (int j = 0; j < newState.Dimensions[1]; j++)
                {
                    state[i][j] = new float[newState.Dimensions[2]];
                    for (int k = 0; k < newState.Dimensions[2]; k++)
                    {
                        state[i][j][k] = newState[i, j, k];
                    }
                }
            }

            return output.ToArray();
        }
    }
}
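For orientation, the single inference step this C# wrapper performs can be sketched in a few lines of Python with onnxruntime. This is an illustrative sketch only, mirroring the tensor names used above (`input`, `sr`, `state` in; `output`, `stateN` out); the model file name, batch of one, and zeroed buffers are assumptions:

```python
# Illustrative sketch of one Call() step, assuming silero_vad.onnx is present
# and onnxruntime + numpy are installed.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("silero_vad.onnx")

sr = 16000
context_size = 64 if sr == 16000 else 32         # 64 samples of carried context at 16 kHz
state = np.zeros((2, 1, 128), dtype=np.float32)  # recurrent state, zeroed as in ResetStates()
context = np.zeros((1, context_size), dtype=np.float32)

chunk = np.zeros((1, 512), dtype=np.float32)     # one 512-sample frame (256 at 8 kHz)
x = np.concatenate([context, chunk], axis=1)     # prepend context, as Concatenate() does

prob, state = sess.run(
    ["output", "stateN"],
    {"input": x, "sr": np.array([sr], dtype=np.int64), "state": state},
)
context = x[:, -context_size:]                   # keep the tail for the next frame
print(prob[0].item())                            # speech probability for this frame
```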
25 examples/csharp/VadDotNet.csproj Normal file
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.18.1" />
    <PackageReference Include="NAudio" Version="2.2.1" />
  </ItemGroup>

  <ItemGroup>
    <Folder Include="resources\" />
  </ItemGroup>

  <ItemGroup>
    <Content Include="resources\**">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>

</Project>
1 examples/csharp/resources/put_model_here.txt Normal file
@@ -0,0 +1 @@
place onnx model file and example.wav file in this folder
19 examples/go/README.md Normal file
@@ -0,0 +1,19 @@
## Golang Example

A sample program showing how to run speech detection with `silero-vad` from Go (CGO + ONNX Runtime).

### Requirements

- Go >= 1.21
- ONNX Runtime

### Usage

```sh
go run ./cmd/main.go test.wav
```

> **_Note_**
>
> Make sure you have the ONNX Runtime library and C headers installed in your path.

63 examples/go/cmd/main.go Normal file
@@ -0,0 +1,63 @@
package main

import (
	"log"
	"os"

	"github.com/streamer45/silero-vad-go/speech"

	"github.com/go-audio/wav"
)

func main() {
	sd, err := speech.NewDetector(speech.DetectorConfig{
		ModelPath:            "../../src/silero_vad/data/silero_vad.onnx",
		SampleRate:           16000,
		Threshold:            0.5,
		MinSilenceDurationMs: 100,
		SpeechPadMs:          30,
	})
	if err != nil {
		log.Fatalf("failed to create speech detector: %s", err)
	}

	if len(os.Args) != 2 {
		log.Fatalf("invalid arguments provided: expecting one file path")
	}

	f, err := os.Open(os.Args[1])
	if err != nil {
		log.Fatalf("failed to open sample audio file: %s", err)
	}
	defer f.Close()

	dec := wav.NewDecoder(f)

	if ok := dec.IsValidFile(); !ok {
		log.Fatalf("invalid WAV file")
	}

	buf, err := dec.FullPCMBuffer()
	if err != nil {
		log.Fatalf("failed to get PCM buffer")
	}

	pcmBuf := buf.AsFloat32Buffer()

	segments, err := sd.Detect(pcmBuf.Data)
	if err != nil {
		log.Fatalf("Detect failed: %s", err)
	}

	for _, s := range segments {
		log.Printf("speech starts at %0.2fs", s.SpeechStartAt)
		if s.SpeechEndAt > 0 {
			log.Printf("speech ends at %0.2fs", s.SpeechEndAt)
		}
	}

	err = sd.Destroy()
	if err != nil {
		log.Fatalf("failed to destroy detector: %s", err)
	}
}
13 examples/go/go.mod Normal file
@@ -0,0 +1,13 @@
module silero

go 1.21.4

require (
	github.com/go-audio/wav v1.1.0
	github.com/streamer45/silero-vad-go v0.2.1
)

require (
	github.com/go-audio/audio v1.0.0 // indirect
	github.com/go-audio/riff v1.0.0 // indirect
)
18 examples/go/go.sum Normal file
@@ -0,0 +1,18 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
|
||||
github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
|
||||
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
|
||||
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
30 examples/java-example/pom.xml Normal file
@@ -0,0 +1,30 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <groupId>org.example</groupId>
  <artifactId>java-example</artifactId>
  <version>1.0-SNAPSHOT</version>
  <packaging>jar</packaging>

  <name>silero-vad-example</name>
  <url>http://maven.apache.org</url>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>com.microsoft.onnxruntime</groupId>
      <artifactId>onnxruntime</artifactId>
      <version>1.16.0-rc1</version>
    </dependency>
  </dependencies>
</project>
69 examples/java-example/src/main/java/org/example/App.java Normal file
@@ -0,0 +1,69 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import javax.sound.sampled.*;
|
||||
import java.util.Map;
|
||||
|
||||
public class App {
|
||||
|
||||
private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
|
||||
private static final int SAMPLE_RATE = 16000;
|
||||
private static final float START_THRESHOLD = 0.6f;
|
||||
private static final float END_THRESHOLD = 0.45f;
|
||||
private static final int MIN_SILENCE_DURATION_MS = 600;
|
||||
private static final int SPEECH_PAD_MS = 500;
|
||||
private static final int WINDOW_SIZE_SAMPLES = 2048;
|
||||
|
||||
public static void main(String[] args) {
|
||||
// Initialize the Voice Activity Detector
|
||||
SlieroVadDetector vadDetector;
|
||||
try {
|
||||
vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
||||
} catch (OrtException e) {
|
||||
System.err.println("Error initializing the VAD detector: " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
// Set audio format
|
||||
AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
|
||||
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
|
||||
|
||||
// Get the target data line and open it with the specified format
|
||||
TargetDataLine targetDataLine;
|
||||
try {
|
||||
targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
|
||||
targetDataLine.open(format);
|
||||
targetDataLine.start();
|
||||
} catch (LineUnavailableException e) {
|
||||
System.err.println("Error opening target data line: " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
// Main loop to continuously read data and apply Voice Activity Detection
|
||||
while (targetDataLine.isOpen()) {
|
||||
byte[] data = new byte[WINDOW_SIZE_SAMPLES];
|
||||
|
||||
int numBytesRead = targetDataLine.read(data, 0, data.length);
|
||||
if (numBytesRead <= 0) {
|
||||
System.err.println("Error reading data from target data line.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply the Voice Activity Detector to the data and get the result
|
||||
Map<String, Double> detectResult;
|
||||
try {
|
||||
detectResult = vadDetector.apply(data, true);
|
||||
} catch (Exception e) {
|
||||
System.err.println("Error applying VAD detector: " + e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!detectResult.isEmpty()) {
|
||||
System.out.println(detectResult);
|
||||
}
|
||||
}
|
||||
|
||||
// Close the target data line to release audio resources
|
||||
targetDataLine.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class SlieroVadDetector {
|
||||
// OnnxModel model used for speech processing
|
||||
private final SlieroVadOnnxModel model;
|
||||
// Threshold for speech start
|
||||
private final float startThreshold;
|
||||
// Threshold for speech end
|
||||
private final float endThreshold;
|
||||
// Sampling rate
|
||||
private final int samplingRate;
|
||||
// Minimum number of silence samples to determine the end threshold of speech
|
||||
private final float minSilenceSamples;
|
||||
// Additional number of samples for speech start or end to calculate speech start or end time
|
||||
private final float speechPadSamples;
|
||||
// Whether in the triggered state (i.e. whether speech is being detected)
|
||||
private boolean triggered;
|
||||
// Temporarily stored number of speech end samples
|
||||
private int tempEnd;
|
||||
// Number of samples currently being processed
|
||||
private int currentSample;
|
||||
|
||||
|
||||
public SlieroVadDetector(String modelPath,
|
||||
float startThreshold,
|
||||
float endThreshold,
|
||||
int samplingRate,
|
||||
int minSilenceDurationMs,
|
||||
int speechPadMs) throws OrtException {
|
||||
// Check if the sampling rate is 8000 or 16000, if not, throw an exception
|
||||
if (samplingRate != 8000 && samplingRate != 16000) {
|
||||
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
|
||||
}
|
||||
|
||||
// Initialize the parameters
|
||||
this.model = new SlieroVadOnnxModel(modelPath);
|
||||
this.startThreshold = startThreshold;
|
||||
this.endThreshold = endThreshold;
|
||||
this.samplingRate = samplingRate;
|
||||
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
|
||||
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
|
||||
// Reset the state
|
||||
reset();
|
||||
}
|
||||
|
||||
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
|
||||
public void reset() {
|
||||
model.resetStates();
|
||||
triggered = false;
|
||||
tempEnd = 0;
|
||||
currentSample = 0;
|
||||
}
|
||||
|
||||
// apply method for processing the audio array, returning possible speech start or end times
|
||||
public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
|
||||
|
||||
// Convert the byte array to a float array
|
||||
float[] audioData = new float[data.length / 2];
|
||||
for (int i = 0; i < audioData.length; i++) {
|
||||
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
|
||||
}
|
||||
|
||||
// Get the length of the audio array as the window size
|
||||
int windowSizeSamples = audioData.length;
|
||||
// Update the current sample count
|
||||
currentSample += windowSizeSamples;
|
||||
|
||||
// Call the model to get the prediction probability of speech
|
||||
float speechProb = 0;
|
||||
try {
|
||||
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
|
||||
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
|
||||
if (speechProb >= startThreshold && tempEnd != 0) {
|
||||
tempEnd = 0;
|
||||
}
|
||||
|
||||
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
|
||||
if (speechProb >= startThreshold && !triggered) {
|
||||
triggered = true;
|
||||
int speechStart = (int) (currentSample - speechPadSamples);
|
||||
speechStart = Math.max(speechStart, 0);
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
|
||||
if (returnSeconds) {
|
||||
double speechStartSeconds = speechStart / (double) samplingRate;
|
||||
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
|
||||
result.put("start", roundedSpeechStart);
|
||||
} else {
|
||||
result.put("start", (double) speechStart);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
|
||||
if (speechProb < endThreshold && triggered) {
|
||||
// Initialize or update the temporary end time
|
||||
if (tempEnd == 0) {
|
||||
tempEnd = currentSample;
|
||||
}
|
||||
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
|
||||
// This indicates that it is not yet possible to determine whether the speech has ended
|
||||
if (currentSample - tempEnd < minSilenceSamples) {
|
||||
return Collections.emptyMap();
|
||||
} else {
|
||||
// Calculate the speech end time, reset the trigger state and temporary end time
|
||||
int speechEnd = (int) (tempEnd + speechPadSamples);
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
|
||||
if (returnSeconds) {
|
||||
double speechEndSeconds = speechEnd / (double) samplingRate;
|
||||
double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
|
||||
result.put("end", roundedSpeechEnd);
|
||||
} else {
|
||||
result.put("end", (double) speechEnd);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// If the above conditions are not met, return null by default
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
reset();
|
||||
model.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SlieroVadOnnxModel {
|
||||
// Define private variable OrtSession
|
||||
private final OrtSession session;
|
||||
private float[][][] h;
|
||||
private float[][][] c;
|
||||
// Define the last sample rate
|
||||
private int lastSr = 0;
|
||||
// Define the last batch size
|
||||
private int lastBatchSize = 0;
|
||||
// Define a list of supported sample rates
|
||||
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
||||
|
||||
// Constructor
|
||||
public SlieroVadOnnxModel(String modelPath) throws OrtException {
|
||||
// Get the ONNX runtime environment
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
// Create an ONNX session options object
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
|
||||
opts.setInterOpNumThreads(1);
|
||||
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
|
||||
opts.setIntraOpNumThreads(1);
|
||||
// Add a CPU device, setting to false disables CPU execution optimization
|
||||
opts.addCPU(true);
|
||||
// Create an ONNX session using the environment, model path, and options
|
||||
session = env.createSession(modelPath, opts);
|
||||
// Reset states
|
||||
resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset states
|
||||
*/
|
||||
void resetStates() {
|
||||
h = new float[2][1][64];
|
||||
c = new float[2][1][64];
|
||||
lastSr = 0;
|
||||
lastBatchSize = 0;
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
session.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Define inner class ValidationResult
|
||||
*/
|
||||
public static class ValidationResult {
|
||||
public final float[][] x;
|
||||
public final int sr;
|
||||
|
||||
// Constructor
|
||||
public ValidationResult(float[][] x, int sr) {
|
||||
this.x = x;
|
||||
this.sr = sr;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Function to validate input data
|
||||
*/
|
||||
private ValidationResult validateInput(float[][] x, int sr) {
|
||||
// Process the input data with dimension 1
|
||||
if (x.length == 1) {
|
||||
x = new float[][]{x[0]};
|
||||
}
|
||||
// Throw an exception when the input data dimension is greater than 2
|
||||
if (x.length > 2) {
|
||||
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
||||
}
|
||||
|
||||
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
|
||||
if (sr != 16000 && (sr % 16000 == 0)) {
|
||||
int step = sr / 16000;
|
||||
float[][] reducedX = new float[x.length][];
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
float[] current = x[i];
|
||||
float[] newArr = new float[(current.length + step - 1) / step];
|
||||
|
||||
for (int j = 0, index = 0; j < current.length; j += step, index++) {
|
||||
newArr[index] = current[j];
|
||||
}
|
||||
|
||||
reducedX[i] = newArr;
|
||||
}
|
||||
|
||||
x = reducedX;
|
||||
sr = 16000;
|
||||
}
|
||||
|
||||
// If the sample rate is not in the list of supported sample rates, throw an exception
|
||||
if (!SAMPLE_RATES.contains(sr)) {
|
||||
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
||||
}
|
||||
|
||||
// If the input audio block is too short, throw an exception
|
||||
if (((float) sr) / x[0].length > 31.25) {
|
||||
throw new IllegalArgumentException("Input audio is too short");
|
||||
}
|
||||
|
||||
// Return the validated result
|
||||
return new ValidationResult(x, sr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to call the ONNX model
|
||||
*/
|
||||
public float[] call(float[][] x, int sr) throws OrtException {
|
||||
ValidationResult result = validateInput(x, sr);
|
||||
x = result.x;
|
||||
sr = result.sr;
|
||||
|
||||
int batchSize = x.length;
|
||||
|
||||
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
|
||||
resetStates();
|
||||
}
|
||||
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
OnnxTensor inputTensor = null;
|
||||
OnnxTensor hTensor = null;
|
||||
OnnxTensor cTensor = null;
|
||||
OnnxTensor srTensor = null;
|
||||
OrtSession.Result ortOutputs = null;
|
||||
|
||||
try {
|
||||
// Create input tensors
|
||||
inputTensor = OnnxTensor.createTensor(env, x);
|
||||
hTensor = OnnxTensor.createTensor(env, h);
|
||||
cTensor = OnnxTensor.createTensor(env, c);
|
||||
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input", inputTensor);
|
||||
inputs.put("sr", srTensor);
|
||||
inputs.put("h", hTensor);
|
||||
inputs.put("c", cTensor);
|
||||
|
||||
// Call the ONNX model for calculation
|
||||
ortOutputs = session.run(inputs);
|
||||
// Get the output results
|
||||
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
||||
h = (float[][][]) ortOutputs.get(1).getValue();
|
||||
c = (float[][][]) ortOutputs.get(2).getValue();
|
||||
|
||||
lastSr = sr;
|
||||
lastBatchSize = batchSize;
|
||||
return output[0];
|
||||
} finally {
|
||||
if (inputTensor != null) {
|
||||
inputTensor.close();
|
||||
}
|
||||
if (hTensor != null) {
|
||||
hTensor.close();
|
||||
}
|
||||
if (cTensor != null) {
|
||||
cTensor.close();
|
||||
}
|
||||
if (srTensor != null) {
|
||||
srTensor.close();
|
||||
}
|
||||
if (ortOutputs != null) {
|
||||
ortOutputs.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
|
||||
public class App {
|
||||
|
||||
private static final String MODEL_PATH = "/path/silero_vad.onnx";
|
||||
private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
|
||||
private static final int SAMPLE_RATE = 16000;
|
||||
private static final float THRESHOLD = 0.5f;
|
||||
private static final int MIN_SPEECH_DURATION_MS = 250;
|
||||
private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
|
||||
private static final int MIN_SILENCE_DURATION_MS = 100;
|
||||
private static final int SPEECH_PAD_MS = 30;
|
||||
|
||||
public static void main(String[] args) {
|
||||
// Initialize the Voice Activity Detector
|
||||
SileroVadDetector vadDetector;
|
||||
try {
|
||||
vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
|
||||
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
||||
fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
|
||||
} catch (OrtException e) {
|
||||
System.err.println("Error initializing the VAD detector: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
|
||||
List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
|
||||
for (SileroSpeechSegment speechSegment : speechTimeList) {
|
||||
System.out.println(String.format("start second: %f, end second: %f",
|
||||
speechSegment.getStartSecond(), speechSegment.getEndSecond()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package org.example;
|
||||
|
||||
|
||||
public class SileroSpeechSegment {
|
||||
private Integer startOffset;
|
||||
private Integer endOffset;
|
||||
private Float startSecond;
|
||||
private Float endSecond;
|
||||
|
||||
public SileroSpeechSegment() {
|
||||
}
|
||||
|
||||
public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
|
||||
this.startOffset = startOffset;
|
||||
this.endOffset = endOffset;
|
||||
this.startSecond = startSecond;
|
||||
this.endSecond = endSecond;
|
||||
}
|
||||
|
||||
public Integer getStartOffset() {
|
||||
return startOffset;
|
||||
}
|
||||
|
||||
public Integer getEndOffset() {
|
||||
return endOffset;
|
||||
}
|
||||
|
||||
public Float getStartSecond() {
|
||||
return startSecond;
|
||||
}
|
||||
|
||||
public Float getEndSecond() {
|
||||
return endSecond;
|
||||
}
|
||||
|
||||
public void setStartOffset(Integer startOffset) {
|
||||
this.startOffset = startOffset;
|
||||
}
|
||||
|
||||
public void setEndOffset(Integer endOffset) {
|
||||
this.endOffset = endOffset;
|
||||
}
|
||||
|
||||
public void setStartSecond(Float startSecond) {
|
||||
this.startSecond = startSecond;
|
||||
}
|
||||
|
||||
public void setEndSecond(Float endSecond) {
|
||||
this.endSecond = endSecond;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package org.example;
|
||||
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
|
||||
import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
public class SileroVadDetector {
|
||||
private final SileroVadOnnxModel model;
|
||||
private final float threshold;
|
||||
private final float negThreshold;
|
||||
private final int samplingRate;
|
||||
private final int windowSizeSample;
|
||||
private final float minSpeechSamples;
|
||||
private final float speechPadSamples;
|
||||
private final float maxSpeechSamples;
|
||||
private final float minSilenceSamples;
|
||||
private final float minSilenceSamplesAtMaxSpeech;
|
||||
private int audioLengthSamples;
|
||||
private static final float THRESHOLD_GAP = 0.15f;
|
||||
private static final Integer SAMPLING_RATE_8K = 8000;
|
||||
private static final Integer SAMPLING_RATE_16K = 16000;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
* @param onnxModelPath the path of silero-vad onnx model
|
||||
* @param threshold threshold for speech start
|
||||
* @param samplingRate audio sampling rate, only available for [8k, 16k]
|
||||
* @param minSpeechDurationMs Minimum speech length in millis; any speech segment shorter than this is not considered speech
|
||||
* @param maxSpeechDurationSeconds Maximum speech length in seconds; Float.POSITIVE_INFINITY is recommended
|
||||
* @param minSilenceDurationMs Minimum silence length in millis; any silence shorter than this is not considered silence
|
||||
* @param speechPadMs Additional pad millis for speech start and end
|
||||
* @throws OrtException
|
||||
*/
|
||||
public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
|
||||
int minSpeechDurationMs, float maxSpeechDurationSeconds,
|
||||
int minSilenceDurationMs, int speechPadMs) throws OrtException {
|
||||
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
|
||||
throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]");
|
||||
}
|
||||
this.model = new SileroVadOnnxModel(onnxModelPath);
|
||||
this.samplingRate = samplingRate;
|
||||
this.threshold = threshold;
|
||||
this.negThreshold = threshold - THRESHOLD_GAP;
|
||||
if (samplingRate == SAMPLING_RATE_16K) {
|
||||
this.windowSizeSample = 512;
|
||||
} else {
|
||||
this.windowSizeSample = 256;
|
||||
}
|
||||
this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
|
||||
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
|
||||
this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
|
||||
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
|
||||
this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
|
||||
this.reset();
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to reset the state
|
||||
*/
|
||||
public void reset() {
|
||||
model.resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get speech segment list by given wav-format file
|
||||
* @param wavFile wav file
|
||||
* @return list of speech segment
|
||||
*/
|
||||
public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
|
||||
reset();
|
||||
try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
|
||||
List<Float> speechProbList = new ArrayList<>();
|
||||
this.audioLengthSamples = audioInputStream.available() / 2;
|
||||
byte[] data = new byte[this.windowSizeSample * 2];
|
||||
int numBytesRead = 0;
|
||||
|
||||
while ((numBytesRead = audioInputStream.read(data)) != -1) {
|
||||
if (numBytesRead <= 0) {
|
||||
break;
|
||||
}
|
||||
// Convert the byte array to a float array
|
||||
float[] audioData = new float[data.length / 2];
|
||||
for (int i = 0; i < audioData.length; i++) {
|
||||
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
|
||||
}
|
||||
|
||||
float speechProb = 0;
|
||||
try {
|
||||
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
|
||||
speechProbList.add(speechProb);
|
||||
} catch (OrtException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
return calculateProb(speechProbList);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate speech segments from the probability list
|
||||
* @param speechProbList speech probability list
|
||||
* @return list of speech segment
|
||||
*/
|
||||
private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
|
||||
List<SileroSpeechSegment> result = new ArrayList<>();
|
||||
boolean triggered = false;
|
||||
int tempEnd = 0, prevEnd = 0, nextStart = 0;
|
||||
SileroSpeechSegment segment = new SileroSpeechSegment();
|
||||
|
||||
for (int i = 0; i < speechProbList.size(); i++) {
|
||||
Float speechProb = speechProbList.get(i);
|
||||
if (speechProb >= threshold && (tempEnd != 0)) {
|
||||
tempEnd = 0;
|
||||
if (nextStart < prevEnd) {
|
||||
nextStart = windowSizeSample * i;
|
||||
}
|
||||
}
|
||||
|
||||
if (speechProb >= threshold && !triggered) {
|
||||
triggered = true;
|
||||
segment.setStartOffset(windowSizeSample * i);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
|
||||
if (prevEnd != 0) {
|
||||
segment.setEndOffset(prevEnd);
|
||||
result.add(segment);
|
||||
segment = new SileroSpeechSegment();
|
||||
if (nextStart < prevEnd) {
|
||||
triggered = false;
|
||||
}else {
|
||||
segment.setStartOffset(nextStart);
|
||||
}
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
}else {
|
||||
segment.setEndOffset(windowSizeSample * i);
|
||||
result.add(segment);
|
||||
segment = new SileroSpeechSegment();
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (speechProb < negThreshold && triggered) {
|
||||
if (tempEnd == 0) {
|
||||
tempEnd = windowSizeSample * i;
|
||||
}
|
||||
if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
|
||||
prevEnd = tempEnd;
|
||||
}
|
||||
if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
|
||||
continue;
|
||||
}else {
|
||||
segment.setEndOffset(tempEnd);
|
||||
if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
|
||||
result.add(segment);
|
||||
}
|
||||
segment = new SileroSpeechSegment();
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
|
||||
segment.setEndOffset(audioLengthSamples);
|
||||
result.add(segment);
|
||||
}
|
||||
|
||||
for (int i = 0; i < result.size(); i++) {
|
||||
SileroSpeechSegment item = result.get(i);
|
||||
if (i == 0) {
|
||||
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
|
||||
}
|
||||
if (i != result.size() - 1) {
|
||||
SileroSpeechSegment nextItem = result.get(i + 1);
|
||||
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
|
||||
if(silenceDuration < 2 * speechPadSamples){
|
||||
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
|
||||
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
|
||||
} else {
|
||||
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
|
||||
}
|
||||
}else {
|
||||
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||
}
|
||||
}
|
||||
|
||||
return mergeListAndCalculateSecond(result, samplingRate);
|
||||
}
|
||||
|
||||
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
|
||||
List<SileroSpeechSegment> result = new ArrayList<>();
|
||||
if (original == null || original.size() == 0) {
|
||||
return result;
|
||||
}
|
||||
Integer left = original.get(0).getStartOffset();
|
||||
Integer right = original.get(0).getEndOffset();
|
||||
if (original.size() > 1) {
|
||||
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
|
||||
for (int i = 1; i < original.size(); i++) {
|
||||
SileroSpeechSegment segment = original.get(i);
|
||||
|
||||
if (segment.getStartOffset() > right) {
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
left = segment.getStartOffset();
|
||||
right = segment.getEndOffset();
|
||||
} else {
|
||||
right = Math.max(right, segment.getEndOffset());
|
||||
}
|
||||
}
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
}else {
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
|
||||
float secondValue = offset * 1.0f / samplingRate;
|
||||
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SileroVadOnnxModel {
|
||||
// Define private variable OrtSession
|
||||
private final OrtSession session;
|
||||
private float[][][] state;
|
||||
private float[][] context;
|
||||
// Define the last sample rate
|
||||
private int lastSr = 0;
|
||||
// Define the last batch size
|
||||
private int lastBatchSize = 0;
|
||||
// Define a list of supported sample rates
|
||||
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
||||
|
||||
// Constructor
|
||||
public SileroVadOnnxModel(String modelPath) throws OrtException {
|
||||
// Get the ONNX runtime environment
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
// Create an ONNX session options object
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
|
||||
opts.setInterOpNumThreads(1);
|
||||
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
|
||||
opts.setIntraOpNumThreads(1);
|
||||
// Add a CPU device, setting to false disables CPU execution optimization
|
||||
opts.addCPU(true);
|
||||
// Create an ONNX session using the environment, model path, and options
|
||||
session = env.createSession(modelPath, opts);
|
||||
// Reset states
|
||||
resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset states
|
||||
*/
|
||||
void resetStates() {
|
||||
state = new float[2][1][128];
|
||||
context = new float[0][];
|
||||
lastSr = 0;
|
||||
lastBatchSize = 0;
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
session.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Define inner class ValidationResult
|
||||
*/
|
||||
public static class ValidationResult {
|
||||
public final float[][] x;
|
||||
public final int sr;
|
||||
|
||||
// Constructor
|
||||
public ValidationResult(float[][] x, int sr) {
|
||||
this.x = x;
|
||||
this.sr = sr;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Function to validate input data
|
||||
*/
|
||||
private ValidationResult validateInput(float[][] x, int sr) {
|
||||
// Process the input data with dimension 1
|
||||
if (x.length == 1) {
|
||||
x = new float[][]{x[0]};
|
||||
}
|
||||
// Throw an exception when the input data dimension is greater than 2
|
||||
if (x.length > 2) {
|
||||
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
||||
}
|
||||
|
||||
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
|
||||
if (sr != 16000 && (sr % 16000 == 0)) {
|
||||
int step = sr / 16000;
|
||||
float[][] reducedX = new float[x.length][];
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
float[] current = x[i];
|
||||
float[] newArr = new float[(current.length + step - 1) / step];
|
||||
|
||||
for (int j = 0, index = 0; j < current.length; j += step, index++) {
|
||||
newArr[index] = current[j];
|
||||
}
|
||||
|
||||
reducedX[i] = newArr;
|
||||
}
|
||||
|
||||
x = reducedX;
|
||||
sr = 16000;
|
||||
}
|
||||
|
||||
// If the sample rate is not in the list of supported sample rates, throw an exception
|
||||
if (!SAMPLE_RATES.contains(sr)) {
|
||||
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
||||
}
|
||||
|
||||
// If the input audio block is too short, throw an exception
|
||||
if (((float) sr) / x[0].length > 31.25) {
|
||||
throw new IllegalArgumentException("Input audio is too short");
|
||||
}
|
||||
|
||||
// Return the validated result
|
||||
return new ValidationResult(x, sr);
|
||||
}
|
||||
|
||||
private static float[][] concatenate(float[][] a, float[][] b) {
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
|
||||
}
|
||||
|
||||
int rows = a.length;
|
||||
int colsA = a[0].length;
|
||||
int colsB = b[0].length;
|
||||
float[][] result = new float[rows][colsA + colsB];
|
||||
|
||||
for (int i = 0; i < rows; i++) {
|
||||
System.arraycopy(a[i], 0, result[i], 0, colsA);
|
||||
System.arraycopy(b[i], 0, result[i], colsA, colsB);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static float[][] getLastColumns(float[][] array, int contextSize) {
|
||||
int rows = array.length;
|
||||
int cols = array[0].length;
|
||||
|
||||
if (contextSize > cols) {
|
||||
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
|
||||
}
|
||||
|
||||
float[][] result = new float[rows][contextSize];
|
||||
|
||||
for (int i = 0; i < rows; i++) {
|
||||
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to call the ONNX model
|
||||
*/
|
||||
public float[] call(float[][] x, int sr) throws OrtException {
|
||||
ValidationResult result = validateInput(x, sr);
|
||||
x = result.x;
|
||||
sr = result.sr;
|
||||
int numberSamples = 256;
|
||||
if (sr == 16000) {
|
||||
numberSamples = 512;
|
||||
}
|
||||
|
||||
if (x[0].length != numberSamples) {
|
||||
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
|
||||
}
|
||||
|
||||
int batchSize = x.length;
|
||||
|
||||
int contextSize = 32;
|
||||
if (sr == 16000) {
|
||||
contextSize = 64;
|
||||
}
|
||||
|
||||
if (lastBatchSize == 0) {
|
||||
resetStates();
|
||||
}
|
||||
if (lastSr != 0 && lastSr != sr) {
|
||||
resetStates();
|
||||
}
|
||||
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
|
||||
resetStates();
|
||||
}
|
||||
|
||||
if (context.length == 0) {
|
||||
context = new float[batchSize][contextSize];
|
||||
}
|
||||
|
||||
x = concatenate(context, x);
|
||||
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
OnnxTensor inputTensor = null;
|
||||
OnnxTensor stateTensor = null;
|
||||
OnnxTensor srTensor = null;
|
||||
OrtSession.Result ortOutputs = null;
|
||||
|
||||
try {
|
||||
// Create input tensors
|
||||
inputTensor = OnnxTensor.createTensor(env, x);
|
||||
stateTensor = OnnxTensor.createTensor(env, state);
|
||||
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input", inputTensor);
|
||||
inputs.put("sr", srTensor);
|
||||
inputs.put("state", stateTensor);
|
||||
|
||||
// Call the ONNX model for calculation
|
||||
ortOutputs = session.run(inputs);
|
||||
// Get the output results
|
||||
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
||||
state = (float[][][]) ortOutputs.get(1).getValue();
|
||||
|
||||
context = getLastColumns(x, contextSize);
|
||||
lastSr = sr;
|
||||
lastBatchSize = batchSize;
|
||||
return output[0];
|
||||
} finally {
|
||||
if (inputTensor != null) {
|
||||
inputTensor.close();
|
||||
}
|
||||
if (stateTensor != null) {
|
||||
stateTensor.close();
|
||||
}
|
||||
if (srTensor != null) {
|
||||
srTensor.close();
|
||||
}
|
||||
if (ortOutputs != null) {
|
||||
ortOutputs.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -186,7 +186,7 @@ if __name__ == '__main__':
                         help="same as trig_sum, but for switching from triggered to non-triggered state (non-speech)")
 
     parser.add_argument('-N', '--num_steps', type=int, default=8,
-                        help="nubmer of overlapping windows to split audio chunk into (we recommend 4 or 8)")
+                        help="number of overlapping windows to split audio chunk into (we recommend 4 or 8)")
 
     parser.add_argument('-nspw', '--num_samples_per_window', type=int, default=4000,
                         help="number of samples in each window, our models were trained using 4000 samples (250 ms) per window, so this is preferable value (lesser values reduce quality)")
@@ -198,4 +198,4 @@ if __name__ == '__main__':
                         help=" minimum silence duration in samples between to separate speech chunks")
     ARGS = parser.parse_args()
     ARGS.rate=DEFAULT_SAMPLE_RATE
-    main(ARGS)
+    main(ARGS)
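To make `--num_steps` and `--num_samples_per_window` concrete: at 16 kHz, 4000 samples is 250 ms, and a chunk is split into windows that overlap by hopping `window / num_steps` samples at a time. The helper below is a hypothetical illustration of that splitting, not code from this repository:

```python
import numpy as np

def split_into_windows(chunk: np.ndarray, num_steps: int = 8,
                       num_samples_per_window: int = 4000):
    # Hypothetical helper: hop by window/num_steps so consecutive windows overlap.
    hop = num_samples_per_window // num_steps  # 500 samples when num_steps=8
    return [chunk[start:start + num_samples_per_window]
            for start in range(0, len(chunk) - num_samples_per_window + 1, hop)]

windows = split_into_windows(np.zeros(16000, dtype=np.float32))
print(len(windows), windows[0].shape)  # 25 windows of 4000 samples (250 ms each)
```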
149 examples/parallel_example.ipynb Normal file
@@ -0,0 +1,149 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install Dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install -q torchaudio\n",
|
||||
"SAMPLING_RATE = 16000\n",
|
||||
"import torch\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"torch.set_num_threads(1)\n",
|
||||
"NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
|
||||
"NUM_COPIES=8\n",
|
||||
"# download wav files, make multiple copies\n",
|
||||
"for idx in range(NUM_COPIES):\n",
|
||||
" torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load VAD model from torch hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||
" model='silero_vad',\n",
|
||||
" force_reload=True,\n",
|
||||
" onnx=False)\n",
|
||||
"\n",
|
||||
"(get_speech_timestamps,\n",
|
||||
"save_audio,\n",
|
||||
"read_audio,\n",
|
||||
"VADIterator,\n",
|
||||
"collect_chunks) = utils"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define a vad process function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import multiprocessing\n",
|
||||
"\n",
|
||||
"vad_models = dict()\n",
|
||||
"\n",
|
||||
"def init_model(model):\n",
|
||||
" pid = multiprocessing.current_process().pid\n",
|
||||
" model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||
" model='silero_vad',\n",
|
||||
" force_reload=False,\n",
|
||||
" onnx=False)\n",
|
||||
" vad_models[pid] = model\n",
|
||||
"\n",
|
||||
"def vad_process(audio_file: str):\n",
|
||||
" \n",
|
||||
" pid = multiprocessing.current_process().pid\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
|
||||
" return get_speech_timestamps(\n",
|
||||
" wav,\n",
|
||||
" vad_models[pid],\n",
|
||||
" 0.46, # speech prob threshold\n",
|
||||
" 16000, # sample rate\n",
|
||||
" 300, # min speech duration in ms\n",
|
||||
" 20, # max speech duration in seconds\n",
|
||||
" 600, # min silence duration\n",
|
||||
" 512, # window size\n",
|
||||
" 200, # spech pad ms\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Parallelization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
|
||||
"\n",
|
||||
"futures = []\n",
|
||||
"\n",
|
||||
"with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n",
|
||||
" for i in range(NUM_COPIES):\n",
|
||||
" futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n",
|
||||
"\n",
|
||||
"for finished in as_completed(futures):\n",
|
||||
" pprint(finished.result())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "diarization",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
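The positional arguments in the `vad_process` cell above are easy to misread, so here is the same call written with keyword arguments — a sketch using the parameter names of this repo's `get_speech_timestamps`; note that recent model versions fix the window size per sample rate, so that argument may be ignored:

```python
speech_ts = get_speech_timestamps(
    wav,
    vad_models[pid],
    threshold=0.46,              # speech probability threshold
    sampling_rate=16000,
    min_speech_duration_ms=300,
    max_speech_duration_s=20,
    min_silence_duration_ms=600,
    window_size_samples=512,
    speech_pad_ms=200,
)
```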
@@ -31,11 +31,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#!pip install numpy==1.20.2\n",
-    "#!pip install torch==1.9.0\n",
-    "#!pip install matplotlib==3.4.2\n",
-    "#!pip install torchaudio==0.9.0\n",
-    "#!pip install soundfile==0.10.3.post1\n",
+    "#!pip install numpy==2.0.2\n",
+    "#!pip install torch==2.4.1\n",
+    "#!pip install matplotlib==3.9.2\n",
+    "#!pip install torchaudio==2.4.1\n",
+    "#!pip install soundfile==0.12.1\n",
     "#!pip install pyaudio==0.2.11"
    ]
   },
@@ -61,7 +61,6 @@
    "import torchaudio\n",
    "import matplotlib\n",
    "import matplotlib.pylab as plt\n",
-   "torchaudio.set_audio_backend(\"soundfile\")\n",
    "import pyaudio"
   ]
  },
@@ -118,7 +117,7 @@
    "    abs_max = np.abs(sound).max()\n",
    "    sound = sound.astype('float32')\n",
    "    if abs_max > 0:\n",
-    "        sound *= 1/abs_max\n",
+    "        sound *= 1/32768\n",
    "    sound = sound.squeeze()  # depends on the use case\n",
    "    return sound"
   ]
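For readability, here is the patched cell's helper written out in full — a sketch assuming the notebook's `numpy as np` import and int16 PyAudio frames:

```python
import numpy as np

def int2float(sound: np.ndarray) -> np.ndarray:
    # Convert int16 PCM samples to float32 roughly in [-1, 1].
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1 / 32768  # fixed int16 full-scale normalization
    sound = sound.squeeze()  # depends on the use case
    return sound
```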
@@ -162,7 +161,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "num_samples = 1536"
+    "num_samples = 512"
    ]
   },
   {
@@ -180,6 +179,8 @@
    "data = []\n",
    "voiced_confidences = []\n",
    "\n",
+    "frames_to_record = 50\n",
+    "\n",
    "print(\"Started Recording\")\n",
    "for i in range(0, frames_to_record):\n",
    "    \n",
@@ -296,7 +297,7 @@
    ],
    "metadata": {
     "kernelspec": {
-     "display_name": "Python 3",
+     "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
     },
@@ -310,7 +311,7 @@
     "name": "python",
     "nbconvert_exporter": "python",
     "pygments_lexer": "ipython3",
-     "version": "3.7.10"
+     "version": "3.9.10"
    },
    "toc": {
     "base_numbering": 1,

2 examples/rust-example/.gitignore vendored Normal file
@@ -0,0 +1,2 @@
target/
recorder.wav
781 examples/rust-example/Cargo.lock generated Normal file
@@ -0,0 +1,781 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.98"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filetime"
|
||||
version = "0.2.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hound"
|
||||
version = "3.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
|
||||
dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.155"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae"
|
||||
dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "ort"
|
||||
version = "2.0.0-rc.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
|
||||
dependencies = [
|
||||
"half",
|
||||
"js-sys",
|
||||
"libloading",
|
||||
"ndarray",
|
||||
"ort-sys",
|
||||
"thiserror",
|
||||
"tracing",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ort-sys"
|
||||
version = "2.0.0-rc.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
|
||||
dependencies = [
|
||||
"flate2",
|
||||
"sha2",
|
||||
"tar",
|
||||
"ureq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.84"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"getrandom",
|
||||
"libc",
|
||||
"spin",
|
||||
"untrusted",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-example"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"hound",
|
||||
"ndarray",
|
||||
"ort",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.22.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
|
||||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.61"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.61"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||
dependencies = [
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
|
||||
dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.9.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"url",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"log",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
|
||||
9
examples/rust-example/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
[package]
name = "rust-example"
version = "0.1.0"
edition = "2021"

[dependencies]
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] }
ndarray = "0.15"
hound = "3"
19
examples/rust-example/README.md
Normal file
@@ -0,0 +1,19 @@
# Stream example in Rust

Modeled after the [C++ stream example](https://github.com/snakers4/silero-vad/tree/master/examples/cpp).

## Dependencies

- To build the Rust crate `ort` you need `cc` installed.

## Usage

Just run

```
cargo run
```

If you run the example outside of this repo, point the environment variable at the model:

```
SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run
```

To test against a wav file other than `recorder.wav`, pass it as the first argument:

```
cargo run -- /path/to/audio/file.wav
```
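Both options can be combined; for instance, an out-of-repo run against a custom file (both paths here are placeholders):

```
SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run -- /path/to/audio/file.wav
```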
36
examples/rust-example/src/main.rs
Normal file
@@ -0,0 +1,36 @@
mod silero;
mod utils;
mod vad_iter;

fn main() {
    let model_path = std::env::var("SILERO_MODEL_PATH")
        .unwrap_or_else(|_| String::from("../../files/silero_vad.onnx"));
    let audio_path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| String::from("recorder.wav"));
    let mut wav_reader = hound::WavReader::open(audio_path).unwrap();
    let sample_rate = match wav_reader.spec().sample_rate {
        8000 => utils::SampleRate::EightkHz,
        16000 => utils::SampleRate::SixteenkHz,
        _ => panic!("Unsupported sample rate. Expect 8 kHz or 16 kHz."),
    };
    if wav_reader.spec().sample_format != hound::SampleFormat::Int {
        panic!("Unsupported sample format. Expect Int.");
    }
    let content = wav_reader
        .samples()
        .filter_map(|x| x.ok())
        .collect::<Vec<i16>>();
    assert!(!content.is_empty());
    let silero = silero::Silero::new(sample_rate, model_path).unwrap();
    let vad_params = utils::VadParams {
        sample_rate: sample_rate.into(),
        ..Default::default()
    };
    let mut vad_iterator = vad_iter::VadIter::new(silero, vad_params);
    vad_iterator.process(&content).unwrap();
    for timestamp in vad_iterator.speeches() {
        println!("{}", timestamp);
    }
    println!("Finished.");
}
54
examples/rust-example/src/silero.rs
Normal file
@@ -0,0 +1,54 @@
use crate::utils;
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use std::path::Path;

#[derive(Debug)]
pub struct Silero {
    session: ort::Session,
    sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
    state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
}

impl Silero {
    pub fn new(
        sample_rate: utils::SampleRate,
        model_path: impl AsRef<Path>,
    ) -> Result<Self, ort::Error> {
        let session = ort::Session::builder()?.commit_from_file(model_path)?;
        let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
        let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
        Ok(Self {
            session,
            sample_rate,
            state,
        })
    }

    pub fn reset(&mut self) {
        self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
    }

    pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
        let data = audio_frame
            .iter()
            .map(|x| (*x as f32) / (i16::MAX as f32))
            .collect::<Vec<_>>();
        let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
        frame = frame.slice(s![.., ..480]).to_owned();
        let inps = ort::inputs![
            frame,
            std::mem::take(&mut self.state),
            self.sample_rate.clone(),
        ]?;
        let res = self
            .session
            .run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
        self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
        Ok(*res["output"]
            .try_extract_raw_tensor::<f32>()
            .unwrap()
            .1
            .first()
            .unwrap())
    }
}
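Note the single `state` tensor of shape `[2, 1, 128]`: this matches the updated ONNX model interface, mirroring the `utils_vad.py` change later in this diff, where the old `h`/`c` pair of shape `(2, batch, 64)` is replaced by one `state` tensor of shape `(2, batch, 128)`.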
60
examples/rust-example/src/utils.rs
Normal file
@@ -0,0 +1,60 @@
#[derive(Debug, Clone, Copy)]
pub enum SampleRate {
    EightkHz,
    SixteenkHz,
}

impl From<SampleRate> for i64 {
    fn from(value: SampleRate) -> Self {
        match value {
            SampleRate::EightkHz => 8000,
            SampleRate::SixteenkHz => 16000,
        }
    }
}

impl From<SampleRate> for usize {
    fn from(value: SampleRate) -> Self {
        match value {
            SampleRate::EightkHz => 8000,
            SampleRate::SixteenkHz => 16000,
        }
    }
}

#[derive(Debug)]
pub struct VadParams {
    pub frame_size: usize,
    pub threshold: f32,
    pub min_silence_duration_ms: usize,
    pub speech_pad_ms: usize,
    pub min_speech_duration_ms: usize,
    pub max_speech_duration_s: f32,
    pub sample_rate: usize,
}

impl Default for VadParams {
    fn default() -> Self {
        Self {
            frame_size: 64,
            threshold: 0.5,
            min_silence_duration_ms: 0,
            speech_pad_ms: 64,
            min_speech_duration_ms: 64,
            max_speech_duration_s: f32::INFINITY,
            sample_rate: 16000,
        }
    }
}

#[derive(Debug, Default)]
pub struct TimeStamp {
    pub start: i64,
    pub end: i64,
}

impl std::fmt::Display for TimeStamp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[start:{:08}, end:{:08}]", self.start, self.end)
    }
}
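With the `Display` impl above, each detected segment prints as a pair of zero-padded sample indices; for example, a segment running from 1 s to 3 s at 16 kHz would print as `[start:00016000, end:00048000]`.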
223
examples/rust-example/src/vad_iter.rs
Normal file
@@ -0,0 +1,223 @@
use crate::{silero, utils};

const DEBUG_SPEECH_PROB: bool = true;
#[derive(Debug)]
pub struct VadIter {
    silero: silero::Silero,
    params: Params,
    state: State,
}

impl VadIter {
    pub fn new(silero: silero::Silero, params: utils::VadParams) -> Self {
        Self {
            silero,
            params: Params::from(params),
            state: State::new(),
        }
    }

    pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
        self.reset_states();
        for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
            let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
            self.state.update(&self.params, speech_prob);
        }
        self.state.check_for_last_speech(samples.len());
        Ok(())
    }

    pub fn speeches(&self) -> &[utils::TimeStamp] {
        &self.state.speeches
    }
}

impl VadIter {
    fn reset_states(&mut self) {
        self.silero.reset();
        self.state = State::new()
    }
}

#[allow(unused)]
#[derive(Debug)]
struct Params {
    frame_size: usize,
    threshold: f32,
    min_silence_duration_ms: usize,
    speech_pad_ms: usize,
    min_speech_duration_ms: usize,
    max_speech_duration_s: f32,
    sample_rate: usize,
    sr_per_ms: usize,
    frame_size_samples: usize,
    min_speech_samples: usize,
    speech_pad_samples: usize,
    max_speech_samples: f32,
    min_silence_samples: usize,
    min_silence_samples_at_max_speech: usize,
}

impl From<utils::VadParams> for Params {
    fn from(value: utils::VadParams) -> Self {
        let frame_size = value.frame_size;
        let threshold = value.threshold;
        let min_silence_duration_ms = value.min_silence_duration_ms;
        let speech_pad_ms = value.speech_pad_ms;
        let min_speech_duration_ms = value.min_speech_duration_ms;
        let max_speech_duration_s = value.max_speech_duration_s;
        let sample_rate = value.sample_rate;
        let sr_per_ms = sample_rate / 1000;
        let frame_size_samples = frame_size * sr_per_ms;
        let min_speech_samples = sr_per_ms * min_speech_duration_ms;
        let speech_pad_samples = sr_per_ms * speech_pad_ms;
        let max_speech_samples = sample_rate as f32 * max_speech_duration_s
            - frame_size_samples as f32
            - 2.0 * speech_pad_samples as f32;
        let min_silence_samples = sr_per_ms * min_silence_duration_ms;
        let min_silence_samples_at_max_speech = sr_per_ms * 98;
        Self {
            frame_size,
            threshold,
            min_silence_duration_ms,
            speech_pad_ms,
            min_speech_duration_ms,
            max_speech_duration_s,
            sample_rate,
            sr_per_ms,
            frame_size_samples,
            min_speech_samples,
            speech_pad_samples,
            max_speech_samples,
            min_silence_samples,
            min_silence_samples_at_max_speech,
        }
    }
}

#[derive(Debug, Default)]
struct State {
    current_sample: usize,
    temp_end: usize,
    next_start: usize,
    prev_end: usize,
    triggered: bool,
    current_speech: utils::TimeStamp,
    speeches: Vec<utils::TimeStamp>,
}

impl State {
    fn new() -> Self {
        Default::default()
    }

    fn update(&mut self, params: &Params, speech_prob: f32) {
        self.current_sample += params.frame_size_samples;
        if speech_prob > params.threshold {
            if self.temp_end != 0 {
                self.temp_end = 0;
                if self.next_start < self.prev_end {
                    self.next_start = self
                        .current_sample
                        .saturating_sub(params.frame_size_samples)
                }
            }
            if !self.triggered {
                self.debug(speech_prob, params, "start");
                self.triggered = true;
                self.current_speech.start =
                    self.current_sample as i64 - params.frame_size_samples as i64;
            }
            return;
        }
        if self.triggered
            && (self.current_sample as i64 - self.current_speech.start) as f32
                > params.max_speech_samples
        {
            if self.prev_end > 0 {
                self.current_speech.end = self.prev_end as _;
                self.take_speech();
                if self.next_start < self.prev_end {
                    self.triggered = false
                } else {
                    self.current_speech.start = self.next_start as _;
                }
                self.prev_end = 0;
                self.next_start = 0;
                self.temp_end = 0;
            } else {
                self.current_speech.end = self.current_sample as _;
                self.take_speech();
                self.prev_end = 0;
                self.next_start = 0;
                self.temp_end = 0;
                self.triggered = false;
            }
            return;
        }
        if speech_prob >= (params.threshold - 0.15) && (speech_prob < params.threshold) {
            if self.triggered {
                self.debug(speech_prob, params, "speaking")
            } else {
                self.debug(speech_prob, params, "silence")
            }
        }
        if self.triggered && speech_prob < (params.threshold - 0.15) {
            self.debug(speech_prob, params, "end");
            if self.temp_end == 0 {
                self.temp_end = self.current_sample;
            }
            if self.current_sample.saturating_sub(self.temp_end)
                > params.min_silence_samples_at_max_speech
            {
                self.prev_end = self.temp_end;
            }
            if self.current_sample.saturating_sub(self.temp_end) >= params.min_silence_samples {
                self.current_speech.end = self.temp_end as _;
                if self.current_speech.end - self.current_speech.start
                    > params.min_speech_samples as _
                {
                    self.take_speech();
                    self.prev_end = 0;
                    self.next_start = 0;
                    self.temp_end = 0;
                    self.triggered = false;
                }
            }
        }
    }

    fn take_speech(&mut self) {
        self.speeches.push(std::mem::take(&mut self.current_speech)); // current speech becomes TimeStamp::default() due to take()
    }

    fn check_for_last_speech(&mut self, last_sample: usize) {
        if self.current_speech.start > 0 {
            self.current_speech.end = last_sample as _;
            self.take_speech();
            self.prev_end = 0;
            self.next_start = 0;
            self.temp_end = 0;
            self.triggered = false;
        }
    }

    fn debug(&self, speech_prob: f32, params: &Params, title: &str) {
        if DEBUG_SPEECH_PROB {
            let speech = self.current_sample as f32
                - params.frame_size_samples as f32
                - if title == "end" {
                    params.speech_pad_samples
                } else {
                    0
                } as f32; // minus window_size_samples to get precise start time point.
            println!(
                "[{:10}: {:.3} s ({:.3}) {:8}]",
                title,
                speech / params.sample_rate as f32,
                speech_prob,
                self.current_sample - params.frame_size_samples,
            );
        }
    }
}
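For concreteness, with the default `VadParams` above (16 kHz, `frame_size` of 64 ms): `sr_per_ms` = 16, so `frame_size_samples` = 16 × 64 = 1024, `speech_pad_samples` = `min_speech_samples` = 1024, and `min_silence_samples_at_max_speech` = 16 × 98 = 1568. Each `calc_level` call therefore consumes one 1024-sample (64 ms) frame, of which `silero.rs` feeds only the first 480 samples to the model.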
@@ -1 +0,0 @@
{"59": "mg, Malagasy", "76": "tk, Turkmen", "20": "lb, Luxembourgish, Letzeburgesch", "62": "or, Oriya", "30": "en, English", "26": "oc, Occitan", "69": "no, Norwegian", "77": "sr, Serbian", "90": "bs, Bosnian", "71": "el, Greek, Modern (1453\u2013)", "15": "az, Azerbaijani", "12": "lo, Lao", "85": "zh-HK, Chinese", "79": "cs, Czech", "43": "sv, Swedish", "37": "mn, Mongolian", "32": "fi, Finnish", "51": "tg, Tajik", "46": "am, Amharic", "17": "nn, Norwegian Nynorsk", "40": "ja, Japanese", "8": "it, Italian", "21": "ha, Hausa", "11": "as, Assamese", "29": "fa, Persian", "82": "bn, Bengali", "54": "mk, Macedonian", "31": "sw, Swahili", "45": "vi, Vietnamese", "41": "ur, Urdu", "74": "bo, Tibetan", "4": "hi, Hindi", "86": "mr, Marathi", "3": "fy-NL, Western Frisian", "65": "sk, Slovak", "2": "ln, Lingala", "92": "gl, Galician", "53": "sn, Shona", "87": "su, Sundanese", "35": "tt, Tatar", "93": "kn, Kannada", "6": "yo, Yoruba", "27": "ps, Pashto, Pushto", "34": "hy, Armenian", "25": "pa-IN, Punjabi, Panjabi", "23": "nl, Dutch, Flemish", "48": "th, Thai", "73": "mt, Maltese", "55": "ar, Arabic", "89": "ba, Bashkir", "78": "bg, Bulgarian", "42": "yi, Yiddish", "5": "ru, Russian", "84": "sv-SE, Swedish", "80": "tr, Turkish", "33": "sq, Albanian", "38": "kk, Kazakh", "50": "pl, Polish", "9": "hr, Croatian", "66": "ky, Kirghiz, Kyrgyz", "49": "hu, Hungarian", "10": "si, Sinhala, Sinhalese", "56": "la, Latin", "75": "de, German", "14": "ko, Korean", "22": "id, Indonesian", "47": "sl, Slovenian", "57": "be, Belarusian", "36": "ta, Tamil", "7": "da, Danish", "91": "sd, Sindhi", "28": "et, Estonian", "63": "pt, Portuguese", "60": "ne, Nepali", "94": "zh-TW, Chinese", "18": "zh-CN, Chinese", "88": "rw, Kinyarwanda", "19": "es, Spanish, Castilian", "39": "ht, Haitian, Haitian Creole", "64": "tl, Tagalog", "83": "ms, Malay", "70": "ro, Romanian, Moldavian, Moldovan", "68": "pa, Punjabi, Panjabi", "52": "uz, Uzbek", "58": "km, Central Khmer", "67": "my, Burmese", "0": "fr, French", "24": "af, Afrikaans", "16": "gu, Gujarati", "81": "so, Somali", "13": "uk, Ukrainian", "44": "ca, Catalan, Valencian", "72": "ml, Malayalam", "61": "te, Telugu", "1": "zh, Chinese"}
@@ -1 +0,0 @@
{"0": ["Afrikaans", "Dutch, Flemish", "Western Frisian"], "1": ["Turkish", "Azerbaijani"], "2": ["Russian", "Slovak", "Ukrainian", "Czech", "Polish", "Belarusian"], "3": ["Bulgarian", "Macedonian", "Serbian", "Croatian", "Bosnian", "Slovenian"], "4": ["Norwegian Nynorsk", "Swedish", "Danish", "Norwegian"], "5": ["English"], "6": ["Finnish", "Estonian"], "7": ["Yiddish", "Luxembourgish, Letzeburgesch", "German"], "8": ["Spanish", "Occitan", "Portuguese", "Catalan, Valencian", "Galician", "Spanish, Castilian", "Italian"], "9": ["Maltese", "Arabic"], "10": ["Marathi"], "11": ["Hindi", "Urdu"], "12": ["Lao", "Thai"], "13": ["Malay", "Indonesian"], "14": ["Romanian, Moldavian, Moldovan"], "15": ["Tagalog"], "16": ["Tajik", "Persian"], "17": ["Kazakh", "Uzbek", "Kirghiz, Kyrgyz"], "18": ["Kinyarwanda"], "19": ["Tatar", "Bashkir"], "20": ["French"], "21": ["Chinese"], "22": ["Lingala"], "23": ["Yoruba"], "24": ["Sinhala, Sinhalese"], "25": ["Assamese"], "26": ["Korean"], "27": ["Gujarati"], "28": ["Hausa"], "29": ["Punjabi, Panjabi"], "30": ["Pashto, Pushto"], "31": ["Swahili"], "32": ["Albanian"], "33": ["Armenian"], "34": ["Mongolian"], "35": ["Tamil"], "36": ["Haitian, Haitian Creole"], "37": ["Japanese"], "38": ["Vietnamese"], "39": ["Amharic"], "40": ["Hungarian"], "41": ["Shona"], "42": ["Latin"], "43": ["Central Khmer"], "44": ["Malagasy"], "45": ["Nepali"], "46": ["Telugu"], "47": ["Oriya"], "48": ["Burmese"], "49": ["Greek, Modern (1453\u2013)"], "50": ["Malayalam"], "51": ["Tibetan"], "52": ["Turkmen"], "53": ["Somali"], "54": ["Bengali"], "55": ["Sundanese"], "56": ["Sindhi"], "57": ["Kannada"]}
Binary file not shown.
Binary file not shown.
92
hubconf.py
@@ -1,23 +1,26 @@
dependencies = ['torch', 'torchaudio']
import torch
import os
import json
from utils_vad import (init_jit_model,
                       get_speech_timestamps,
                       get_number_ts,
                       get_language,
                       get_language_and_group,
                       save_audio,
                       read_audio,
                       VADIterator,
                       collect_chunks,
                       drop_chunks,
                       Validator,
                       OnnxWrapper)
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from silero_vad.utils_vad import (init_jit_model,
                                  get_speech_timestamps,
                                  save_audio,
                                  read_audio,
                                  VADIterator,
                                  collect_chunks,
                                  OnnxWrapper)


def versiontuple(v):
    return tuple(map(int, (v.split('+')[0].split("."))))
    splitted = v.split('+')[0].split(".")
    version_list = []
    for i in splitted:
        try:
            version_list.append(int(i))
        except:
            version_list.append(0)
    return tuple(version_list)


def silero_vad(onnx=False, force_onnx_cpu=False):
@@ -32,7 +35,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
    if versiontuple(installed_version) < versiontuple(supported_version):
        raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')

    model_dir = os.path.join(os.path.dirname(__file__), 'files')
    model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
    if onnx:
        model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
    else:
@@ -44,62 +47,3 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
             collect_chunks)

    return model, utils


def silero_number_detector(onnx=False, force_onnx_cpu=False):
    """Silero Number Detector
    Returns a model with a set of utils
    Please see https://github.com/snakers4/silero-vad for usage examples
    """
    if onnx:
        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
    else:
        url = 'https://models.silero.ai/vad_models/number_detector.jit'
    model = Validator(url, force_onnx_cpu)
    utils = (get_number_ts,
             save_audio,
             read_audio,
             collect_chunks,
             drop_chunks)

    return model, utils


def silero_lang_detector(onnx=False, force_onnx_cpu=False):
    """Silero Language Classifier
    Returns a model with a set of utils
    Please see https://github.com/snakers4/silero-vad for usage examples
    """
    if onnx:
        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
    else:
        url = 'https://models.silero.ai/vad_models/number_detector.jit'
    model = Validator(url, force_onnx_cpu)
    utils = (get_language,
             read_audio)

    return model, utils


def silero_lang_detector_95(onnx=False, force_onnx_cpu=False):
    """Silero Language Classifier (95 languages)
    Returns a model with a set of utils
    Please see https://github.com/snakers4/silero-vad for usage examples
    """

    if onnx:
        url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
    else:
        url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
    model = Validator(url, force_onnx_cpu)

    model_dir = os.path.join(os.path.dirname(__file__), 'files')
    with open(os.path.join(model_dir, 'lang_dict_95.json'), 'r') as f:
        lang_dict = json.load(f)

    with open(os.path.join(model_dir, 'lang_group_dict_95.json'), 'r') as f:
        lang_group_dict = json.load(f)

    utils = (get_language_and_group, read_audio)

    return model, lang_dict, lang_group_dict, utils
35
pyproject.toml
Normal file
@@ -0,0 +1,35 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "silero-vad"
version = "5.1"
authors = [
  {name="Silero Team", email="hello@silero.ai"},
]
description = "Voice Activity Detector (VAD) by Silero"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Development Status :: 5 - Production/Stable",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering",
]
dependencies = [
    "torch>=1.12.0",
    "torchaudio>=0.12.0",
    "onnxruntime>=1.16.1",
]

[project.urls]
Homepage = "https://github.com/snakers4/silero-vad"
Issues = "https://github.com/snakers4/silero-vad/issues"
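With this packaging in place, installing and using the model reduces to a short sketch like the following (the wav path is a placeholder; the imported names are the ones exported by `src/silero_vad/__init__.py` further down in this diff):

```
# pip install silero-vad
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()  # onnx=True selects the ONNX variant
wav = read_audio('example.wav', sampling_rate=16000)  # placeholder file
speech_timestamps = get_speech_timestamps(wav, model)
print(speech_timestamps)  # [{'start': ..., 'end': ...}, ...] in samples
```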
313
silero-vad.ipynb
@@ -1,14 +1,5 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FpMplOCA2Fwp"
   },
   "source": [
    "#VAD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
@@ -52,20 +43,30 @@
   },
   "outputs": [],
   "source": [
    "USE_PIP = True # download model using pip package or torch.hub\n",
    "USE_ONNX = False # change this to True if you want to test onnx model\n",
    "if USE_ONNX:\n",
    "  !pip install -q onnxruntime\n",
    "  \n",
    "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
    "                              model='silero_vad',\n",
    "                              force_reload=True,\n",
    "                              onnx=USE_ONNX)\n",
    "if USE_PIP:\n",
    "  !pip install -q silero-vad\n",
    "  from silero_vad import (load_silero_vad,\n",
    "                          read_audio,\n",
    "                          get_speech_timestamps,\n",
    "                          save_audio,\n",
    "                          VADIterator,\n",
    "                          collect_chunks)\n",
    "  model = load_silero_vad(onnx=USE_ONNX)\n",
    "else:\n",
    "  model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
    "                                model='silero_vad',\n",
    "                                force_reload=True,\n",
    "                                onnx=USE_ONNX)\n",
    "\n",
    "(get_speech_timestamps,\n",
    " save_audio,\n",
    " read_audio,\n",
    " VADIterator,\n",
    " collect_chunks) = utils"
    "  (get_speech_timestamps,\n",
    "  save_audio,\n",
    "  read_audio,\n",
    "  VADIterator,\n",
    "  collect_chunks) = utils"
   ]
  },
  {
@@ -74,16 +75,7 @@
    "id": "fXbbaUO3jsrw"
   },
   "source": [
    "## Full Audio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RAfJPb_a-Auj"
   },
   "source": [
    "**Speech timestapms from full audio**"
    "## Speech timestamps from full audio"
   ]
  },
  {
@@ -110,10 +102,33 @@
   "source": [
    "# merge all speech chunks to one audio\n",
    "save_audio('only_speech.wav',\n",
    "           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
    "           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
    "Audio('only_speech.wav')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zeO1xCqxUC6w"
   },
   "source": [
    "## Entire audio inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LjZBcsaTT7Mk"
   },
   "outputs": [],
   "source": [
    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "# audio is split into chunks of 512 samples (31.25 chunks per second at 16 kHz),\n",
    "# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
    "predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
@@ -133,10 +148,10 @@
   "source": [
    "## using VADIterator class\n",
    "\n",
    "vad_iterator = VADIterator(model)\n",
    "vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
    "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "\n",
    "window_size_samples = 1536 # number of samples in a single audio chunk\n",
    "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
    "for i in range(0, len(wav), window_size_samples):\n",
    "    chunk = wav[i: i+ window_size_samples]\n",
    "    if len(chunk) < window_size_samples:\n",
@@ -159,7 +174,7 @@
    "\n",
    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "speech_probs = []\n",
    "window_size_samples = 1536\n",
    "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
    "for i in range(0, len(wav), window_size_samples):\n",
    "    chunk = wav[i: i+ window_size_samples]\n",
    "    if len(chunk) < window_size_samples:\n",
@@ -170,238 +185,6 @@
    "\n",
    "print(speech_probs[:10]) # predictions for the first 10 chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "id": "36jY0niD2Fww"
   },
   "source": [
    "# Number detector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true,
    "id": "scd1DlS42Fwx"
   },
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "Kq5gQuYq2Fwx"
   },
   "outputs": [],
   "source": [
    "#@title Install and Import Dependencies\n",
    "\n",
    "# this assumes that you have a relevant version of PyTorch installed\n",
    "!pip install -q torchaudio\n",
    "\n",
    "SAMPLING_RATE = 16000\n",
    "\n",
    "import torch\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "from IPython.display import Audio\n",
    "from pprint import pprint\n",
    "# download example\n",
    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en_num.wav', 'en_number_example.wav')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dPwCFHmFycUF"
   },
   "outputs": [],
   "source": [
    "USE_ONNX = False # change this to True if you want to test onnx model\n",
    "if USE_ONNX:\n",
    "  !pip install -q onnxruntime\n",
    "  \n",
    "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
    "                              model='silero_number_detector',\n",
    "                              force_reload=True,\n",
    "                              onnx=USE_ONNX)\n",
    "\n",
    "(get_number_ts,\n",
    " save_audio,\n",
    " read_audio,\n",
    " collect_chunks,\n",
    " drop_chunks) = utils\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true,
    "id": "qhPa30ij2Fwy"
   },
   "source": [
    "## Full audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "EXpau6xq2Fwy"
   },
   "outputs": [],
   "source": [
    "wav = read_audio('en_number_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "# get number timestamps from full audio file\n",
    "number_timestamps = get_number_ts(wav, model)\n",
    "pprint(number_timestamps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "u-KfXRhZ2Fwy"
   },
   "outputs": [],
   "source": [
    "# convert ms in timestamps to samples\n",
    "for timestamp in number_timestamps:\n",
    "    timestamp['start'] = int(timestamp['start'] * SAMPLING_RATE / 1000)\n",
    "    timestamp['end'] = int(timestamp['end'] * SAMPLING_RATE / 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "iwYEC4aZ2Fwy"
   },
   "outputs": [],
   "source": [
    "# merge all number chunks to one audio\n",
    "save_audio('only_numbers.wav',\n",
    "           collect_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
    "Audio('only_numbers.wav')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "fHaYejX12Fwy"
   },
   "outputs": [],
   "source": [
    "# drop all number chunks from audio\n",
    "save_audio('no_numbers.wav',\n",
    "           drop_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
    "Audio('no_numbers.wav')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "id": "PnKtJKbq2Fwz"
   },
   "source": [
    "# Language detector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true,
    "id": "F5cAmMbP2Fwz"
   },
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "Zu9D0t6n2Fwz"
   },
   "outputs": [],
   "source": [
    "#@title Install and Import Dependencies\n",
    "\n",
    "# this assumes that you have a relevant version of PyTorch installed\n",
    "!pip install -q torchaudio\n",
    "\n",
    "SAMPLING_RATE = 16000\n",
    "\n",
    "import torch\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "from IPython.display import Audio\n",
    "from pprint import pprint\n",
    "# download example\n",
    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JfRKDZiRztFe"
   },
   "outputs": [],
   "source": [
    "USE_ONNX = False # change this to True if you want to test onnx model\n",
    "if USE_ONNX:\n",
    "  !pip install -q onnxruntime\n",
    "  \n",
    "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
    "                              model='silero_lang_detector',\n",
    "                              force_reload=True,\n",
    "                              onnx=USE_ONNX)\n",
    "\n",
    "get_language, read_audio = utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true,
    "id": "iC696eMX2Fwz"
   },
   "source": [
    "## Full audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true,
    "id": "c8UYnYBF2Fw0"
   },
   "outputs": [],
   "source": [
    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "lang = get_language(wav, model)\n",
    "print(lang)"
   ]
  }
 ],
 "metadata": {
12
src/silero_vad/__init__.py
Normal file
@@ -0,0 +1,12 @@
from importlib.metadata import version
try:
    __version__ = version(__name__)
except:
    pass

from silero_vad.model import load_silero_vad
from silero_vad.utils_vad import (get_speech_timestamps,
                                  save_audio,
                                  read_audio,
                                  VADIterator,
                                  collect_chunks)
0
src/silero_vad/data/__init__.py
Normal file
BIN
src/silero_vad/data/silero_vad.jit
Normal file
Binary file not shown.
BIN
src/silero_vad/data/silero_vad.onnx
Normal file
Binary file not shown.
BIN
src/silero_vad/data/silero_vad_half.onnx
Normal file
Binary file not shown.
25
src/silero_vad/model.py
Normal file
@@ -0,0 +1,25 @@
from .utils_vad import init_jit_model, OnnxWrapper
import torch
torch.set_num_threads(1)

def load_silero_vad(onnx=False):
    model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
    package_path = "silero_vad.data"

    try:
        import importlib_resources as impresources
        model_file_path = str(impresources.files(package_path).joinpath(model_name))
    except:
        from importlib import resources as impresources
        try:
            with impresources.path(package_path, model_name) as f:
                model_file_path = f
        except:
            model_file_path = str(impresources.files(package_path).joinpath(model_name))

    if onnx:
        model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
    else:
        model = init_jit_model(model_file_path)

    return model
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from typing import List
|
||||
import torch.nn.functional as F
|
||||
from typing import Callable, List
|
||||
import warnings
|
||||
|
||||
languages = ['ru', 'en', 'de', 'es']
|
||||
@@ -13,12 +12,15 @@ class OnnxWrapper():
|
||||
import numpy as np
|
||||
global np
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path)
|
||||
self.session.intra_op_num_threads = 1
|
||||
self.session.inter_op_num_threads = 1
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
self.sample_rates = [8000, 16000]
|
||||
@@ -31,27 +33,32 @@ class OnnxWrapper():
 
         if sr != 16000 and (sr % 16000 == 0):
             step = sr // 16000
-            x = x[::step]
+            x = x[:,::step]
             sr = 16000
 
         if sr not in self.sample_rates:
             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiples of 16000)")
 
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")
 
         return x, sr
 
     def reset_states(self, batch_size=1):
-        self._h = np.zeros((2, batch_size, 64)).astype('float32')
-        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
         self._last_sr = 0
         self._last_batch_size = 0
 
     def __call__(self, x, sr: int):
 
         x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
         batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32
 
         if not self._last_batch_size:
             self.reset_states(batch_size)
@@ -60,28 +67,35 @@ class OnnxWrapper():
         if (self._last_batch_size) and (self._last_batch_size != batch_size):
             self.reset_states(batch_size)
 
+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
         if sr in [8000, 16000]:
-            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr)}
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
             ort_outs = self.session.run(None, ort_inputs)
-            out, self._h, self._c = ort_outs
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
         else:
             raise ValueError()
 
+        self._context = x[..., -context_size:]
         self._last_sr = sr
         self._last_batch_size = batch_size
 
-        out = torch.tensor(out)
+        out = torch.from_numpy(out)
         return out
 
-    def audio_forward(self, x, sr: int, num_samples: int = 512):
+    def audio_forward(self, x, sr: int):
         outs = []
         x, sr = self._validate_input(x, sr)
-        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256
 
         if x.shape[1] % num_samples:
             pad_num = num_samples - (x.shape[1] % num_samples)
             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
 
+        self.reset_states(x.shape[0])
        for i in range(0, x.shape[1], num_samples):
             wavs_batch = x[:, i:i+num_samples]
             out_chunk = self.__call__(wavs_batch, sr)
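With the reworked state handling, each call must receive a fixed-size chunk (512 samples at 16 kHz, 256 at 8 kHz); the wrapper keeps the LSTM state and trailing context between calls. A hedged streaming sketch (assuming `model` is an `OnnxWrapper` instance and `wav` a mono 16 kHz float tensor):

```python
# Streaming sketch: feed fixed 512-sample chunks, collect per-chunk speech probabilities
import torch

CHUNK = 512                    # required chunk size at 16000 Hz (256 at 8000 Hz)
model.reset_states()           # clear state and context before a new stream
probs = []
for i in range(0, wav.shape[0] - CHUNK + 1, CHUNK):
    chunk = wav[i:i + CHUNK].unsqueeze(0)      # shape (1, 512): batch of one
    probs.append(model(chunk, 16000).item())   # probability this chunk is speech
```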
@@ -118,17 +132,29 @@ class Validator():
 
 def read_audio(path: str,
                sampling_rate: int = 16000):
+    list_backends = torchaudio.list_audio_backends()
 
-    wav, sr = torchaudio.load(path)
+    assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
+                                    \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
 
-    if wav.size(0) > 1:
-        wav = wav.mean(dim=0, keepdim=True)
+    try:
+        effects = [
+            ['channels', '1'],
+            ['rate', str(sampling_rate)]
+        ]
 
-    if sr != sampling_rate:
-        transform = torchaudio.transforms.Resample(orig_freq=sr,
-                                                   new_freq=sampling_rate)
-        wav = transform(wav)
-        sr = sampling_rate
+        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+    except:
+        wav, sr = torchaudio.load(path)
+
+        if wav.size(0) > 1:
+            wav = wav.mean(dim=0, keepdim=True)
+
+        if sr != sampling_rate:
+            transform = torchaudio.transforms.Resample(orig_freq=sr,
+                                                       new_freq=sampling_rate)
+            wav = transform(wav)
+            sr = sampling_rate
 
     assert sr == sampling_rate
     return wav.squeeze(0)
@@ -137,12 +163,11 @@ def read_audio(path: str,
 
 def save_audio(path: str,
                tensor: torch.Tensor,
                sampling_rate: int = 16000):
-    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate)
+    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
 
 
 def init_jit_model(model_path: str,
                    device=torch.device('cpu')):
-    torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
     return model
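Together, `read_audio` and `save_audio` make a simple round trip; a hedged sketch that keeps only the detected speech, assuming `collect_chunks` keeps its exported signature `(timestamps, wav)` (file names are placeholders):

```python
# Sketch: read a file, detect speech, save only the speech regions
# in.wav / speech_only.wav are placeholder paths
from silero_vad import load_silero_vad, read_audio, save_audio, get_speech_timestamps, collect_chunks

model = load_silero_vad()
wav = read_audio('in.wav', sampling_rate=16000)
ts = get_speech_timestamps(wav, model)          # sample-based timestamps
save_audio('speech_only.wav', collect_chunks(ts, wav), sampling_rate=16000)
```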
@@ -158,16 +183,20 @@ def make_visualization(probs, step):
                        colormap='tab20')
 
 
+@torch.no_grad()
 def get_speech_timestamps(audio: torch.Tensor,
                           model,
                           threshold: float = 0.5,
                           sampling_rate: int = 16000,
                           min_speech_duration_ms: int = 250,
                           max_speech_duration_s: float = float('inf'),
                           min_silence_duration_ms: int = 100,
-                          window_size_samples: int = 512,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
-                          visualize_probs: bool = False):
+                          visualize_probs: bool = False,
+                          progress_tracking_callback: Callable[[float], None] = None,
+                          neg_threshold: float = None,
+                          window_size_samples: int = 512,):
 
     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -177,26 +206,26 @@ def get_speech_timestamps(audio: torch.Tensor,
     audio: torch.Tensor, one dimensional
         One dimensional float torch.Tensor, other types are cast to torch if possible
 
-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model
 
     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
         It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
 
     sampling_rate: int (default - 16000)
-        Currently silero VAD models support 8000 and 16000 sample rates
+        Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates
 
     min_speech_duration_ms: int (default - 250 milliseconds)
         Final speech chunks shorter than min_speech_duration_ms are thrown out
 
     max_speech_duration_s: int (default - inf)
         Maximum duration of speech chunks in seconds
         Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100 ms (if any), to prevent aggressive cutting.
         Otherwise, they will be split aggressively just before max_speech_duration_s.
 
     min_silence_duration_ms: int (default - 100 milliseconds)
         At the end of each speech chunk, wait for min_silence_duration_ms before separating it
 
-    window_size_samples: int (default - 1536 samples)
-        Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
-        Values other than these may affect model performance!!
-
     speech_pad_ms: int (default - 30 milliseconds)
         Final speech chunks are padded by speech_pad_ms each side
 
@@ -206,6 +235,15 @@ def get_speech_timestamps(audio: torch.Tensor,
     visualize_probs: bool (default - False)
         whether to draw prob hist or not
 
+    progress_tracking_callback: Callable[[float], None] (default - None)
+        callback function taking progress in percents as an argument
+
+    neg_threshold: float (default = threshold - 0.15)
+        Negative threshold (noise or exit threshold). If the model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
+
+    window_size_samples: int (default - 512 samples)
+        !!! DEPRECATED, DOES NOTHING !!!
+
     Returns
     ----------
     speeches: list of dicts
@@ -232,15 +270,17 @@ def get_speech_timestamps(audio: torch.Tensor,
     else:
         step = 1
 
-    if sampling_rate == 8000 and window_size_samples > 768:
-        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
-    if window_size_samples not in [256, 512, 768, 1024, 1536]:
-        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
+    if sampling_rate not in [8000, 16000]:
+        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates")
+
+    window_size_samples = 512 if sampling_rate == 16000 else 256
 
     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
-    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
+    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
     min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
 
     audio_length_samples = len(audio)
 
@@ -251,33 +291,65 @@ def get_speech_timestamps(audio: torch.Tensor,
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
         speech_prob = model(chunk, sampling_rate).item()
         speech_probs.append(speech_prob)
+        # calculate progress and send it to the callback function
+        progress = current_start_sample + window_size_samples
+        if progress > audio_length_samples:
+            progress = audio_length_samples
+        progress_percent = (progress / audio_length_samples) * 100
+        if progress_tracking_callback:
+            progress_tracking_callback(progress_percent)
 
     triggered = False
     speeches = []
     current_speech = {}
-    neg_threshold = threshold - 0.15
-    temp_end = 0
+
+    if neg_threshold is None:
+        neg_threshold = threshold - 0.15
+    temp_end = 0  # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
 
     for i, speech_prob in enumerate(speech_probs):
+        if (speech_prob >= threshold) and temp_end:
+            temp_end = 0
+            if next_start < prev_end:
+                next_start = window_size_samples * i
+
         if (speech_prob >= threshold) and not triggered:
             triggered = True
             current_speech['start'] = window_size_samples * i
             continue
 
+        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if prev_end:
+                current_speech['end'] = prev_end
+                speeches.append(current_speech)
+                current_speech = {}
+                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    triggered = False
+                else:
+                    current_speech['start'] = next_start
+                prev_end = next_start = temp_end = 0
+            else:
+                current_speech['end'] = window_size_samples * i
+                speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
 
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
                 temp_end = window_size_samples * i
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+                prev_end = temp_end
             if (window_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
                 if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                     speeches.append(current_speech)
-                temp_end = 0
                 current_speech = {}
+                prev_end = next_start = temp_end = 0
                 triggered = False
                 continue
 
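The new parameters slot in as follows; an illustrative call (assuming `wav` and `model` as in the sketches above, with arbitrary example values, not recommendations):

```python
# Illustrative call using the new arguments; values are arbitrary examples
def on_progress(pct: float):
    print(f'VAD progress: {pct:.0f}%')

speech_ts = get_speech_timestamps(wav, model,
                                  threshold=0.5,
                                  neg_threshold=0.35,  # exit threshold; defaults to threshold - 0.15
                                  progress_tracking_callback=on_progress,
                                  return_seconds=True)
```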
@@ -314,72 +386,6 @@ def get_speech_timestamps(audio: torch.Tensor,
     return speeches
 
 
-def get_number_ts(wav: torch.Tensor,
-                  model,
-                  model_stride=8,
-                  hop_length=160,
-                  sample_rate=16000):
-    wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = model(wav)[0]
-    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()   # (1, num_frames_strided)
-    extended_preds = []
-    for i in perframe_preds:
-        extended_preds.extend([i.item()] * model_stride)
-    # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
-    triggered = False
-    timings = []
-    cur_timing = {}
-    for i, pred in enumerate(extended_preds):
-        if pred == 1:
-            if not triggered:
-                cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
-                triggered = True
-        elif pred == 0:
-            if triggered:
-                cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
-                timings.append(cur_timing)
-                cur_timing = {}
-                triggered = False
-    if cur_timing:
-        cur_timing['end'] = int(len(wav) / (sample_rate / 1000))
-        timings.append(cur_timing)
-    return timings
-
-
-def get_language(wav: torch.Tensor,
-                 model):
-    wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = model(wav)[2]
-    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()   # from 0 to len(languages) - 1
-    assert lang_pred < len(languages)
-    return languages[lang_pred]
-
-
-def get_language_and_group(wav: torch.Tensor,
-                           model,
-                           lang_dict: dict,
-                           lang_group_dict: dict,
-                           top_n=1):
-    wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = model(wav)
-
-    softm = torch.softmax(lang_logits, dim=1).squeeze()
-    softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
-
-    srtd = torch.argsort(softm, descending=True)
-    srtd_group = torch.argsort(softm_group, descending=True)
-
-    outs = []
-    outs_group = []
-    for i in range(top_n):
-        prob = round(softm[srtd[i]].item(), 2)
-        prob_group = round(softm_group[srtd_group[i]].item(), 2)
-        outs.append((lang_dict[str(srtd[i].item())], prob))
-        outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
-
-    return outs, outs_group
-
-
 class VADIterator:
     def __init__(self,
                  model,
@@ -394,7 +400,7 @@ class VADIterator:
 
     Parameters
     ----------
-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model
 
     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
@@ -428,6 +434,7 @@ class VADIterator:
         self.temp_end = 0
         self.current_sample = 0
 
+    @torch.no_grad()
     def __call__(self, x, return_seconds=False):
         """
         x: torch.Tensor
@@ -453,7 +460,7 @@ class VADIterator:
 
         if (speech_prob >= self.threshold) and not self.triggered:
             self.triggered = True
-            speech_start = self.current_sample - self.speech_pad_samples
+            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
             return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
 
         if (speech_prob < self.threshold - 0.15) and self.triggered:
@@ -462,7 +469,7 @@ class VADIterator:
             if self.current_sample - self.temp_end < self.min_silence_samples:
                 return None
             else:
-                speech_end = self.temp_end + self.speech_pad_samples
+                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                 self.temp_end = 0
                 self.triggered = False
                 return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
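For completeness, a hedged streaming sketch with `VADIterator` (assuming `model` and a 16 kHz `wav` as above; the chunk size must match the model's window):

```python
# Streaming start/end detection sketch with VADIterator
from silero_vad import VADIterator

vad_iterator = VADIterator(model, sampling_rate=16000)
for i in range(0, wav.shape[0] - 512 + 1, 512):
    event = vad_iterator(wav[i:i + 512], return_seconds=True)
    if event:                    # {'start': ...} or {'end': ...}, otherwise None
        print(event)
vad_iterator.reset_states()      # reset between independent audio streams
```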
74   tuning/README.md    Normal file
@@ -0,0 +1,74 @@
# Tuning the Silero-VAD model

> The tuning code was created with the support of the Innovation Assistance Foundation (Фонд содействия инновациям) within the federal project "Artificial Intelligence" of the national program "Digital Economy of the Russian Federation".

Tuning is used to improve the speech-detection quality of the Silero-VAD model on custom data.

## Dependencies
The following dependencies are used when tuning the VAD model:
- `torchaudio>=0.12.0`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`

## Data preparation

Dataframes for tuning must be prepared and saved in `.feather` format. The following columns are mandatory in the training and validation `.feather` files:
- **audio_path** - absolute path to the audio file on disk. Audio files must contain `PCM` data, preferably in `.wav` or `.opus` format (other popular audio formats are supported too). To speed up tuning, it is recommended to resample the audio files (change their sampling rate) to 16000 Hz in advance;
- **speech_ts** - annotation for the corresponding audio file: a list of dicts of the form `{'start': START_SEC, 'end': END_SEC}`, where `START_SEC` and `END_SEC` are the start and end of a speech segment in seconds. For high-quality tuning, annotation accurate to within 30 milliseconds is recommended.

The more data is used at the tuning stage, the better the adapted model performs on the target domain. Audio length is not limited, since each audio is cut to `max_train_length_sec` seconds before being fed to the network; it is better to pre-cut long audio into pieces of `max_train_length_sec` length.

An example `.feather` dataframe is provided in `example_dataframe.feather`; a sketch of building such a file is shown below.
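A hedged sketch of assembling a compatible dataframe with pandas (the audio paths and timestamps below are placeholders):

```python
# Sketch: build a tuning dataframe with the two mandatory columns
# audio paths and timestamps are placeholders
import pandas as pd

df = pd.DataFrame({
    'audio_path': ['/data/audio/0001.wav', '/data/audio/0002.wav'],
    'speech_ts': [
        [{'start': 0.30, 'end': 1.95}],
        [{'start': 0.00, 'end': 0.80}, {'start': 2.10, 'end': 4.55}],
    ],
})
df.to_feather('train_dataset_path.feather')  # matches train_dataset_path in config.yml
```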
## Configuration file `config.yml`

The configuration file `config.yml` contains the paths to the training and validation sets, as well as the tuning parameters:
- `train_dataset_path` - absolute path to the training dataframe in `.feather` format. Must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
- `val_dataset_path` - absolute path to the validation dataframe in `.feather` format. Must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
- `jit_model_path` - absolute path to a Silero-VAD model in `.jit` format. If this field is left empty, the model will be downloaded depending on the value of the `use_torchhub` field;
- `use_torchhub` - if `True`, the model to tune will be loaded via torch.hub. If `False`, it will be loaded via the silero-vad library (install it beforehand with `pip install silero-vad`);
- `tune_8k` - determines which Silero-VAD head to tune. If `True`, the head with the 8000 Hz sampling rate is tuned, otherwise the 16000 Hz one;
- `model_save_path` - path where the tuned model is saved;
- `noise_loss` - loss coefficient applied to non-speech audio windows;
- `max_train_length_sec` - maximum audio length in seconds at the tuning stage. Longer audio will be cut to this value;
- `aug_prob` - probability of applying augmentations to an audio file during tuning;
- `learning_rate` - tuning learning rate;
- `batch_size` - batch size for tuning and validation;
- `num_workers` - number of workers used for data loading;
- `num_epochs` - number of tuning epochs. One epoch is one full pass over the training data;
- `device` - `cpu` or `cuda`.

## Tuning

Tuning is started with the command

`python tune.py`

It runs for `num_epochs`; the checkpoint with the best ROC-AUC on the validation set will be saved to `model_save_path` in jit format.

## Threshold search

The enter and exit thresholds can be selected using the command

`python search_thresholds.py`

This script uses the configuration file described above. The model specified in the configuration will be used to search for the optimal thresholds on the validation dataset.

## Citation

```
@misc{Silero VAD,
  author = {Silero Team},
  title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad}},
  commit = {insert_some_commit_here},
  email = {hello@silero.ai}
}
```
0    tuning/__init__.py    Normal file
17   tuning/config.yml    Normal file
@@ -0,0 +1,17 @@
jit_model_path: ''  # path to a Silero-VAD model in jit format; this model will be used for tuning. If the field is left empty, the model will be downloaded automatically
use_torchhub: True  # the jit model will be loaded via torchhub if True, or via pip if False

tune_8k: False  # tunes the 16k head if False, the 8k head if True
train_dataset_path: 'train_dataset_path.feather'  # path to the tuning dataset in feather format, see README for details
val_dataset_path: 'val_dataset_path.feather'  # path to the validation dataset in feather format, see README for details
model_save_path: 'model_save_path.jit'  # path where the tuned model is saved

noise_loss: 0.5  # coefficient applied to the loss on non-speech windows
max_train_length_sec: 8  # during tuning, longer audio will be cut to this value
aug_prob: 0.4  # probability of applying augmentations to an audio during tuning

learning_rate: 5e-4  # model tuning learning rate
batch_size: 128  # batch size for tuning and validation
num_workers: 4  # number of workers used for the dataloaders
num_epochs: 20  # number of tuning epochs, 1 epoch = one full pass over the training data
device: 'cuda'  # cpu or cuda, where tuning will run
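The tuning scripts read this file with OmegaConf; a minimal access sketch:

```python
# Sketch: load config.yml and read fields, as the tuning scripts do
from omegaconf import OmegaConf

config = OmegaConf.load('config.yml')
print(config.batch_size, config.device)  # e.g. 128 cuda
```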
BIN  tuning/example_dataframe.feather    Normal file (Binary file not shown.)
36   tuning/search_thresholds.py    Normal file
@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)

if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
                                         batch_size=config.batch_size,
                                         collate_fn=SileroVadPadder,
                                         num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)

    print('Making predictions...')
    all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
    print('Calculating thresholds...')
    best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')
65   tuning/tune.py    Normal file
@@ -0,0 +1,65 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch.nn as nn
import torch


if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    train_dataset = SileroVadDataset(config, mode='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               collate_fn=SileroVadPadder,
                                               num_workers=config.num_workers)

    val_dataset = SileroVadDataset(config, mode='val')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.batch_size,
                                             collate_fn=SileroVadPadder,
                                             num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)
    decoder = VADDecoderRNNJIT().to(config.device)
    decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
    decoder.train()
    params = decoder.parameters()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
                                 lr=config.learning_rate)
    criterion = nn.BCELoss(reduction='none')

    best_val_roc = 0
    for i in range(config.num_epochs):
        print(f'Starting epoch {i + 1}')
        train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
        val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
        print(f'Metrics after epoch {i + 1}:\n'
              f'\tTrain loss: {round(train_loss, 3)}\n'
              f'\tValidation loss: {round(val_loss, 3)}\n'
              f'\tValidation ROC-AUC: {round(val_roc, 3)}')

        if val_roc > best_val_roc:
            print('New best ROC-AUC, saving model')
            best_val_roc = val_roc
            if config.tune_8k:
                model._model_8k.decoder.load_state_dict(decoder.state_dict())
            else:
                model._model.decoder.load_state_dict(decoder.state_dict())
            torch.jit.save(model, config.model_save_path)
    print('Done')
357  tuning/utils.py    Normal file
@@ -0,0 +1,357 @@
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')


def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):

    wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sampling_rate:
        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    if normalize and wav.abs().max() != 0:
        wav = wav / wav.abs().max()

    return wav.squeeze(0)


def build_audiomentations_augs(p):
    from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
        LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
        Aliasing, AddGaussianNoise
    transforms = [Aliasing(p=1),
                  AddGaussianNoise(p=1),
                  AirAbsorption(p=1),
                  BandPassFilter(p=1),
                  BandStopFilter(p=1),
                  ClippingDistortion(p=1),
                  HighPassFilter(p=1),
                  HighShelfFilter(p=1),
                  LowPassFilter(p=1),
                  LowShelfFilter(p=1),
                  Mp3Compression(p=1),
                  PeakingFilter(p=1),
                  PitchShift(p=1),
                  RoomSimulator(p=1, leave_length_unchanged=True),
                  SevenBandParametricEQ(p=1)]
    tr = SomeOf((1, 3), transforms=transforms, p=p)
    return tr


class SileroVadDataset(Dataset):
    def __init__(self,
                 config,
                 mode='train'):

        self.num_samples = 512  # constant, do not change
        self.sr = 16000  # constant, do not change

        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec
        self.max_train_length_samples = config.max_train_length_sec * self.sr

        assert self.max_train_length_samples % self.num_samples == 0
        assert mode in ['train', 'val']

        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE : {len(self.dataframe)}')

        if mode == 'train':
            self.augs = build_audiomentations_augs(p=config.aug_prob)
        else:
            self.augs = None

    def __getitem__(self, idx):
        idx = None if self.mode == 'train' else idx
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
        if len(wav) > self.max_train_length_samples:
            wav = wav[:self.max_train_length_samples]
            gt = gt[:int(self.max_train_length_samples / self.num_samples)]
            mask = mask[:int(self.max_train_length_samples / self.num_samples)]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            transform = torchaudio.transforms.Resample(orig_freq=self.sr,
                                                       new_freq=8000)
            wav = transform(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()

        if len(wav) % self.num_samples != 0:
            pad_num = self.num_samples - (len(wav) % self.num_samples)
            wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))

        assert len(gt) == len(wav) / self.num_samples

        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        gt = np.zeros(audio_length_samples)

        for i in annotation:
            gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1

        # squeeze per-sample labels to one label per 512-sample window
        squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
        squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
        mask = np.ones(len(squeezed_predicts))
        mask[squeezed_predicts == 0] = self.noise_loss
        return squeezed_predicts, mask

    def add_augs(self, wav):
        while True:
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
                    return wav
                return wav_aug
            except Exception:
                continue


def SileroVadPadder(batch):
    wavs = [batch[i][0] for i in range(len(batch))]
    labels = [batch[i][1] for i in range(len(batch))]
    masks = [batch[i][2] for i in range(len(batch))]

    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0)

    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)

    masks = torch.nn.utils.rnn.pad_sequence(
        masks, batch_first=True, padding_value=0)

    return wavs, labels, masks


class VADDecoderRNNJIT(nn.Module):

    def __init__(self):
        super(VADDecoderRNNJIT, self).__init__()

        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(nn.Dropout(0.1),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 1, kernel_size=1),
                                     nn.Sigmoid())

    def forward(self, x, state=torch.zeros(0)):
        x = x.squeeze(-1)
        if len(state):
            h, c = self.rnn(x, (state[0], state[1]))
        else:
            h, c = self.rnn(x)

        x = h.unsqueeze(-1).float()
        state = torch.stack([h, c])
        x = self.decoder(x)
        return x, state


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):

    losses = AverageMeter()
    decoder.train()

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.enable_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            optimizer.zero_grad()  # reset gradients before each step
            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg


def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):

    losses = AverageMeter()
    decoder.eval()

    predicts = []
    gts = []

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.no_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            losses.update(loss.item(), masks.numel())
        score = roc_auc_score(gts, predicts)

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, round(score, 3)


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def predict(model, loader, device, sr):
    with torch.no_grad():
        all_predicts = []
        all_gts = []
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            x = x.to(device)
            out = model.audio_forward(x, sr=sr)

            for i, out_chunk in enumerate(out):
                predict = out_chunk[masks[i] != 0].cpu().tolist()
                gt = targets[i, masks[i] != 0].cpu().tolist()

                all_predicts.append(predict)
                all_gts.append(gt)
    return all_predicts, all_gts


def calculate_best_thresholds(all_predicts, all_gts):
    best_acc = 0
    for ths_enter in tqdm(np.linspace(0, 1, 20)):
        for ths_exit in np.linspace(0, 1, 20):
            if ths_exit >= ths_enter:
                continue

            accs = []
            for j, predict in enumerate(all_predicts):
                predict_bool = []
                is_speech = False
                for i in predict:
                    if i >= ths_enter:
                        is_speech = True
                        predict_bool.append(1)
                    elif i <= ths_exit:
                        is_speech = False
                        predict_bool.append(0)
                    else:
                        val = 1 if is_speech else 0
                        predict_bool.append(val)

                score = round(accuracy_score(all_gts[j], predict_bool), 4)
                accs.append(score)

            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc
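To make the enter/exit hysteresis in `calculate_best_thresholds` concrete, a toy run (all numbers invented for illustration):

```python
# Toy illustration of the enter/exit hysteresis search; numbers are invented
probs = [[0.1, 0.7, 0.6, 0.4, 0.2, 0.8]]   # per-window speech probabilities for one clip
labels = [[0, 1, 1, 1, 0, 1]]              # per-window ground truth
enter, exit_, acc = calculate_best_thresholds(probs, labels)
print(enter, exit_, acc)   # the enter threshold is always above the exit threshold
```

Probabilities between the two thresholds keep the previous speech/non-speech state, which is what makes the exit threshold a "tolerated silence" knob rather than a hard cutoff.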