mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00

Compare commits: 190 commits, range v2.0-legac...adamnsandl

SHA1s, in listed order: e8850d2b9b, 3888946c0c, 24f51645d0, fdbb0a3a81, 60ae7abfb7, 0b3d43d432, a395853982, 78958b6fb6, 902cfc9248, 89e66a3474, a3bdebed16, 4bdcf31d17, 136cdcdf5b, 5cd2ba54db, 06b9e17c1e, aace7e25b1, 74c3f7f3fb, d77d0fd42c, 49b421a9cd, fd1f1a62b7, 8145ed9a91, 4392725328, 8b0566682b, 82342b8a4c, 4b8ce743a8, b4b6f2ab3e, 5b02d84a4a, 60465a7e61, fcef5b3955, 156436762f, a27b176e45, 8125483ef7, 6969dcc2dc, 6c816a05f0, 9dc344df7f, 41c5172dd9, 894ea259f9, f56f56ffaa, 6c8d844710, d8cc947c73, 797a88a386, 48b7c742dd, c5ec6bae3d, af152c18f6, d391f4c302, bf18ea6b56, a65732a393, aae1e4f40d, 94504ece54, 0b7da6e74b, efb5effc8f, 03dc3fae5c, 4a6d1701a4, 5e7ee10ee0, 03fb810fab, e30a7e32a9, bbbc657dad, cb92cdd1e3, 3780baf49f, 563106ef8c, f795bc479b, 7e9680bc83, 3b4c02dfe3, bc5a0a2dbf, b03fcb2ebe, 026bc3d292, e755baa3c2, b88084c7ed, a9d2b591de, c3c67cdcb8, 874c66ccbc, 51fbbcb32e, 14a0715955, a0d26769e0, e0c2015193, 5872cffd78, 86400b9a12, 6ef43d1c5d, 540e092276, 55c41abf46, 17903cb41d, c39dccc1fd, a6a067de44, 9865b3cb93, 3d10c2d950, 4f57fae3fa, 085d76f08e, 262bcb4b40, e84eca68d7, e7c4539106, a480e85aec, c69cb6c9c0, 11da69d88b, df1d52042d, d5a944b9f1, d90416e63e, 91f0aaecef, 015bfc8b21, 5d56b1ea40, ff3c596cab, 63e1be5a22, 1d8f8f38db, 7198087152, ad57d17f5f, 04e87c208a, c583fd1e52, 5814e548db, 42565d5baa, ab7af9745b, 83e68c56ea, d3882c9ebf, 25f04dda35, 94b4c21874, 324bc74a58, 82d199ff22, 5ba388d894, 790844ba0f, 51b5245410, 888970e77d, cb6d308335, 1b212c6e95, 452060ad65, c7eab751b5, d1714a9ff7, 94c79d899d, 1baf307b35, e324285cdc, 13dce2d067, 081e6b9886, 572134fdf1, a799dea837, 17209e6c4f, 6661cc9691, 7c671a75c2, 622016e672, 8eba346bc9, 900c71a109, bf0127e016, ea7af70fe9, 8cdc8d36c9, 6e9fd77500, 6cc08b1077, 0e8e080894, af6931d1de, 76687cbe25, b2329fa5f2, 005886e7eb, f6b1294cb2, 2392ea33f4, 45d72863b6, f40cc128a4, 0d61e4cee1, 011268e492, 8ebaf139c6, 0a90316625, 35d8969322, 7c3eb8bfb5, 74f759c8f8, 5816eb08c4, 0feae6cbbe, fc0a70f42e, 13fd927b84, 124d6564a0, 56fa93a1c9, 1a93276208, 9fbd0c4c2d, 7b05a183a3, f67e68efc3, 51b1365bb0, 79fdb55f1c, b17da75dac, 184e384697, adf5d6d020, 41ee0f6b9f, 236d250a11, 8794d6f835, 8f16c14066, f638c47595, 1fad5f4ffb, 7160ce99d3, 8af246df49, b1142bcba4, a243bd5dc8, d4d2af5833, 469ca8a2f6, 8c1ae73ee7, aba7862d58, b648546a21, 2e852d7d41, 044278aa12
40 .github/workflows/python-publish.yml (vendored, new file)

@@ -0,0 +1,40 @@

```yaml
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  push:
    tags:
      - '*'

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
```
657 README.md

@@ -1,617 +1,106 @@
**New version (106 lines, added):**

[](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE)

[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)



<br/>
<h1 align="center">Silero VAD</h1>
<br/>

**Silero VAD** - pre-trained enterprise-grade [Voice Activity Detector](https://en.wikipedia.org/wiki/Voice_activity_detection) (also see our [STT models](https://github.com/snakers4/silero-models)).

<br/>

<p align="center">
<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
</p>

<details>
<summary>Real Time Example</summary>

https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4

</details>

<br/>
<h2 align="center">Key Features</h2>
<br/>

- **Stellar accuracy**

  Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks.

- **Fast**

  One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1 ms** to process on a single CPU thread. Batching or GPU can also improve performance considerably, and under certain conditions ONNX may even run up to 4-5x faster.

- **Lightweight**

  The JIT model is around two megabytes in size.

- **General**

  Silero VAD was trained on huge corpora that include over **6000** languages, and it performs well on audios from different domains with various levels of background noise and quality.

- **Flexible sampling rate**

  Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).

- **Highly Portable**

  Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX**, running everywhere these runtimes are available.

- **No Strings Attached**

  Published under the permissive MIT license, Silero VAD has zero strings attached - no telemetry, no registration, no built-in expiration, no keys or vendor lock-in.

<br/>
<h2 align="center">Fast start</h2>
<br/>

```python3
import torch
torch.set_num_threads(1)

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils

wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(wav, model)
```
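The snippet above processes a whole file at once. The model can also be fed fixed-size chunks one at a time for streaming use, returning a per-chunk speech probability. A minimal sketch continuing from the snippet above; the 512-sample window and the `model(chunk, sampling_rate)` call are taken from the Colab recording example included later in this diff, and are an illustration rather than the only valid setup:

```python3
window = 512  # ~32 ms at 16 kHz, as in the Colab example in this diff
speech_probs = []
for i in range(0, len(wav) - window + 1, window):
    # each call returns the speech probability for one chunk
    speech_probs.append(model(wav[i:i + window], 16000).item())
model.reset_states()  # clear internal state before processing another stream
```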
<br/>
<h2 align="center">Typical Use Cases</h2>
<br/>

- Voice activity detection for IOT / edge / mobile use cases
- Data cleaning and preparation, voice detection in general
- Telephony and call-center automation, voice bots
- Voice interfaces

<br/>
<h2 align="center">Links</h2>
<br/>

- [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
- [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
- [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
- [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
- [Further reading](https://github.com/snakers4/silero-models#further-reading)
- [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)

<br/>
<h2 align="center">Get In Touch</h2>
<br/>

Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news).

Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for relevant information and [email](mailto:hello@silero.ai) us directly.

**Citations**

```
@misc{Silero VAD,
  ...
  email = {hello@silero.ai}
}
```

<br/>
<h2 align="center">Examples and VAD-based Community Apps</h2>
<br/>

- Example of VAD ONNX Runtime model usage in [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples
**Old version (617 lines, removed):**

[](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE)

[](https://pytorch.org/hub/snakers4_silero-vad_vad/)

[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)



- [Silero VAD](#silero-vad)
  - [TLDR](#tldr)
    - [Live Demonstration](#live-demonstration)
  - [Getting Started](#getting-started)
    - [Pre-trained Models](#pre-trained-models)
    - [Version History](#version-history)
    - [PyTorch](#pytorch)
      - [VAD](#vad)
      - [Number Detector](#number-detector)
      - [Language Classifier](#language-classifier)
    - [ONNX](#onnx)
      - [VAD](#vad-1)
      - [Number Detector](#number-detector-1)
      - [Language Classifier](#language-classifier-1)
  - [Metrics](#metrics)
    - [Performance Metrics](#performance-metrics)
      - [Streaming Latency](#streaming-latency)
      - [Full Audio Throughput](#full-audio-throughput)
    - [VAD Quality Metrics](#vad-quality-metrics)
  - [FAQ](#faq)
    - [VAD Parameter Fine Tuning](#vad-parameter-fine-tuning)
      - [Classic way](#classic-way)
      - [Adaptive way](#adaptive-way)
    - [How VAD Works](#how-vad-works)
    - [VAD Quality Metrics Methodology](#vad-quality-metrics-methodology)
    - [How Number Detector Works](#how-number-detector-works)
    - [How Language Classifier Works](#how-language-classifier-works)
  - [Contact](#contact)
    - [Get in Touch](#get-in-touch)
    - [Commercial Inquiries](#commercial-inquiries)
  - [Further reading](#further-reading)
  - [Citations](#citations)

# Silero VAD



## TLDR

**Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier.**
Enterprise-grade Speech Products made refreshingly simple (also see our [STT](https://github.com/snakers4/silero-models) models).

Currently, there are hardly any high-quality / modern / free / public voice activity detectors except for the WebRTC Voice Activity Detector ([link](https://github.com/wiseman/py-webrtcvad)). WebRTC, though, is starting to show its age and suffers from many false positives.

Also, in some cases it is crucial to be able to anonymize large-scale spoken corpora (i.e. remove personal data). Typically, personal data is considered private / sensitive if it contains (i) a name or (ii) some private ID. Name recognition is a highly subjective matter that depends on locale and business case, but Voice Activity and Number Detection are quite general tasks.

**Key features:**

- Modern, portable;
- Low memory footprint;
- Superior metrics to WebRTC;
- Trained on huge spoken corpora and noise / sound libraries;
- Slower than WebRTC, but fast enough for IOT / edge / mobile applications;
- Unlike WebRTC (which mostly tells silence from voice), our VAD can tell voice from noise / music / silence;

**Typical use cases:**

- Spoken corpora anonymization;
- Can be used together with WebRTC;
- Voice activity detection for IOT / edge / mobile use cases;
- Data cleaning and preparation, number and voice detection in general;
- PyTorch and ONNX can be used with a wide variety of deployment options and backends in mind;

### Live Demonstration

For more information, please see [examples](https://github.com/snakers4/silero-vad/tree/master/examples).

https://user-images.githubusercontent.com/28188499/116685087-182ff100-a9b2-11eb-927d-ed9f621226ee.mp4

https://user-images.githubusercontent.com/8079748/117580455-4622dd00-b0f8-11eb-858d-e6368ed4eada.mp4

## Getting Started

The models are small enough to be included directly into this repository. Newer models directly supersede older ones.

### Pre-trained Models

**Currently we provide the following endpoints:**

| model= | Params | Model type | Streaming | Languages | PyTorch | ONNX | Colab |
| --- | --- | --- | --- | --- | --- | --- | --- |
| `'silero_vad'` | 1.1M | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_micro'` | 10K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_micro_8k'` | 10K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_mini'` | 100K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_vad_mini_8k'` | 100K | VAD | Yes | `ru`, `en`, `de`, `es` (*) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_number_detector'` | 1.1M | Number Detector | No | `ru`, `en`, `de`, `es` | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| `'silero_lang_detector'` | 1.1M | Language Classifier | No | `ru`, `en`, `de`, `es` | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |
| ~~`'silero_lang_detector_116'`~~ | ~~1.7M~~ | ~~Language Classifier~~ | | | | | |
| `'silero_lang_detector_95'` | 4.7M | Language Classifier | No | [95 languages](https://github.com/snakers4/silero-vad/blob/master/files/lang_dict_95.json) | :heavy_check_mark: | :heavy_check_mark: | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) |

(*) Though explicitly trained on these languages, the VAD should work on any Germanic, Romance or Slavic language out of the box.

What the models do:

- VAD - detects speech;
- Number Detector - detects spoken numbers (i.e. thirty five);
- Language Classifier - classifies utterances between languages;
- Language Classifier 95 - classifies among 95 languages as well as 58 language groups (mutually intelligible languages -> same group)

### Version History

| Version | Date | Comment |
| ------- | ---------- | --------------------------------------------------------------------------------------------------------------------------- |
| `v1` | 2020-12-15 | Initial release |
| `v1.1` | 2020-12-24 | Better VAD models compatible with chunks shorter than 250 ms |
| `v1.2` | 2020-12-30 | Number Detector added |
| `v2` | 2021-01-11 | Add Language Classifier heads (en, ru, de, es) |
| `v2.1` | 2021-02-11 | Add micro (10k params) VAD models |
| `v2.2` | 2021-03-22 | Add micro 8000 sample rate VAD models |
| `v2.3` | 2021-04-12 | Add mini (100k params) VAD models (8k and 16k sample rate) + **new** adaptive utils for full audio and single audio stream |
| `v2.4` | 2021-07-09 | Add 116-language classifier and group classifier |
| `v2.4` | 2021-07-09 | Deleted the 116-language classifier, added a 95-language classifier instead (dropped low-spoken languages to improve quality) |

### PyTorch

[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)

We keep the colab examples up-to-date, but you can manage your dependencies manually:

- `pytorch` >= 1.7.1 (there were breaking changes in `torch.hub` introduced in 1.7);
- `torchaudio` >= 0.7.2 (used only for IO and resampling, can be easily replaced);
- `soundfile` >= 0.10.3 (used as a default backend for torchaudio, can be replaced);

All of the dependencies except for PyTorch are superficial and for utils / examples only. You can use any libraries / pipelines that read files and resample into 16 kHz.
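As an illustration of that claim, a torchaudio-free loading pipeline might look like the following sketch (assuming `soundfile` and `scipy` are installed; the file path is hypothetical):

```python3
import numpy as np
import soundfile as sf
import torch
from scipy.signal import resample_poly

data, sr = sf.read('path_to_audio_file', dtype='float32')  # hypothetical path
if data.ndim > 1:
    data = data.mean(axis=1)               # downmix to mono
if sr != 16000:
    data = resample_poly(data, 16000, sr)  # resample to 16 kHz
wav = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))
```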
#### VAD

[](https://pytorch.org/hub/snakers4_silero-vad_vad/)

```python
import torch
torch.set_num_threads(1)
from pprint import pprint

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=True)

(get_speech_ts,
 get_speech_ts_adaptive,
 _, read_audio,
 _, _, _) = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

wav = read_audio(f'{files_dir}/en.wav')

# get speech timestamps from the full audio file

# classic way
speech_timestamps = get_speech_ts(wav, model,
                                  num_steps=4)
pprint(speech_timestamps)

# adaptive way
speech_timestamps = get_speech_ts_adaptive(wav, model)
pprint(speech_timestamps)
```

#### Number Detector

[](https://pytorch.org/hub/snakers4_silero-vad_number/)

```python
import torch
torch.set_num_threads(1)
from pprint import pprint

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_number_detector',
                              force_reload=True)

(get_number_ts,
 _, read_audio,
 _, _) = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

wav = read_audio(f'{files_dir}/en_num.wav')

# get number timestamps from the full audio file
number_timestamps = get_number_ts(wav, model)

pprint(number_timestamps)
```

#### Language Classifier

##### 4 languages

[](https://pytorch.org/hub/snakers4_silero-vad_language/)

```python
import torch
torch.set_num_threads(1)
from pprint import pprint

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_lang_detector',
                              force_reload=True)

get_language, read_audio = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

wav = read_audio(f'{files_dir}/de.wav')
language = get_language(wav, model)

pprint(language)
```

##### 95 languages

[](https://pytorch.org/hub/snakers4_silero-vad_language/)

```python
import torch
torch.set_num_threads(1)
from pprint import pprint

model, lang_dict, lang_group_dict, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_lang_detector_95',
    force_reload=True)

get_language_and_group, read_audio = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

wav = read_audio(f'{files_dir}/de.wav')
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2)

for i in languages:
    pprint(f'Language: {i[0]} with prob {i[-1]}')

for i in language_groups:
    pprint(f'Language group: {i[0]} with prob {i[-1]}')
```

### ONNX

[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)

You can run our models everywhere you can import an ONNX model or run the ONNX runtime.

#### VAD

```python
import torch
import onnxruntime
from pprint import pprint

_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                          model='silero_vad',
                          force_reload=True)

(get_speech_ts,
 get_speech_ts_adaptive,
 _, read_audio,
 _, _, _) = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

def init_onnx_model(model_path: str):
    return onnxruntime.InferenceSession(model_path)

def validate_onnx(model, inputs):
    with torch.no_grad():
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        outs = [torch.Tensor(x) for x in outs]
        return outs[0]

model = init_onnx_model(f'{files_dir}/model.onnx')
wav = read_audio(f'{files_dir}/en.wav')

# get speech timestamps from the full audio file

# classic way
speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx)
pprint(speech_timestamps)

# adaptive way
speech_timestamps = get_speech_ts_adaptive(wav, model, run_function=validate_onnx)
pprint(speech_timestamps)
```

#### Number Detector

```python
import torch
import onnxruntime
from pprint import pprint

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_number_detector',
                              force_reload=True)

(get_number_ts,
 _, read_audio,
 _, _) = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

def init_onnx_model(model_path: str):
    return onnxruntime.InferenceSession(model_path)

def validate_onnx(model, inputs):
    with torch.no_grad():
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        outs = [torch.Tensor(x) for x in outs]
        return outs

model = init_onnx_model(f'{files_dir}/number_detector.onnx')
wav = read_audio(f'{files_dir}/en_num.wav')

# get number timestamps from the full audio file
number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)

pprint(number_timestamps)
```

#### Language Classifier

##### 4 languages

```python
import torch
import onnxruntime
from pprint import pprint

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_lang_detector',
                              force_reload=True)

get_language, read_audio = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

def init_onnx_model(model_path: str):
    return onnxruntime.InferenceSession(model_path)

def validate_onnx(model, inputs):
    with torch.no_grad():
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        outs = [torch.Tensor(x) for x in outs]
        return outs

model = init_onnx_model(f'{files_dir}/number_detector.onnx')
wav = read_audio(f'{files_dir}/de.wav')

language = get_language(wav, model, run_function=validate_onnx)
print(language)
```

##### 95 languages

```python
import torch
import onnxruntime
from pprint import pprint

model, lang_dict, lang_group_dict, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_lang_detector_95',
    force_reload=True)

get_language_and_group, read_audio = utils

files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'

def init_onnx_model(model_path: str):
    return onnxruntime.InferenceSession(model_path)

def validate_onnx(model, inputs):
    with torch.no_grad():
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        outs = [torch.Tensor(x) for x in outs]
        return outs

model = init_onnx_model(f'{files_dir}/lang_classifier_95.onnx')
wav = read_audio(f'{files_dir}/de.wav')

languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2, run_function=validate_onnx)

for i in languages:
    pprint(f'Language: {i[0]} with prob {i[-1]}')

for i in language_groups:
    pprint(f'Language group: {i[0]} with prob {i[-1]}')
```

[](https://pytorch.org/hub/snakers4_silero-vad_language/)

## Metrics

### Performance Metrics

All speed tests were run on an AMD Ryzen Threadripper 3960X using only 1 thread:

```
torch.set_num_threads(1) # pytorch
ort_session.intra_op_num_threads = 1 # onnx
ort_session.inter_op_num_threads = 1 # onnx
```

#### Streaming Latency

Streaming latency depends on 2 variables:

- **num_steps** - the number of windows each audio chunk is split into. Our post-processing class keeps the previous chunk (250 ms) in memory, and the new chunk (also 250 ms) is appended to it. The resulting big chunk (500 ms) is split into **num_steps** overlapping windows, each 250 ms long.

- **number of audio streams**

So the **batch size** for streaming is **num_steps * number of audio streams**. The time between receiving a new audio chunk and getting results is shown in the table below.
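For instance, under an assumed load of ten parallel streams:

```python3
num_steps = 4       # overlapping windows per 250 ms chunk
num_streams = 10    # parallel audio streams (assumed for illustration)
batch_size = num_steps * num_streams
print(batch_size)   # 40 -> ~36 ms (PyTorch) or ~29 ms (ONNX) per the table below
```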
| Batch size | Pytorch model time, ms | Onnx model time, ms |
| :--------: | :--------------------: | :-----------------: |
| **2** | 9 | 2 |
| **4** | 11 | 4 |
| **8** | 14 | 7 |
| **16** | 19 | 12 |
| **40** | 36 | 29 |
| **80** | 64 | 55 |
| **120** | 96 | 85 |
| **200** | 157 | 137 |

#### Full Audio Throughput

**RTS** (seconds of audio processed per second, real time speed, or 1 / RTF) for full audio processing depends on **num_steps** (see the previous paragraph) and **batch size** (bigger is better).
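As a worked example against the table below: with batch size 80 and `num_steps=4`, the PyTorch model reaches RTS of about 78, so an hour of audio takes roughly 46 seconds to process:

```python3
rts = 78                        # PyTorch RTS at batch size 80, num_steps 4 (see table)
seconds_per_hour = 3600 / rts
print(round(seconds_per_hour))  # ~46 s of compute per hour of audio
```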
| Batch size | num_steps | Pytorch model RTS | Onnx model RTS |
| :--------: | :-------: | :---------------: | :------------: |
| **40** | **4** | 68 | 86 |
| **40** | **8** | 34 | 43 |
| **80** | **4** | 78 | 91 |
| **80** | **8** | 39 | 45 |
| **120** | **4** | 78 | 88 |
| **120** | **8** | 39 | 44 |
| **200** | **4** | 80 | 91 |
| **200** | **8** | 40 | 46 |

### VAD Quality Metrics

We use random 250 ms audio chunks for validation. The speech to non-speech ratio among chunks is about ~50/50 (i.e. balanced). Speech chunks are sampled from real audios in four different languages (English, Russian, Spanish, German), and random background noise is added to some of them (~40%).

Since our VAD (only the VAD, the other networks are more flexible) was trained on chunks of the same length, the model's output is just one float from 0 to 1 - the **speech probability**. We use the speech probabilities as thresholds for a precision-recall curve. This can be extended to 100 - 150 ms chunks; anything shorter than 100 - 150 ms cannot be distinguished as speech with confidence.

[Webrtc](https://github.com/wiseman/py-webrtcvad) splits audio into frames, each frame getting a corresponding number (0 **or** 1). We use 30 ms frames for webrtc, so each 250 ms chunk is split into 8 frames, and their **mean** value is used as a threshold for the plot.

[Auditok](https://github.com/amsehili/auditok) - same logic as Webrtc, but we use 50 ms frames.


## FAQ

### VAD Parameter Fine Tuning

#### Classic way

**This is the straightforward classic method `get_speech_ts`, where the thresholds (`trig_sum` and `neg_trig_sum`) are specified by the user.**

- Among others, we provide several [utils](https://github.com/snakers4/silero-vad/blob/8b28767292b424e3e505c55f15cd3c4b91e4804b/utils.py#L52-L59) to simplify working with VAD;
- We provide sensible basic hyper-parameters that work for us, but your case can be different;
- `trig_sum` - overlapping windows are used for each audio chunk; trig sum defines the average probability among those windows for switching into the triggered state (speech state);
- `neg_trig_sum` - same as `trig_sum`, but for switching from the triggered to the non-triggered state (non-speech);
- `num_steps` - number of overlapping windows to split an audio chunk into (we recommend 4 or 8);
- `num_samples_per_window` - number of samples in each window; our models were trained using `4000` samples (250 ms) per window, so this is the preferable value (lesser values reduce [quality](https://github.com/snakers4/silero-vad/issues/2#issuecomment-750840434));
- `min_speech_samples` - minimum speech chunk duration in samples;
- `min_silence_samples` - minimum silence duration in samples between two separate speech chunks.

Optimal parameters may vary per domain, but we provide a tiny tool to find the best ones. You can invoke `get_speech_ts` with `visualize_probs=True` (`pandas` required):

```
speech_timestamps = get_speech_ts(wav, model,
                                  num_samples_per_window=4000,
                                  num_steps=4,
                                  visualize_probs=True)
```

#### Adaptive way

**The adaptive algorithm (`get_speech_ts_adaptive`) automatically selects the thresholds (`trig_sum` and `neg_trig_sum`) based on median speech probabilities over the whole audio. SOME ARGUMENTS DIFFER FROM THE CLASSIC FUNCTION'S ARGUMENTS.**

- `batch_size` - batch size to feed to silero VAD (default - `200`);
- `step` - step size in samples (default - `500`) (`num_samples_per_window` / `num_steps` from the classic method);
- `num_samples_per_window` - number of samples in each window; our models were trained using `4000` samples (250 ms) per window, so this is the preferable value (lesser values reduce [quality](https://github.com/snakers4/silero-vad/issues/2#issuecomment-750840434));
- `min_speech_samples` - minimum speech chunk duration in samples (default - `10000`);
- `min_silence_samples` - minimum silence duration in samples between two separate speech chunks (default - `4000`);
- `speech_pad_samples` - widen speech by this amount of samples on each side (default - `2000`).

```
speech_timestamps = get_speech_ts_adaptive(wav, model,
                                           num_samples_per_window=4000,
                                           step=500,
                                           visualize_probs=True)
```

The chart should look something like this:



With this particular example you can try shorter chunks (`num_samples_per_window=1600`), but this results in too much noise:



### How VAD Works

- Audio is split into 250 ms chunks (you can choose any chunk size, but quality with chunks shorter than 100 ms will suffer, with more false positives and "unnatural" pauses);
- VAD keeps a record of the previous chunk (or zeros at the beginning of the stream);
- This 500 ms of audio (250 ms + 250 ms) is then split into N (typically 4 or 8) windows, and the model is applied to this window batch. Each window is 250 ms long (naturally, the windows overlap);
- The probability is then averaged across these windows, as shown in the sketch after this list;
- Though pauses in speech are typically 300 ms or longer (pauses shorter than 200-300 ms are typically not meaningful), it is hard to confidently classify speech vs noise / music on very short chunks (i.e. 30 - 50 ms);
- ~~We are working on lifting this limitation, so that you can use 100 - 125 ms windows~~;
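The windowing step can be sketched as follows (an illustration of the description above, not the library's exact code; the even window spacing is one plausible choice):

```python3
import torch

def window_batch(prev_chunk: torch.Tensor, new_chunk: torch.Tensor,
                 num_steps: int = 4) -> torch.Tensor:
    buffer = torch.cat([prev_chunk, new_chunk])       # 500 ms of audio
    win = prev_chunk.numel()                          # 250 ms window
    hop = (buffer.numel() - win) // (num_steps - 1)   # spread windows evenly
    windows = [buffer[i * hop:i * hop + win] for i in range(num_steps)]
    return torch.stack(windows)                       # batch fed to the model

# The per-window probabilities returned by the model are then averaged,
# e.g. speech_prob = probs.mean() over this batch.
```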
### VAD Quality Metrics Methodology

Please see [VAD Quality Metrics](#vad-quality-metrics).

### How Number Detector Works

- It is recommended to split long audio into short pieces (< 15 s) and apply the model to each of them, as in the sketch after this list;
- The Number Detector can classify whether the whole audio contains a number, or whether each audio frame contains a number;
- Audio is split into frames in a certain way, so, given the per-frame output, we can restore the timing bounds of the numbers with an accuracy of about 0.2 s;
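A sketch of the recommended long-audio handling (assuming, as an illustration, that `get_number_ts` from the PyTorch example above returns `{'start', 'end'}` dicts measured in samples within the piece it was given):

```python3
piece = 15 * 16000  # < 15 s pieces at a 16 kHz sampling rate
number_ts = []
for offset in range(0, len(wav), piece):
    for ts in get_number_ts(wav[offset:offset + piece], model):
        # shift piece-local bounds back to positions in the full audio
        number_ts.append({'start': ts['start'] + offset,
                          'end': ts['end'] + offset})
```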
### How Language Classifier Works

- **99%** validation accuracy;
- The Language Classifier was trained using audio samples in 4 languages: **Russian**, **English**, **Spanish**, **German**;
- More languages TBD;
- Arbitrary audio lengths can be used, although the network was trained using audio shorter than 15 seconds.

### How Language Classifier 95 Works

- **85%** validation accuracy among 95 languages, **90%** validation accuracy among [58 language groups](https://github.com/snakers4/silero-vad/blob/master/files/lang_group_dict_95.json);
- Language Classifier 95 was trained using audio samples in [95 languages](https://github.com/snakers4/silero-vad/blob/master/files/lang_dict_95.json);
- Arbitrary audio lengths can be used, although the network was trained using audio shorter than 20 seconds.

## Contact

### Get in Touch

Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news).

### Commercial Inquiries

Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers](https://github.com/snakers4/silero-models/wiki/Licensing-and-Tiers) for relevant information and [email](mailto:hello@silero.ai) us directly.

## Further reading

### General

- Silero-models - https://github.com/snakers4/silero-models
- Nice [thread](https://github.com/snakers4/silero-vad/discussions/16#discussioncomment-305830) in discussions

### English

- STT:
  - Towards an Imagenet Moment For Speech-To-Text - [link](https://thegradient.pub/towards-an-imagenet-moment-for-speech-to-text/)
  - A Speech-To-Text Practitioner's Criticisms of Industry and Academia - [link](https://thegradient.pub/a-speech-to-text-practitioners-criticisms-of-industry-and-academia/)
  - Modern Google-level STT Models Released - [link](https://habr.com/ru/post/519562/)

- TTS:
  - High-Quality Text-to-Speech Made Accessible, Simple and Fast - [link](https://habr.com/ru/post/549482/)

- VAD:
  - Modern Portable Voice Activity Detector Released - [link](https://habr.com/ru/post/537276/)

- Text Enhancement:
  - We have published a model for text repunctuation and recapitalization for four languages - [link](https://habr.com/ru/post/581960/)

### Chinese

- STT:
  - Towards an ImageNet Moment for Speech Recognition - [link](https://www.infoq.cn/article/4u58WcFCs0RdpoXev1E2)
  - The Seven Deadly Sins of Speech Academia and Industry - [link](https://www.infoq.cn/article/lEe6GCRjF1CNToVITvNw)

### Russian

- STT:
  - The latest updates to the speech recognition models from Silero Models - [link](https://habr.com/ru/post/577630/)
  - Compressing transformers: simple, universal and practical ways to make them compact and fast - [link](https://habr.com/ru/post/563778/)
  - The ultimate comparison of speech recognition systems: Ashmanov, Google, Sber, Silero, Tinkoff, Yandex - [link](https://habr.com/ru/post/559640/)
  - We have published modern STT models comparable in quality to Google's - [link](https://habr.com/ru/post/519564/)
  - Lowering the barriers to entry into speech recognition - [link](https://habr.com/ru/post/494006/)
  - A huge open dataset of Russian speech, version 1.0 - [link](https://habr.com/ru/post/474462/)
  - How fast can an STT system be made? - [link](https://habr.com/ru/post/531524/)
  - Our Speech-To-Text system - [link](https://www.silero.ai/tag/our-speech-to-text/)
  - Speech To Text - [link](https://www.silero.ai/tag/speech-to-text/)

- TTS:
  - We made our public speech synthesis even better - [link](https://habr.com/ru/post/563484/)
  - We have published high-quality, simple, accessible and fast speech synthesis - [link](https://habr.com/ru/post/549480/)

- VAD:
  - Models for speech detection, number detection and language classification - [link](https://www.silero.ai/vad-lang-classifier-number-detector/)
  - We have published a modern Voice Activity Detector and more - [link](https://habr.com/ru/post/537274/)

- Text Enhancement:
  - We have published a model that restores punctuation and capitalization in text in four languages - [link](https://habr.com/ru/post/581946/)

## Citations

```
@misc{Silero VAD,
  ...
  email = {hello@silero.ai}
}
```
84 datasets/README.md (new file)

@@ -0,0 +1,84 @@

# Silero-VAD Dataset

> The dataset was created with the support of the Innovation Assistance Fund as part of the federal project "Artificial Intelligence" of the national program "Digital Economy of the Russian Federation".

The links below provide `.feather` files containing open audio datasets annotated with Silero VAD, along with a short description of each dataset and loading examples. `.feather` files can be opened with the `pandas` library:

```python3
import pandas as pd
dataframe = pd.read_feather(PATH_TO_FEATHER_FILE)
```

Each annotation `.feather` file contains the following columns:

- `speech_timings` - the annotation of the given audio. This is a list of dicts of the form `{'start': START_SECOND, 'end': END_SECOND}`, where `START_SECOND` and `END_SECOND` are the start and end times of speech in seconds. The number of these dicts equals the number of speech segments found in the audio;
- `language` - the ISO language code of the audio.

The columns with information on how to download the audio files differ between datasets and are described for each dataset below.
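For example, a minimal pass over such a file might look like this (the file name is illustrative; any of the `.feather` files below works the same way):

```python3
import pandas as pd

df = pd.read_feather('BibleIs.feather')  # illustrative file name

for _, row in df.iterrows():
    for segment in row['speech_timings']:
        # start/end of one detected speech segment, in seconds
        print(row['language'], segment['start'], segment['end'])
```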
**All data was annotated with a time step of ~30 milliseconds (`num_samples` - 512).**

| Name | Hours | Languages | Link | License | md5sum |
|----------------------|-------------|-------------|--------|----------|----------|
| **Bible.is** | 53,138 | 1,596 | [URL](https://live.bible.is/) | [Custom](https://live.bible.is/terms) | ea404eeaf2cd283b8223f63002be11f9 |
| **globalrecordings.net** | 9,743 | 6,171[^1] | [URL](https://globalrecordings.net/en) | CC BY-NC-SA 4.0 | 3c5c0f31b0abd9fe94ddbe8b1e2eb326 |
| **VoxLingua107** | 6,628 | 107 | [URL](https://bark.phon.ioc.ee/voxlingua107/) | CC BY 4.0 | 5dfef33b4d091b6d399cfaf3d05f2140 |
| **Common Voice** | 30,329 | 120 | [URL](https://commonvoice.mozilla.org/en/datasets) | CC0 | 5e30a85126adf74a5fd1496e6ac8695d |
| **MLS** | 50,709 | 8 | [URL](https://www.openslr.org/94/) | CC BY 4.0 | a339d0e94bdf41bba3c003756254ac4e |
| **Total** | **150,547** | **6,171+** | | | |

## Bible.is

[Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/BibleIs.feather)

- The `audio_link` column contains links to the individual audio files.

## globalrecordings.net

[Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/globalrecordings.feather)

- The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: archive links repeat, since one archive can contain many audio files.
- The `audio_path` column contains the paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.

## VoxLingua107

[Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/VoxLingua107.feather)

- The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: archive links repeat, since one archive can contain many audio files.
- The `audio_path` column contains the paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.

## Common Voice

[Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/common_voice.feather)

This dataset cannot be downloaded via static links. To download it, follow this [link](https://commonvoice.mozilla.org/en/datasets) and, after getting access via the corresponding form, download the archives for each available language. Note: the provided annotation corresponds to version `Common Voice Corpus 16.1` of the source dataset.

- The `audio_path` column contains the unique names of the `.mp3` files obtained after downloading the corresponding dataset.

## MLS

[Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/MLS.feather)

- The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: archive links repeat, since one archive can contain many audio files.
- The `audio_path` column contains the paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.

## License

This dataset is distributed under the `CC BY-NC-SA 4.0` [license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).

## Citation

```
@misc{Silero VAD Dataset,
  author = {Silero Team},
  title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}},
  email = {hello@silero.ai}
}
```

[^1]: The number of unique ISO codes in this dataset does not match the actual number of languages present, since some closely related languages may share the same ISO code.
241
examples/colab_record_example.ipynb
Normal file
241
examples/colab_record_example.ipynb
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "bccAucKjnPHm"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Dependencies and inputs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "cSih95WFmwgi"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip -q install pydub\n",
|
||||||
|
"from google.colab import output\n",
|
||||||
|
"from base64 import b64decode, b64encode\n",
|
||||||
|
"from io import BytesIO\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from pydub import AudioSegment\n",
|
||||||
|
"from IPython.display import HTML, display\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import moviepy.editor as mpe\n",
|
||||||
|
"from matplotlib.animation import FuncAnimation, FFMpegWriter\n",
|
||||||
|
"import matplotlib\n",
|
||||||
|
"matplotlib.use('Agg')\n",
|
||||||
|
"\n",
|
||||||
|
"torch.set_num_threads(1)\n",
|
||||||
|
"\n",
|
||||||
|
"model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||||
|
" model='silero_vad',\n",
|
||||||
|
" force_reload=True)\n",
|
||||||
|
"\n",
|
||||||
|
"def int2float(sound):\n",
|
||||||
|
" abs_max = np.abs(sound).max()\n",
|
||||||
|
" sound = sound.astype('float32')\n",
|
||||||
|
" if abs_max > 0:\n",
|
||||||
|
" sound *= 1/32768\n",
|
||||||
|
" sound = sound.squeeze()\n",
|
||||||
|
" return sound\n",
|
||||||
|
"\n",
|
||||||
|
"AUDIO_HTML = \"\"\"\n",
|
||||||
|
"<script>\n",
|
||||||
|
"var my_div = document.createElement(\"DIV\");\n",
|
||||||
|
"var my_p = document.createElement(\"P\");\n",
|
||||||
|
"var my_btn = document.createElement(\"BUTTON\");\n",
|
||||||
|
"var t = document.createTextNode(\"Press to start recording\");\n",
|
||||||
|
"\n",
|
||||||
|
"my_btn.appendChild(t);\n",
|
||||||
|
"//my_p.appendChild(my_btn);\n",
|
||||||
|
"my_div.appendChild(my_btn);\n",
|
||||||
|
"document.body.appendChild(my_div);\n",
|
||||||
|
"\n",
|
||||||
|
"var base64data = 0;\n",
|
||||||
|
"var reader;\n",
|
||||||
|
"var recorder, gumStream;\n",
|
||||||
|
"var recordButton = my_btn;\n",
|
||||||
|
"\n",
|
||||||
|
"var handleSuccess = function(stream) {\n",
|
||||||
|
" gumStream = stream;\n",
|
||||||
|
" var options = {\n",
|
||||||
|
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
|
||||||
|
" mimeType : 'audio/webm;codecs=opus'\n",
|
||||||
|
" //mimeType : 'audio/webm;codecs=pcm'\n",
|
||||||
|
" }; \n",
|
||||||
|
" //recorder = new MediaRecorder(stream, options);\n",
|
||||||
|
" recorder = new MediaRecorder(stream);\n",
|
||||||
|
" recorder.ondataavailable = function(e) { \n",
|
||||||
|
" var url = URL.createObjectURL(e.data);\n",
|
||||||
|
" // var preview = document.createElement('audio');\n",
|
||||||
|
" // preview.controls = true;\n",
|
||||||
|
" // preview.src = url;\n",
|
||||||
|
" // document.body.appendChild(preview);\n",
|
||||||
|
"\n",
|
||||||
|
" reader = new FileReader();\n",
|
||||||
|
" reader.readAsDataURL(e.data); \n",
|
||||||
|
" reader.onloadend = function() {\n",
|
||||||
|
" base64data = reader.result;\n",
|
||||||
|
" //console.log(\"Inside FileReader:\" + base64data);\n",
|
||||||
|
" }\n",
|
||||||
|
" };\n",
|
||||||
|
" recorder.start();\n",
|
||||||
|
" };\n",
|
||||||
|
"\n",
|
||||||
|
"recordButton.innerText = \"Recording... press to stop\";\n",
|
||||||
|
"\n",
|
||||||
|
"navigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"function toggleRecording() {\n",
|
||||||
|
" if (recorder && recorder.state == \"recording\") {\n",
|
||||||
|
" recorder.stop();\n",
|
||||||
|
" gumStream.getAudioTracks()[0].stop();\n",
|
||||||
|
" recordButton.innerText = \"Saving recording...\"\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"// https://stackoverflow.com/a/951057\n",
|
||||||
|
"function sleep(ms) {\n",
|
||||||
|
" return new Promise(resolve => setTimeout(resolve, ms));\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"var data = new Promise(resolve=>{\n",
|
||||||
|
"//recordButton.addEventListener(\"click\", toggleRecording);\n",
|
||||||
|
"recordButton.onclick = ()=>{\n",
|
||||||
|
"toggleRecording()\n",
|
||||||
|
"\n",
|
||||||
|
"sleep(2000).then(() => {\n",
|
||||||
|
" // wait 2000ms for the data to be available...\n",
|
||||||
|
" // ideally this should use something like await...\n",
|
||||||
|
" //console.log(\"Inside data:\" + base64data)\n",
|
||||||
|
" resolve(base64data.toString())\n",
|
||||||
|
"\n",
|
||||||
|
"});\n",
|
||||||
|
"\n",
|
||||||
|
"}\n",
|
||||||
|
"});\n",
|
||||||
|
" \n",
|
||||||
|
"</script>\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"def record(sec=10):\n",
|
||||||
|
" display(HTML(AUDIO_HTML))\n",
|
||||||
|
" s = output.eval_js(\"data\")\n",
|
"    b = b64decode(s.split(',')[1])\n",
"    audio = AudioSegment.from_file(BytesIO(b))\n",
"    audio.export('test.mp3', format='mp3')\n",
"    audio = audio.set_channels(1)\n",
"    audio = audio.set_frame_rate(16000)\n",
"    audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
"    audio_tens = torch.tensor(audio_float)\n",
"    return audio_tens\n",
"\n",
"def make_animation(probs, audio_duration, interval=40):\n",
"    fig = plt.figure(figsize=(16, 9))\n",
"    ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n",
"    line, = ax.plot([], [], lw=2)\n",
"    x = [i / 16000 * 512 for i in range(len(probs))]\n",
"    plt.xlabel('Time, seconds', fontsize=16)\n",
"    plt.ylabel('Speech Probability', fontsize=16)\n",
"\n",
"    def init():\n",
"        plt.fill_between(x, probs, color='#064273')\n",
"        line.set_data([], [])\n",
"        line.set_color('#990000')\n",
"        return line,\n",
"\n",
"    def animate(i):\n",
"        # vertical cursor line at the current playback position\n",
"        cursor = i * interval / 1000 - 0.04\n",
"        y = np.linspace(0, 1.02, 2)\n",
"        line.set_data([cursor, cursor], y)\n",
"        line.set_color('#990000')\n",
"        return line,\n",
"\n",
"    anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n",
"\n",
"    f = r\"animation.mp4\"\n",
"    writervideo = FFMpegWriter(fps=1000 / interval)\n",
"    anim.save(f, writer=writervideo)\n",
"    plt.close('all')\n",
"\n",
"def combine_audio(vidname, audname, outname, fps=25):\n",
"    my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
"    audio_background = mpe.AudioFileClip(audname)\n",
"    final_clip = my_clip.set_audio(audio_background)\n",
"    final_clip.write_videofile(outname, fps=fps, verbose=False)\n",
"\n",
"def record_make_animation():\n",
"    tensor = record()\n",
"\n",
"    print('Calculating probabilities...')\n",
"    speech_probs = []\n",
"    window_size_samples = 512\n",
"    for i in range(0, len(tensor), window_size_samples):\n",
"        if len(tensor[i: i + window_size_samples]) < window_size_samples:\n",
"            break\n",
"        speech_prob = model(tensor[i: i + window_size_samples], 16000).item()\n",
"        speech_probs.append(speech_prob)\n",
"    model.reset_states()\n",
"    print('Making animation...')\n",
"    make_animation(speech_probs, len(tensor) / 16000)\n",
"\n",
"    print('Merging your voice with animation...')\n",
"    combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n",
"    print('Done!')\n",
"    mp4 = open('merged.mp4', 'rb').read()\n",
"    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"    display(HTML(\"\"\"\n",
"    <video width=800 controls>\n",
"      <source src=\"%s\" type=\"video/mp4\">\n",
"    </video>\n",
"    \"\"\" % data_url))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFVs3GvTnpB1"
},
"source": [
"## Record example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5EBjrTwiqAaQ"
},
"outputs": [],
"source": [
"record_make_animation()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"bccAucKjnPHm"
],
"name": "Untitled2.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
43  examples/cpp/README.md  Normal file
@@ -0,0 +1,43 @@
# Stream example in C++

Here's a simple example of running the VAD model in C++ with ONNX Runtime.

## Requirements

The code was tested in the environment below; feel free to try others.

- WSL2 + Debian-bullseye (docker)
- gcc 12.2.0
- onnxruntime-linux-x64-1.12.1

## Usage

1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`

2. Install onnxruntime-linux-x64-1.12.1

   - Download the onnxruntime library:

     `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`

   - Unzip it. The commands below assume the path `/root/onnxruntime-linux-x64-1.12.1`

3. Modify the wav path & test configs in the main function

   `wav::WavReader wav_reader("${path_to_your_wav_file}");`

   Adjust the test sample rate, frame length in ms, threshold, etc. as needed.

4. Build with gcc and run

```bash
# Build
g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test

# Run
./test
```
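For orientation, here is a minimal sketch of what the program's entry point boils down to. This is not extra API: `wav::WavReader` and `VadIterator` are the helpers defined in `wav.h` and `silero-vad-onnx.cpp` below, and `recorder.wav` / `silero_vad.onnx` are assumed local paths.

```cpp
#include <iostream>
#include <vector>

#include "wav.h"
// VadIterator is defined in silero-vad-onnx.cpp below.

int main() {
    // Load a 16 kHz mono wav into a float vector.
    wav::WavReader wav_reader("recorder.wav");  // assumed input path
    std::vector<float> input_wav(wav_reader.num_samples());
    for (int i = 0; i < wav_reader.num_samples(); i++)
        input_wav[i] = wav_reader.data()[i];

    // Run the VAD over the whole file and print the detected speech segments.
    VadIterator vad(L"silero_vad.onnx");  // assumed model path
    vad.process(input_wav);
    for (auto ts : vad.get_speech_timestamps())
        std::cout << ts.c_str() << std::endl;
}
```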
478  examples/cpp/silero-vad-onnx.cpp  Normal file
@@ -0,0 +1,478 @@
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <limits>
#include <chrono>
#include <memory>
#include <string>
#include <stdexcept>
#include <cstdio>
#include <cstdarg>

#include "onnxruntime_cxx_api.h"
#include "wav.h"

//#define __DEBUG_SPEECH_PROB___

class timestamp_t
{
public:
    int start;
    int end;

    // default + parameterized constructor
    timestamp_t(int start = -1, int end = -1)
        : start(start), end(end)
    {
    }

    // assignment operator modifies object, therefore non-const
    timestamp_t& operator=(const timestamp_t& a)
    {
        start = a.start;
        end = a.end;
        return *this;
    }

    // equality comparison; doesn't modify the object, therefore const
    bool operator==(const timestamp_t& a) const
    {
        return (start == a.start && end == a.end);
    }

    std::string c_str()
    {
        //return std::format("timestamp {:08d}, {:08d}", start, end);
        return format("{start:%08d,end:%08d}", start, end);
    }

private:
    std::string format(const char* fmt, ...)
    {
        char buf[256];

        va_list args;
        va_start(args, fmt);
        const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
        va_end(args);

        if (r < 0)
            // conversion failed
            return {};

        const size_t len = r;
        if (len < sizeof buf)
            // we fit in the buffer
            return { buf, len };

#if __cplusplus >= 201703L
        // C++17: Create a string and write to its underlying array
        std::string s(len, '\0');
        va_start(args, fmt);
        std::vsnprintf(s.data(), len + 1, fmt, args);
        va_end(args);

        return s;
#else
        // C++11 or C++14: We need to allocate scratch memory
        auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
        va_start(args, fmt);
        std::vsnprintf(vbuf.get(), len + 1, fmt, args);
        va_end(args);

        return { vbuf.get(), len };
#endif
    }
};


class VadIterator
{
private:
    // OnnxRuntime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

private:
    void init_engine_threads(int inter_threads, int intra_threads)
    {
        // The method should be called in each thread/proc in multi-thread/proc work
        session_options.SetIntraOpNumThreads(intra_threads);
        session_options.SetInterOpNumThreads(inter_threads);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    }

    void init_onnx_model(const std::wstring& model_path)
    {
        // Init threads = 1 for a single-threaded session
        init_engine_threads(1, 1);
        // Load model
        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
    }

    void reset_states()
    {
        // Call reset before each audio start
        std::memset(_state.data(), 0, _state.size() * sizeof(float));
        triggered = false;
        temp_end = 0;
        current_sample = 0;

        prev_end = next_start = 0;

        speeches.clear();
        current_speech = timestamp_t();
    }

    void predict(const std::vector<float>& data)
    {
        // Create ort tensors
        input.assign(data.begin(), data.end());
        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
            memory_info, input.data(), input.size(), input_node_dims, 2);
        Ort::Value state_ort = Ort::Value::CreateTensor<float>(
            memory_info, _state.data(), _state.size(), state_node_dims, 3);
        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
            memory_info, sr.data(), sr.size(), sr_node_dims, 1);

        // Clear and add inputs
        ort_inputs.clear();
        ort_inputs.emplace_back(std::move(input_ort));
        ort_inputs.emplace_back(std::move(state_ort));
        ort_inputs.emplace_back(std::move(sr_ort));

        // Infer
        ort_outputs = session->Run(
            Ort::RunOptions{nullptr},
            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
            output_node_names.data(), output_node_names.size());

        // Read the output probability & update the recurrent state
        float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
        float* stateN = ort_outputs[1].GetTensorMutableData<float>();
        std::memcpy(_state.data(), stateN, size_state * sizeof(float));

        // Push forward sample index
        current_sample += window_size_samples;

        // 1) Speech: reset temp_end when prob >= threshold
        if (speech_prob >= threshold)
        {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point.
            printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            if (temp_end != 0)
            {
                temp_end = 0;
                if (next_start < prev_end)
                    next_start = current_sample - window_size_samples;
            }
            if (triggered == false)
            {
                triggered = true;
                current_speech.start = current_sample - window_size_samples;
            }
            return;
        }

        // 2) The current segment exceeded the maximum speech duration
        if (triggered && ((current_sample - current_speech.start) > max_speech_samples))
        {
            if (prev_end > 0)
            {
                current_speech.end = prev_end;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();

                // previously reached silence (< neg_thres) and is still not speech (< thres)
                if (next_start < prev_end)
                    triggered = false;
                else
                    current_speech.start = next_start;

                prev_end = 0;
                next_start = 0;
                temp_end = 0;
            }
            else
            {
                current_speech.end = current_sample;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();
                prev_end = 0;
                next_start = 0;
                temp_end = 0;
                triggered = false;
            }
            return;
        }

        // 3) In-between zone: (threshold - 0.15) <= prob < threshold
        if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
        {
            if (triggered)
            {
#ifdef __DEBUG_SPEECH_PROB___
                float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point.
                printf("{ speaking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            }
            else
            {
#ifdef __DEBUG_SPEECH_PROB___
                float speech = current_sample - window_size_samples; // minus window_size_samples to get the precise start time point.
                printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            }
            return;
        }

        // 4) End: prob < (threshold - 0.15)
        if (speech_prob < (threshold - 0.15))
        {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get the precise start time point.
            printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
            if (triggered == true)
            {
                if (temp_end == 0)
                {
                    temp_end = current_sample;
                }
                if (current_sample - temp_end > min_silence_samples_at_max_speech)
                    prev_end = temp_end;
                // a. silence < min_silence_samples: continue speaking
                if ((current_sample - temp_end) < min_silence_samples)
                {
                }
                // b. silence >= min_silence_samples: end speaking
                else
                {
                    current_speech.end = temp_end;
                    if (current_speech.end - current_speech.start > min_speech_samples)
                    {
                        speeches.push_back(current_speech);
                        current_speech = timestamp_t();
                        prev_end = 0;
                        next_start = 0;
                        temp_end = 0;
                        triggered = false;
                    }
                }
            }
            else
            {
                // the very first windows may already be below the threshold; nothing to do
            }
            return;
        }
    }

public:
    void process(const std::vector<float>& input_wav)
    {
        reset_states();

        audio_length_samples = input_wav.size();

        for (int j = 0; j < audio_length_samples; j += window_size_samples)
        {
            if (j + window_size_samples > audio_length_samples)
                break;
            std::vector<float> r{ input_wav.data() + j, input_wav.data() + j + window_size_samples };
            predict(r);
        }

        if (current_speech.start >= 0)
        {
            current_speech.end = audio_length_samples;
            speeches.push_back(current_speech);
            current_speech = timestamp_t();
            prev_end = 0;
            next_start = 0;
            temp_end = 0;
            triggered = false;
        }
    }

    void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        process(input_wav);
        collect_chunks(input_wav, output_wav);
    }

    void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        output_wav.clear();
        for (size_t i = 0; i < speeches.size(); i++)
        {
#ifdef __DEBUG_SPEECH_PROB___
            std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
            std::vector<float> slice(input_wav.data() + speeches[i].start, input_wav.data() + speeches[i].end);
            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
        }
    }

    const std::vector<timestamp_t> get_speech_timestamps() const
    {
        return speeches;
    }

    void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
    {
        output_wav.clear();
        int current_start = 0;
        for (size_t i = 0; i < speeches.size(); i++)
        {
            std::vector<float> slice(input_wav.data() + current_start, input_wav.data() + speeches[i].start);
            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
            current_start = speeches[i].end;
        }

        std::vector<float> slice(input_wav.data() + current_start, input_wav.data() + input_wav.size());
        output_wav.insert(output_wav.end(), slice.begin(), slice.end());
    }

private:
    // model config
    int64_t window_size_samples;  // Assigned at init; supports 256/512/768 samples for 8 kHz and 512/1024/1536 for 16 kHz
    int sample_rate;              // Assigned at init; supports 16000 or 8000
    int sr_per_ms;                // Assigned at init; 8 or 16
    float threshold;
    int min_silence_samples;               // sr_per_ms * #ms
    int min_silence_samples_at_max_speech; // sr_per_ms * 98
    int min_speech_samples;                // sr_per_ms * #ms
    float max_speech_samples;
    int speech_pad_samples;                // sr_per_ms * #ms
    int audio_length_samples;

    // model states
    bool triggered = false;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
    int prev_end;
    int next_start = 0;

    // Output timestamps
    std::vector<timestamp_t> speeches;
    timestamp_t current_speech;


    // Onnx model
    // Inputs
    std::vector<Ort::Value> ort_inputs;

    std::vector<const char*> input_node_names = {"input", "state", "sr"};
    std::vector<float> input;
    unsigned int size_state = 2 * 1 * 128;  // It's FIXED.
    std::vector<float> _state;
    std::vector<int64_t> sr;

    int64_t input_node_dims[2] = {};
    const int64_t state_node_dims[3] = {2, 1, 128};
    const int64_t sr_node_dims[1] = {1};

    // Outputs
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char*> output_node_names = {"output", "stateN"};

public:
    // Construction
    VadIterator(const std::wstring ModelPath,
        int Sample_rate = 16000, int windows_frame_size = 32,
        float Threshold = 0.5, int min_silence_duration_ms = 0,
        int speech_pad_ms = 32, int min_speech_duration_ms = 32,
        float max_speech_duration_s = std::numeric_limits<float>::infinity())
    {
        init_onnx_model(ModelPath);
        threshold = Threshold;
        sample_rate = Sample_rate;
        sr_per_ms = sample_rate / 1000;

        window_size_samples = windows_frame_size * sr_per_ms;

        min_speech_samples = sr_per_ms * min_speech_duration_ms;
        speech_pad_samples = sr_per_ms * speech_pad_ms;

        max_speech_samples = (
            sample_rate * max_speech_duration_s
            - window_size_samples
            - 2 * speech_pad_samples
        );

        min_silence_samples = sr_per_ms * min_silence_duration_ms;
        min_silence_samples_at_max_speech = sr_per_ms * 98;

        input.resize(window_size_samples);
        input_node_dims[0] = 1;
        input_node_dims[1] = window_size_samples;

        _state.resize(size_state);
        sr.resize(1);
        sr[0] = sample_rate;
    }
};

int main()
{
    std::vector<timestamp_t> stamps;

    // Read wav (16 kHz, mono, 32-bit float)
    wav::WavReader wav_reader("recorder.wav");
    std::vector<float> input_wav(wav_reader.num_samples());
    std::vector<float> output_wav;

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
    }

    // ===== Test configs =====
    std::wstring path = L"silero_vad.onnx";
    VadIterator vad(path);

    // ==============================================
    // ===== Example 1 of full function =====
    // ==============================================
    vad.process(input_wav);

    // 1.a get_speech_timestamps
    stamps = vad.get_speech_timestamps();
    for (size_t i = 0; i < stamps.size(); i++)
    {
        std::cout << stamps[i].c_str() << std::endl;
    }

    // 1.b collect_chunks output wav
    vad.collect_chunks(input_wav, output_wav);

    // 1.c drop_chunks output wav
    vad.drop_chunks(input_wav, output_wav);

    // ==============================================
    // ===== Example 2 of simple full function =====
    // ==============================================
    vad.process(input_wav, output_wav);

    stamps = vad.get_speech_timestamps();
    for (size_t i = 0; i < stamps.size(); i++)
    {
        std::cout << stamps[i].c_str() << std::endl;
    }

    // ==============================================
    // ===== Example 3 of repeated processing =====
    // ==============================================
    for (int i = 0; i < 2; i++)
        vad.process(input_wav, output_wav);
}
235  examples/cpp/wav.h  Normal file
@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    std::cout << "num_channel_ :" << num_channel_ << std::endl;
    std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples :" << num_data << std::endl;
    std::cout << "num_data_size :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        char sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 16: {
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32: {
        if (header.format == 1) {  // S32
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            data_[i] = static_cast<float>(sample) / 32768;
          }
        } else if (header.format == 3) {  // IEEE-float
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        } else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "wb");  // "wb": write in binary mode
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_
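A minimal round-trip sketch of how the two helpers above fit together. The paths are assumptions; note that `WavReader` normalizes integer PCM by 1/32768, so writing the floats back as 16-bit requires rescaling first.

```cpp
#include <vector>
#include "wav.h"

int main() {
    wav::WavReader reader("in.wav");  // assumed input path

    // Rescale the normalized floats back to 16-bit range before writing.
    std::vector<float> samples(reader.data(),
                               reader.data() + reader.num_samples() * reader.num_channel());
    for (auto& s : samples) s *= 32768.0f;

    wav::WavWriter writer(samples.data(), reader.num_samples(), reader.num_channel(),
                          reader.sample_rate(), 16);
    writer.Write("out.wav");  // assumed output path
    return 0;
}
```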
19  examples/go/README.md  Normal file
@@ -0,0 +1,19 @@
## Golang Example

This is a sample program showing how to run speech detection with `silero-vad` from Golang (CGO + ONNX Runtime).

### Requirements

- Golang >= v1.21
- ONNX Runtime

### Usage

```sh
go run ./cmd/main.go test.wav
```

> **_Note_**
>
> Make sure you have the ONNX Runtime library and C headers installed in your path.
63  examples/go/cmd/main.go  Normal file
@@ -0,0 +1,63 @@
package main

import (
	"log"
	"os"

	"github.com/streamer45/silero-vad-go/speech"

	"github.com/go-audio/wav"
)

func main() {
	sd, err := speech.NewDetector(speech.DetectorConfig{
		ModelPath:            "../../files/silero_vad.onnx",
		SampleRate:           16000,
		Threshold:            0.5,
		MinSilenceDurationMs: 0,
		SpeechPadMs:          0,
	})
	if err != nil {
		log.Fatalf("failed to create speech detector: %s", err)
	}

	if len(os.Args) != 2 {
		log.Fatalf("invalid arguments provided: expecting one file path")
	}

	f, err := os.Open(os.Args[1])
	if err != nil {
		log.Fatalf("failed to open sample audio file: %s", err)
	}
	defer f.Close()

	dec := wav.NewDecoder(f)

	if ok := dec.IsValidFile(); !ok {
		log.Fatalf("invalid WAV file")
	}

	buf, err := dec.FullPCMBuffer()
	if err != nil {
		log.Fatalf("failed to get PCM buffer")
	}

	pcmBuf := buf.AsFloat32Buffer()

	segments, err := sd.Detect(pcmBuf.Data)
	if err != nil {
		log.Fatalf("Detect failed: %s", err)
	}

	for _, s := range segments {
		log.Printf("speech starts at %0.2fs", s.SpeechStartAt)
		if s.SpeechEndAt > 0 {
			log.Printf("speech ends at %0.2fs", s.SpeechEndAt)
		}
	}

	err = sd.Destroy()
	if err != nil {
		log.Fatalf("failed to destroy detector: %s", err)
	}
}
13  examples/go/go.mod  Normal file
@@ -0,0 +1,13 @@
module silero

go 1.21.4

require (
	github.com/go-audio/wav v1.1.0
	github.com/streamer45/silero-vad-go v0.2.0
)

require (
	github.com/go-audio/audio v1.0.0 // indirect
	github.com/go-audio/riff v1.0.0 // indirect
)
16  examples/go/go.sum  Normal file
@@ -0,0 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
30  examples/java-example/pom.xml  Normal file
@@ -0,0 +1,30 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <groupId>org.example</groupId>
  <artifactId>java-example</artifactId>
  <version>1.0-SNAPSHOT</version>
  <packaging>jar</packaging>

  <name>silero-vad-example</name>
  <url>http://maven.apache.org</url>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>com.microsoft.onnxruntime</groupId>
      <artifactId>onnxruntime</artifactId>
      <version>1.16.0-rc1</version>
    </dependency>
  </dependencies>
</project>
69  examples/java-example/src/main/java/org/example/App.java  Normal file
@@ -0,0 +1,69 @@
package org.example;

import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.util.Map;

public class App {

    private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
    private static final int SAMPLE_RATE = 16000;
    private static final float START_THRESHOLD = 0.6f;
    private static final float END_THRESHOLD = 0.45f;
    private static final int MIN_SILENCE_DURATION_MS = 600;
    private static final int SPEECH_PAD_MS = 500;
    private static final int WINDOW_SIZE_SAMPLES = 2048;

    public static void main(String[] args) {
        // Initialize the Voice Activity Detector
        SlieroVadDetector vadDetector;
        try {
            vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
        } catch (OrtException e) {
            System.err.println("Error initializing the VAD detector: " + e.getMessage());
            return;
        }

        // Set audio format
        AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
        DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);

        // Get the target data line and open it with the specified format
        TargetDataLine targetDataLine;
        try {
            targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
            targetDataLine.open(format);
            targetDataLine.start();
        } catch (LineUnavailableException e) {
            System.err.println("Error opening target data line: " + e.getMessage());
            return;
        }

        // Main loop to continuously read data and apply Voice Activity Detection
        while (targetDataLine.isOpen()) {
            byte[] data = new byte[WINDOW_SIZE_SAMPLES];  // buffer size in bytes (two bytes per 16-bit sample)

            int numBytesRead = targetDataLine.read(data, 0, data.length);
            if (numBytesRead <= 0) {
                System.err.println("Error reading data from target data line.");
                continue;
            }

            // Apply the Voice Activity Detector to the data and get the result
            Map<String, Double> detectResult;
            try {
                detectResult = vadDetector.apply(data, true);
            } catch (Exception e) {
                System.err.println("Error applying VAD detector: " + e.getMessage());
                continue;
            }

            if (!detectResult.isEmpty()) {
                System.out.println(detectResult);
            }
        }

        // Close the target data line to release audio resources
        targetDataLine.close();
    }
}
145  examples/java-example/src/main/java/org/example/SlieroVadDetector.java  Normal file
@@ -0,0 +1,145 @@
package org.example;

import ai.onnxruntime.OrtException;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;


public class SlieroVadDetector {
    // OnnxModel model used for speech processing
    private final SlieroVadOnnxModel model;
    // Threshold for speech start
    private final float startThreshold;
    // Threshold for speech end
    private final float endThreshold;
    // Sampling rate
    private final int samplingRate;
    // Minimum number of silence samples required to declare the end of speech
    private final float minSilenceSamples;
    // Additional number of samples used to pad the calculated speech start or end time
    private final float speechPadSamples;
    // Whether in the triggered state (i.e. whether speech is being detected)
    private boolean triggered;
    // Temporarily stored speech end position, in samples
    private int tempEnd;
    // Number of samples processed so far
    private int currentSample;


    public SlieroVadDetector(String modelPath,
                             float startThreshold,
                             float endThreshold,
                             int samplingRate,
                             int minSilenceDurationMs,
                             int speechPadMs) throws OrtException {
        // Check if the sampling rate is 8000 or 16000; if not, throw an exception
        if (samplingRate != 8000 && samplingRate != 16000) {
            throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
        }

        // Initialize the parameters
        this.model = new SlieroVadOnnxModel(modelPath);
        this.startThreshold = startThreshold;
        this.endThreshold = endThreshold;
        this.samplingRate = samplingRate;
        this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this.speechPadSamples = samplingRate * speechPadMs / 1000f;
        // Reset the state
        reset();
    }

    // Reset the state: the model state, trigger state, temporary end time, and current sample count
    public void reset() {
        model.resetStates();
        triggered = false;
        tempEnd = 0;
        currentSample = 0;
    }

    // apply method for processing the audio array, returning possible speech start or end times
    public Map<String, Double> apply(byte[] data, boolean returnSeconds) {

        // Convert the little-endian 16-bit byte array to a float array
        float[] audioData = new float[data.length / 2];
        for (int i = 0; i < audioData.length; i++) {
            audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
        }

        // Get the length of the audio array as the window size
        int windowSizeSamples = audioData.length;
        // Update the current sample count
        currentSample += windowSizeSamples;

        // Call the model to get the prediction probability of speech
        float speechProb = 0;
        try {
            speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }

        // If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
        // This indicates that the speech duration has exceeded expectations and the end time needs to be recalculated
        if (speechProb >= startThreshold && tempEnd != 0) {
            tempEnd = 0;
        }

        // If the speech probability is greater than the threshold and not in the triggered state, switch to the triggered state and calculate the speech start time
        if (speechProb >= startThreshold && !triggered) {
            triggered = true;
            int speechStart = (int) (currentSample - speechPadSamples);
            speechStart = Math.max(speechStart, 0);
            Map<String, Double> result = new HashMap<>();
            // Decide whether to return the result in seconds or in samples based on the returnSeconds parameter
            if (returnSeconds) {
                double speechStartSeconds = speechStart / (double) samplingRate;
                double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
                result.put("start", roundedSpeechStart);
            } else {
                result.put("start", (double) speechStart);
            }

            return result;
        }

        // If the speech probability is less than the end threshold and in the triggered state, calculate the speech end time
        if (speechProb < endThreshold && triggered) {
            // Initialize or update the temporary end time
            if (tempEnd == 0) {
                tempEnd = currentSample;
            }
            // If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return an empty map
            // This indicates that it is not yet possible to determine whether the speech has ended
            if (currentSample - tempEnd < minSilenceSamples) {
                return Collections.emptyMap();
            } else {
                // Calculate the speech end time, reset the trigger state and temporary end time
                int speechEnd = (int) (tempEnd + speechPadSamples);
                tempEnd = 0;
                triggered = false;
                Map<String, Double> result = new HashMap<>();

                if (returnSeconds) {
                    double speechEndSeconds = speechEnd / (double) samplingRate;
                    double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
                    result.put("end", roundedSpeechEnd);
                } else {
                    result.put("end", (double) speechEnd);
                }
                return result;
            }
        }

        // If none of the above conditions are met, return an empty map by default
        return Collections.emptyMap();
    }

    public void close() throws OrtException {
        reset();
        model.close();
    }
}
180  examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java  Normal file
@@ -0,0 +1,180 @@
package org.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SlieroVadOnnxModel {
    // Define private variable OrtSession
    private final OrtSession session;
    private float[][][] h;
    private float[][][] c;
    // Define the last sample rate
    private int lastSr = 0;
    // Define the last batch size
    private int lastBatchSize = 0;
    // Define a list of supported sample rates
    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);

    // Constructor
    public SlieroVadOnnxModel(String modelPath) throws OrtException {
        // Get the ONNX runtime environment
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // Create an ONNX session options object
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // Set the InterOp thread count to 1; InterOp threads are used for parallel processing of different computation graph operations
        opts.setInterOpNumThreads(1);
        // Set the IntraOp thread count to 1; IntraOp threads are used for parallel processing within a single operation
        opts.setIntraOpNumThreads(1);
        // Add a CPU device; setting to false disables CPU execution optimization
        opts.addCPU(true);
        // Create an ONNX session using the environment, model path, and options
        session = env.createSession(modelPath, opts);
        // Reset states
        resetStates();
    }

    /**
     * Reset states
     */
    void resetStates() {
        h = new float[2][1][64];
        c = new float[2][1][64];
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void close() throws OrtException {
        session.close();
    }

    /**
     * Inner class ValidationResult
     */
    public static class ValidationResult {
        public final float[][] x;
        public final int sr;

        // Constructor
        public ValidationResult(float[][] x, int sr) {
            this.x = x;
            this.sr = sr;
        }
    }

    /**
     * Validate the input data
     */
    private ValidationResult validateInput(float[][] x, int sr) {
        // Process input data with dimension 1
        if (x.length == 1) {
            x = new float[][]{x[0]};
        }
        // Throw an exception when the input data dimension is greater than 2
        if (x.length > 2) {
            throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
        }

        // Downsample the input when the sample rate is a multiple of 16000 (other than 16000 itself)
        if (sr != 16000 && (sr % 16000 == 0)) {
            int step = sr / 16000;
            float[][] reducedX = new float[x.length][];

            for (int i = 0; i < x.length; i++) {
                float[] current = x[i];
                float[] newArr = new float[(current.length + step - 1) / step];

                for (int j = 0, index = 0; j < current.length; j += step, index++) {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        // If the sample rate is not in the list of supported sample rates, throw an exception
        if (!SAMPLE_RATES.contains(sr)) {
            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
        }

        // If the input audio chunk is too short, throw an exception
        if (((float) sr) / x[0].length > 31.25) {
            throw new IllegalArgumentException("Input audio is too short");
        }

        // Return the validated result
        return new ValidationResult(x, sr);
    }

    /**
     * Call the ONNX model
     */
    public float[] call(float[][] x, int sr) throws OrtException {
        ValidationResult result = validateInput(x, sr);
        x = result.x;
        sr = result.sr;

        int batchSize = x.length;

        if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
            resetStates();
        }

        OrtEnvironment env = OrtEnvironment.getEnvironment();

        OnnxTensor inputTensor = null;
        OnnxTensor hTensor = null;
        OnnxTensor cTensor = null;
        OnnxTensor srTensor = null;
        OrtSession.Result ortOutputs = null;

        try {
            // Create input tensors
            inputTensor = OnnxTensor.createTensor(env, x);
            hTensor = OnnxTensor.createTensor(env, h);
            cTensor = OnnxTensor.createTensor(env, c);
            srTensor = OnnxTensor.createTensor(env, new long[]{sr});

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", inputTensor);
            inputs.put("sr", srTensor);
            inputs.put("h", hTensor);
            inputs.put("c", cTensor);

            // Run the ONNX model
            ortOutputs = session.run(inputs);
            // Get the output results; h and c are carried over to the next call
            float[][] output = (float[][]) ortOutputs.get(0).getValue();
            h = (float[][][]) ortOutputs.get(1).getValue();
            c = (float[][][]) ortOutputs.get(2).getValue();

            lastSr = sr;
            lastBatchSize = batchSize;
            return output[0];
        } finally {
            if (inputTensor != null) {
                inputTensor.close();
            }
            if (hTensor != null) {
                hTensor.close();
            }
            if (cTensor != null) {
                cTensor.close();
            }
            if (srTensor != null) {
                srTensor.close();
            }
            if (ortOutputs != null) {
                ortOutputs.close();
            }
        }
    }
}
@@ -186,7 +186,7 @@ if __name__ == '__main__':
                         help="same as trig_sum, but for switching from triggered to non-triggered state (non-speech)")
     parser.add_argument('-N', '--num_steps', type=int, default=8,
-                        help="nubmer of overlapping windows to split audio chunk into (we recommend 4 or 8)")
+                        help="number of overlapping windows to split audio chunk into (we recommend 4 or 8)")
     parser.add_argument('-nspw', '--num_samples_per_window', type=int, default=4000,
                         help="number of samples in each window, our models were trained using 4000 samples (250 ms) per window, so this is the preferable value (lesser values reduce quality)")
@@ -198,4 +198,4 @@ if __name__ == '__main__':
                         help="minimum silence duration in samples between two separate speech chunks")
     ARGS = parser.parse_args()
     ARGS.rate = DEFAULT_SAMPLE_RATE
     main(ARGS)
149  examples/parallel_example.ipynb  Normal file
@@ -0,0 +1,149 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install -q torchaudio\n",
"SAMPLING_RATE = 16000\n",
"import torch\n",
"from pprint import pprint\n",
"\n",
"torch.set_num_threads(1)\n",
"NUM_PROCESS = 4  # set to the number of CPU cores in the machine\n",
"NUM_COPIES = 8\n",
"# download wav files, make multiple copies\n",
"for idx in range(NUM_COPIES):\n",
"    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load VAD model from torch hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
"                              model='silero_vad',\n",
"                              force_reload=True,\n",
"                              onnx=False)\n",
"\n",
"(get_speech_timestamps,\n",
" save_audio,\n",
" read_audio,\n",
" VADIterator,\n",
" collect_chunks) = utils"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define a vad process function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import multiprocessing\n",
"\n",
"vad_models = dict()\n",
"\n",
"def init_model(model):\n",
"    pid = multiprocessing.current_process().pid\n",
"    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
"                              model='silero_vad',\n",
"                              force_reload=False,\n",
"                              onnx=False)\n",
"    vad_models[pid] = model\n",
"\n",
"def vad_process(audio_file: str):\n",
"    pid = multiprocessing.current_process().pid\n",
"\n",
"    with torch.no_grad():\n",
"        wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
"        return get_speech_timestamps(\n",
"            wav,\n",
"            vad_models[pid],\n",
"            0.46,   # speech prob threshold\n",
"            16000,  # sample rate\n",
"            300,    # min speech duration in ms\n",
"            20,     # max speech duration in seconds\n",
"            600,    # min silence duration\n",
"            512,    # window size\n",
"            200,    # speech pad ms\n",
"        )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parallelization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"\n",
"futures = []\n",
"\n",
"with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n",
"    for i in range(NUM_COPIES):\n",
"        futures.append(ex.submit(vad_process, f\"en_example{i}.wav\"))\n",
"\n",
"for finished in as_completed(futures):\n",
"    pprint(finished.result())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "diarization",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
||||||
File diff suppressed because one or more lines are too long
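The `vad_process` helper above passes every tuning value to `get_speech_timestamps` positionally, with a trailing comment carrying each argument's meaning. For readability, here is the same call written with keyword arguments; this is a minimal sketch, assuming the parameter names of `get_speech_timestamps` in this repo's `utils_vad.py` (`threshold`, `sampling_rate`, `min_speech_duration_ms`, `max_speech_duration_s`, `min_silence_duration_ms`, `window_size_samples`, `speech_pad_ms`) and reusing `read_audio`, `vad_models`, and `SAMPLING_RATE` from the notebook cells above:

```
import multiprocessing

import torch

def vad_process_kw(audio_file: str):
    # Same values as the positional call in the notebook cell above;
    # keyword names are assumed to match utils_vad.get_speech_timestamps.
    pid = multiprocessing.current_process().pid
    with torch.no_grad():
        wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)
        return get_speech_timestamps(
            wav,
            vad_models[pid],
            threshold=0.46,
            sampling_rate=16000,
            min_speech_duration_ms=300,
            max_speech_duration_s=20,
            min_silence_duration_ms=600,
            window_size_samples=512,
            speech_pad_ms=200,
        )
```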
examples/rust-example/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
target/
recorder.wav
examples/rust-example/Cargo.lock (generated, new file, 781 lines)
@@ -0,0 +1,781 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"

[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"

[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"

[[package]]
name = "bitflags"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"

[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
 "generic-array",
]

[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"

[[package]]
name = "cc"
version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f"

[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"

[[package]]
name = "cpufeatures"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
dependencies = [
 "libc",
]

[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
 "cfg-if",
]

[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"

[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
 "generic-array",
 "typenum",
]

[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
 "block-buffer",
 "crypto-common",
]

[[package]]
name = "errno"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [
 "libc",
 "windows-sys",
]

[[package]]
name = "filetime"
version = "0.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd"
dependencies = [
 "cfg-if",
 "libc",
 "redox_syscall",
 "windows-sys",
]

[[package]]
name = "flate2"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
 "crc32fast",
 "miniz_oxide",
]

[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
 "percent-encoding",
]

[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
 "typenum",
 "version_check",
]

[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
 "cfg-if",
 "libc",
 "wasi",
]

[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
 "cfg-if",
 "crunchy",
]

[[package]]
name = "hound"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"

[[package]]
name = "idna"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
 "unicode-bidi",
 "unicode-normalization",
]

[[package]]
name = "js-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
dependencies = [
 "wasm-bindgen",
]

[[package]]
name = "libc"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"

[[package]]
name = "libloading"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
 "cfg-if",
 "windows-targets",
]

[[package]]
name = "linux-raw-sys"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"

[[package]]
name = "log"
version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"

[[package]]
name = "matrixmultiply"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
dependencies = [
 "autocfg",
 "rawpointer",
]

[[package]]
name = "miniz_oxide"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae"
dependencies = [
 "adler",
]

[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
 "matrixmultiply",
 "num-complex",
 "num-integer",
 "num-traits",
 "rawpointer",
]

[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
 "num-traits",
]

[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
 "num-traits",
]

[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
 "autocfg",
]

[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"

[[package]]
name = "ort"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
dependencies = [
 "half",
 "js-sys",
 "libloading",
 "ndarray",
 "ort-sys",
 "thiserror",
 "tracing",
 "web-sys",
]

[[package]]
name = "ort-sys"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
dependencies = [
 "flate2",
 "sha2",
 "tar",
 "ureq",
]

[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

[[package]]
name = "pin-project-lite"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"

[[package]]
name = "proc-macro2"
version = "1.0.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
dependencies = [
 "unicode-ident",
]

[[package]]
name = "quote"
version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [
 "proc-macro2",
]

[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"

[[package]]
name = "redox_syscall"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa"
dependencies = [
 "bitflags 1.3.2",
]

[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
 "cc",
 "cfg-if",
 "getrandom",
 "libc",
 "spin",
 "untrusted",
 "windows-sys",
]

[[package]]
name = "rust-example"
version = "0.1.0"
dependencies = [
 "hound",
 "ndarray",
 "ort",
]

[[package]]
name = "rustix"
version = "0.38.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
dependencies = [
 "bitflags 2.5.0",
 "errno",
 "libc",
 "linux-raw-sys",
 "windows-sys",
]

[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
 "log",
 "ring",
 "rustls-pki-types",
 "rustls-webpki",
 "subtle",
 "zeroize",
]

[[package]]
name = "rustls-pki-types"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"

[[package]]
name = "rustls-webpki"
version = "0.102.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
dependencies = [
 "ring",
 "rustls-pki-types",
 "untrusted",
]

[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
 "cfg-if",
 "cpufeatures",
 "digest",
]

[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"

[[package]]
name = "subtle"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"

[[package]]
name = "syn"
version = "2.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
dependencies = [
 "proc-macro2",
 "quote",
 "unicode-ident",
]

[[package]]
name = "tar"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
dependencies = [
 "filetime",
 "libc",
 "xattr",
]

[[package]]
name = "thiserror"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
 "thiserror-impl",
]

[[package]]
name = "thiserror-impl"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
 "tinyvec_macros",
]

[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"

[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
 "pin-project-lite",
 "tracing-attributes",
 "tracing-core",
]

[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
 "once_cell",
]

[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

[[package]]
name = "unicode-bidi"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"

[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"

[[package]]
name = "unicode-normalization"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
dependencies = [
 "tinyvec",
]

[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"

[[package]]
name = "ureq"
version = "2.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
dependencies = [
 "base64",
 "log",
 "once_cell",
 "rustls",
 "rustls-pki-types",
 "rustls-webpki",
 "url",
 "webpki-roots",
]

[[package]]
name = "url"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
dependencies = [
 "form_urlencoded",
 "idna",
 "percent-encoding",
]

[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"

[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"

[[package]]
name = "wasm-bindgen"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
dependencies = [
 "cfg-if",
 "wasm-bindgen-macro",
]

[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
dependencies = [
 "bumpalo",
 "log",
 "once_cell",
 "proc-macro2",
 "quote",
 "syn",
 "wasm-bindgen-shared",
]

[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
dependencies = [
 "quote",
 "wasm-bindgen-macro-support",
]

[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
 "wasm-bindgen-backend",
 "wasm-bindgen-shared",
]

[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"

[[package]]
name = "web-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
dependencies = [
 "js-sys",
 "wasm-bindgen",
]

[[package]]
name = "webpki-roots"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [
 "rustls-pki-types",
]

[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
 "windows-targets",
]

[[package]]
name = "windows-targets"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
 "windows_aarch64_gnullvm",
 "windows_aarch64_msvc",
 "windows_i686_gnu",
 "windows_i686_gnullvm",
 "windows_i686_msvc",
 "windows_x86_64_gnu",
 "windows_x86_64_gnullvm",
 "windows_x86_64_msvc",
]

[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"

[[package]]
name = "windows_aarch64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"

[[package]]
name = "windows_i686_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"

[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"

[[package]]
name = "windows_i686_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"

[[package]]
name = "windows_x86_64_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"

[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"

[[package]]
name = "windows_x86_64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"

[[package]]
name = "xattr"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
dependencies = [
 "libc",
 "linux-raw-sys",
 "rustix",
]

[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
examples/rust-example/Cargo.toml (new file, 9 lines)
@@ -0,0 +1,9 @@
[package]
name = "rust-example"
version = "0.1.0"
edition = "2021"

[dependencies]
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] }
ndarray = "0.15"
hound = "3"
examples/rust-example/README.md (new file, 19 lines)
@@ -0,0 +1,19 @@
# Stream example in Rust

Made after the [C++ stream example](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)

## Dependencies

- To build the Rust crate `ort` you need `cc` installed.

## Usage

Just

```
cargo run
```

If you run the example outside of this repo, adjust the environment variable

```
SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run
```

If you need to test against a wav file other than `recorder.wav`, specify it as the first argument

```
cargo run -- /path/to/audio/file.wav
```
examples/rust-example/src/main.rs (new file, 36 lines)
@@ -0,0 +1,36 @@
mod silero;
mod utils;
mod vad_iter;

fn main() {
    let model_path = std::env::var("SILERO_MODEL_PATH")
        .unwrap_or_else(|_| String::from("../../files/silero_vad.onnx"));
    let audio_path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| String::from("recorder.wav"));
    let mut wav_reader = hound::WavReader::open(audio_path).unwrap();
    let sample_rate = match wav_reader.spec().sample_rate {
        8000 => utils::SampleRate::EightkHz,
        16000 => utils::SampleRate::SixteenkHz,
        _ => panic!("Unsupported sample rate. Expect 8 kHz or 16 kHz."),
    };
    if wav_reader.spec().sample_format != hound::SampleFormat::Int {
        panic!("Unsupported sample format. Expect Int.");
    }
    let content = wav_reader
        .samples()
        .filter_map(|x| x.ok())
        .collect::<Vec<i16>>();
    assert!(!content.is_empty());
    let silero = silero::Silero::new(sample_rate, model_path).unwrap();
    let vad_params = utils::VadParams {
        sample_rate: sample_rate.into(),
        ..Default::default()
    };
    let mut vad_iterator = vad_iter::VadIter::new(silero, vad_params);
    vad_iterator.process(&content).unwrap();
    for timestamp in vad_iterator.speeches() {
        println!("{}", timestamp);
    }
    println!("Finished.");
}
examples/rust-example/src/silero.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
use crate::utils;
use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use std::path::Path;

#[derive(Debug)]
pub struct Silero {
    session: ort::Session,
    sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
    h: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
    c: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
}

impl Silero {
    pub fn new(
        sample_rate: utils::SampleRate,
        model_path: impl AsRef<Path>,
    ) -> Result<Self, ort::Error> {
        let session = ort::Session::builder()?.commit_from_file(model_path)?;
        let h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
        let c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
        let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
        Ok(Self {
            session,
            sample_rate,
            h,
            c,
        })
    }

    pub fn reset(&mut self) {
        self.h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
        self.c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
    }

    pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
        let data = audio_frame
            .iter()
            .map(|x| (*x as f32) / (i16::MAX as f32))
            .collect::<Vec<_>>();
        let frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
        let inps = ort::inputs![
            frame,
            self.sample_rate.clone(),
            std::mem::take(&mut self.h),
            std::mem::take(&mut self.c)
        ]?;
        let res = self
            .session
            .run(ort::SessionInputs::ValueSlice::<4>(&inps))?;
        self.h = res["hn"].try_extract_tensor().unwrap().to_owned();
        self.c = res["cn"].try_extract_tensor().unwrap().to_owned();
        Ok(*res["output"]
            .try_extract_raw_tensor::<f32>()
            .unwrap()
            .1
            .first()
            .unwrap())
    }
}
examples/rust-example/src/utils.rs (new file, 60 lines)
@@ -0,0 +1,60 @@
#[derive(Debug, Clone, Copy)]
pub enum SampleRate {
    EightkHz,
    SixteenkHz,
}

impl From<SampleRate> for i64 {
    fn from(value: SampleRate) -> Self {
        match value {
            SampleRate::EightkHz => 8000,
            SampleRate::SixteenkHz => 16000,
        }
    }
}

impl From<SampleRate> for usize {
    fn from(value: SampleRate) -> Self {
        match value {
            SampleRate::EightkHz => 8000,
            SampleRate::SixteenkHz => 16000,
        }
    }
}

#[derive(Debug)]
pub struct VadParams {
    pub frame_size: usize,
    pub threshold: f32,
    pub min_silence_duration_ms: usize,
    pub speech_pad_ms: usize,
    pub min_speech_duration_ms: usize,
    pub max_speech_duration_s: f32,
    pub sample_rate: usize,
}

impl Default for VadParams {
    fn default() -> Self {
        Self {
            frame_size: 64,
            threshold: 0.5,
            min_silence_duration_ms: 0,
            speech_pad_ms: 64,
            min_speech_duration_ms: 64,
            max_speech_duration_s: f32::INFINITY,
            sample_rate: 16000,
        }
    }
}

#[derive(Debug, Default)]
pub struct TimeStamp {
    pub start: i64,
    pub end: i64,
}

impl std::fmt::Display for TimeStamp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[start:{:08}, end:{:08}]", self.start, self.end)
    }
}
examples/rust-example/src/vad_iter.rs (new file, 223 lines)
@@ -0,0 +1,223 @@
use crate::{silero, utils};

const DEBUG_SPEECH_PROB: bool = true;

#[derive(Debug)]
pub struct VadIter {
    silero: silero::Silero,
    params: Params,
    state: State,
}

impl VadIter {
    pub fn new(silero: silero::Silero, params: utils::VadParams) -> Self {
        Self {
            silero,
            params: Params::from(params),
            state: State::new(),
        }
    }

    pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
        self.reset_states();
        for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
            let speech_prob = self.silero.calc_level(audio_frame)?;
            self.state.update(&self.params, speech_prob);
        }
        self.state.check_for_last_speech(samples.len());
        Ok(())
    }

    pub fn speeches(&self) -> &[utils::TimeStamp] {
        &self.state.speeches
    }
}

impl VadIter {
    fn reset_states(&mut self) {
        self.silero.reset();
        self.state = State::new()
    }
}

#[allow(unused)]
#[derive(Debug)]
struct Params {
    frame_size: usize,
    threshold: f32,
    min_silence_duration_ms: usize,
    speech_pad_ms: usize,
    min_speech_duration_ms: usize,
    max_speech_duration_s: f32,
    sample_rate: usize,
    sr_per_ms: usize,
    frame_size_samples: usize,
    min_speech_samples: usize,
    speech_pad_samples: usize,
    max_speech_samples: f32,
    min_silence_samples: usize,
    min_silence_samples_at_max_speech: usize,
}

impl From<utils::VadParams> for Params {
    fn from(value: utils::VadParams) -> Self {
        let frame_size = value.frame_size;
        let threshold = value.threshold;
        let min_silence_duration_ms = value.min_silence_duration_ms;
        let speech_pad_ms = value.speech_pad_ms;
        let min_speech_duration_ms = value.min_speech_duration_ms;
        let max_speech_duration_s = value.max_speech_duration_s;
        let sample_rate = value.sample_rate;
        let sr_per_ms = sample_rate / 1000;
        let frame_size_samples = frame_size * sr_per_ms;
        let min_speech_samples = sr_per_ms * min_speech_duration_ms;
        let speech_pad_samples = sr_per_ms * speech_pad_ms;
        let max_speech_samples = sample_rate as f32 * max_speech_duration_s
            - frame_size_samples as f32
            - 2.0 * speech_pad_samples as f32;
        let min_silence_samples = sr_per_ms * min_silence_duration_ms;
        let min_silence_samples_at_max_speech = sr_per_ms * 98;
        Self {
            frame_size,
            threshold,
            min_silence_duration_ms,
            speech_pad_ms,
            min_speech_duration_ms,
            max_speech_duration_s,
            sample_rate,
            sr_per_ms,
            frame_size_samples,
            min_speech_samples,
            speech_pad_samples,
            max_speech_samples,
            min_silence_samples,
            min_silence_samples_at_max_speech,
        }
    }
}

#[derive(Debug, Default)]
struct State {
    current_sample: usize,
    temp_end: usize,
    next_start: usize,
    prev_end: usize,
    triggered: bool,
    current_speech: utils::TimeStamp,
    speeches: Vec<utils::TimeStamp>,
}

impl State {
    fn new() -> Self {
        Default::default()
    }

    fn update(&mut self, params: &Params, speech_prob: f32) {
        self.current_sample += params.frame_size_samples;
        if speech_prob > params.threshold {
            if self.temp_end != 0 {
                self.temp_end = 0;
                if self.next_start < self.prev_end {
                    self.next_start = self
                        .current_sample
                        .saturating_sub(params.frame_size_samples)
                }
            }
            if !self.triggered {
                self.debug(speech_prob, params, "start");
                self.triggered = true;
                self.current_speech.start =
                    self.current_sample as i64 - params.frame_size_samples as i64;
            }
            return;
        }
        if self.triggered
            && (self.current_sample as i64 - self.current_speech.start) as f32
                > params.max_speech_samples
        {
            if self.prev_end > 0 {
                self.current_speech.end = self.prev_end as _;
                self.take_speech();
                if self.next_start < self.prev_end {
                    self.triggered = false
                } else {
                    self.current_speech.start = self.next_start as _;
                }
                self.prev_end = 0;
                self.next_start = 0;
                self.temp_end = 0;
            } else {
                self.current_speech.end = self.current_sample as _;
                self.take_speech();
                self.prev_end = 0;
                self.next_start = 0;
                self.temp_end = 0;
                self.triggered = false;
            }
            return;
        }
        if speech_prob >= (params.threshold - 0.15) && (speech_prob < params.threshold) {
            if self.triggered {
                self.debug(speech_prob, params, "speaking")
            } else {
                self.debug(speech_prob, params, "silence")
            }
        }
        if self.triggered && speech_prob < (params.threshold - 0.15) {
            self.debug(speech_prob, params, "end");
            if self.temp_end == 0 {
                self.temp_end = self.current_sample;
            }
            if self.current_sample.saturating_sub(self.temp_end)
                > params.min_silence_samples_at_max_speech
            {
                self.prev_end = self.temp_end;
            }
            if self.current_sample.saturating_sub(self.temp_end) >= params.min_silence_samples {
                self.current_speech.end = self.temp_end as _;
                if self.current_speech.end - self.current_speech.start
                    > params.min_speech_samples as _
                {
                    self.take_speech();
                    self.prev_end = 0;
                    self.next_start = 0;
                    self.temp_end = 0;
                    self.triggered = false;
                }
            }
        }
    }

    fn take_speech(&mut self) {
        self.speeches.push(std::mem::take(&mut self.current_speech)); // current speech becomes TimeStamp::default() due to take()
    }

    fn check_for_last_speech(&mut self, last_sample: usize) {
        if self.current_speech.start > 0 {
            self.current_speech.end = last_sample as _;
            self.take_speech();
            self.prev_end = 0;
            self.next_start = 0;
            self.temp_end = 0;
            self.triggered = false;
        }
    }

    fn debug(&self, speech_prob: f32, params: &Params, title: &str) {
        if DEBUG_SPEECH_PROB {
            let speech = self.current_sample as f32
                - params.frame_size_samples as f32
                - if title == "end" {
                    params.speech_pad_samples
                } else {
                    0
                } as f32; // minus window_size_samples to get precise start time point.
            println!(
                "[{:10}: {:.3} s ({:.3}) {:8}]",
                title,
                speech / params.sample_rate as f32,
                speech_prob,
                self.current_sample - params.frame_size_samples,
            );
        }
    }
}
BIN  files/de.wav (binary file not shown)
BIN  files/en.wav (binary file not shown)
BIN  files/en_num.wav (binary file not shown)
BIN  files/es.wav (binary file not shown)
Binary file not shown.
Binary file not shown.
BIN  files/model.jit (binary file not shown)
BIN  files/model.onnx (binary file not shown)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN  files/ru.wav (binary file not shown)
BIN  files/ru_num.wav (binary file not shown)
BIN  files/silero_vad.jit (new file, binary file not shown)
BIN  files/silero_vad.onnx (new file, binary file not shown)
hubconf.py (168 changed lines)
@@ -1,154 +1,50 @@
 dependencies = ['torch', 'torchaudio']
 import torch
 import json
+import os
 from utils_vad import (init_jit_model,
-                       get_speech_ts,
-                       get_speech_ts_adaptive,
-                       get_number_ts,
-                       get_language,
-                       get_language_and_group,
+                       get_speech_timestamps,
                        save_audio,
                        read_audio,
-                       state_generator,
-                       single_audio_stream,
+                       VADIterator,
                        collect_chunks,
-                       drop_chunks)
+                       drop_chunks,
+                       Validator,
+                       OnnxWrapper)
 
 
-def silero_vad(**kwargs):
+def versiontuple(v):
+    splitted = v.split('+')[0].split(".")
+    version_list = []
+    for i in splitted:
+        try:
+            version_list.append(int(i))
+        except:
+            version_list.append(0)
+    return tuple(version_list)
+
+
+def silero_vad(onnx=False, force_onnx_cpu=False):
     """Silero Voice Activity Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model.jit')
-    utils = (get_speech_ts,
-             get_speech_ts_adaptive,
+    if not onnx:
+        installed_version = torch.__version__
+        supported_version = '1.12.0'
+        if versiontuple(installed_version) < versiontuple(supported_version):
+            raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
+
+    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    if onnx:
+        model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
+    else:
+        model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
+    utils = (get_speech_timestamps,
              save_audio,
             read_audio,
-             state_generator,
-             single_audio_stream,
+             VADIterator,
              collect_chunks)
 
     return model, utils
-
-
-def silero_vad_micro(**kwargs):
-    """Silero Voice Activity Detector
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_micro.jit')
-    utils = (get_speech_ts,
-             get_speech_ts_adaptive,
-             save_audio,
-             read_audio,
-             state_generator,
-             single_audio_stream,
-             collect_chunks)
-
-    return model, utils
-
-
-def silero_vad_micro_8k(**kwargs):
-    """Silero Voice Activity Detector
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_micro_8k.jit')
-    utils = (get_speech_ts,
-             get_speech_ts_adaptive,
-             save_audio,
-             read_audio,
-             state_generator,
-             single_audio_stream,
-             collect_chunks)
-
-    return model, utils
-
-
-def silero_vad_mini(**kwargs):
-    """Silero Voice Activity Detector
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_mini.jit')
-    utils = (get_speech_ts,
-             get_speech_ts_adaptive,
-             save_audio,
-             read_audio,
-             state_generator,
-             single_audio_stream,
-             collect_chunks)
-
-    return model, utils
-
-
-def silero_vad_mini_8k(**kwargs):
-    """Silero Voice Activity Detector
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/model_mini_8k.jit')
-    utils = (get_speech_ts,
-             get_speech_ts_adaptive,
-             save_audio,
-             read_audio,
-             state_generator,
-             single_audio_stream,
-             collect_chunks)
-
-    return model, utils
-
-
-def silero_number_detector(**kwargs):
-    """Silero Number Detector
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
-    utils = (get_number_ts,
-             save_audio,
-             read_audio,
-             collect_chunks,
-             drop_chunks)
-
-    return model, utils
-
-
-def silero_lang_detector(**kwargs):
-    """Silero Language Classifier
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
-    utils = (get_language,
-             read_audio)
-
-    return model, utils
-
-
-def silero_lang_detector_95(**kwargs):
-    """Silero Language Classifier (95 languages)
-    Returns a model with a set of utils
-    Please see https://github.com/snakers4/silero-vad for usage examples
-    """
-
-    hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/lang_classifier_95.jit')
-
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
-        lang_dict = json.load(f)
-
-    with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
-        lang_group_dict = json.load(f)
-
-    utils = (get_language_and_group, read_audio)
-
-    return model, lang_dict, lang_group_dict, utils
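For orientation, here is a minimal sketch of calling the rewritten entrypoint through torch hub. The `onnx` and `force_onnx_cpu` keyword arguments are forwarded by `torch.hub.load` to the `silero_vad(onnx=False, force_onnx_cpu=False)` signature above; the audio file name is a hypothetical placeholder:

```
import torch

# Load the ONNX flavour of the model; torch.hub.load forwards extra
# keyword arguments to the silero_vad() entrypoint in hubconf.py.
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              onnx=True,            # use files/silero_vad.onnx via OnnxWrapper
                              force_onnx_cpu=True)  # keep onnxruntime on CPU
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils

wav = read_audio('example.wav', sampling_rate=16000)  # hypothetical input file
print(get_speech_timestamps(wav, model, sampling_rate=16000))
```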
895
silero-vad.ipynb
895
silero-vad.ipynb
@@ -1,23 +1,5 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "sVNOuHQQjsrp"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"# PyTorch Examples"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "FpMplOCA2Fwp"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## VAD"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -25,7 +7,7 @@
|
|||||||
"id": "62A6F_072Fwq"
|
"id": "62A6F_072Fwq"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"### Install Dependencies"
|
"## Install Dependencies"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -40,28 +22,41 @@
|
|||||||
"#@title Install and Import Dependencies\n",
|
"#@title Install and Import Dependencies\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# this assumes that you have a relevant version of PyTorch installed\n",
|
"# this assumes that you have a relevant version of PyTorch installed\n",
|
||||||
"!pip install -q torchaudio soundfile\n",
|
"!pip install -q torchaudio\n",
|
||||||
|
"\n",
|
||||||
|
"SAMPLING_RATE = 16000\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import glob\n",
|
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"torch.set_num_threads(1)\n",
|
"torch.set_num_threads(1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from IPython.display import Audio\n",
|
"from IPython.display import Audio\n",
|
||||||
"from pprint import pprint\n",
|
"from pprint import pprint\n",
|
||||||
|
"# download example\n",
|
||||||
|
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "pSifus5IilRp"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"USE_ONNX = False # change this to True if you want to test onnx model\n",
|
||||||
|
"if USE_ONNX:\n",
|
||||||
|
" !pip install -q onnxruntime\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||||
" model='silero_vad',\n",
|
" model='silero_vad',\n",
|
||||||
" force_reload=True)\n",
|
" force_reload=True,\n",
|
||||||
|
" onnx=USE_ONNX)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"(get_speech_ts,\n",
|
"(get_speech_timestamps,\n",
|
||||||
" get_speech_ts_adaptive,\n",
|
|
||||||
" save_audio,\n",
|
" save_audio,\n",
|
||||||
" read_audio,\n",
|
" read_audio,\n",
|
||||||
" state_generator,\n",
|
" VADIterator,\n",
|
||||||
" single_audio_stream,\n",
|
" collect_chunks) = utils"
|
||||||
" collect_chunks) = utils\n",
|
|
||||||
"\n",
|
|
||||||
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -70,16 +65,7 @@
|
|||||||
"id": "fXbbaUO3jsrw"
|
"id": "fXbbaUO3jsrw"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"### Full Audio"
|
"## Speech timestapms from full audio"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "dY2Us3_Q2Fws"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -90,10 +76,9 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"wav = read_audio(f'{files_dir}/en.wav')\n",
|
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
|
||||||
"# get speech timestamps from full audio file\n",
|
"# get speech timestamps from full audio file\n",
|
||||||
"speech_timestamps = get_speech_ts(wav, model,\n",
|
"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
|
||||||
" num_steps=4)\n",
|
|
||||||
"pprint(speech_timestamps)"
|
"pprint(speech_timestamps)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -107,45 +92,31 @@
 "source": [
 "# merge all speech chunks to one audio\n",
 "save_audio('only_speech.wav',\n",
-"           collect_chunks(speech_timestamps, wav), 16000) \n",
+"           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
 "Audio('only_speech.wav')"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "n8plzbJU2Fws"
+"id": "zeO1xCqxUC6w"
 },
 "source": [
-"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
+"## Entire audio inference"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "SQOtu2Vl2Fwt"
+"id": "LjZBcsaTT7Mk"
 },
 "outputs": [],
 "source": [
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"# get speech timestamps from full audio file\n",
-"speech_timestamps = get_speech_ts_adaptive(wav, model, step=500, num_samples_per_window=4000)\n",
-"pprint(speech_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "Lr6zCGXh2Fwt"
-},
-"outputs": [],
-"source": [
-"# merge all speech chunks to one audio\n",
-"save_audio('only_speech.wav',\n",
-"           collect_chunks(speech_timestamps, wav), 16000) \n",
-"Audio('only_speech.wav')"
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
+"# audio is being splitted into 31.25 ms long pieces\n",
+"# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
+"predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
 ]
 },
 {
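A note on the new cell's comment: 16000 / 512 = 31.25 is the number of windows per second, i.e. each `audio_forward` output frame covers 512 samples (32 ms) at 16 kHz. A sketch of mapping frame indices to times under that assumption:

```python
# Sketch: convert audio_forward output frames to start times in seconds.
# Assumes the 512-sample window this diff uses for 16 kHz audio.
predicts = model.audio_forward(wav, sr=SAMPLING_RATE)   # shape (1, num_frames)
frame_s = 512 / SAMPLING_RATE                           # 0.032 s per frame
times = [round(i * frame_s, 3) for i in range(predicts.shape[1])]
```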
@@ -154,16 +125,7 @@
 "id": "iDKQbVr8jsry"
 },
 "source": [
-"### Single Audio Stream"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "xCM-HrUR2Fwu"
-},
-"source": [
-"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
+"## Stream imitation example"
 ]
 },
 {
@@ -174,20 +136,20 @@
 },
 "outputs": [],
 "source": [
-"wav = f'{files_dir}/en.wav'\n",
+"## using VADIterator class\n",
 "\n",
-"for batch in single_audio_stream(model, wav):\n",
-"    if batch:\n",
-"        print(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "t8TXtnvk2Fwv"
-},
-"source": [
-"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
+"vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
+"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
+"\n",
+"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
+"for i in range(0, len(wav), window_size_samples):\n",
+"    chunk = wav[i: i+ window_size_samples]\n",
+"    if len(chunk) < window_size_samples:\n",
+"      break\n",
+"    speech_dict = vad_iterator(chunk, return_seconds=True)\n",
+"    if speech_dict:\n",
+"        print(speech_dict, end=' ')\n",
+"vad_iterator.reset_states() # reset model states after each audio"
 ]
 },
 {
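The loop above drops the trailing partial chunk (`break`); a sketch of a variation that pads it instead, so the tail of the audio is also scanned (this padding variant is not part of the commit):

```python
import torch.nn.functional as F

vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)
window_size_samples = 512 if SAMPLING_RATE == 16000 else 256
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        # pad the last chunk with zeros instead of discarding it
        chunk = F.pad(chunk, (0, window_size_samples - len(chunk)))
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict, end=' ')
vad_iterator.reset_states()  # reset model states after each audio
```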
@@ -198,755 +160,20 @@
 },
 "outputs": [],
 "source": [
-"wav = f'{files_dir}/en.wav'\n",
+"## just probabilities\n",
 "\n",
-"for batch in single_audio_stream(model, wav, iterator_type='adaptive'):\n",
-"    if batch:\n",
-"        print(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "KBDVybJCjsrz"
-},
-"source": [
-"### Multiple Audio Streams"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "BK4tGfWgjsrz"
-},
-"outputs": [],
-"source": [
-"audios_for_stream = glob.glob(f'{files_dir}/*.wav')\n",
-"len(audios_for_stream) # total 4 audios"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "v1l8sam1jsrz"
-},
-"outputs": [],
-"source": [
-"for batch in state_generator(model, audios_for_stream, audios_in_stream=2): # 2 audio stream\n",
-"    if batch:\n",
-"        pprint(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "36jY0niD2Fww"
-},
-"source": [
-"## Number detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "scd1DlS42Fwx"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "Kq5gQuYq2Fwx"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
+"speech_probs = []\n",
+"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
+"for i in range(0, len(wav), window_size_samples):\n",
+"    chunk = wav[i: i+ window_size_samples]\n",
+"    if len(chunk) < window_size_samples:\n",
+"      break\n",
+"    speech_prob = model(chunk, SAMPLING_RATE).item()\n",
+"    speech_probs.append(speech_prob)\n",
+"vad_iterator.reset_states() # reset model states after each audio\n",
 "\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"torch.set_num_threads(1)\n",
-"\n",
-"from IPython.display import Audio\n",
-"from pprint import pprint\n",
-"\n",
-"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-"                              model='silero_number_detector',\n",
-"                              force_reload=True)\n",
-"\n",
-"(get_number_ts,\n",
-" save_audio,\n",
-" read_audio,\n",
-" collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "qhPa30ij2Fwy"
-},
-"source": [
-"### Full audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "EXpau6xq2Fwy"
-},
-"outputs": [],
-"source": [
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
-"# get number timestamps from full audio file\n",
-"number_timestamps = get_number_ts(wav, model)\n",
-"pprint(number_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "u-KfXRhZ2Fwy"
-},
-"outputs": [],
-"source": [
-"sample_rate = 16000\n",
-"# convert ms in timestamps to samples\n",
-"for timestamp in number_timestamps:\n",
-"    timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-"    timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "iwYEC4aZ2Fwy"
-},
-"outputs": [],
-"source": [
-"# merge all number chunks to one audio\n",
-"save_audio('only_numbers.wav',\n",
-"           collect_chunks(number_timestamps, wav), sample_rate) \n",
-"Audio('only_numbers.wav')"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "fHaYejX12Fwy"
-},
-"outputs": [],
-"source": [
-"# drop all number chunks from audio\n",
-"save_audio('no_numbers.wav',\n",
-"           drop_chunks(number_timestamps, wav), sample_rate) \n",
-"Audio('no_numbers.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "PnKtJKbq2Fwz"
-},
-"source": [
-"## Language detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "F5cAmMbP2Fwz"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "Zu9D0t6n2Fwz"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"torch.set_num_threads(1)\n",
-"\n",
-"from IPython.display import Audio\n",
-"from pprint import pprint\n",
-"\n",
-"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-"                              model='silero_lang_detector',\n",
-"                              force_reload=True)\n",
-"\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "iC696eMX2Fwz"
-},
-"source": [
-"### Full audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "c8UYnYBF2Fw0"
-},
-"outputs": [],
-"source": [
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"lang = get_language(wav, model)\n",
-"print(lang)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "57avIBd6jsrz"
-},
-"source": [
-"# ONNX Example"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "hEhnfORV2Fw0"
-},
-"source": [
-"## VAD"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "bL4kn4KJrlyL"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "Q4QIfSpprnkI"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-"                          model='silero_vad',\n",
-"                          force_reload=True)\n",
-"\n",
-"(get_speech_ts,\n",
-" get_speech_ts_adaptive,\n",
-" save_audio,\n",
-" read_audio,\n",
-" state_generator,\n",
-" single_audio_stream,\n",
-" collect_speeches) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-"    return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-"    with torch.no_grad():\n",
-"        ort_inputs = {'input': inputs.cpu().numpy()}\n",
-"        outs = model.run(None, ort_inputs)\n",
-"        outs = [torch.Tensor(x) for x in outs]\n",
-"    return outs[0]"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "5JHErdB7jsr0"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "TNEtK5zi2Fw2"
-},
-"source": [
-"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "krnGoA6Kjsr0"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"\n",
-"# get speech timestamps from full audio file\n",
-"speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx) \n",
-"pprint(speech_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "B176Lzfnjsr1"
-},
-"outputs": [],
-"source": [
-"# merge all speech chunks to one audio\n",
-"save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), 16000)\n",
-"Audio('only_speech.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "21RE8KEC2Fw2"
-},
-"source": [
-"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "uIVs56rb2Fw2"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"\n",
-"# get speech timestamps from full audio file\n",
-"speech_timestamps = get_speech_ts_adaptive(wav, model, run_function=validate_onnx) \n",
-"pprint(speech_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "cox6oumC2Fw3"
-},
-"outputs": [],
-"source": [
-"# merge all speech chunks to one audio\n",
-"save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), 16000)\n",
-"Audio('only_speech.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "Rio9W50gjsr1"
-},
-"source": [
-"### Single Audio Stream"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "i8EZwtaA2Fw3"
-},
-"source": [
-"**Classic way of getting speech chunks, you may need to select the thresholds yourself**"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "IPkl8Yy1jsr1"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
-"wav = f'{files_dir}/en.wav'"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "NC6Jim0hjsr1"
-},
-"outputs": [],
-"source": [
-"for batch in single_audio_stream(model, wav, run_function=validate_onnx):\n",
-"    if batch:\n",
-"        pprint(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "0pSKslpz2Fw3"
-},
-"source": [
-"**Experimental Adaptive method, algorithm selects thresholds itself (see readme for more information)**"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "RZwc-Khk2Fw4"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
-"wav = f'{files_dir}/en.wav'"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "Z4lzFPs02Fw4"
-},
-"outputs": [],
-"source": [
-"for batch in single_audio_stream(model, wav, iterator_type='adaptive', run_function=validate_onnx):\n",
-"    if batch:\n",
-"        pprint(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "WNZ42u0ajsr1"
-},
-"source": [
-"### Multiple Audio Streams"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "XjhGQGppjsr1"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/model.onnx')\n",
-"audios_for_stream = glob.glob(f'{files_dir}/*.wav')\n",
-"pprint(len(audios_for_stream)) # total 4 audios"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "QI7-arlqjsr2"
-},
-"outputs": [],
-"source": [
-"for batch in state_generator(model, audios_for_stream, audios_in_stream=2, run_function=validate_onnx): # 2 audio stream\n",
-"    if batch:\n",
-"        pprint(batch)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "7QMvUvpg2Fw4"
-},
-"source": [
-"## Number detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "tBPDkpHr2Fw4"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "PdjGd56R2Fw5"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-"                          model='silero_number_detector',\n",
-"                          force_reload=True)\n",
-"\n",
-"(get_number_ts,\n",
-" save_audio,\n",
-" read_audio,\n",
-" collect_chunks,\n",
-" drop_chunks) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-"    return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-"    with torch.no_grad():\n",
-"        ort_inputs = {'input': inputs.cpu().numpy()}\n",
-"        outs = model.run(None, ort_inputs)\n",
-"        outs = [torch.Tensor(x) for x in outs]\n",
-"    return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "I9QWSFZh2Fw5"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "_r6QZiwu2Fw5"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en_num.wav')\n",
-"\n",
-"# get number timestamps from full audio file\n",
-"number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
-"pprint(number_timestamps)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "FN4aDwLV2Fw5"
-},
-"outputs": [],
-"source": [
-"sample_rate = 16000\n",
-"# convert ms in timestamps to samples\n",
-"for timestamp in number_timestamps:\n",
-"    timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-"    timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "JnvS6WTK2Fw5"
-},
-"outputs": [],
-"source": [
-"# merge all number chunks to one audio\n",
-"save_audio('only_numbers.wav',\n",
-"           collect_chunks(number_timestamps, wav), 16000) \n",
-"Audio('only_numbers.wav')"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "yUxOcOFG2Fw6"
-},
-"outputs": [],
-"source": [
-"# drop all number chunks from audio\n",
-"save_audio('no_numbers.wav',\n",
-"           drop_chunks(number_timestamps, wav), 16000) \n",
-"Audio('no_numbers.wav')"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"id": "SR8Bgcd52Fw6"
-},
-"source": [
-"## Language detector"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"heading_collapsed": true,
-"hidden": true,
-"id": "PBnXPtKo2Fw6"
-},
-"source": [
-"### Install Dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"cellView": "form",
-"hidden": true,
-"id": "iNkDWJ3H2Fw6"
-},
-"outputs": [],
-"source": [
-"#@title Install and Import Dependencies\n",
-"\n",
-"# this assumes that you have a relevant version of PyTorch installed\n",
-"!pip install -q torchaudio soundfile onnxruntime\n",
-"\n",
-"import glob\n",
-"import torch\n",
-"import onnxruntime\n",
-"from pprint import pprint\n",
-"\n",
-"from IPython.display import Audio\n",
-"\n",
-"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-"                          model='silero_lang_detector',\n",
-"                          force_reload=True)\n",
-"\n",
-"(get_language,\n",
-" read_audio) = utils\n",
-"\n",
-"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-"\n",
-"def init_onnx_model(model_path: str):\n",
-"    return onnxruntime.InferenceSession(model_path)\n",
-"\n",
-"def validate_onnx(model, inputs):\n",
-"    with torch.no_grad():\n",
-"        ort_inputs = {'input': inputs.cpu().numpy()}\n",
-"        outs = model.run(None, ort_inputs)\n",
-"        outs = [torch.Tensor(x) for x in outs]\n",
-"    return outs"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"hidden": true,
-"id": "G8N8oP4q2Fw6"
-},
-"source": [
-"### Full Audio"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"hidden": true,
-"id": "WHXnh9IV2Fw6"
-},
-"outputs": [],
-"source": [
-"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
-"wav = read_audio(f'{files_dir}/en.wav')\n",
-"\n",
-"lang = get_language(wav, model, run_function=validate_onnx)\n",
-"print(lang)"
+"print(speech_probs[:10]) # first 10 chunks predicts"
 ]
 }
 ],
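A sketch of summarising the probabilities collected by the cell above; the 0.5 threshold mirrors the library default, the rest is illustrative:

```python
import statistics

frames_per_second = SAMPLING_RATE / window_size_samples  # 31.25 at 16 kHz
speech_ratio = sum(p > 0.5 for p in speech_probs) / len(speech_probs)
print(f'median prob: {statistics.median(speech_probs):.2f}, '
      f'speech in ~{speech_ratio:.0%} of {len(speech_probs)} windows '
      f'({len(speech_probs) / frames_per_second:.1f} s of audio)')
```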
utils_vad.py (886 changed lines)
@@ -1,617 +1,473 @@
 import torch
 import torchaudio
-from typing import List
-from itertools import repeat
-from collections import deque
-import torch.nn.functional as F
+from typing import Callable, List
+import warnings


-torchaudio.set_audio_backend("soundfile") # switch backend


 languages = ['ru', 'en', 'de', 'es']


-class IterativeMedianMeter():
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.median = 0
-        self.counts = {}
-        for i in range(0, 101, 1):
-            self.counts[i / 100] = 0
-        self.total_values = 0
-
-    def __call__(self, val):
-        self.total_values += 1
-        rounded = round(abs(val), 2)
-        self.counts[rounded] += 1
-        bin_sum = 0
-        for j in self.counts:
-            bin_sum += self.counts[j]
-            if bin_sum >= self.total_values / 2:
-                self.median = j
-                break
-        return self.median
+class OnnxWrapper():
+
+    def __init__(self, path, force_onnx_cpu=False):
+        import numpy as np
+        global np
+        import onnxruntime
+
+        opts = onnxruntime.SessionOptions()
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
+
+        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+        else:
+            self.session = onnxruntime.InferenceSession(path, sess_options=opts)
+
+        self.reset_states()
+        self.sample_rates = [8000, 16000]
+
+    def _validate_input(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+        if sr != 16000 and (sr % 16000 == 0):
+            step = sr // 16000
+            x = x[:,::step]
+            sr = 16000
+
+        if sr not in self.sample_rates:
+            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        return x, sr
+
+    def reset_states(self, batch_size=1):
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
+        self._last_sr = 0
+        self._last_batch_size = 0
+
+    def __call__(self, x, sr: int):
+
+        x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
+        batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32
+
+        if not self._last_batch_size:
+            self.reset_states(batch_size)
+        if (self._last_sr) and (self._last_sr != sr):
+            self.reset_states(batch_size)
+        if (self._last_batch_size) and (self._last_batch_size != batch_size):
+            self.reset_states(batch_size)
+
+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
+        if sr in [8000, 16000]:
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
+            ort_outs = self.session.run(None, ort_inputs)
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
+        else:
+            raise ValueError()
+
+        self._context = x[..., -context_size:]
+        self._last_sr = sr
+        self._last_batch_size = batch_size
+
+        out = torch.from_numpy(out)
+        return out
+
+    def audio_forward(self, x, sr: int):
+        outs = []
+        x, sr = self._validate_input(x, sr)
+        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[1] % num_samples:
+            pad_num = num_samples - (x.shape[1] % num_samples)
+            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+        for i in range(0, x.shape[1], num_samples):
+            wavs_batch = x[:, i:i+num_samples]
+            out_chunk = self.__call__(wavs_batch, sr)
+            outs.append(out_chunk)
+
+        stacked = torch.cat(outs, dim=1)
+        return stacked.cpu()
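A sketch of driving the new `OnnxWrapper` directly; the `silero_vad.onnx` path is hypothetical, and the wrapper presumes an ONNX graph with `input`, `state` and `sr` inputs as coded above:

```python
import torch

model = OnnxWrapper('silero_vad.onnx', force_onnx_cpu=True)  # hypothetical local path
chunk = torch.zeros(512)            # one 512-sample window at 16 kHz
prob = model(chunk, 16000)          # per-window speech probability tensor
probs = model.audio_forward(torch.zeros(16000), 16000)  # a whole second at once
model.reset_states()                # clear recurrent state between audios
```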
-def validate(model,
-             inputs: torch.Tensor):
-    with torch.no_grad():
-        outs = model(inputs)
-    return outs
+class Validator():
+    def __init__(self, url, force_onnx_cpu):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
+            else:
+                self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+
+        return outs

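`Validator` dispatches on the URL suffix; a sketch with a hypothetical URL:

```python
# Sketch: Validator downloads the file to 'inf.model' and wraps it in either
# onnxruntime or torch.jit depending on the '.onnx' suffix. The URL is hypothetical
# and the expected input shape depends on the wrapped model.
validator = Validator('https://models.example.org/vad.onnx', force_onnx_cpu=True)
outs = validator(torch.zeros(1, 512))  # ONNX branch returns a list of tensors
```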
 def read_audio(path: str,
-               target_sr: int = 16000):
+               sampling_rate: int = 16000):

-    assert torchaudio.get_audio_backend() == 'soundfile'
-    wav, sr = torchaudio.load(path)
+    sox_backends = set(['sox', 'sox_io'])
+    audio_backends = torchaudio.list_audio_backends()

-    if wav.size(0) > 1:
-        wav = wav.mean(dim=0, keepdim=True)
+    if len(sox_backends.intersection(audio_backends)) > 0:
+        effects = [
+            ['channels', '1'],
+            ['rate', str(sampling_rate)]
+        ]

-    if sr != target_sr:
-        transform = torchaudio.transforms.Resample(orig_freq=sr,
-                                                   new_freq=target_sr)
-        wav = transform(wav)
-        sr = target_sr
+        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+    else:
+        wav, sr = torchaudio.load(path)

-    assert sr == target_sr
+    if wav.size(0) > 1:
+        wav = wav.mean(dim=0, keepdim=True)
+
+    if sr != sampling_rate:
+        transform = torchaudio.transforms.Resample(orig_freq=sr,
+                                                   new_freq=sampling_rate)
+        wav = transform(wav)
+        sr = sampling_rate
+
+    assert sr == sampling_rate
     return wav.squeeze(0)


 def save_audio(path: str,
                tensor: torch.Tensor,
-               sr: int = 16000):
-    torchaudio.save(path, tensor.unsqueeze(0), sr)
+               sampling_rate: int = 16000):
+    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)


 def init_jit_model(model_path: str,
                    device=torch.device('cpu')):
-    torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
     return model

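The renamed `sampling_rate` keyword in a round trip; a sketch assuming a local wav file such as the notebook's `en_example.wav`:

```python
# Sketch: read anything torchaudio can open as mono 16 kHz, then write 16-bit PCM.
wav = read_audio('en_example.wav', sampling_rate=16000)  # 1-D float tensor
save_audio('en_example_copy.wav', wav, sampling_rate=16000)
```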
-def get_speech_ts(wav: torch.Tensor,
-                  model,
-                  trig_sum: float = 0.25,
-                  neg_trig_sum: float = 0.07,
-                  num_steps: int = 8,
-                  batch_size: int = 200,
-                  num_samples_per_window: int = 4000,
-                  min_speech_samples: int = 10000, #samples
-                  min_silence_samples: int = 500,
-                  run_function=validate,
-                  visualize_probs=False,
-                  smoothed_prob_func='mean',
-                  device='cpu'):
-
-    assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
-    num_samples = num_samples_per_window
-    assert num_samples % num_steps == 0
-    step = int(num_samples / num_steps) # stride / hop
-    outs = []
-    to_concat = []
-    for i in range(0, len(wav), step):
-        chunk = wav[i: i+num_samples]
-        if len(chunk) < num_samples:
-            chunk = F.pad(chunk, (0, num_samples - len(chunk)))
-        to_concat.append(chunk.unsqueeze(0))
-        if len(to_concat) >= batch_size:
-            chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
-            out = run_function(model, chunks)
-            outs.append(out)
-            to_concat = []
-
-    if to_concat:
-        chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
-        out = run_function(model, chunks)
-        outs.append(out)
-
-    outs = torch.cat(outs, dim=0)
-
-    buffer = deque(maxlen=num_steps) # maxlen reached => first element dropped
-    triggered = False
-    speeches = []
-    current_speech = {}
-    if visualize_probs:
-        import pandas as pd
-        smoothed_probs = []
-
-    speech_probs = outs[:, 1] # this is very misleading
-    temp_end = 0
-    for i, predict in enumerate(speech_probs): # add name
-        buffer.append(predict)
-        if smoothed_prob_func == 'mean':
-            smoothed_prob = (sum(buffer) / len(buffer))
-        elif smoothed_prob_func == 'max':
-            smoothed_prob = max(buffer)
-
-        if visualize_probs:
-            smoothed_probs.append(float(smoothed_prob))
-        if (smoothed_prob >= trig_sum) and temp_end:
-            temp_end=0
-        if (smoothed_prob >= trig_sum) and not triggered:
-            triggered = True
-            current_speech['start'] = step * max(0, i-num_steps)
-            continue
-        if (smoothed_prob < neg_trig_sum) and triggered:
-            if not temp_end:
-                temp_end = step * i
-            if step * i - temp_end < min_silence_samples:
-                continue
-            else:
-                current_speech['end'] = temp_end
-                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
-                    speeches.append(current_speech)
-                temp_end = 0
-                current_speech = {}
-                triggered = False
-                continue
-    if current_speech:
-        current_speech['end'] = len(wav)
-        speeches.append(current_speech)
-
-    if visualize_probs:
-        pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
-    return speeches
+def make_visualization(probs, step):
+    import pandas as pd
+    pd.DataFrame({'probs': probs},
+                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
+                 kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
+                 xlabel='seconds',
+                 ylabel='speech probability',
+                 colormap='tab20')

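`make_visualization` takes per-window probabilities and the window length in seconds; a sketch (needs pandas and matplotlib installed):

```python
# Sketch: plot speech probabilities over time; 512 / 16000 = 0.032 s per window.
make_visualization(speech_probs, 512 / 16000)
```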
-def get_speech_ts_adaptive(wav: torch.Tensor,
-                           model,
-                           batch_size: int = 200,
-                           step: int = 500,
-                           num_samples_per_window: int = 4000, # Number of samples per audio chunk to feed to NN (4000 for 16k SR, 2000 for 8k SR is optimal)
-                           min_speech_samples: int = 10000, # samples
-                           min_silence_samples: int = 4000,
-                           speech_pad_samples: int = 2000,
-                           run_function=validate,
-                           visualize_probs=False,
-                           device='cpu'):
+@torch.no_grad()
+def get_speech_timestamps(audio: torch.Tensor,
+                          model,
+                          threshold: float = 0.5,
+                          sampling_rate: int = 16000,
+                          min_speech_duration_ms: int = 250,
+                          max_speech_duration_s: float = float('inf'),
+                          min_silence_duration_ms: int = 100,
+                          speech_pad_ms: int = 30,
+                          return_seconds: bool = False,
+                          visualize_probs: bool = False,
+                          progress_tracking_callback: Callable[[float], None] = None,
+                          window_size_samples: int = 512,):

     """
-    This function is used for splitting long audios into speech chunks using silero VAD
-    Attention! All default sample rate values are optimal for 16000 sample rate model, if you are using 8000 sample rate model optimal values are half as much!
+    This method is used for splitting long audios into speech chunks using silero VAD

     Parameters
     ----------
-    batch_size: int
-        batch size to feed to silero VAD (default - 200)
-
-    step: int
-        step size in samples, (default - 500)
-
-    num_samples_per_window: int
-        window size in samples (chunk length in samples to feed to NN, default - 4000)
-
-    min_speech_samples: int
-        if speech duration is shorter than this value, do not consider it speech (default - 10000)
-
-    min_silence_samples: int
-        number of samples to wait before considering as the end of speech (default - 4000)
-
-    speech_pad_samples: int
-        widen speech by this amount of samples each side (default - 2000)
-
-    run_function: function
-        function to use for the model call
-
-    visualize_probs: bool
-        whether draw prob hist or not (default: False)
-
-    device: string
-        torch device to use for the model call (default - "cpu")
+    audio: torch.Tensor, one dimensional
+        One dimensional float torch.Tensor, other types are casted to torch if possible
+
+    model: preloaded .jit/.onnx silero VAD model
+
+    threshold: float (default - 0.5)
+        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+    sampling_rate: int (default - 16000)
+        Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates
+
+    min_speech_duration_ms: int (default - 250 milliseconds)
+        Final speech chunks shorter min_speech_duration_ms are thrown out
+
+    max_speech_duration_s: int (default - inf)
+        Maximum duration of speech chunks in seconds
+        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent agressive cutting.
+        Otherwise, they will be split aggressively just before max_speech_duration_s.
+
+    min_silence_duration_ms: int (default - 100 milliseconds)
+        In the end of each speech chunk wait for min_silence_duration_ms before separating it
+
+    speech_pad_ms: int (default - 30 milliseconds)
+        Final speech chunks are padded by speech_pad_ms each side
+
+    return_seconds: bool (default - False)
+        whether return timestamps in seconds (default - samples)
+
+    visualize_probs: bool (default - False)
+        whether draw prob hist or not
+
+    progress_tracking_callback: Callable[[float], None] (default - None)
+        callback function taking progress in percents as an argument
+
+    window_size_samples: int (default - 512 samples)
+        !!! DEPRECATED, DOES NOTHING !!!

     Returns
     ----------
-    speeches: list
-        list containing ends and beginnings of speech chunks (in samples)
+    speeches: list of dicts
+        list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
     """

-    if visualize_probs:
-        import pandas as pd
-
-    num_samples = num_samples_per_window
-    num_steps = int(num_samples / step)
-    assert min_silence_samples >= step
-    outs = []
-    to_concat = []
-    for i in range(0, len(wav), step):
-        chunk = wav[i: i+num_samples]
-        if len(chunk) < num_samples:
-            chunk = F.pad(chunk, (0, num_samples - len(chunk)))
-        to_concat.append(chunk.unsqueeze(0))
-        if len(to_concat) >= batch_size:
-            chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
-            out = run_function(model, chunks)
-            outs.append(out)
-            to_concat = []
-
-    if to_concat:
-        chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
-        out = run_function(model, chunks)
-        outs.append(out)
-
-    outs = torch.cat(outs, dim=0).cpu()
+    if not torch.is_tensor(audio):
+        try:
+            audio = torch.Tensor(audio)
+        except:
+            raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
+    if len(audio.shape) > 1:
+        for i in range(len(audio.shape)): # trying to squeeze empty dimensions
+            audio = audio.squeeze(0)
+        if len(audio.shape) > 1:
+            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
+
+    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
+        step = sampling_rate // 16000
+        sampling_rate = 16000
+        audio = audio[::step]
+        warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')
+    else:
+        step = 1
+
+    if sampling_rate not in [8000, 16000]:
+        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
+
+    window_size_samples = 512 if sampling_rate == 16000 else 256
+
+    model.reset_states()
+    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
+    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+    audio_length_samples = len(audio)
+
+    speech_probs = []
+    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
+        if len(chunk) < window_size_samples:
+            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
+        speech_prob = model(chunk, sampling_rate).item()
+        speech_probs.append(speech_prob)
+        # caculate progress and seng it to callback function
+        progress = current_start_sample + window_size_samples
+        if progress > audio_length_samples:
+            progress = audio_length_samples
+        progress_percent = (progress / audio_length_samples) * 100
+        if progress_tracking_callback:
+            progress_tracking_callback(progress_percent)

-    buffer = deque(maxlen=num_steps)
     triggered = False
     speeches = []
-    smoothed_probs = []
     current_speech = {}
-    speech_probs = outs[:, 1] # 0 index for silence probs, 1 index for speech probs
-    median_probs = speech_probs.median()
-
-    trig_sum = 0.89 * median_probs + 0.08 # 0.08 when median is zero, 0.97 when median is 1
-
-    temp_end = 0
-    for i, predict in enumerate(speech_probs):
-        buffer.append(predict)
-        smoothed_prob = max(buffer)
-        if visualize_probs:
-            smoothed_probs.append(float(smoothed_prob))
-        if (smoothed_prob >= trig_sum) and temp_end:
+    neg_threshold = threshold - 0.15
+    temp_end = 0 # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
+
+    for i, speech_prob in enumerate(speech_probs):
+        if (speech_prob >= threshold) and temp_end:
             temp_end = 0
-        if (smoothed_prob >= trig_sum) and not triggered:
+            if next_start < prev_end:
+                next_start = window_size_samples * i
+
+        if (speech_prob >= threshold) and not triggered:
             triggered = True
-            current_speech['start'] = step * max(0, i-num_steps)
+            current_speech['start'] = window_size_samples * i
             continue
-        if (smoothed_prob < trig_sum) and triggered:
+
+        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if prev_end:
+                current_speech['end'] = prev_end
+                speeches.append(current_speech)
+                current_speech = {}
+                if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    triggered = False
+                else:
+                    current_speech['start'] = next_start
+                prev_end = next_start = temp_end = 0
+            else:
+                current_speech['end'] = window_size_samples * i
+                speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
+
+        if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = step * i
-            if step * i - temp_end < min_silence_samples:
+                temp_end = window_size_samples * i
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech : # condition to avoid cutting in very short silence
+                prev_end = temp_end
+            if (window_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
                 if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                     speeches.append(current_speech)
-                temp_end = 0
                 current_speech = {}
+                prev_end = next_start = temp_end = 0
                 triggered = False
                 continue
-    if current_speech:
-        current_speech['end'] = len(wav)
-        speeches.append(current_speech)
-    if visualize_probs:
-        pd.DataFrame({'probs': smoothed_probs}).plot(figsize=(16, 8))

-    for i, ts in enumerate(speeches):
+    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
+        current_speech['end'] = audio_length_samples
+        speeches.append(current_speech)
+
+    for i, speech in enumerate(speeches):
         if i == 0:
-            ts['start'] = max(0, ts['start'] - speech_pad_samples)
+            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
         if i != len(speeches) - 1:
-            silence_duration = speeches[i+1]['start'] - ts['end']
+            silence_duration = speeches[i+1]['start'] - speech['end']
             if silence_duration < 2 * speech_pad_samples:
-                ts['end'] += silence_duration // 2
-                speeches[i+1]['start'] = max(0, speeches[i+1]['start'] - silence_duration // 2)
+                speech['end'] += int(silence_duration // 2)
+                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
             else:
-                ts['end'] += speech_pad_samples
+                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
         else:
-            ts['end'] = min(len(wav), ts['end'] + speech_pad_samples)
+            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+
+    if return_seconds:
+        for speech_dict in speeches:
+            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
+            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+    elif step > 1:
+        for speech_dict in speeches:
+            speech_dict['start'] *= step
+            speech_dict['end'] *= step
+
+    if visualize_probs:
+        make_visualization(speech_probs, window_size_samples / sampling_rate)

     return speeches

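The new keyword arguments in one call; a sketch where the print-style progress callback is illustrative:

```python
# Sketch: seconds-based timestamps plus progress reporting, per the signature above.
def on_progress(percent: float) -> None:
    print(f'\rVAD progress: {percent:5.1f}%', end='')

speech_timestamps = get_speech_timestamps(
    wav, model,
    threshold=0.5,
    sampling_rate=16000,
    return_seconds=True,  # dicts like {'start': 1.2, 'end': 3.4} in seconds
    progress_tracking_callback=on_progress)
```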
def get_number_ts(wav: torch.Tensor,
|
class VADIterator:
|
||||||
model,
|
def __init__(self,
|
||||||
model_stride=8,
|
|
||||||
hop_length=160,
|
|
||||||
sample_rate=16000,
|
|
||||||
run_function=validate):
|
|
||||||
wav = torch.unsqueeze(wav, dim=0)
|
|
||||||
perframe_logits = run_function(model, wav)[0]
|
|
||||||
perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
|
|
||||||
extended_preds = []
|
|
||||||
for i in perframe_preds:
|
|
||||||
extended_preds.extend([i.item()] * model_stride)
|
|
||||||
# len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
|
|
||||||
triggered = False
|
|
||||||
timings = []
|
|
||||||
cur_timing = {}
|
|
||||||
for i, pred in enumerate(extended_preds):
|
|
||||||
if pred == 1:
|
|
||||||
if not triggered:
|
|
||||||
cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
|
|
||||||
triggered = True
|
|
||||||
elif pred == 0:
|
|
||||||
if triggered:
|
|
||||||
cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
|
|
||||||
timings.append(cur_timing)
|
|
||||||
cur_timing = {}
|
|
||||||
triggered = False
|
|
||||||
if cur_timing:
|
|
||||||
cur_timing['end'] = int(len(wav) / (sample_rate / 1000))
|
|
||||||
timings.append(cur_timing)
|
|
||||||
return timings
|
|
||||||
|
|
||||||
|
|
||||||
def get_language(wav: torch.Tensor,
|
|
||||||
model,
|
model,
|
||||||
run_function=validate):
|
threshold: float = 0.5,
|
||||||
wav = torch.unsqueeze(wav, dim=0)
|
sampling_rate: int = 16000,
|
||||||
lang_logits = run_function(model, wav)[2]
|
min_silence_duration_ms: int = 100,
|
||||||
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
|
speech_pad_ms: int = 30
|
||||||
assert lang_pred < len(languages)
|
):
|
||||||
return languages[lang_pred]
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_and_group(wav: torch.Tensor,
|
|
||||||
model,
|
|
||||||
lang_dict: dict,
|
|
||||||
lang_group_dict: dict,
|
|
||||||
top_n=1,
|
|
||||||
run_function=validate):
|
|
||||||
wav = torch.unsqueeze(wav, dim=0)
|
|
||||||
lang_logits, lang_group_logits = run_function(model, wav)
|
|
||||||
|
|
||||||
softm = torch.softmax(lang_logits, dim=1).squeeze()
|
|
||||||
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
|
|
||||||
|
|
||||||
srtd = torch.argsort(softm, descending=True)
|
|
||||||
srtd_group = torch.argsort(softm_group, descending=True)
|
|
||||||
|
|
||||||
outs = []
|
|
||||||
outs_group = []
|
|
||||||
for i in range(top_n):
|
|
||||||
prob = round(softm[srtd[i]].item(), 2)
|
|
||||||
prob_group = round(softm_group[srtd_group[i]].item(), 2)
|
|
||||||
outs.append((lang_dict[str(srtd[i].item())], prob))
|
|
||||||
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
|
|
||||||
|
|
||||||
return outs, outs_group
|
|
||||||
|
|
||||||
|
|
||||||
class VADiterator:
    def __init__(self,
                 trig_sum: float = 0.26,
                 neg_trig_sum: float = 0.07,
                 num_steps: int = 8,
                 num_samples_per_window: int = 4000):
        self.num_samples = num_samples_per_window
        self.num_steps = num_steps
        assert self.num_samples % num_steps == 0
        self.step = int(self.num_samples / num_steps)  # 500 samples is good enough
        self.prev = torch.zeros(self.num_samples)
        self.last = False
        self.triggered = False
        self.buffer = deque(maxlen=num_steps)
        self.num_frames = 0
        self.trig_sum = trig_sum
        self.neg_trig_sum = neg_trig_sum
        self.current_name = ''

    def refresh(self):
        self.prev = torch.zeros(self.num_samples)
        self.last = False
        self.triggered = False
        self.buffer = deque(maxlen=self.num_steps)
        self.num_frames = 0

    def prepare_batch(self, wav_chunk, name=None):
        if (name is not None) and (name != self.current_name):
            self.refresh()
            self.current_name = name
        assert len(wav_chunk) <= self.num_samples
        self.num_frames += len(wav_chunk)
        if len(wav_chunk) < self.num_samples:
            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # short chunk => end of audio
            self.last = True

        stacked = torch.cat([self.prev, wav_chunk])
        self.prev = wav_chunk

        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
                          for i in range(self.step, self.num_samples+1, self.step)]
        return torch.cat(overlap_chunks, dim=0)

    def state(self, model_out):
        current_speech = {}
        speech_probs = model_out[:, 1]  # 0 index for silence probs, 1 index for speech probs
        for i, predict in enumerate(speech_probs):
            self.buffer.append(predict)
            if ((sum(self.buffer) / len(self.buffer)) >= self.trig_sum) and not self.triggered:
                self.triggered = True
                current_speech[self.num_frames - (self.num_steps - i) * self.step] = 'start'
            if ((sum(self.buffer) / len(self.buffer)) < self.neg_trig_sum) and self.triggered:
                current_speech[self.num_frames - (self.num_steps - i) * self.step] = 'end'
                self.triggered = False
        if self.triggered and self.last:
            current_speech[self.num_frames] = 'end'
        if self.last:
            self.refresh()
        return current_speech, self.current_name

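# Self-contained sketch of the overlapping-window batching that
# VADiterator.prepare_batch performs: with the defaults (window of 4000
# samples, 8 steps of 500), each incoming chunk is concatenated with the
# previous one and sliced into 8 windows of 4000 samples at 500-sample offsets.
import torch

num_samples, num_steps = 4000, 8
step = num_samples // num_steps                         # 500
prev, chunk = torch.zeros(num_samples), torch.randn(num_samples)
stacked = torch.cat([prev, chunk])                      # 8000 samples
windows = [stacked[i:i + num_samples].unsqueeze(0)
           for i in range(step, num_samples + 1, step)]
print(torch.cat(windows, dim=0).shape)                  # torch.Size([8, 4000])
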
class VADiteratorAdaptive:
    def __init__(self,
                 trig_sum: float = 0.26,
                 neg_trig_sum: float = 0.06,
                 step: int = 500,
                 num_samples_per_window: int = 4000,
                 speech_pad_samples: int = 1000,
                 accum_period: int = 50):
        """
        This class is used for streaming silero VAD usage

        Parameters
        ----------
        trig_sum: float
            trigger value for speech probability; probs above this value are considered speech, switch to TRIGGERED state (default - 0.26)

        neg_trig_sum: float
            in TRIGGERED state, probabilities below this value are considered non-speech, switch to NONTRIGGERED state (default - 0.06)

        step: int
            step size in samples (default - 500)

        num_samples_per_window: int
            window size in samples (chunk length in samples to feed to the NN, default - 4000)

        speech_pad_samples: int
            widen speech by this amount of samples on each side (default - 1000)

        accum_period: int
            number of chunks / iterations to wait before switching from constant (initial) trig and neg_trig coeffs to adaptive median coeffs (default - 50)
        """
        self.num_samples = num_samples_per_window
        self.num_steps = int(num_samples_per_window / step)
        self.step = step
        self.prev = torch.zeros(self.num_samples)
        self.last = False
        self.triggered = False
        self.buffer = deque(maxlen=self.num_steps)
        self.num_frames = 0
        self.trig_sum = trig_sum
        self.neg_trig_sum = neg_trig_sum
        self.current_name = ''
        self.median_meter = IterativeMedianMeter()
        self.median = 0
        self.total_steps = 0
        self.accum_period = accum_period
        self.speech_pad_samples = speech_pad_samples

    def refresh(self):
        self.prev = torch.zeros(self.num_samples)
        self.last = False
        self.triggered = False
        self.buffer = deque(maxlen=self.num_steps)
        self.num_frames = 0
        self.median_meter.reset()
        self.median = 0
        self.total_steps = 0

    def prepare_batch(self, wav_chunk, name=None):
        if (name is not None) and (name != self.current_name):
            self.refresh()
            self.current_name = name
        assert len(wav_chunk) <= self.num_samples
        self.num_frames += len(wav_chunk)
        if len(wav_chunk) < self.num_samples:
            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # short chunk => end of audio
            self.last = True

        stacked = torch.cat([self.prev, wav_chunk])
        self.prev = wav_chunk

        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
                          for i in range(self.step, self.num_samples+1, self.step)]
        return torch.cat(overlap_chunks, dim=0)

    def state(self, model_out):
        current_speech = {}
        speech_probs = model_out[:, 1]  # 0 index for silence probs, 1 index for speech probs
        for i, predict in enumerate(speech_probs):
            self.median = self.median_meter(predict.item())
            if self.total_steps < self.accum_period:
                trig_sum = self.trig_sum
                neg_trig_sum = self.neg_trig_sum
            else:
                trig_sum = 0.89 * self.median + 0.08  # 0.08 when median is zero, 0.97 when median is 1
                neg_trig_sum = 0.6 * self.median
            self.total_steps += 1
            self.buffer.append(predict)
            smoothed_prob = max(self.buffer)
            if (smoothed_prob >= trig_sum) and not self.triggered:
                self.triggered = True
                current_speech[max(0, self.num_frames - (self.num_steps - i) * self.step - self.speech_pad_samples)] = 'start'
            if (smoothed_prob < neg_trig_sum) and self.triggered:
                current_speech[self.num_frames - (self.num_steps - i) * self.step + self.speech_pad_samples] = 'end'
                self.triggered = False
        if self.triggered and self.last:
            current_speech[self.num_frames] = 'end'
        if self.last:
            self.refresh()
        return current_speech, self.current_name

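# Numeric sketch of the adaptive coefficients used in state() above: after
# accum_period chunks, both thresholds become linear functions of the running
# median speech probability.
for median in (0.0, 0.5, 1.0):
    trig_sum = 0.89 * median + 0.08   # -> 0.08, 0.525, 0.97
    neg_trig_sum = 0.6 * median       # -> 0.0,  0.3,   0.6
    print(median, trig_sum, neg_trig_sum)
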
def state_generator(model,
                    audios: List[str],
                    onnx: bool = False,
                    trig_sum: float = 0.26,
                    neg_trig_sum: float = 0.07,
                    num_steps: int = 8,
                    num_samples_per_window: int = 4000,
                    audios_in_stream: int = 2,
                    run_function=validate):
    VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps, num_samples_per_window) for i in range(audios_in_stream)]
    for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream, num_samples_per_window)):
        for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
        batch = torch.cat(for_batch)

        outs = run_function(model, batch)
        vad_outs = torch.split(outs, num_steps)

        states = []
        for x, y in zip(VADiters, vad_outs):
            cur_st = x.state(y)
            if cur_st[0]:
                states.append(cur_st)
        yield states

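# Usage sketch (hypothetical paths; assumes a preloaded jit model): two files
# are processed in one batched stream, and non-empty states are yielded as
# (timings_dict, audio_name) pairs.
#
#     for states in state_generator(model, ['a.wav', 'b.wav'], audios_in_stream=2):
#         for timings, name in states:
#             print(name, timings)     # e.g. a.wav {12000: 'start', 52000: 'end'}
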
def stream_imitator(audios: List[str],
                    audios_in_stream: int,
                    num_samples_per_window: int = 4000):
    audio_iter = iter(audios)
    iterators = []
    num_samples = num_samples_per_window
    # initial wavs
    for i in range(audios_in_stream):
        next_wav = next(audio_iter)
        wav = read_audio(next_wav)
        wav_chunks = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
        iterators.append(wav_chunks)
    print('Done initial loading')
    good_iters = audios_in_stream
    while True:
        values = []
        for i, it in enumerate(iterators):
            try:
                out, wav_name = next(it)
            except StopIteration:
                try:
                    next_wav = next(audio_iter)
                    print('Loading next wav: ', next_wav)
                    wav = read_audio(next_wav)
                    iterators[i] = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
                    out, wav_name = next(iterators[i])
                except StopIteration:
                    good_iters -= 1
                    iterators[i] = repeat((torch.zeros(num_samples), 'junk'))
                    out, wav_name = next(iterators[i])
                    if good_iters == 0:
                        return
            values.append((out, wav_name))
        yield values

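# Self-contained sketch of the exhausted-stream handling above: a finished
# audio iterator is swapped for itertools.repeat(...), which keeps yielding
# silent 'junk' chunks so the zipped batch keeps its size until all real
# streams end.
import torch
from itertools import repeat

it = repeat((torch.zeros(4000), 'junk'))
chunk, name = next(it)
print(name, chunk.shape)               # junk torch.Size([4000])
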
def single_audio_stream(model,
                        audio: str,
                        num_samples_per_window: int = 4000,
                        run_function=validate,
                        iterator_type='basic',
                        **kwargs):
    num_samples = num_samples_per_window
    if iterator_type == 'basic':
        VADiter = VADiterator(num_samples_per_window=num_samples_per_window, **kwargs)
    elif iterator_type == 'adaptive':
        VADiter = VADiteratorAdaptive(num_samples_per_window=num_samples_per_window, **kwargs)

    wav = read_audio(audio)
    wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
    for chunk in wav_chunks:
        batch = VADiter.prepare_batch(chunk)

        outs = run_function(model, batch)

        states = []
        state = VADiter.state(outs)
        if state[0]:
            states.append(state[0])
        yield states

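# Usage sketch (hypothetical path; assumes a preloaded jit model):
#
#     for states in single_audio_stream(model, 'test.wav', iterator_type='adaptive'):
#         if states:
#             print(states)            # e.g. [{3000: 'start'}] ... [{51500: 'end'}]
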
# Updated streaming iterator from the new side of this diff:
class VADIterator:
    def __init__(self,
                 model,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30
                 ):
        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit/.onnx silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 100 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """
        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')

        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):
        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    @torch.no_grad()
    def __call__(self, x, return_seconds=False):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)
        """
        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                self.temp_end = 0
                self.triggered = False
                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}

        return None


def collect_chunks(tss: List[dict],
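# Usage sketch for the updated VADIterator above (hypothetical paths; the
# chunk size is an assumption, any window size the model accepts works):
#
#     vad_iterator = VADIterator(model)
#     wav = read_audio('test.wav')
#     window = 1536
#     for i in range(0, len(wav), window):
#         speech_dict = vad_iterator(wav[i: i + window], return_seconds=True)
#         if speech_dict:
#             print(speech_dict)       # {'start': 1.2} ... {'end': 3.4}
#     vad_iterator.reset_states()
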

@@ -1,56 +0,0 @@
from utils_vad import *
import sys
import os
from pathlib import Path
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/pipelines/align/bin/')
from align_utils import load_audio_norm
import torch
import pandas as pd
import numpy as np
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/utils/')
from open_stt import soundfile_opus as sf


def split_save_audio_chunks(audio_path, model_path, save_path=None, device='cpu', absolute=True, max_duration=10, adaptive=False, **kwargs):

    if not save_path:
        save_path = str(Path(audio_path).with_name('after_vad'))
        print(f'No save path specified! Using {save_path} to save audio chunks!')

    SAMPLE_RATE = 16000
    if isinstance(model_path, str):
        # print('Loading model...')
        model = init_jit_model(model_path, device)
    else:
        # print('Using loaded model')
        model = model_path
    save_name = Path(audio_path).stem
    audio, sr = load_audio_norm(audio_path)
    wav = torch.tensor(audio)
    if adaptive:
        speech_timestamps = get_speech_ts_adaptive(wav, model, device=device, **kwargs)
    else:
        speech_timestamps = get_speech_ts(wav, model, device=device, **kwargs)

    full_save_path = Path(save_path, save_name)
    if not os.path.exists(full_save_path):
        os.makedirs(full_save_path, exist_ok=True)

    chunks = []
    if not speech_timestamps:
        return pd.DataFrame()
    for ts in speech_timestamps:
        start_ts = int(ts['start'])
        end_ts = int(ts['end'])

        for i in range(start_ts, end_ts, max_duration * SAMPLE_RATE):
            new_start = i
            new_end = min(end_ts, i + max_duration * SAMPLE_RATE)
            duration = round((new_end - new_start) / SAMPLE_RATE, 2)
            chunk_path = Path(full_save_path, f'{save_name}_{new_start}-{new_end}.opus')
            chunk_path = chunk_path.absolute() if absolute else chunk_path
            sf.write(str(chunk_path), audio[new_start:new_end], SAMPLE_RATE, format='OGG', subtype='OPUS')
            chunks.append({'audio_path': chunk_path,
                           'text': '',
                           'duration': duration,
                           'domain': ''})
    return pd.DataFrame(chunks)
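# Usage sketch (hypothetical paths): split one long recording into <= 10 s
# OPUS chunks and get their manifest back as a DataFrame.
#
#     df = split_save_audio_chunks('long_call.opus', 'model.jit', adaptive=True)
#     print(df[['audio_path', 'duration']].head())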