mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-05 01:49:22 +08:00

Compare commits — 107 commits, newest first:

fba061dc55, 11631356a2, 34dea51680, 51fd43130a, 3080062489, f974f2d6bc, f1886d9088, 4c00cd14be, 5d70880844, a16f3ed079, b0fbf4bec6, ab02267584, 485a7d91b0, 1da76acfc3, 3c70b587e8, 7aff370d68, 931eddfdab, 6143b9a5d9, 8ca8cf7d9b, ad0fdbe4ac, 06806eb70b, c90e1603c5, 023d3a36f0, aa2a66cf46, b1cd34aae2, 50be3744fe, fce776f872, fbddc91a5d, bbf22a0064, 94811cbe12, 22a2362b4c, 0dd45f0bcd, feba8cd5c4, 6622e562e4, d5625d5c38, cd92290a15, 33a9d190fe, 7440bc4689, 10e7e8a8bc, 5a5b662496, 9060f664f2, 94271e9096, 3f9fffc261, eaf633ec9d, cff5eb2980, f356a8081a, 782e30d28f, caee535cf6, 8ab5be005f, 9f67a54e87, c8df1dee3f, 0189ebd8af, 05e380c1de, 93b9782f28, d2ab7c254e, 6217b08bbb, d53ba1ea11, 102e6d0962, e531cd3462, fd41da0b15, 9db72c35bd, 867a067bee, 2c43391b17, 6478567951, add6e3028e, e7025ed8c5, 35d601adc6, 032ca21a70, 001d57d6ff, 6e6da04e7a, 9c1eff9169, 36b759d053, 1a7499607a, 87451b059f, becc7770c7, 3f2eff0303, 3a25110cf9, d23867da10, 2043282182, fa8036ae1c, 2fff4b8ce8, 64b863d2ff, 8a3600665b, 9c2c90aa1c, 1d48167271, d0139d94d9, 46f94b7d60, 3de3ee3abe, e680ea6633, 199de226e5, 4109b107c1, 36854a90db, 827e86e685, e706ec6fee, 88df0ce1dd, d18b91e037, 1e3f343767, 6a8ee81ee0, cb25c0c047, 7af8628a27, 3682cb189c, 57c0b51f9b, dd0b143803, 181cdf92b6, a7bd2dd38f, df7de797a5, 87ed11b508
.github/workflows/test.yml (vendored, new file, 39 lines)
@@ -0,0 +1,39 @@
+name: Test Package
+
+on:
+  workflow_dispatch: # manual trigger
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+        python-version: ["3.8","3.9","3.10","3.11","3.12","3.13"]
+
+    steps:
+    - uses: actions/checkout@v4
+
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install build hatchling pytest soundfile
+
+    - name: Build package
+      run: python -m build --wheel --outdir dist
+
+    - name: Install package
+      run: |
+        import glob, subprocess, sys
+        whl = glob.glob("dist/*.whl")[0]
+        subprocess.check_call([sys.executable, "-m", "pip", "install", whl])
+      shell: python
+
+    - name: Run tests
+      run: pytest tests
CITATION.cff (new file, 20 lines)
@@ -0,0 +1,20 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+title: "Silero VAD"
+authors:
+- family-names: "Silero Team"
+  email: "hello@silero.ai"
+type: software
+repository-code: "https://github.com/snakers4/silero-vad"
+license: MIT
+abstract: "Pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
+preferred-citation:
+  type: software
+  authors:
+  - family-names: "Silero Team"
+    email: "hello@silero.ai"
+  title: "Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
+  year: 2024
+  publisher: "GitHub"
+  journal: "GitHub repository"
+  howpublished: "https://github.com/snakers4/silero-vad"
README.md (56 lines changed)
@@ -1,6 +1,6 @@
-[](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE)
+[](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [](https://pypi.org/project/silero-vad/)
 
-[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
+[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [](https://github.com/snakers4/silero-vad/actions/workflows/test.yml)
 
@@ -13,7 +13,7 @@
 <br/>
 
 <p align="center">
-<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
+<img src="https://github.com/user-attachments/assets/f2940867-0a51-4bdb-8c14-1129d3c44e64" />
 </p>
 
@@ -22,6 +22,8 @@
 
 https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
 
+Please note that the video loads only if you are logged in to your GitHub account.
+
 </details>
 
 <br/>
@@ -29,14 +31,46 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 <h2 align="center">Fast start</h2>
 <br/>
 
+<details>
+<summary>Dependencies</summary>
+
+System requirements to run python examples on `x86-64` systems:
+
+- `python 3.8+`;
+- 1G+ RAM;
+- A modern CPU with AVX, AVX2, AVX-512 or AMX instruction sets.
+
+Dependencies:
+
+- `torch>=1.12.0`;
+- `torchaudio>=0.12.0` (for I/O only);
+- `onnxruntime>=1.16.1` (for ONNX model usage).
+
+Silero VAD uses the torchaudio library for audio I/O (`torchaudio.info`, `torchaudio.load`, and `torchaudio.save`), so a proper audio backend is required:
+
+- Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`;
+- Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2;
+- Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`.
+
+If you are planning to run the VAD using solely the `onnx-runtime`, it will run on any other system architectures where onnxruntime is [supported](https://onnxruntime.ai/getting-started). In this case please note that:
+
+- You will have to implement the I/O;
+- You will have to adapt the existing wrappers / examples / post-processing for your use-case.
+
+</details>
+
 **Using pip**:
 `pip install silero-vad`
 
 ```python3
 from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
 model = load_silero_vad()
-wav = read_audio('path_to_audio_file') # backend (sox, soundfile, or ffmpeg) required!
-speech_timestamps = get_speech_timestamps(wav, model)
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(
+  wav,
+  model,
+  return_seconds=True,  # Return speech timestamps in seconds (default is samples)
+)
 ```
 
 **Using torch.hub**:
@@ -47,8 +81,12 @@ torch.set_num_threads(1)
 model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
 (get_speech_timestamps, _, read_audio, _, _) = utils
 
-wav = read_audio('path_to_audio_file') # backend (sox, soundfile, or ffmpeg) required!
-speech_timestamps = get_speech_timestamps(wav, model)
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(
+  wav,
+  model,
+  return_seconds=True,  # Return speech timestamps in seconds (default is samples)
+)
 ```
 
 <br/>
@@ -120,7 +158,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
 @misc{Silero VAD,
   author = {Silero Team},
   title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
-  year = {2021},
+  year = {2024},
   publisher = {GitHub},
   journal = {GitHub repository},
   howpublished = {\url{https://github.com/snakers4/silero-vad}},
@@ -137,4 +175,4 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
 
 - Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
 
-- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples
+- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example), [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp), [C#](https://github.com/snakers4/silero-vad/tree/master/examples/csharp) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) community examples
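The Dependencies section above notes that ONNX-only users must implement their own I/O and adapt the wrappers. A rough sketch of what that can look like, mirroring the tensor names and shapes ("input", "state", "sr" → "output", "stateN"; 512-sample windows plus 64 samples of left context at 16 kHz) used by the C++ ONNX example further down this page; the soundfile dependency and file paths are illustrative assumptions:

```python
# Minimal ONNX-only sketch, assuming a 16 kHz mono file and local model copy.
import numpy as np
import onnxruntime as ort
import soundfile as sf

audio, sr = sf.read("audio.wav", dtype="float32")   # assumed 16 kHz mono
session = ort.InferenceSession("silero_vad.onnx")   # assumed local path

context = np.zeros(64, dtype=np.float32)            # 64 samples of left context
state = np.zeros((2, 1, 128), dtype=np.float32)     # recurrent model state
window = 512                                        # 32 ms at 16 kHz

probs = []
for i in range(0, len(audio) - window + 1, window):
    chunk = np.concatenate([context, audio[i:i + window]])   # 576-sample input
    out, state = session.run(
        ["output", "stateN"],
        {"input": chunk[None, :],
         "state": state,
         "sr": np.array([16000], dtype=np.int64)},
    )
    probs.append(float(out[0]))        # per-window speech probability
    context = chunk[-64:]              # carry the last 64 samples forward
```

Turning the probabilities into speech timestamps (thresholding, padding, minimum durations) would still need to be re-implemented, as the C++ example below does.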
@@ -17,6 +17,7 @@
 },
 "outputs": [],
 "source": [
+"#!apt install ffmpeg\n",
 "!pip -q install pydub\n",
 "from google.colab import output\n",
 "from base64 import b64decode, b64encode\n",
@@ -37,13 +38,12 @@
 " model='silero_vad',\n",
 " force_reload=True)\n",
 "\n",
-"def int2float(sound):\n",
-" abs_max = np.abs(sound).max()\n",
-" sound = sound.astype('float32')\n",
-" if abs_max > 0:\n",
-" sound *= 1/32768\n",
-" sound = sound.squeeze()\n",
-" return sound\n",
+"def int2float(audio):\n",
+" samples = audio.get_array_of_samples()\n",
+" new_sound = audio._spawn(samples)\n",
+" arr = np.array(samples).astype(np.float32)\n",
+" arr = arr / np.abs(arr).max()\n",
+" return arr\n",
 "\n",
 "AUDIO_HTML = \"\"\"\n",
 "<script>\n",
@@ -68,10 +68,10 @@
 " //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
 " mimeType : 'audio/webm;codecs=opus'\n",
 " //mimeType : 'audio/webm;codecs=pcm'\n",
-" }; \n",
+" };\n",
 " //recorder = new MediaRecorder(stream, options);\n",
 " recorder = new MediaRecorder(stream);\n",
-" recorder.ondataavailable = function(e) { \n",
+" recorder.ondataavailable = function(e) {\n",
 " var url = URL.createObjectURL(e.data);\n",
 " // var preview = document.createElement('audio');\n",
 " // preview.controls = true;\n",
@@ -79,7 +79,7 @@
 " // document.body.appendChild(preview);\n",
 "\n",
 " reader = new FileReader();\n",
-" reader.readAsDataURL(e.data); \n",
+" reader.readAsDataURL(e.data);\n",
 " reader.onloadend = function() {\n",
 " base64data = reader.result;\n",
 " //console.log(\"Inside FileReader:\" + base64data);\n",
@@ -121,7 +121,7 @@
 "\n",
 "}\n",
 "});\n",
-" \n",
+"\n",
 "</script>\n",
 "\"\"\"\n",
 "\n",
@@ -133,8 +133,8 @@
 " audio.export('test.mp3', format='mp3')\n",
 " audio = audio.set_channels(1)\n",
 " audio = audio.set_frame_rate(16000)\n",
-" audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
-" audio_tens = torch.tensor(audio_float )\n",
+" audio_float = int2float(audio)\n",
+" audio_tens = torch.tensor(audio_float)\n",
 " return audio_tens\n",
 "\n",
 "def make_animation(probs, audio_duration, interval=40):\n",
@@ -154,19 +154,18 @@
 " def animate(i):\n",
 " x = i * interval / 1000 - 0.04\n",
 " y = np.linspace(0, 1.02, 2)\n",
-" \n",
+"\n",
 " line.set_data(x, y)\n",
 " line.set_color('#990000')\n",
 " return line,\n",
+" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n",
 "\n",
-" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
-"\n",
-" f = r\"animation.mp4\" \n",
-" writervideo = FFMpegWriter(fps=1000/interval) \n",
+" f = r\"animation.mp4\"\n",
+" writervideo = FFMpegWriter(fps=1000/interval)\n",
 " anim.save(f, writer=writervideo)\n",
 " plt.close('all')\n",
 "\n",
-"def combine_audio(vidname, audname, outname, fps=25): \n",
+"def combine_audio(vidname, audname, outname, fps=25):\n",
 " my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
 " audio_background = mpe.AudioFileClip(audname)\n",
 " final_clip = my_clip.set_audio(audio_background)\n",
@@ -174,15 +173,10 @@
 "\n",
 "def record_make_animation():\n",
 " tensor = record()\n",
-"\n",
 " print('Calculating probabilities...')\n",
 " speech_probs = []\n",
 " window_size_samples = 512\n",
-" for i in range(0, len(tensor), window_size_samples):\n",
-" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
-" break\n",
-" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
-" speech_probs.append(speech_prob)\n",
+" speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()\n",
 " model.reset_states()\n",
 " print('Making animation...')\n",
 " make_animation(speech_probs, len(tensor) / 16000)\n",
@@ -196,7 +190,9 @@
 " <video width=800 controls>\n",
 " <source src=\"%s\" type=\"video/mp4\">\n",
 " </video>\n",
-" \"\"\" % data_url))"
+" \"\"\" % data_url))\n",
+"\n",
+" return speech_probs"
 ]
 },
 {
@@ -216,7 +212,7 @@
 },
 "outputs": [],
 "source": [
-"record_make_animation()"
+"speech_probs = record_make_animation()"
 ]
 }
 ],
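The notebook change above replaces the manual per-window loop with `model.audio_forward`. For reference, a standalone sketch of the two approaches (names as in the notebook; `audio_forward` pads the final partial chunk, so the tail probabilities can differ slightly from the loop):

```python
# Batched: one call over the whole recording (sketch).
speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()
model.reset_states()

# Per-chunk: the streaming-style loop the diff removes.
speech_probs_loop = []
for i in range(0, len(tensor), 512):
    chunk = tensor[i:i + 512]
    if len(chunk) < 512:       # drop the incomplete final window
        break
    speech_probs_loop.append(model(chunk, 16000).item())
model.reset_states()
```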
@@ -1,211 +1,227 @@
+#ifndef _CRT_SECURE_NO_WARNINGS
+#define _CRT_SECURE_NO_WARNINGS
+#endif
+
 #include <iostream>
 #include <vector>
 #include <sstream>
 #include <cstring>
 #include <limits>
 #include <chrono>
+#include <iomanip>
 #include <memory>
 #include <string>
 #include <stdexcept>
-#include <iostream>
-#include <string>
-#include "onnxruntime_cxx_api.h"
-#include "wav.h"
 #include <cstdio>
 #include <cstdarg>
+#include <cmath>   // for std::rint
 #if __cplusplus < 201703L
 #include <memory>
 #endif
 
 //#define __DEBUG_SPEECH_PROB___
 
-class timestamp_t
-{
+#include "onnxruntime_cxx_api.h"
+#include "wav.h"   // For reading WAV files
+
+// timestamp_t class: stores the start and end (in samples) of a speech segment.
+class timestamp_t {
 public:
     int start;
     int end;
 
-    // default + parameterized constructor
     timestamp_t(int start = -1, int end = -1)
-        : start(start), end(end)
-    {
-    };
+        : start(start), end(end) { }
 
-    // assignment operator modifies object, therefore non-const
-    timestamp_t& operator=(const timestamp_t& a)
-    {
+    timestamp_t& operator=(const timestamp_t& a) {
         start = a.start;
         end = a.end;
         return *this;
-    };
+    }
 
-    // equality comparison. doesn't modify object. therefore const.
-    bool operator==(const timestamp_t& a) const
-    {
+    bool operator==(const timestamp_t& a) const {
         return (start == a.start && end == a.end);
-    };
-    std::string c_str()
-    {
-        //return std::format("timestamp {:08d}, {:08d}", start, end);
-        return format("{start:%08d,end:%08d}", start, end);
-    };
+    }
+
+    // Returns a formatted string of the timestamp.
+    std::string c_str() const {
+        return format("{start:%08d, end:%08d}", start, end);
+    }
 private:
-    std::string format(const char* fmt, ...)
-    {
+    // Helper function for formatting.
+    std::string format(const char* fmt, ...) const {
         char buf[256];
 
         va_list args;
         va_start(args, fmt);
-        const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
+        const auto r = std::vsnprintf(buf, sizeof(buf), fmt, args);
         va_end(args);
 
         if (r < 0)
-            // conversion failed
             return {};
 
         const size_t len = r;
-        if (len < sizeof buf)
-            // we fit in the buffer
-            return { buf, len };
+        if (len < sizeof(buf))
+            return std::string(buf, len);
 
 #if __cplusplus >= 201703L
-        // C++17: Create a string and write to its underlying array
         std::string s(len, '\0');
         va_start(args, fmt);
         std::vsnprintf(s.data(), len + 1, fmt, args);
         va_end(args);
 
         return s;
 #else
-        // C++11 or C++14: We need to allocate scratch memory
         auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
         va_start(args, fmt);
        std::vsnprintf(vbuf.get(), len + 1, fmt, args);
         va_end(args);
-        return { vbuf.get(), len };
+        return std::string(vbuf.get(), len);
 #endif
-    };
+    }
 };
 
-class VadIterator
-{
+// VadIterator class: uses ONNX Runtime to detect speech segments.
+class VadIterator {
 private:
-    // OnnxRuntime resources
+    // ONNX Runtime resources
    Ort::Env env;
     Ort::SessionOptions session_options;
     std::shared_ptr<Ort::Session> session = nullptr;
     Ort::AllocatorWithDefaultOptions allocator;
     Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
 
-private:
-    void init_engine_threads(int inter_threads, int intra_threads)
-    {
-        // The method should be called in each thread/proc in multi-thread/proc work
+    // ----- Context-related additions -----
+    const int context_samples = 64;  // For 16kHz, 64 samples are added as context.
+    std::vector<float> _context;     // Holds the last 64 samples from the previous chunk (initialized to zero).
+
+    // Original window size (e.g., 32ms corresponds to 512 samples)
+    int window_size_samples;
+    // Effective window size = window_size_samples + context_samples
+    int effective_window_size;
+
+    // Additional declaration: samples per millisecond
+    int sr_per_ms;
+
+    // ONNX Runtime input/output buffers
+    std::vector<Ort::Value> ort_inputs;
+    std::vector<const char*> input_node_names = { "input", "state", "sr" };
+    std::vector<float> input;
+    unsigned int size_state = 2 * 1 * 128;
+    std::vector<float> _state;
+    std::vector<int64_t> sr;
+    int64_t input_node_dims[2] = {};
+    const int64_t state_node_dims[3] = { 2, 1, 128 };
+    const int64_t sr_node_dims[1] = { 1 };
+    std::vector<Ort::Value> ort_outputs;
+    std::vector<const char*> output_node_names = { "output", "stateN" };
+
+    // Model configuration parameters
+    int sample_rate;
+    float threshold;
+    int min_silence_samples;
+    int min_silence_samples_at_max_speech;
+    int min_speech_samples;
+    float max_speech_samples;
+    int speech_pad_samples;
+    int audio_length_samples;
+
+    // State management
+    bool triggered = false;
+    unsigned int temp_end = 0;
+    unsigned int current_sample = 0;
+    int prev_end;
+    int next_start = 0;
+    std::vector<timestamp_t> speeches;
+    timestamp_t current_speech;
+
+    // Loads the ONNX model.
+    void init_onnx_model(const std::wstring& model_path) {
+        init_engine_threads(1, 1);
+        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
+    }
+
+    // Initializes threading settings.
+    void init_engine_threads(int inter_threads, int intra_threads) {
         session_options.SetIntraOpNumThreads(intra_threads);
         session_options.SetInterOpNumThreads(inter_threads);
         session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
-    };
+    }
 
-    void init_onnx_model(const std::wstring& model_path)
-    {
-        // Init threads = 1 for
-        init_engine_threads(1, 1);
-        // Load model
-        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
-    };
-
-    void reset_states()
-    {
-        // Call reset before each audio start
-        std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
+    // Resets internal state (_state, _context, etc.)
+    void reset_states() {
+        std::memset(_state.data(), 0, _state.size() * sizeof(float));
         triggered = false;
         temp_end = 0;
         current_sample = 0;
 
         prev_end = next_start = 0;
 
         speeches.clear();
         current_speech = timestamp_t();
-    };
+        std::fill(_context.begin(), _context.end(), 0.0f);
+    }
 
-    void predict(const std::vector<float> &data)
-    {
-        // Infer
-        // Create ort tensors
-        input.assign(data.begin(), data.end());
+    // Inference: runs inference on one chunk of input data.
+    // data_chunk is expected to have window_size_samples samples.
+    void predict(const std::vector<float>& data_chunk) {
+        // Build new input: first context_samples from _context, followed by the current chunk (window_size_samples).
+        std::vector<float> new_data(effective_window_size, 0.0f);
+        std::copy(_context.begin(), _context.end(), new_data.begin());
+        std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
+        input = new_data;
+
+        // Create input tensor (input_node_dims[1] is already set to effective_window_size).
         Ort::Value input_ort = Ort::Value::CreateTensor<float>(
             memory_info, input.data(), input.size(), input_node_dims, 2);
         Ort::Value state_ort = Ort::Value::CreateTensor<float>(
             memory_info, _state.data(), _state.size(), state_node_dims, 3);
         Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
             memory_info, sr.data(), sr.size(), sr_node_dims, 1);
 
-        // Clear and add inputs
         ort_inputs.clear();
         ort_inputs.emplace_back(std::move(input_ort));
         ort_inputs.emplace_back(std::move(state_ort));
         ort_inputs.emplace_back(std::move(sr_ort));
 
-        // Infer
+        // Run inference.
         ort_outputs = session->Run(
-            Ort::RunOptions{nullptr},
+            Ort::RunOptions{ nullptr },
             input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
             output_node_names.data(), output_node_names.size());
 
-        // Output probability & update h,c recursively
         float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
-        float *stateN = ort_outputs[1].GetTensorMutableData<float>();
+        float* stateN = ort_outputs[1].GetTensorMutableData<float>();
         std::memcpy(_state.data(), stateN, size_state * sizeof(float));
+        current_sample += static_cast<unsigned int>(window_size_samples);  // Advance by the original window size.
 
-        // Push forward sample index
-        current_sample += window_size_samples;
-
-        // Reset temp_end when > threshold
-        if ((speech_prob >= threshold))
-        {
+        // If speech is detected (probability >= threshold)
+        if (speech_prob >= threshold) {
 #ifdef __DEBUG_SPEECH_PROB___
-            float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
-            printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples);
-#endif //__DEBUG_SPEECH_PROB___
-            if (temp_end != 0)
-            {
+            float speech = current_sample - window_size_samples;
+            printf("{ start: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
+#endif
+            if (temp_end != 0) {
                 temp_end = 0;
                 if (next_start < prev_end)
                     next_start = current_sample - window_size_samples;
             }
-            if (triggered == false)
-            {
+            if (!triggered) {
                 triggered = true;
                 current_speech.start = current_sample - window_size_samples;
             }
+            // Update context: copy the last context_samples from new_data.
+            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
             return;
         }
 
-        if (
-            (triggered == true)
-            && ((current_sample - current_speech.start) > max_speech_samples)
-            ) {
+        // If the speech segment becomes too long.
+        if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) {
             if (prev_end > 0) {
                 current_speech.end = prev_end;
                 speeches.push_back(current_speech);
                 current_speech = timestamp_t();
 
-                // previously reached silence(< neg_thres) and is still not speech(< thres)
                 if (next_start < prev_end)
                     triggered = false;
-                else{
+                else
                     current_speech.start = next_start;
-                }
                 prev_end = 0;
                 next_start = 0;
                 temp_end = 0;
             }
-            else{
+            else {
                 current_speech.end = current_sample;
                 speeches.push_back(current_speech);
                 current_speech = timestamp_t();
@@ -214,53 +230,29 @@ private:
                 temp_end = 0;
                 triggered = false;
             }
+            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
             return;
         }
-        if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
-        {
+
+        if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) {
+            // When the speech probability temporarily drops but is still in speech, update context without changing state.
+            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
+            return;
+        }
+
+        if (speech_prob < (threshold - 0.15)) {
+#ifdef __DEBUG_SPEECH_PROB___
+            float speech = current_sample - window_size_samples - speech_pad_samples;
+            printf("{ end: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
+#endif
             if (triggered) {
-#ifdef __DEBUG_SPEECH_PROB___
-                float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
-                printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
-#endif //__DEBUG_SPEECH_PROB___
-            }
-            else {
-#ifdef __DEBUG_SPEECH_PROB___
-                float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
-                printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
-#endif //__DEBUG_SPEECH_PROB___
-            }
-            return;
-        }
-
-
-        // 4) End
-        if ((speech_prob < (threshold - 0.15)))
-        {
-#ifdef __DEBUG_SPEECH_PROB___
-            float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
-            printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
-#endif //__DEBUG_SPEECH_PROB___
-            if (triggered == true)
-            {
                 if (temp_end == 0)
-                {
                     temp_end = current_sample;
-                }
                 if (current_sample - temp_end > min_silence_samples_at_max_speech)
                     prev_end = temp_end;
-                // a. silence < min_slience_samples, continue speaking
-                if ((current_sample - temp_end) < min_silence_samples)
-                {
-                }
-                // b. silence >= min_slience_samples, end speaking
-                else
-                {
+                if ((current_sample - temp_end) >= min_silence_samples) {
                     current_speech.end = temp_end;
-                    if (current_speech.end - current_speech.start > min_speech_samples)
-                    {
+                    if (current_speech.end - current_speech.start > min_speech_samples) {
                         speeches.push_back(current_speech);
                         current_speech = timestamp_t();
                         prev_end = 0;
@@ -270,27 +262,23 @@ private:
                 }
             }
         }
-        else {
-            // may first windows see end state.
-        }
+        std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
         return;
     }
-    };
 
 public:
-    void process(const std::vector<float>& input_wav)
-    {
+    // Process the entire audio input.
+    void process(const std::vector<float>& input_wav) {
         reset_states();
-        audio_length_samples = input_wav.size();
-
-        for (int j = 0; j < audio_length_samples; j += window_size_samples)
-        {
-            if (j + window_size_samples > audio_length_samples)
+        audio_length_samples = static_cast<int>(input_wav.size());
+        // Process audio in chunks of window_size_samples (e.g., 512 samples)
+        for (size_t j = 0; j < static_cast<size_t>(audio_length_samples); j += static_cast<size_t>(window_size_samples)) {
+            if (j + static_cast<size_t>(window_size_samples) > static_cast<size_t>(audio_length_samples))
                 break;
-            std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
-            predict(r);
+            std::vector<float> chunk(&input_wav[j], &input_wav[j] + window_size_samples);
+            predict(chunk);
         }
 
         if (current_speech.start >= 0) {
            current_speech.end = audio_length_samples;
             speeches.push_back(current_speech);
@@ -300,179 +288,80 @@ public:
            temp_end = 0;
             triggered = false;
         }
-    };
-
-    void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        process(input_wav);
-        collect_chunks(input_wav, output_wav);
     }
 
-    void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        output_wav.clear();
-        for (int i = 0; i < speeches.size(); i++) {
-#ifdef __DEBUG_SPEECH_PROB___
-            std::cout << speeches[i].c_str() << std::endl;
-#endif //#ifdef __DEBUG_SPEECH_PROB___
-            std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
-            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
-        }
-    };
-
-    const std::vector<timestamp_t> get_speech_timestamps() const
-    {
+    // Returns the detected speech timestamps.
+    const std::vector<timestamp_t> get_speech_timestamps() const {
         return speeches;
     }
 
-    void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        output_wav.clear();
-        int current_start = 0;
-        for (int i = 0; i < speeches.size(); i++) {
-            std::vector<float> slice(&input_wav[current_start], &input_wav[speeches[i].start]);
-            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
-            current_start = speeches[i].end;
-        }
-        std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
-        output_wav.insert(output_wav.end(), slice.begin(), slice.end());
-    };
-
-private:
-    // model config
-    int64_t window_size_samples;  // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
-    int sample_rate;  // Assign when init, support 16000 or 8000
-    int sr_per_ms;    // Assign when init, support 8 or 16
-    float threshold;
-    int min_silence_samples;               // sr_per_ms * #ms
-    int min_silence_samples_at_max_speech; // sr_per_ms * #98
-    int min_speech_samples;                // sr_per_ms * #ms
-    float max_speech_samples;
-    int speech_pad_samples;
-    int audio_length_samples;
-
-    // model states
-    bool triggered = false;
-    unsigned int temp_end = 0;
-    unsigned int current_sample = 0;
-    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
-    int prev_end;
-    int next_start = 0;
-
-    // Output timestamps
-    std::vector<timestamp_t> speeches;
-    timestamp_t current_speech;
-
-    // Onnx model
-    // Inputs
-    std::vector<Ort::Value> ort_inputs;
-    std::vector<const char *> input_node_names = {"input", "state", "sr"};
-    std::vector<float> input;
-    unsigned int size_state = 2 * 1 * 128; // It's FIXED.
-    std::vector<float> _state;
-    std::vector<int64_t> sr;
-    int64_t input_node_dims[2] = {};
-    const int64_t state_node_dims[3] = {2, 1, 128};
-    const int64_t sr_node_dims[1] = {1};
-    // Outputs
-    std::vector<Ort::Value> ort_outputs;
-    std::vector<const char *> output_node_names = {"output", "stateN"};
+    // Public method to reset the internal state.
+    void reset() {
+        reset_states();
+    }
 
 public:
-    // Construction
+    // Constructor: sets model path, sample rate, window size (ms), and other parameters.
+    // The parameters are set to match the Python version.
     VadIterator(const std::wstring ModelPath,
         int Sample_rate = 16000, int windows_frame_size = 32,
-        float Threshold = 0.5, int min_silence_duration_ms = 0,
-        int speech_pad_ms = 32, int min_speech_duration_ms = 32,
+        float Threshold = 0.5, int min_silence_duration_ms = 100,
+        int speech_pad_ms = 30, int min_speech_duration_ms = 250,
         float max_speech_duration_s = std::numeric_limits<float>::infinity())
+        : sample_rate(Sample_rate), threshold(Threshold), speech_pad_samples(speech_pad_ms), prev_end(0)
     {
-        init_onnx_model(ModelPath);
-        threshold = Threshold;
-        sample_rate = Sample_rate;
-        sr_per_ms = sample_rate / 1000;
-        window_size_samples = windows_frame_size * sr_per_ms;
-
-        min_speech_samples = sr_per_ms * min_speech_duration_ms;
-        speech_pad_samples = sr_per_ms * speech_pad_ms;
-
-        max_speech_samples = (
-            sample_rate * max_speech_duration_s
-            - window_size_samples
-            - 2 * speech_pad_samples
-        );
-
-        min_silence_samples = sr_per_ms * min_silence_duration_ms;
-        min_silence_samples_at_max_speech = sr_per_ms * 98;
-
-        input.resize(window_size_samples);
+        sr_per_ms = sample_rate / 1000;                                 // e.g., 16000 / 1000 = 16
+        window_size_samples = windows_frame_size * sr_per_ms;           // e.g., 32ms * 16 = 512 samples
+        effective_window_size = window_size_samples + context_samples;  // e.g., 512 + 64 = 576 samples
         input_node_dims[0] = 1;
-        input_node_dims[1] = window_size_samples;
+        input_node_dims[1] = effective_window_size;
 
         _state.resize(size_state);
         sr.resize(1);
         sr[0] = sample_rate;
-    };
+        _context.assign(context_samples, 0.0f);
+        min_speech_samples = sr_per_ms * min_speech_duration_ms;
+        max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples);
+        min_silence_samples = sr_per_ms * min_silence_duration_ms;
+        min_silence_samples_at_max_speech = sr_per_ms * 98;
+        init_onnx_model(ModelPath);
+    }
 };
 
-int main()
-{
-    std::vector<timestamp_t> stamps;
-
-    // Read wav
-    wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
-    std::vector<float> input_wav(wav_reader.num_samples());
-    std::vector<float> output_wav;
-
-    for (int i = 0; i < wav_reader.num_samples(); i++)
-    {
+int main() {
+    // Read the WAV file (expects 16000 Hz, mono, PCM).
+    wav::WavReader wav_reader("audio/recorder.wav");  // File located in the "audio" folder.
+    int numSamples = wav_reader.num_samples();
+    std::vector<float> input_wav(static_cast<size_t>(numSamples));
+    for (size_t i = 0; i < static_cast<size_t>(numSamples); i++) {
         input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
     }
 
-    // ===== Test configs =====
-    std::wstring path = L"silero_vad.onnx";
-    VadIterator vad(path);
-
-    // ==============================================
-    // ===== Example 1 of full function =====
-    // ==============================================
+    // Set the ONNX model path (file located in the "model" folder).
+    std::wstring model_path = L"model/silero_vad.onnx";
+
+    // Initialize the VadIterator.
+    VadIterator vad(model_path);
+
+    // Process the audio.
     vad.process(input_wav);
 
-    // 1.a get_speech_timestamps
-    stamps = vad.get_speech_timestamps();
-    for (int i = 0; i < stamps.size(); i++) {
-        std::cout << stamps[i].c_str() << std::endl;
+    // Retrieve the speech timestamps (in samples).
+    std::vector<timestamp_t> stamps = vad.get_speech_timestamps();
+
+    // Convert timestamps to seconds and round to one decimal place (for 16000 Hz).
+    const float sample_rate_float = 16000.0f;
+    for (size_t i = 0; i < stamps.size(); i++) {
+        float start_sec = std::rint((stamps[i].start / sample_rate_float) * 10.0f) / 10.0f;
+        float end_sec = std::rint((stamps[i].end / sample_rate_float) * 10.0f) / 10.0f;
+        std::cout << "Speech detected from "
+                  << std::fixed << std::setprecision(1) << start_sec
+                  << " s to "
+                  << std::fixed << std::setprecision(1) << end_sec
+                  << " s" << std::endl;
     }
 
-    // 1.b collect_chunks output wav
-    vad.collect_chunks(input_wav, output_wav);
-
-    // 1.c drop_chunks output wav
-    vad.drop_chunks(input_wav, output_wav);
-
-    // ==============================================
-    // ===== Example 2 of simple full function =====
-    // ==============================================
-    vad.process(input_wav, output_wav);
-
-    stamps = vad.get_speech_timestamps();
-    for (int i = 0; i < stamps.size(); i++) {
-        std::cout << stamps[i].c_str() << std::endl;
-    }
-
-    // ==============================================
-    // ===== Example 3 of full function =====
-    // ==============================================
-    for (int i = 0; i < 2; i++)
-        vad.process(input_wav, output_wav);
+    // Optionally, reset the internal state.
+    vad.reset();
+
+    return 0;
 }
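Build instructions for this ONNX example are not part of the diff; a minimal sketch of a Linux compile command, assuming an extracted ONNX Runtime release under ./onnxruntime and the source saved as silero-vad-onnx.cc (both paths are assumptions, not taken from the diff):

```bash
# Sketch only: adjust the include/lib paths to your ONNX Runtime install.
g++ silero-vad-onnx.cc \
    -I ./onnxruntime/include \
    -L ./onnxruntime/lib -lonnxruntime \
    -Wl,-rpath,./onnxruntime/lib \
    -std=c++17 -O2 -o silero-vad-onnx
```

-std=c++17 matches the `#if __cplusplus >= 201703L` branch in the code; C++11/14 also works via the fallback branch.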
wav.h
@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 
 #ifndef FRONTEND_WAV_H_
 #define FRONTEND_WAV_H_
 
 
 #include <assert.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -24,6 +24,8 @@
 
 #include <string>
 
+#include <iostream>
+
 // #include "utils/log.h"
 
 namespace wav {
@@ -230,6 +232,6 @@ class WavWriter {
   int bits_per_sample_;
 };
 
-}  // namespace wenet
+}  // namespace wav
 
 #endif  // FRONTEND_WAV_H_
examples/cpp_libtorch/README.md (new file, 45 lines)
@@ -0,0 +1,45 @@
+# Silero-VAD V5 in C++ (based on LibTorch)
+
+This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested.
+
+Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may differ slightly from the standard version, likely due to differences in caching: unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
+
+## Requirements
+
+- GCC 11.4.0 (GCC >= 5.1)
+- LibTorch 1.13.0 (other versions are also acceptable)
+
+## Download LibTorch
+
+```bash
+# CPU version
+wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
+unzip libtorch-shared-with-deps-1.13.0+cpu.zip
+
+# CUDA version
+wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
+unzip libtorch-shared-with-deps-1.13.0+cu116.zip
+```
+
+## Compilation
+
+```bash
+# CPU version
+g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
+
+# CUDA version
+g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
+```
+
+## Optional Compilation Flags
+
+- `-DUSE_BATCH`: Enable batch inference (a combined example is shown below)
+- `-DUSE_GPU`: Use GPU for inference
+
+## Run the Program
+
+To run the program, use the following command:
+
+`./silero aepyx.wav 16000 0.5`
+
+The sample file aepyx.wav is part of the VoxConverse dataset; it is a 16kHz, 16-bit audio file.
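The flag list above stops short of a combined command; building the CPU version with batch inference enabled would look like this (same paths and flags as the CPU command above, plus -DUSE_BATCH):

```bash
# CPU version with batch inference enabled
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include \
    -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ \
    -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_BATCH
```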
examples/cpp_libtorch/aepyx.wav (new binary file, not shown)
examples/cpp_libtorch/main.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
+#include <iostream>
+#include "silero_torch.h"
+#include "wav.h"
+
+int main(int argc, char* argv[]) {
+
+  if (argc != 4) {
+    std::cerr << "Usage : " << argv[0] << " <wav.path> <SampleRate> <Threshold>" << std::endl;
+    std::cerr << "Usage : " << argv[0] << " sample.wav 16000 0.5" << std::endl;
+    return 1;
+  }
+
+  std::string wav_path = argv[1];
+  float sample_rate = std::stof(argv[2]);
+  float threshold = std::stof(argv[3]);
+
+  // Load model
+  std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
+  silero::VadIterator vad(model_path);
+
+  vad.threshold = threshold;       // (Default: 0.5)
+  vad.sample_rate = sample_rate;   // 16000 Hz or 8000 Hz (Default: 16000)
+  vad.print_as_samples = true;     // if true, prints timestamps in samples; otherwise in seconds (Default: false)
+
+  vad.SetVariables();
+
+  // Read wav
+  wav::WavReader wav_reader(wav_path);
+  std::vector<float> input_wav(wav_reader.num_samples());
+
+  for (int i = 0; i < wav_reader.num_samples(); i++) {
+    input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
+  }
+
+  vad.SpeechProbs(input_wav);
+
+  std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
+  for (const auto& speech : speeches) {
+    if (vad.print_as_samples) {
+      std::cout << "{'start': " << static_cast<int>(speech.start) << ", 'end': " << static_cast<int>(speech.end) << "}" << std::endl;
+    }
+    else {
+      std::cout << "{'start': " << speech.start << ", 'end': " << speech.end << "}" << std::endl;
+    }
+  }
+
+  return 0;
+}
examples/cpp_libtorch/silero (new executable binary, not shown)
285
examples/cpp_libtorch/silero_torch.cc
Normal file
285
examples/cpp_libtorch/silero_torch.cc
Normal file
@@ -0,0 +1,285 @@
//Author      : Nathan Lee
//Created On  : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version     : 1.0

#include "silero_torch.h"

namespace silero {

VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms,
                         int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms,
                         int max_duration_merge_ms, bool print_as_samples)
  : sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms),
    min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms),
    max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
  init_torch_model(model_path);
  //init_engine(window_size_ms);
}

VadIterator::~VadIterator(){
}

void VadIterator::SpeechProbs(std::vector<float>& input_wav){
  // The sample rate must match the model's expected sample rate.
  // Process the waveform in chunks of window_size_samples (512 at 16 kHz).
  int num_samples = input_wav.size();
  int num_chunks = num_samples / window_size_samples;
  int remainder_samples = num_samples % window_size_samples;

  total_sample_size += num_samples;

  torch::Tensor output;
  std::vector<torch::Tensor> chunks;

  for (int i = 0; i < num_chunks; i++) {
    float* chunk_start = input_wav.data() + i * window_size_samples;
    torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
    //std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
    chunks.push_back(chunk);

    if (i == num_chunks - 1 && remainder_samples > 0) { // last chunk, and leftover samples exist
      int remaining_samples = num_samples - num_chunks * window_size_samples;
      //std::cout<<"Remainder size : "<<remaining_samples;
      float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;

      torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remaining_samples}, torch::kFloat32);
      // Pad the remainder chunk to match window_size_samples
      torch::Tensor padded_chunk = torch::cat({remainder_chunk,
          torch::zeros({1, window_size_samples - remaining_samples}, torch::kFloat32)}, 1);
      //std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;

      chunks.push_back(padded_chunk);
    }
  }

  if (!chunks.empty()) {

#ifdef USE_BATCH
    torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
    //batched_chunks = batched_chunks.squeeze(1);
    batched_chunks = torch::cat({batched_chunks.squeeze(1)});

#ifdef USE_GPU
    batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
#endif
    // Prepare input for the model
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(batched_chunks); // Batch of chunks
    inputs.push_back(sample_rate);    // Assuming sample_rate is a valid input for the model

    // Run inference on the batch
    torch::NoGradGuard no_grad;
    torch::Tensor output = model.forward(inputs).toTensor();
#ifdef USE_GPU
    output = output.to(at::kCPU); // Move the output back to CPU once
#endif
    // Collect output probabilities
    for (int i = 0; i < chunks.size(); i++) {
      float output_f = output[i].item<float>();
      outputs_prob.push_back(output_f);
      //std::cout << "Chunk " << i << " prob: " << output_f << "\n";
    }
#else

    std::vector<torch::Tensor> outputs;
    torch::Tensor batched_chunks = torch::stack(chunks);
#ifdef USE_GPU
    batched_chunks = batched_chunks.to(at::kCUDA);
#endif
    for (int i = 0; i < chunks.size(); i++) {
      torch::NoGradGuard no_grad;
      std::vector<torch::jit::IValue> inputs;
      inputs.push_back(batched_chunks[i]);
      inputs.push_back(sample_rate);

      torch::Tensor output = model.forward(inputs).toTensor();
      outputs.push_back(output);
    }
    torch::Tensor all_outputs = torch::stack(outputs);
#ifdef USE_GPU
    all_outputs = all_outputs.to(at::kCPU);
#endif
    for (int i = 0; i < chunks.size(); i++) {
      float output_f = all_outputs[i].item<float>();
      outputs_prob.push_back(output_f);
    }

#endif

  }
}

std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
  std::vector<SpeechSegment> speeches = DoVad();

#ifdef USE_BATCH
  // When using batch inference, prefer the 'mergeSpeeches' function to arrange the timestamps:
  // it tends to produce more reasonable output, since batched probabilities can be distorted.
  duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
  std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
  if (!print_as_samples) {
    for (auto& speech : speeches_merge) { // samples to seconds
      speech.start /= sample_rate;
      speech.end /= sample_rate;
    }
  }

  return speeches_merge;
#else

  if (!print_as_samples) {
    for (auto& speech : speeches) { // samples to seconds
      speech.start /= sample_rate;
      speech.end /= sample_rate;
    }
  }

  return speeches;

#endif
}

void VadIterator::SetVariables(){
  init_engine(window_size_ms);
}

void VadIterator::init_engine(int window_size_ms) {
  min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
  speech_pad_samples = sample_rate * speech_pad_ms / 1000;
  window_size_samples = sample_rate / 1000 * window_size_ms;
  min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}

void VadIterator::init_torch_model(const std::string& model_path) {
  at::set_num_threads(1);
  model = torch::jit::load(model_path);

#ifdef USE_GPU
  if (!torch::cuda::is_available()) {
    std::cout << "CUDA is not available! Please check your GPU settings" << std::endl;
    throw std::runtime_error("CUDA is not available!");
    model.to(at::Device(at::kCPU));
  } else {
    std::cout << "CUDA available! Running on 0'th GPU" << std::endl;
    model.to(at::Device(at::kCUDA, 0)); // select the 0'th device
  }
#endif

  model.eval();
  torch::NoGradGuard no_grad;
  std::cout << "Model loaded successfully" << std::endl;
}

void VadIterator::reset_states() {
  triggered = false;
  current_sample = 0;
  temp_end = 0;
  outputs_prob.clear();
  model.run_method("reset_states");
  total_sample_size = 0;
}

std::vector<SpeechSegment> VadIterator::DoVad() {
  std::vector<SpeechSegment> speeches;

  for (size_t i = 0; i < outputs_prob.size(); ++i) {
    float speech_prob = outputs_prob[i];
    //std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
    current_sample += window_size_samples;

    if (speech_prob >= threshold && temp_end != 0) {
      temp_end = 0;
    }

    if (speech_prob >= threshold && !triggered) {
      triggered = true;
      SpeechSegment segment;
      segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
      speeches.push_back(segment);
      continue;
    }

    if (speech_prob < threshold - 0.15f && triggered) {
      if (temp_end == 0) {
        temp_end = current_sample;
      }

      if (current_sample - temp_end < min_silence_samples) {
        continue;
      } else {
        SpeechSegment& segment = speeches.back();
        segment.end = temp_end + speech_pad_samples - window_size_samples;
        temp_end = 0;
        triggered = false;
      }
    }
  }

  if (triggered) {
    // If the probabilities stay low and only the very last frame spikes above the
    // threshold, triggered = true and the segment start get set at the same moment,
    // which could yield start == end; the minimum-length filter below should handle it.
    std::cout << "still triggered at the last probability; closing the final segment" << std::endl;
    SpeechSegment& segment = speeches.back();
    segment.end = total_sample_size; // use the total sample count as the end of the last segment
    triggered = false;               // reset the VAD state
  }

  speeches.erase(
    std::remove_if(
      speeches.begin(),
      speeches.end(),
      [this](const SpeechSegment& speech) {
        return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
        // min_speech_samples is 4000 samples (0.25 s at 16 kHz).
        // Note: the length is measured after re-applying 'speech_pad_samples' to the start and end samples.
      }
    ),
    speeches.end()
  );

  //std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;

  reset_states();
  return speeches;
}

std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
  std::vector<SpeechSegment> mergedSpeeches;

  if (speeches.empty()) {
    return mergedSpeeches; // return an empty vector
  }

  // Initialize with the first segment
  SpeechSegment currentSegment = speeches[0];

  for (size_t i = 1; i < speeches.size(); ++i) { // the first segment seeds currentSegment, so start from i = 1
    // If the gap between the two segments is smaller than duration_merge_samples, merge them
    if (speeches[i].start - currentSegment.end < duration_merge_samples) {
      // Extend the current segment's end
      currentSegment.end = speeches[i].end;
    } else {
      // Otherwise store the current segment and start a new one
      mergedSpeeches.push_back(currentSegment);
      currentSegment = speeches[i];
    }
  }

  // Append the last segment
  mergedSpeeches.push_back(currentSegment);

  return mergedSpeeches;
}

}
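The segmentation in `DoVad()` above is a small hysteresis state machine: a segment opens once a window's probability reaches `threshold`, and only closes after the probability has stayed below `threshold - 0.15` for at least `min_silence_samples`. A stripped-down, self-contained sketch of that loop (plain C++, toy values, no libtorch; padding and the minimum-length filter are omitted, and the variable names are mine):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // One probability per 512-sample window, as SpeechProbs() would produce.
    std::vector<float> probs = {0.1f, 0.7f, 0.8f, 0.6f, 0.2f, 0.1f, 0.1f, 0.9f};
    const float threshold = 0.5f, neg_threshold = threshold - 0.15f;
    const int win = 512, min_silence = 1024;  // toy value: 2 windows of silence
    bool triggered = false;
    int sample = 0, temp_end = 0, start = 0;

    for (float p : probs) {
        sample += win;
        if (p >= threshold && temp_end != 0) temp_end = 0;   // speech resumed
        if (p >= threshold && !triggered) {                  // open a segment
            triggered = true; start = sample - win;
            continue;
        }
        if (p < neg_threshold && triggered) {                // candidate end
            if (temp_end == 0) temp_end = sample;
            if (sample - temp_end >= min_silence) {          // enough silence
                std::printf("{'start': %d, 'end': %d}\n", start, temp_end);
                triggered = false; temp_end = 0;
            }
        }
    }
    if (triggered)  // audio ended while still in speech
        std::printf("{'start': %d, 'end': %d}\n", start, sample);
    return 0;
}
```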
75  examples/cpp_libtorch/silero_torch.h  Normal file
@@ -0,0 +1,75 @@
//Author      : Nathan Lee
//Created On  : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version     : 1.0

#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H

#include <string>
#include <memory>
#include <stdexcept>
#include <iostream>
#include <vector>
#include <fstream>
#include <chrono>

#include <torch/torch.h>
#include <torch/script.h>


namespace silero{

struct SpeechSegment{
  int start;
  int end;
};

class VadIterator{
public:
  VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
              int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
              int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
  ~VadIterator();

  void SpeechProbs(std::vector<float>& input_wav);
  std::vector<silero::SpeechSegment> GetSpeechTimestamps();
  void SetVariables();

  float threshold;
  int sample_rate;
  int window_size_ms;
  int min_speech_duration_ms;
  int max_duration_merge_ms;
  bool print_as_samples;

private:
  torch::jit::script::Module model;
  std::vector<float> outputs_prob;
  int min_silence_samples;
  int min_speech_samples;
  int speech_pad_samples;
  int window_size_samples;
  int duration_merge_samples;
  int current_sample = 0;

  int total_sample_size = 0;

  int min_silence_duration_ms;
  int speech_pad_ms;
  bool triggered = false;
  int temp_end = 0;

  void init_engine(int window_size_ms);
  void init_torch_model(const std::string& model_path);
  void reset_states();
  std::vector<SpeechSegment> DoVad();
  std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
};

}
#endif // SILERO_TORCH_H
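A caveat implied by the declarations above: `SpeechSegment` stores `start`/`end` as `int`, so the samples-to-seconds conversion in `GetSpeechTimestamps()` (`speech.start /= sample_rate`) is integer division and truncates to whole seconds. A float-carrying variant (hypothetical, not in this header) would keep sub-second boundaries:

```cpp
// Hypothetical alternative to silero::SpeechSegment for second-based output.
struct SpeechSegmentSeconds {
    float start;  // e.g. 12345 samples at 16 kHz -> 12345 / 16000.0f = 0.7715625 s
    float end;
};
```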
235  examples/cpp_libtorch/wav.h  Normal file
@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for binary reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    std::cout << "num_channel_    :" << num_channel_ << std::endl;
    std::cout << "sample_rate_    :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples     :" << num_data << std::endl;
    std::cout << "num_data_size   :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        char sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 16: {
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32: {
        if (header.format == 1) {  // S32
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            data_[i] = static_cast<float>(sample) / 32768;
          }
        } else if (header.format == 3) {  // IEEE float
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        } else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "w");
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_
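A note on the reader above: every integer branch divides by 32768, which is the int16 full-scale value, so 16-bit files land in [-1, 1) but 8-bit files come out roughly 256x too quiet and 32-bit integer files far too loud; only 16-bit PCM and IEEE-float inputs are really normalized the way the VAD expects. A per-format normalization, as a hedged sketch (function name is mine, and it assumes the 8-bit byte is read as the unsigned value WAV actually stores):

```cpp
#include <cstdint>

// Sketch: scale each integer PCM format by its own full-scale value so all
// formats land in [-1, 1). The wav.h reader above uses 32768 for every branch.
inline float pcm_to_float(int32_t raw, int bits_per_sample) {
    switch (bits_per_sample) {
        case 8:  return (raw - 128) / 128.0f;   // 8-bit WAV is unsigned, 128-centered
        case 16: return raw / 32768.0f;         // int16 full scale
        case 32: return raw / 2147483648.0f;    // int32 full scale
        default: return 0.0f;                   // unsupported
    }
}
```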
35  examples/csharp/Program.cs  Normal file
@@ -0,0 +1,35 @@
using System.Text;

namespace VadDotNet;


class Program
{
    private const string MODEL_PATH = "./resources/silero_vad.onnx";
    private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
    private const int SAMPLE_RATE = 16000;
    private const float THRESHOLD = 0.5f;
    private const int MIN_SPEECH_DURATION_MS = 250;
    private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
    private const int MIN_SILENCE_DURATION_MS = 100;
    private const int SPEECH_PAD_MS = 30;

    public static void Main(string[] args)
    {
        var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
            MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
        List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
        //Console.WriteLine(speechTimeList.ToJson());
        StringBuilder sb = new StringBuilder();
        foreach (var speechSegment in speechTimeList)
        {
            sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
        }
        Console.WriteLine(sb.ToString());
    }
}
21  examples/csharp/SileroSpeechSegment.cs  Normal file
@@ -0,0 +1,21 @@
namespace VadDotNet;

public class SileroSpeechSegment
{
    public int? StartOffset { get; set; }
    public int? EndOffset { get; set; }
    public float? StartSecond { get; set; }
    public float? EndSecond { get; set; }

    public SileroSpeechSegment()
    {
    }

    public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
    {
        StartOffset = startOffset;
        EndOffset = endOffset;
        StartSecond = startSecond;
        EndSecond = endSecond;
    }
}
250  examples/csharp/SileroVadDetector.cs  Normal file
@@ -0,0 +1,250 @@
using NAudio.Wave;
using VADdotnet;

namespace VadDotNet;

public class SileroVadDetector
{
    private readonly SileroVadOnnxModel _model;
    private readonly float _threshold;
    private readonly float _negThreshold;
    private readonly int _samplingRate;
    private readonly int _windowSizeSample;
    private readonly float _minSpeechSamples;
    private readonly float _speechPadSamples;
    private readonly float _maxSpeechSamples;
    private readonly float _minSilenceSamples;
    private readonly float _minSilenceSamplesAtMaxSpeech;
    private int _audioLengthSamples;
    private const float THRESHOLD_GAP = 0.15f;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_8K = 8000;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_16K = 16000;

    public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
        int minSpeechDurationMs, float maxSpeechDurationSeconds,
        int minSilenceDurationMs, int speechPadMs)
    {
        if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
        {
            throw new ArgumentException("Sampling rate not supported, only available for [8000, 16000]");
        }

        this._model = new SileroVadOnnxModel(onnxModelPath);
        this._samplingRate = samplingRate;
        this._threshold = threshold;
        this._negThreshold = threshold - THRESHOLD_GAP;
        this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
        this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
        this._speechPadSamples = samplingRate * speechPadMs / 1000f;
        this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
        this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
        this.Reset();
    }

    public void Reset()
    {
        _model.ResetStates();
    }

    public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
    {
        Reset();

        using (var audioFile = new AudioFileReader(wavFile.FullName))
        {
            List<float> speechProbList = new List<float>();
            this._audioLengthSamples = (int)(audioFile.Length / 2);
            float[] buffer = new float[this._windowSizeSample];

            while (audioFile.Read(buffer, 0, buffer.Length) > 0)
            {
                float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
                speechProbList.Add(speechProb);
            }

            return CalculateProb(speechProbList);
        }
    }

    private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        bool triggered = false;
        int tempEnd = 0, prevEnd = 0, nextStart = 0;
        SileroSpeechSegment segment = new SileroSpeechSegment();

        for (int i = 0; i < speechProbList.Count; i++)
        {
            float speechProb = speechProbList[i];
            if (speechProb >= _threshold && (tempEnd != 0))
            {
                tempEnd = 0;
                if (nextStart < prevEnd)
                {
                    nextStart = _windowSizeSample * i;
                }
            }

            if (speechProb >= _threshold && !triggered)
            {
                triggered = true;
                segment.StartOffset = _windowSizeSample * i;
                continue;
            }

            if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
            {
                if (prevEnd != 0)
                {
                    segment.EndOffset = prevEnd;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    if (nextStart < prevEnd)
                    {
                        triggered = false;
                    }
                    else
                    {
                        segment.StartOffset = nextStart;
                    }

                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                }
                else
                {
                    segment.EndOffset = _windowSizeSample * i;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }

            if (speechProb < _negThreshold && triggered)
            {
                if (tempEnd == 0)
                {
                    tempEnd = _windowSizeSample * i;
                }

                if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
                {
                    prevEnd = tempEnd;
                }

                if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
                {
                    continue;
                }
                else
                {
                    segment.EndOffset = tempEnd;
                    if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
                    {
                        result.Add(segment);
                    }

                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }
        }

        if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
        {
            segment.EndOffset = _audioLengthSamples;
            result.Add(segment);
        }

        for (int i = 0; i < result.Count; i++)
        {
            SileroSpeechSegment item = result[i];
            if (i == 0)
            {
                item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
            }

            if (i != result.Count - 1)
            {
                SileroSpeechSegment nextItem = result[i + 1];
                int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
                if (silenceDuration < 2 * _speechPadSamples)
                {
                    item.EndOffset = item.EndOffset + (silenceDuration / 2);
                    nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
                }
                else
                {
                    item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
                    nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
                }
            }
            else
            {
                item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
            }
        }

        return MergeListAndCalculateSecond(result, _samplingRate);
    }

    private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        if (original == null || original.Count == 0)
        {
            return result;
        }

        int left = original[0].StartOffset.Value;
        int right = original[0].EndOffset.Value;
        if (original.Count > 1)
        {
            original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
            for (int i = 1; i < original.Count; i++)
            {
                SileroSpeechSegment segment = original[i];

                if (segment.StartOffset > right)
                {
                    result.Add(new SileroSpeechSegment(left, right,
                        CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
                    left = segment.StartOffset.Value;
                    right = segment.EndOffset.Value;
                }
                else
                {
                    right = Math.Max(right, segment.EndOffset.Value);
                }
            }

            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }
        else
        {
            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }

        return result;
    }

    private float CalculateSecondByOffset(int offset, int samplingRate)
    {
        float secondValue = offset * 1.0f / samplingRate;
        return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
    }
}
220  examples/csharp/SileroVadOnnxModel.cs  Normal file
@@ -0,0 +1,220 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Linq;

namespace VADdotnet;


public class SileroVadOnnxModel : IDisposable
{
    private readonly InferenceSession session;
    private float[][][] state;
    private float[][] context;
    private int lastSr = 0;
    private int lastBatchSize = 0;
    private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };

    public SileroVadOnnxModel(string modelPath)
    {
        var sessionOptions = new SessionOptions();
        sessionOptions.InterOpNumThreads = 1;
        sessionOptions.IntraOpNumThreads = 1;
        sessionOptions.EnableCpuMemArena = true;

        session = new InferenceSession(modelPath, sessionOptions);
        ResetStates();
    }

    public void ResetStates()
    {
        state = new float[2][][];
        state[0] = new float[1][];
        state[1] = new float[1][];
        state[0][0] = new float[128];
        state[1][0] = new float[128];
        context = Array.Empty<float[]>();
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void Dispose()
    {
        session?.Dispose();
    }

    public class ValidationResult
    {
        public float[][] X { get; }
        public int Sr { get; }

        public ValidationResult(float[][] x, int sr)
        {
            X = x;
            Sr = sr;
        }
    }

    private ValidationResult ValidateInput(float[][] x, int sr)
    {
        if (x.Length == 1)
        {
            x = new float[][] { x[0] };
        }
        if (x.Length > 2)
        {
            throw new ArgumentException($"Incorrect audio data dimension: {x.Length}");
        }

        if (sr != 16000 && (sr % 16000 == 0))
        {
            int step = sr / 16000;
            float[][] reducedX = new float[x.Length][];

            for (int i = 0; i < x.Length; i++)
            {
                float[] current = x[i];
                float[] newArr = new float[(current.Length + step - 1) / step];

                for (int j = 0, index = 0; j < current.Length; j += step, index++)
                {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        if (!SAMPLE_RATES.Contains(sr))
        {
            throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
        }

        if (((float)sr) / x[0].Length > 31.25)
        {
            throw new ArgumentException("Input audio is too short");
        }

        return new ValidationResult(x, sr);
    }

    private static float[][] Concatenate(float[][] a, float[][] b)
    {
        if (a.Length != b.Length)
        {
            throw new ArgumentException("The number of rows in both arrays must be the same.");
        }

        int rows = a.Length;
        int colsA = a[0].Length;
        int colsB = b[0].Length;
        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[colsA + colsB];
            Array.Copy(a[i], 0, result[i], 0, colsA);
            Array.Copy(b[i], 0, result[i], colsA, colsB);
        }

        return result;
    }

    private static float[][] GetLastColumns(float[][] array, int contextSize)
    {
        int rows = array.Length;
        int cols = array[0].Length;

        if (contextSize > cols)
        {
            throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
        }

        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[contextSize];
            Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
        }

        return result;
    }

    public float[] Call(float[][] x, int sr)
    {
        var result = ValidateInput(x, sr);
        x = result.X;
        sr = result.Sr;
        int numberSamples = sr == 16000 ? 512 : 256;

        if (x[0].Length != numberSamples)
        {
            throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
        }

        int batchSize = x.Length;
        int contextSize = sr == 16000 ? 64 : 32;

        if (lastBatchSize == 0)
        {
            ResetStates();
        }
        if (lastSr != 0 && lastSr != sr)
        {
            ResetStates();
        }
        if (lastBatchSize != 0 && lastBatchSize != batchSize)
        {
            ResetStates();
        }

        if (context.Length == 0)
        {
            context = new float[batchSize][];
            for (int i = 0; i < batchSize; i++)
            {
                context[i] = new float[contextSize];
            }
        }

        x = Concatenate(context, x);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
            NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
            NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
        };

        using (var outputs = session.Run(inputs))
        {
            var output = outputs.First(o => o.Name == "output").AsTensor<float>();
            var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();

            context = GetLastColumns(x, contextSize);
            lastSr = sr;
            lastBatchSize = batchSize;

            state = new float[newState.Dimensions[0]][][];
            for (int i = 0; i < newState.Dimensions[0]; i++)
            {
                state[i] = new float[newState.Dimensions[1]][];
                for (int j = 0; j < newState.Dimensions[1]; j++)
                {
                    state[i][j] = new float[newState.Dimensions[2]];
                    for (int k = 0; k < newState.Dimensions[2]; k++)
                    {
                        state[i][j][k] = newState[i, j, k];
                    }
                }
            }

            return output.ToArray();
        }
    }
}
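The `Call` method above implements the v5 model's sliding context: the last `contextSize` samples (64 at 16 kHz) of each window are saved via `GetLastColumns` and prepended to the next window via `Concatenate`, so every 512-sample window is actually fed to the network as 576 values. A language-neutral sketch of just that bookkeeping (plain C++, function name is mine):

```cpp
#include <vector>

// Sketch only: keep the last `context_size` samples of each window and
// prepend them to the next one, so a 512-sample window becomes 576 inputs.
std::vector<float> with_context(std::vector<float>& context,
                                const std::vector<float>& window,
                                std::size_t context_size) {
    if (context.empty()) context.assign(context_size, 0.0f);     // first call
    std::vector<float> input = context;                          // prepend saved tail
    input.insert(input.end(), window.begin(), window.end());
    context.assign(window.end() - context_size, window.end());   // save new tail
    return input;                                                // size 512 + 64
}
```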
25  examples/csharp/VadDotNet.csproj  Normal file
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.18.1" />
    <PackageReference Include="NAudio" Version="2.2.1" />
  </ItemGroup>

  <ItemGroup>
    <Folder Include="resources\" />
  </ItemGroup>

  <ItemGroup>
    <Content Include="resources\**">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>

</Project>
1  examples/csharp/resources/put_model_here.txt  Normal file
@@ -0,0 +1 @@
Place the ONNX model file and the example.wav file in this folder.
@@ -11,11 +11,11 @@ import (
 
 func main() {
 	sd, err := speech.NewDetector(speech.DetectorConfig{
-		ModelPath:            "../../files/silero_vad.onnx",
+		ModelPath:            "../../src/silero_vad/data/silero_vad.onnx",
 		SampleRate:           16000,
 		Threshold:            0.5,
-		MinSilenceDurationMs: 0,
-		SpeechPadMs:          0,
+		MinSilenceDurationMs: 100,
+		SpeechPadMs:          30,
 	})
 	if err != nil {
 		log.Fatalf("failed to create speech detector: %s", err)
@@ -4,7 +4,7 @@ go 1.21.4
 
 require (
 	github.com/go-audio/wav v1.1.0
-	github.com/streamer45/silero-vad-go v0.2.0
+	github.com/streamer45/silero-vad-go v0.2.1
 )
 
 require (
@@ -10,6 +10,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
 github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
+github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
+github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
13  examples/haskell/README.md  Normal file
@@ -0,0 +1,13 @@
# Haskell example

To run the example, make sure you put an ``example.wav`` in this directory, and then run the following:
```bash
stack run
```

The ``example.wav`` file must meet the following requirements:
- Must be 16 kHz sample rate.
- Must be mono (single channel).
- Must be 16-bit audio.

This uses the [silero-vad](https://hackage.haskell.org/package/silero-vad) package, a Haskell implementation based on the C# example.
22  examples/haskell/app/Main.hs  Normal file
@@ -0,0 +1,22 @@
module Main (main) where

import qualified Data.Vector.Storable as Vector
import Data.WAVE
import Data.Function
import Silero

main :: IO ()
main =
  withModel $ \model -> do
    wav <- getWAVEFile "example.wav"
    let samples =
          concat (waveSamples wav)
            & Vector.fromList
            & Vector.map (realToFrac . sampleToDouble)
    let vad =
          (defaultVad model)
            { startThreshold = 0.5
            , endThreshold = 0.35
            }
    segments <- detectSegments vad samples
    print segments
23  examples/haskell/example.cabal  Normal file
@@ -0,0 +1,23 @@
cabal-version: 1.12

-- This file has been generated from package.yaml by hpack version 0.37.0.
--
-- see: https://github.com/sol/hpack

name:           example
version:        0.1.0.0
build-type:     Simple

executable example-exe
  main-is: Main.hs
  other-modules:
      Paths_example
  hs-source-dirs:
      app
  ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
  build-depends:
      WAVE
    , base >=4.7 && <5
    , silero-vad
    , vector
  default-language: Haskell2010
28  examples/haskell/package.yaml  Normal file
@@ -0,0 +1,28 @@
name: example
version: 0.1.0.0

dependencies:
- base >= 4.7 && < 5
- silero-vad
- WAVE
- vector

ghc-options:
- -Wall
- -Wcompat
- -Widentities
- -Wincomplete-record-updates
- -Wincomplete-uni-patterns
- -Wmissing-export-lists
- -Wmissing-home-modules
- -Wpartial-fields
- -Wredundant-constraints

executables:
  example-exe:
    main: Main.hs
    source-dirs: app
    ghc-options:
    - -threaded
    - -rtsopts
    - -with-rtsopts=-N
11  examples/haskell/stack.yaml  Normal file
@@ -0,0 +1,11 @@
snapshot:
  url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml

packages:
- .

extra-deps:
- silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
41  examples/haskell/stack.yaml.lock  Normal file
@@ -0,0 +1,41 @@
# This file was autogenerated by Stack.
# You should not edit this file by hand.
# For more information, please see the documentation at:
#   https://docs.haskellstack.org/en/stable/lock_files

packages:
- completed:
    hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
    pantry-tree:
      sha256: a62e813f978d32c87769796fded981d25fcf2875bb2afdf60ed6279f931ccd7f
      size: 1391
  original:
    hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- completed:
    hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
    pantry-tree:
      sha256: ee5ccd70fa7fe6ffc360ebd762b2e3f44ae10406aa27f3842d55b8cbd1a19498
      size: 405
  original:
    hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- completed:
    hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
    pantry-tree:
      sha256: 48e35a72d1bb593173890616c8d7efd636a650a306a50bb3e1513e679939d27e
      size: 902
  original:
    hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- completed:
    hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
    pantry-tree:
      sha256: 2176fd677a02a4c47337f7dca5aeca2745dbb821a6ea5c7099b3a991ecd7f4f0
      size: 4478
  original:
    hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
snapshots:
- completed:
    sha256: 5a59b2a405b3aba3c00188453be172b85893cab8ebc352b1ef58b0eae5d248a2
    size: 650475
    url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
  original:
    url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
@@ -0,0 +1,37 @@
package org.example;

import ai.onnxruntime.OrtException;
import java.io.File;
import java.util.List;

public class App {

    private static final String MODEL_PATH = "/path/silero_vad.onnx";
    private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
    private static final int SAMPLE_RATE = 16000;
    private static final float THRESHOLD = 0.5f;
    private static final int MIN_SPEECH_DURATION_MS = 250;
    private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
    private static final int MIN_SILENCE_DURATION_MS = 100;
    private static final int SPEECH_PAD_MS = 30;

    public static void main(String[] args) {
        // Initialize the Voice Activity Detector
        SileroVadDetector vadDetector;
        try {
            vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
                    MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
            fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
        } catch (OrtException e) {
            System.err.println("Error initializing the VAD detector: " + e.getMessage());
        }
    }

    public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
        List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
        for (SileroSpeechSegment speechSegment : speechTimeList) {
            System.out.println(String.format("start second: %f, end second: %f",
                    speechSegment.getStartSecond(), speechSegment.getEndSecond()));
        }
    }
}
@@ -0,0 +1,51 @@
package org.example;


public class SileroSpeechSegment {
    private Integer startOffset;
    private Integer endOffset;
    private Float startSecond;
    private Float endSecond;

    public SileroSpeechSegment() {
    }

    public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
        this.startOffset = startOffset;
        this.endOffset = endOffset;
        this.startSecond = startSecond;
        this.endSecond = endSecond;
    }

    public Integer getStartOffset() {
        return startOffset;
    }

    public Integer getEndOffset() {
        return endOffset;
    }

    public Float getStartSecond() {
        return startSecond;
    }

    public Float getEndSecond() {
        return endSecond;
    }

    public void setStartOffset(Integer startOffset) {
        this.startOffset = startOffset;
    }

    public void setEndOffset(Integer endOffset) {
        this.endOffset = endOffset;
    }

    public void setStartSecond(Float startSecond) {
        this.startSecond = startSecond;
    }

    public void setEndSecond(Float endSecond) {
        this.endSecond = endSecond;
    }
}
@@ -0,0 +1,244 @@
package org.example;


import ai.onnxruntime.OrtException;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SileroVadDetector {
    private final SileroVadOnnxModel model;
    private final float threshold;
    private final float negThreshold;
    private final int samplingRate;
    private final int windowSizeSample;
    private final float minSpeechSamples;
    private final float speechPadSamples;
    private final float maxSpeechSamples;
    private final float minSilenceSamples;
    private final float minSilenceSamplesAtMaxSpeech;
    private int audioLengthSamples;
    private static final float THRESHOLD_GAP = 0.15f;
    private static final Integer SAMPLING_RATE_8K = 8000;
    private static final Integer SAMPLING_RATE_16K = 16000;

    /**
     * Constructor
     * @param onnxModelPath the path of the silero-vad onnx model
     * @param threshold threshold for speech start
     * @param samplingRate audio sampling rate, only available for [8k, 16k]
     * @param minSpeechDurationMs minimum speech length in millis; any speech duration shorter than this value is not considered speech
     * @param maxSpeechDurationSeconds maximum speech length in seconds; recommended to be set to Float.POSITIVE_INFINITY
     * @param minSilenceDurationMs minimum silence length in millis; any silence duration shorter than this value is not considered silence
     * @param speechPadMs additional padding in millis for speech start and end
     * @throws OrtException
     */
    public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
                             int minSpeechDurationMs, float maxSpeechDurationSeconds,
                             int minSilenceDurationMs, int speechPadMs) throws OrtException {
        if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
            throw new IllegalArgumentException("Sampling rate not supported, only available for [8000, 16000]");
        }
        this.model = new SileroVadOnnxModel(onnxModelPath);
        this.samplingRate = samplingRate;
        this.threshold = threshold;
        this.negThreshold = threshold - THRESHOLD_GAP;
        if (samplingRate == SAMPLING_RATE_16K) {
            this.windowSizeSample = 512;
        } else {
            this.windowSizeSample = 256;
        }
        this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
        this.speechPadSamples = samplingRate * speechPadMs / 1000f;
        this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
        this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
        this.reset();
    }

    /**
     * Method to reset the state
     */
    public void reset() {
        model.resetStates();
    }

    /**
     * Get the speech segment list for a given wav-format file
     * @param wavFile wav file
     * @return list of speech segments
     */
    public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
        reset();
        try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)) {
            List<Float> speechProbList = new ArrayList<>();
            this.audioLengthSamples = audioInputStream.available() / 2;
            byte[] data = new byte[this.windowSizeSample * 2];
            int numBytesRead = 0;

            while ((numBytesRead = audioInputStream.read(data)) != -1) {
                if (numBytesRead <= 0) {
                    break;
                }
                // Convert the byte array to a float array
                float[] audioData = new float[data.length / 2];
                for (int i = 0; i < audioData.length; i++) {
                    audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
                }

                float speechProb = 0;
                try {
                    speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
                    speechProbList.add(speechProb);
                } catch (OrtException e) {
                    throw e;
                }
            }
            return calculateProb(speechProbList);
        } catch (Exception e) {
            throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
        }
    }

    /**
     * Calculate speech segments from the probability list
     * @param speechProbList speech probability list
     * @return list of speech segments
     */
    private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
        List<SileroSpeechSegment> result = new ArrayList<>();
        boolean triggered = false;
        int tempEnd = 0, prevEnd = 0, nextStart = 0;
        SileroSpeechSegment segment = new SileroSpeechSegment();

        for (int i = 0; i < speechProbList.size(); i++) {
            Float speechProb = speechProbList.get(i);
            if (speechProb >= threshold && (tempEnd != 0)) {
                tempEnd = 0;
                if (nextStart < prevEnd) {
                    nextStart = windowSizeSample * i;
                }
            }

            if (speechProb >= threshold && !triggered) {
                triggered = true;
                segment.setStartOffset(windowSizeSample * i);
                continue;
            }

            if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
                if (prevEnd != 0) {
                    segment.setEndOffset(prevEnd);
                    result.add(segment);
                    segment = new SileroSpeechSegment();
                    if (nextStart < prevEnd) {
                        triggered = false;
                    } else {
                        segment.setStartOffset(nextStart);
                    }
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                } else {
                    segment.setEndOffset(windowSizeSample * i);
                    result.add(segment);
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }

            if (speechProb < negThreshold && triggered) {
                if (tempEnd == 0) {
                    tempEnd = windowSizeSample * i;
                }
                if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
                    prevEnd = tempEnd;
                }
                if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
                    continue;
                } else {
                    segment.setEndOffset(tempEnd);
                    if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
                        result.add(segment);
                    }
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }
        }

        if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
|
||||||
|
segment.setEndOffset(audioLengthSamples);
|
||||||
|
result.add(segment);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < result.size(); i++) {
|
||||||
|
SileroSpeechSegment item = result.get(i);
|
||||||
|
if (i == 0) {
|
||||||
|
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
|
||||||
|
}
|
||||||
|
if (i != result.size() - 1) {
|
||||||
|
SileroSpeechSegment nextItem = result.get(i + 1);
|
||||||
|
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
|
||||||
|
if(silenceDuration < 2 * speechPadSamples){
|
||||||
|
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
|
||||||
|
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
|
||||||
|
} else {
|
||||||
|
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||||
|
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
|
||||||
|
}
|
||||||
|
}else {
|
||||||
|
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergeListAndCalculateSecond(result, samplingRate);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
|
||||||
|
List<SileroSpeechSegment> result = new ArrayList<>();
|
||||||
|
if (original == null || original.size() == 0) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
Integer left = original.get(0).getStartOffset();
|
||||||
|
Integer right = original.get(0).getEndOffset();
|
||||||
|
if (original.size() > 1) {
|
||||||
|
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
|
||||||
|
for (int i = 1; i < original.size(); i++) {
|
||||||
|
SileroSpeechSegment segment = original.get(i);
|
||||||
|
|
||||||
|
if (segment.getStartOffset() > right) {
|
||||||
|
result.add(new SileroSpeechSegment(left, right,
|
||||||
|
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||||
|
left = segment.getStartOffset();
|
||||||
|
right = segment.getEndOffset();
|
||||||
|
} else {
|
||||||
|
right = Math.max(right, segment.getEndOffset());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.add(new SileroSpeechSegment(left, right,
|
||||||
|
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||||
|
}else {
|
||||||
|
result.add(new SileroSpeechSegment(left, right,
|
||||||
|
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
|
||||||
|
float secondValue = offset * 1.0f / samplingRate;
|
||||||
|
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
|
||||||
|
}
|
||||||
|
}
|
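The conversion loop above is a plain 16-bit little-endian PCM decode. For readers following along in Python, an equivalent minimal sketch (numpy assumed; the helper name is illustrative, not part of the repo):

```python
import numpy as np

def pcm16le_to_float(data: bytes) -> np.ndarray:
    """Decode 16-bit little-endian PCM bytes to float32 in [-1, 1]."""
    samples = np.frombuffer(data, dtype="<i2").astype(np.float32)
    return samples / 32767.0
```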
@@ -0,0 +1,234 @@
package org.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SileroVadOnnxModel {
    // The ONNX runtime session holding the VAD model
    private final OrtSession session;
    private float[][][] state;
    private float[][] context;
    // The sample rate used on the previous call
    private int lastSr = 0;
    // The batch size used on the previous call
    private int lastBatchSize = 0;
    // Supported sample rates
    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);

    // Constructor
    public SileroVadOnnxModel(String modelPath) throws OrtException {
        // Get the ONNX runtime environment
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // Create an ONNX session options object
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // InterOp threads parallelize different computation-graph operations
        opts.setInterOpNumThreads(1);
        // IntraOp threads parallelize work within a single operation
        opts.setIntraOpNumThreads(1);
        // Add the CPU execution provider
        opts.addCPU(true);
        // Create the ONNX session from the environment, model path, and options
        session = env.createSession(modelPath, opts);
        // Reset states
        resetStates();
    }

    /**
     * Reset states.
     */
    void resetStates() {
        state = new float[2][1][128];
        context = new float[0][];
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void close() throws OrtException {
        session.close();
    }

    /**
     * Inner class holding the validated input.
     */
    public static class ValidationResult {
        public final float[][] x;
        public final int sr;

        // Constructor
        public ValidationResult(float[][] x, int sr) {
            this.x = x;
            this.sr = sr;
        }
    }

    /**
     * Validate the input data.
     */
    private ValidationResult validateInput(float[][] x, int sr) {
        // Keep single-row input as-is
        if (x.length == 1) {
            x = new float[][]{x[0]};
        }
        // Throw an exception when the input has more than two rows
        if (x.length > 2) {
            throw new IllegalArgumentException("Incorrect audio data dimension: " + x.length);
        }

        // Downsample when the sample rate is a multiple of 16000 but not 16000 itself
        if (sr != 16000 && (sr % 16000 == 0)) {
            int step = sr / 16000;
            float[][] reducedX = new float[x.length][];

            for (int i = 0; i < x.length; i++) {
                float[] current = x[i];
                float[] newArr = new float[(current.length + step - 1) / step];

                for (int j = 0, index = 0; j < current.length; j += step, index++) {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        // If the sample rate is not in the list of supported sample rates, throw an exception
        if (!SAMPLE_RATES.contains(sr)) {
            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
        }

        // If the input audio chunk is too short, throw an exception
        if (((float) sr) / x[0].length > 31.25) {
            throw new IllegalArgumentException("Input audio is too short");
        }

        // Return the validated result
        return new ValidationResult(x, sr);
    }

    private static float[][] concatenate(float[][] a, float[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
        }

        int rows = a.length;
        int colsA = a[0].length;
        int colsB = b[0].length;
        float[][] result = new float[rows][colsA + colsB];

        for (int i = 0; i < rows; i++) {
            System.arraycopy(a[i], 0, result[i], 0, colsA);
            System.arraycopy(b[i], 0, result[i], colsA, colsB);
        }

        return result;
    }

    private static float[][] getLastColumns(float[][] array, int contextSize) {
        int rows = array.length;
        int cols = array[0].length;

        if (contextSize > cols) {
            throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
        }

        float[][] result = new float[rows][contextSize];

        for (int i = 0; i < rows; i++) {
            System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
        }

        return result;
    }

    /**
     * Run the ONNX model on one audio window.
     */
    public float[] call(float[][] x, int sr) throws OrtException {
        ValidationResult result = validateInput(x, sr);
        x = result.x;
        sr = result.sr;
        int numberSamples = 256;
        if (sr == 16000) {
            numberSamples = 512;
        }

        if (x[0].length != numberSamples) {
            throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
        }

        int batchSize = x.length;

        int contextSize = 32;
        if (sr == 16000) {
            contextSize = 64;
        }

        if (lastBatchSize == 0) {
            resetStates();
        }
        if (lastSr != 0 && lastSr != sr) {
            resetStates();
        }
        if (lastBatchSize != 0 && lastBatchSize != batchSize) {
            resetStates();
        }

        if (context.length == 0) {
            context = new float[batchSize][contextSize];
        }

        x = concatenate(context, x);

        OrtEnvironment env = OrtEnvironment.getEnvironment();

        OnnxTensor inputTensor = null;
        OnnxTensor stateTensor = null;
        OnnxTensor srTensor = null;
        OrtSession.Result ortOutputs = null;

        try {
            // Create input tensors
            inputTensor = OnnxTensor.createTensor(env, x);
            stateTensor = OnnxTensor.createTensor(env, state);
            srTensor = OnnxTensor.createTensor(env, new long[]{sr});

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", inputTensor);
            inputs.put("sr", srTensor);
            inputs.put("state", stateTensor);

            // Run the ONNX model
            ortOutputs = session.run(inputs);
            // Read the outputs
            float[][] output = (float[][]) ortOutputs.get(0).getValue();
            state = (float[][][]) ortOutputs.get(1).getValue();

            context = getLastColumns(x, contextSize);
            lastSr = sr;
            lastBatchSize = batchSize;
            return output[0];
        } finally {
            if (inputTensor != null) {
                inputTensor.close();
            }
            if (stateTensor != null) {
                stateTensor.close();
            }
            if (srTensor != null) {
                srTensor.close();
            }
            if (ortOutputs != null) {
                ortOutputs.close();
            }
        }
    }
}
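The call sequence above — feed `input`, `state` and `sr`, read back the speech probability and the updated state, then carry the last `contextSize` samples over as context for the next window — is runtime-agnostic. A minimal single-stream Python sketch with onnxruntime, assuming a 16 kHz model file named `silero_vad.onnx` in the working directory:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("silero_vad.onnx")
state = np.zeros((2, 1, 128), dtype=np.float32)
context = np.zeros((1, 64), dtype=np.float32)  # 64 context samples at 16 kHz

def step(window: np.ndarray) -> float:
    """window: float32 array of shape (1, 512); returns the speech probability."""
    global state, context
    x = np.concatenate([context, window], axis=1)  # prepend the carried context
    out, state = sess.run(None, {"input": x,
                                 "state": state,
                                 "sr": np.array([16000], dtype=np.int64)})
    context = x[:, -64:]  # keep the tail as context for the next window
    return float(out[0, 0])
```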
@@ -1,7 +1,6 @@
 {
  "cells": [
   {
-   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -18,17 +17,19 @@
     "SAMPLING_RATE = 16000\n",
     "import torch\n",
     "from pprint import pprint\n",
+    "import time\n",
+    "import shutil\n",
     "\n",
     "torch.set_num_threads(1)\n",
     "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
     "NUM_COPIES=8\n",
     "# download wav files, make multiple copies\n",
-    "for idx in range(NUM_COPIES):\n",
-    "    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
+    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example0.wav\")\n",
+    "for idx in range(NUM_COPIES-1):\n",
+    "    shutil.copy(f\"en_example0.wav\", f\"en_example{idx+1}.wav\")"
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -54,7 +55,6 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -99,7 +99,6 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -127,7 +126,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "diarization",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -141,7 +140,20 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.15"
+  "version": "3.10.14"
+ },
+ "toc": {
+  "base_numbering": 1,
+  "nav_menu": {},
+  "number_sections": true,
+  "sideBar": true,
+  "skip_h1_title": false,
+  "title_cell": "Table of Contents",
+  "title_sidebar": "Contents",
+  "toc_cell": false,
+  "toc_position": {},
+  "toc_section_display": true,
+  "toc_window_display": false
  },
 "nbformat": 4,
@@ -7,6 +7,8 @@ It has been designed as a low-level example for binary real-time streaming using
 Currently, the notebook consists of two examples:
 - One that records audio of a predefined length from the microphone, processes it with Silero-VAD, and plots it afterwards.
 - The other one plots the speech probabilities in real-time (using jupyterplot) and records the audio until you press enter.
+
+This example does not work in google colab! For local usage only.
+
 ## Example Video for the Real-Time Visualization
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "62a0cccb",
+   "id": "76aa55ba",
    "metadata": {},
    "source": [
     "# Pyaudio Microphone Streaming Examples\n",
@@ -12,12 +12,14 @@
     "I created it as an example on how binary data from a stream could be feed into Silero VAD.\n",
     "\n",
     "\n",
-    "Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required."
+    "Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required.\n",
+    "\n",
+    "This notebook does not work in google colab! For local usage only."
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "64cbe1eb",
+   "id": "4a4e15c2",
    "metadata": {},
    "source": [
     "## Dependencies\n",
@@ -26,22 +28,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "57bc2aac",
-   "metadata": {},
+   "execution_count": 1,
+   "id": "24205cce",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-10-09T08:47:34.056898Z",
+     "start_time": "2024-10-09T08:47:34.053418Z"
+    }
+   },
    "outputs": [],
    "source": [
-    "#!pip install numpy==1.20.2\n",
-    "#!pip install torch==1.9.0\n",
-    "#!pip install matplotlib==3.4.2\n",
-    "#!pip install torchaudio==0.9.0\n",
-    "#!pip install soundfile==0.10.3.post1\n",
-    "#!pip install pyaudio==0.2.11"
+    "#!pip install numpy>=1.24.0\n",
+    "#!pip install torch>=1.12.0\n",
+    "#!pip install matplotlib>=3.6.0\n",
+    "#!pip install torchaudio>=0.12.0\n",
+    "#!pip install soundfile==0.12.1\n",
+    "#!apt install python3-pyaudio (linux) or pip install pyaudio (windows)"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "110de761",
+   "id": "cd22818f",
    "metadata": {},
    "source": [
     "## Imports"
@@ -49,10 +56,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "5a647d8d",
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 2,
+   "id": "994d7f3a",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-10-09T08:47:39.005032Z",
+     "start_time": "2024-10-09T08:47:36.489952Z"
+    }
+   },
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'pyaudio'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[2], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpylab\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyaudio\u001b[39;00m\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyaudio'"
+     ]
+    }
+   ],
    "source": [
     "import io\n",
     "import numpy as np\n",
@@ -61,14 +85,13 @@
     "import torchaudio\n",
     "import matplotlib\n",
     "import matplotlib.pylab as plt\n",
-    "torchaudio.set_audio_backend(\"soundfile\")\n",
     "import pyaudio"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "725d7066",
+   "id": "ac5c52f7",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -80,7 +103,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1c0b2ea7",
+   "id": "ad5919dc",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -93,7 +116,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "f9112603",
+   "id": "784d1ab6",
    "metadata": {},
    "source": [
     "### Helper Methods"
@@ -102,7 +125,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5abc6330",
+   "id": "af4bca64",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -125,7 +148,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "5124095e",
+   "id": "ca13e514",
    "metadata": {},
    "source": [
     "## Pyaudio Set-up"
@@ -134,7 +157,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a845356e",
+   "id": "75f99022",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -148,7 +171,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "0b910c99",
+   "id": "4da7d2ef",
    "metadata": {},
    "source": [
     "## Simple Example\n",
@@ -158,17 +181,17 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9d3d2c10",
+   "id": "6fe77661",
    "metadata": {},
    "outputs": [],
    "source": [
-    "num_samples = 1536"
+    "num_samples = 512"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3cb44a4a",
+   "id": "23f4da3e",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -180,6 +203,8 @@
     "data = []\n",
     "voiced_confidences = []\n",
     "\n",
+    "frames_to_record = 50\n",
+    "\n",
     "print(\"Started Recording\")\n",
     "for i in range(0, frames_to_record):\n",
     "    \n",
@@ -206,7 +231,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "a3dda982",
+   "id": "fd243e8f",
    "metadata": {},
    "source": [
     "## Real Time Visualization\n",
@@ -219,7 +244,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "05ef4100",
+   "id": "d36980c2",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -229,7 +254,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d1d4cdd6",
+   "id": "5607b616",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -286,7 +311,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1e398009",
+   "id": "dc4f0108",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -296,7 +321,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -310,7 +335,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.10"
+  "version": "3.10.14"
  },
  "toc": {
   "base_numbering": 1,
@@ -1,13 +1,12 @@
 use crate::utils;
-use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
+use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
 use std::path::Path;
 
 #[derive(Debug)]
 pub struct Silero {
     session: ort::Session,
     sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
-    h: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
-    c: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
+    state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
 }
 
 impl Silero {
@@ -16,20 +15,17 @@ impl Silero {
         model_path: impl AsRef<Path>,
     ) -> Result<Self, ort::Error> {
         let session = ort::Session::builder()?.commit_from_file(model_path)?;
-        let h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
-        let c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
+        let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
         let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
         Ok(Self {
             session,
             sample_rate,
-            h,
-            c,
+            state,
         })
     }
 
     pub fn reset(&mut self) {
-        self.h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
-        self.c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
+        self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
     }
 
     pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
@@ -37,18 +33,17 @@ impl Silero {
             .iter()
             .map(|x| (*x as f32) / (i16::MAX as f32))
             .collect::<Vec<_>>();
-        let frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
+        let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
+        frame = frame.slice(s![.., ..480]).to_owned();
         let inps = ort::inputs![
             frame,
+            std::mem::take(&mut self.state),
             self.sample_rate.clone(),
-            std::mem::take(&mut self.h),
-            std::mem::take(&mut self.c)
         ]?;
         let res = self
             .session
-            .run(ort::SessionInputs::ValueSlice::<4>(&inps))?;
-        self.h = res["hn"].try_extract_tensor().unwrap().to_owned();
-        self.c = res["cn"].try_extract_tensor().unwrap().to_owned();
+            .run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
+        self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
         Ok(*res["output"]
             .try_extract_raw_tensor::<f32>()
             .unwrap()
@@ -20,7 +20,7 @@ impl VadIter {
     pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
         self.reset_states();
         for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
-            let speech_prob = self.silero.calc_level(audio_frame)?;
+            let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
             self.state.update(&self.params, speech_prob);
         }
         self.state.check_for_last_speech(samples.len());
hubconf.py (11 lines changed)
@@ -23,11 +23,14 @@ def versiontuple(v):
     return tuple(version_list)
 
 
-def silero_vad(onnx=False, force_onnx_cpu=False):
+def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
     """Silero Voice Activity Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
+    available_ops = [15, 16]
+    if onnx and opset_version not in available_ops:
+        raise Exception(f'Available ONNX opset_version: {available_ops}')
+
     if not onnx:
         installed_version = torch.__version__
@@ -37,7 +40,11 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
 
     model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
     if onnx:
-        model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
+        if opset_version == 16:
+            model_name = 'silero_vad.onnx'
+        else:
+            model_name = f'silero_vad_16k_op{opset_version}.onnx'
+        model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
     else:
         model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
     utils = (get_speech_timestamps,
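With the new `opset_version` argument, the opset-15 build can also be requested through torch.hub; note the `silero_vad_16k_op15.onnx` file is 16 kHz only. A minimal sketch:

```python
import torch

model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad',
                              onnx=True, opset_version=15)
```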
@@ -3,7 +3,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [project]
 name = "silero-vad"
-version = "5.1"
+version = "6.0.0"
 authors = [
   {name="Silero Team", email="hello@silero.ai"},
 ]
@@ -21,15 +21,18 @@ classifiers = [
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
+    "Programming Language :: Python :: 3.15",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Scientific/Engineering",
 ]
 dependencies = [
     "torch>=1.12.0",
     "torchaudio>=0.12.0",
-    "onnxruntime>=1.18.0",
+    "onnxruntime>=1.16.1",
 ]
 
 [project.urls]
 Homepage = "https://github.com/snakers4/silero-vad"
 Issues = "https://github.com/snakers4/silero-vad/issues"
@@ -9,4 +9,5 @@ from silero_vad.utils_vad import (get_speech_timestamps,
                                   save_audio,
                                   read_audio,
                                   VADIterator,
-                                  collect_chunks)
+                                  collect_chunks,
+                                  drop_chunks)
Binary file not shown.
BIN  src/silero_vad/data/silero_vad_16k_op15.onnx (new file, binary; not shown)
BIN  src/silero_vad/data/silero_vad_half.onnx (new file, binary; not shown)
@@ -2,10 +2,21 @@ from .utils_vad import init_jit_model, OnnxWrapper
 import torch
 torch.set_num_threads(1)
 
-def load_silero_vad(onnx=False):
-    model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
+def load_silero_vad(onnx=False, opset_version=16):
+    available_ops = [15, 16]
+    if onnx and opset_version not in available_ops:
+        raise Exception(f'Available ONNX opset_version: {available_ops}')
+
+    if onnx:
+        if opset_version == 16:
+            model_name = 'silero_vad.onnx'
+        else:
+            model_name = f'silero_vad_16k_op{opset_version}.onnx'
+    else:
+        model_name = 'silero_vad.jit'
     package_path = "silero_vad.data"
 
     try:
         import importlib_resources as impresources
         model_file_path = str(impresources.files(package_path).joinpath(model_name))
@@ -18,8 +29,8 @@ def load_silero_vad(onnx=False):
         model_file_path = str(impresources.files(package_path).joinpath(model_name))
 
     if onnx:
-        model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
+        model = OnnxWrapper(str(model_file_path), force_onnx_cpu=True)
     else:
         model = init_jit_model(model_file_path)
 
     return model
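The same selection is available through the package API; a minimal sketch:

```python
from silero_vad import load_silero_vad

model = load_silero_vad(onnx=True, opset_version=15)  # resolves to silero_vad_16k_op15.onnx
```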
@@ -23,7 +23,11 @@ class OnnxWrapper():
         self.session = onnxruntime.InferenceSession(path, sess_options=opts)
 
         self.reset_states()
-        self.sample_rates = [8000, 16000]
+        if '16k' in path:
+            warnings.warn('This model support only 16000 sampling rate!')
+            self.sample_rates = [16000]
+        else:
+            self.sample_rates = [8000, 16000]
 
     def _validate_input(self, x, sr: int):
         if x.dim() == 1:
@@ -53,10 +57,10 @@ class OnnxWrapper():
 
         x, sr = self._validate_input(x, sr)
         num_samples = 512 if sr == 16000 else 256
 
         if x.shape[-1] != num_samples:
             raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
 
         batch_size = x.shape[0]
         context_size = 64 if sr == 16000 else 32
 
@@ -133,7 +137,7 @@ class Validator():
 def read_audio(path: str,
                sampling_rate: int = 16000):
     list_backends = torchaudio.list_audio_backends()
 
     assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
 \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
@@ -193,9 +197,13 @@ def get_speech_timestamps(audio: torch.Tensor,
                           min_silence_duration_ms: int = 100,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
+                          time_resolution: int = 1,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
-                          window_size_samples: int = 512,):
+                          neg_threshold: float = None,
+                          window_size_samples: int = 512,
+                          min_silence_at_max_speech: float = 98,
+                          use_max_poss_sil_at_max_speech: bool = True):
 
     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -231,12 +239,24 @@ def get_speech_timestamps(audio: torch.Tensor,
     return_seconds: bool (default - False)
         whether return timestamps in seconds (default - samples)
 
+    time_resolution: int (default - 1)
+        time resolution of speech coordinates when requested as seconds
+
     visualize_probs: bool (default - False)
         whether draw prob hist or not
 
     progress_tracking_callback: Callable[[float], None] (default - None)
         callback function taking progress in percents as an argument
 
+    neg_threshold: float (default = threshold - 0.15)
+        Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
+
+    min_silence_at_max_speech: float (default - 98ms)
+        Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached
+
+    use_max_poss_sil_at_max_speech: bool (default - True)
+        Whether to use the maximum possible silence at max_speech_duration_s or not. If not, the last silence is used.
+
     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
 
@@ -245,7 +265,6 @@ def get_speech_timestamps(audio: torch.Tensor,
     speeches: list of dicts
         list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
     """
 
     if not torch.is_tensor(audio):
         try:
             audio = torch.Tensor(audio)
@@ -270,25 +289,29 @@ def get_speech_timestamps(audio: torch.Tensor,
         raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
 
     window_size_samples = 512 if sampling_rate == 16000 else 256
+    hop_size_samples = int(window_size_samples)
 
     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
-    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000
 
     audio_length_samples = len(audio)
 
     speech_probs = []
-    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob = model(chunk, sampling_rate).item()
+        try:
+            speech_prob = model(chunk, sampling_rate).item()
+        except Exception as e:
+            import ipdb; ipdb.set_trace()
         speech_probs.append(speech_prob)
         # calculate progress and send it to the callback function
-        progress = current_start_sample + window_size_samples
+        progress = current_start_sample + hop_size_samples
         if progress > audio_length_samples:
             progress = audio_length_samples
         progress_percent = (progress / audio_length_samples) * 100
@@ -298,45 +321,61 @@ def get_speech_timestamps(audio: torch.Tensor,
     triggered = False
     speeches = []
     current_speech = {}
-    neg_threshold = threshold - 0.15
-    temp_end = 0  # to save potential segment end (and tolerate some silence)
-    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+
+    if neg_threshold is None:
+        neg_threshold = max(threshold - 0.15, 0.01)
+    temp_end = 0  # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+    possible_ends = []
 
     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
-            temp_end = 0
+            if temp_end != 0:
+                sil_dur = (hop_size_samples * i) - temp_end
+                if sil_dur > min_silence_samples_at_max_speech:
+                    possible_ends.append((temp_end, sil_dur))
+            temp_end = 0
             if next_start < prev_end:
-                next_start = window_size_samples * i
+                next_start = hop_size_samples * i
 
         if (speech_prob >= threshold) and not triggered:
            triggered = True
-            current_speech['start'] = window_size_samples * i
+            current_speech['start'] = hop_size_samples * i
            continue
 
-        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
-            if prev_end:
+        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if possible_ends:
+                if use_max_poss_sil_at_max_speech:
+                    prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
+                else:
+                    prev_end, dur = possible_ends[-1]  # use the last possible silence segment
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
-                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
-                    triggered = False
-                else:
-                    current_speech['start'] = next_start
+                next_start = prev_end + dur
+                if next_start < prev_end + hop_size_samples * i:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    #triggered = False
+                    current_speech['start'] = next_start
+                else:
+                    triggered = False
+                    #current_speech['start'] = next_start
                 prev_end = next_start = temp_end = 0
+                possible_ends = []
             else:
-                current_speech['end'] = window_size_samples * i
+                current_speech['end'] = hop_size_samples * i
                 speeches.append(current_speech)
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue
 
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
-                prev_end = temp_end
-            if (window_size_samples * i) - temp_end < min_silence_samples:
+                temp_end = hop_size_samples * i
+            # if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+            #     prev_end = temp_end
+            if (hop_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
@@ -345,6 +384,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue
 
     if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -366,16 +406,17 @@ def get_speech_timestamps(audio: torch.Tensor,
             speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
 
     if return_seconds:
+        audio_length_seconds = audio_length_samples / sampling_rate
         for speech_dict in speeches:
-            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
-            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+            speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0)
+            speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds)
     elif step > 1:
         for speech_dict in speeches:
             speech_dict['start'] *= step
             speech_dict['end'] *= step
 
     if visualize_probs:
-        make_visualization(speech_probs, window_size_samples / sampling_rate)
+        make_visualization(speech_probs, hop_size_samples / sampling_rate)
 
     return speeches
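Taken together, the new arguments control where an over-long segment is cut and how precise the returned second coordinates are. A minimal sketch of a call using them (the file name `example.wav` is illustrative):

```python
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio('example.wav', sampling_rate=16000)
speech = get_speech_timestamps(
    wav, model,
    max_speech_duration_s=30,
    min_silence_at_max_speech=98,         # ms of silence accepted as a cut point
    use_max_poss_sil_at_max_speech=True,  # cut at the longest such silence
    neg_threshold=0.35,                   # explicit exit threshold (default: threshold - 0.15)
    return_seconds=True,
    time_resolution=2,                    # two decimal places in the returned seconds
)
```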
@@ -429,13 +470,16 @@ class VADIterator:
         self.current_sample = 0
 
     @torch.no_grad()
-    def __call__(self, x, return_seconds=False):
+    def __call__(self, x, return_seconds=False, time_resolution: int = 1):
         """
         x: torch.Tensor
            audio chunk (see examples in repo)
 
         return_seconds: bool (default - False)
            whether return timestamps in seconds (default - samples)
+
+        time_resolution: int (default - 1)
+           time resolution of speech coordinates when requested as seconds
         """
 
         if not torch.is_tensor(x):
@@ -454,8 +498,8 @@ class VADIterator:
 
         if (speech_prob >= self.threshold) and not self.triggered:
             self.triggered = True
-            speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
-            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
+            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
 
         if (speech_prob < self.threshold - 0.15) and self.triggered:
             if not self.temp_end:
@@ -466,24 +510,110 @@ class VADIterator:
             speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
             self.temp_end = 0
             self.triggered = False
-            return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+            return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
 
         return None
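The `time_resolution` argument applies to streaming use as well. A minimal sketch feeding 512-sample chunks through `VADIterator` (the file name is illustrative):

```python
from silero_vad import load_silero_vad, read_audio, VADIterator

model = load_silero_vad()
vad_iterator = VADIterator(model, sampling_rate=16000)
wav = read_audio('example.wav', sampling_rate=16000)

for i in range(0, len(wav), 512):
    chunk = wav[i: i + 512]
    if len(chunk) < 512:
        break
    event = vad_iterator(chunk, return_seconds=True, time_resolution=2)
    if event:
        print(event)  # {'start': ...} or {'end': ...}
vad_iterator.reset_states()
```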
def collect_chunks(tss: List[dict],
|
def collect_chunks(tss: List[dict],
|
||||||
wav: torch.Tensor):
|
wav: torch.Tensor,
|
||||||
chunks = []
|
seconds: bool = False,
|
||||||
for i in tss:
|
sampling_rate: int = None) -> torch.Tensor:
|
||||||
chunks.append(wav[i['start']: i['end']])
|
"""Collect audio chunks from a longer audio clip
|
||||||
|
|
||||||
|
This method extracts audio chunks from an audio clip, using a list of
|
||||||
|
provided coordinates, and concatenates them together. Coordinates can be
|
||||||
|
passed either as sample numbers or in seconds, in which case the audio
|
||||||
|
sampling rate is also needed.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tss: List[dict]
|
||||||
|
Coordinate list of the clips to collect from the audio.
|
||||||
|
wav: torch.Tensor, one dimensional
|
||||||
|
One dimensional float torch.Tensor, containing the audio to clip.
|
||||||
|
seconds: bool (default - False)
|
||||||
|
Whether input coordinates are passed as seconds or samples.
|
||||||
|
sampling_rate: int (default - None)
|
||||||
|
Input audio sampling rate. Required if seconds is True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor, one dimensional
|
||||||
|
One dimensional float torch.Tensor of the concatenated clipped audio
|
||||||
|
chunks.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
Raised if sampling_rate is not provided when seconds is True.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if seconds and not sampling_rate:
|
||||||
|
raise ValueError('sampling_rate must be provided when seconds is True')
|
||||||
|
|
||||||
|
chunks = list()
|
||||||
|
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
|
||||||
|
|
||||||
|
for i in _tss:
|
||||||
|
chunks.append(wav[i['start']:i['end']])
|
||||||
|
|
||||||
return torch.cat(chunks)
|
return torch.cat(chunks)
|
||||||
|
|
||||||
|
|
||||||
def drop_chunks(tss: List[dict],
|
def drop_chunks(tss: List[dict],
|
||||||
wav: torch.Tensor):
|
wav: torch.Tensor,
|
||||||
chunks = []
|
seconds: bool = False,
|
||||||
|
sampling_rate: int = None) -> torch.Tensor:
|
||||||
|
"""Drop audio chunks from a longer audio clip
|
||||||
|
|
||||||
|
This method extracts audio chunks from an audio clip, using a list of
|
||||||
|
provided coordinates, and drops them. Coordinates can be passed either as
|
||||||
|
sample numbers or in seconds, in which case the audio sampling rate is also
|
||||||
|
needed.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tss: List[dict]
|
||||||
|
Coordinate list of the clips to drop from from the audio.
|
||||||
|
wav: torch.Tensor, one dimensional
|
||||||
|
One dimensional float torch.Tensor, containing the audio to clip.
|
||||||
|
seconds: bool (default - False)
|
||||||
|
Whether input coordinates are passed as seconds or samples.
|
||||||
|
sampling_rate: int (default - None)
|
||||||
|
Input audio sampling rate. Required if seconds is True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor, one dimensional
|
||||||
|
One dimensional float torch.Tensor of the input audio minus the dropped
|
||||||
|
chunks.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
Raised if sampling_rate is not provided when seconds is True.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if seconds and not sampling_rate:
|
||||||
|
raise ValueError('sampling_rate must be provided when seconds is True')
|
||||||
|
|
||||||
|
chunks = list()
|
||||||
cur_start = 0
|
cur_start = 0
|
||||||
for i in tss:
|
|
||||||
|
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
|
||||||
|
|
||||||
|
for i in _tss:
|
||||||
chunks.append((wav[cur_start: i['start']]))
|
chunks.append((wav[cur_start: i['start']]))
|
||||||
cur_start = i['end']
|
cur_start = i['end']
|
||||||
|
|
||||||
return torch.cat(chunks)
|
return torch.cat(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]:
|
||||||
|
"""Convert coordinates expressed in seconds to sample coordinates.
|
||||||
|
"""
|
||||||
|
return [{
|
||||||
|
'start': round(crd['start']) * sampling_rate,
|
||||||
|
'end': round(crd['end']) * sampling_rate
|
||||||
|
} for crd in tss]
|
||||||
|
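A short usage sketch of the extended helpers, assuming the same `silero_vad` package exports plus `get_speech_timestamps` and `collect_chunks`; with `seconds=True` the output of `return_seconds=True` can be passed through directly:

```python
from silero_vad import (load_silero_vad, read_audio,
                        get_speech_timestamps, collect_chunks)

model = load_silero_vad()
wav = read_audio('tests/data/test.wav', sampling_rate=16000)
speech_ts = get_speech_timestamps(wav, model, return_seconds=True)

# second-based coordinates now work as long as sampling_rate is supplied;
# omitting it raises the ValueError documented above
speech_only = collect_chunks(speech_ts, wav, seconds=True, sampling_rate=16000)
```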
BIN  tests/data/test.mp3  Normal file
Binary file not shown.

BIN  tests/data/test.opus  Normal file
Binary file not shown.

BIN  tests/data/test.wav  Normal file
Binary file not shown.

22  tests/test_basic.py  Normal file
@@ -0,0 +1,22 @@
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
import torch
torch.set_num_threads(1)

def test_jit_model():
    model = load_silero_vad(onnx=False)
    for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
        audio = read_audio(path, sampling_rate=16000)
        speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
        assert speech_timestamps is not None
        out = model.audio_forward(audio, sr=16000)
        assert out is not None

def test_onnx_model():
    model = load_silero_vad(onnx=True)
    for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
        audio = read_audio(path, sampling_rate=16000)
        speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
        assert speech_timestamps is not None

        out = model.audio_forward(audio, sr=16000)
        assert out is not None
74  tuning/README.md  Normal file
@@ -0,0 +1,74 @@
# Fine-tuning the Silero-VAD model

> The tuning code was created with the support of the Foundation for Assistance to Innovation within the federal project "Artificial Intelligence" of the national program "Digital Economy of the Russian Federation".

Fine-tuning is used to improve the speech-detection quality of the Silero-VAD model on custom data.

## Dependencies
The following dependencies are used when tuning the VAD model:
- `torchaudio>=0.12.0`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`

## Data preparation

Dataframes for tuning must be prepared and saved in `.feather` format. The following columns are mandatory in the training and validation `.feather` files:
- **audio_path** - absolute path to the audio file on disk. Audio files must contain `PCM` data, preferably in `.wav` or `.opus` format (other popular audio formats are supported as well). To speed up fine-tuning, it is recommended to resample the audio files to 16000 Hz beforehand;
- **speech_ts** - annotation for the corresponding audio file: a list of dicts of the form `{'start': START_SEC, 'end': END_SEC}`, where `START_SEC` and `END_SEC` are the start and end times of a speech segment in seconds. For high-quality fine-tuning, annotation precise to about 30 milliseconds is recommended.

The more data is used at the fine-tuning stage, the better the adapted model performs on the target domain. Audio length is not limited, since each clip is cut to `max_train_length_sec` seconds before being fed to the network; still, it is better to pre-split long audio into pieces of `max_train_length_sec` length.

An example `.feather` dataframe can be found in the file `example_dataframe.feather`; a minimal sketch of building one is shown below.
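A hedged sketch of preparing a tuning dataframe in the required format (the file path and timestamps below are purely illustrative; writing `.feather` files with pandas requires `pyarrow`):

```python
import pandas as pd

rows = [
    {
        'audio_path': '/data/audio/sample_0001.wav',   # absolute path to PCM audio
        'speech_ts': [{'start': 0.42, 'end': 1.98},    # speech segments in seconds
                      {'start': 2.65, 'end': 5.10}],
    },
]

df = pd.DataFrame(rows)
df.to_feather('train_dataset_path.feather')
```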
## The `config.yml` configuration file

The configuration file `config.yml` contains the paths to the training and validation sets, as well as the fine-tuning parameters:
- `train_dataset_path` - absolute path to the training dataframe in `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
- `val_dataset_path` - absolute path to the validation dataframe in `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
- `jit_model_path` - absolute path to a Silero-VAD model in `.jit` format. If this field is left empty, the model is downloaded according to the value of the `use_torchhub` field;
- `use_torchhub` - if `True`, the model to fine-tune is loaded via torch.hub; if `False`, it is loaded via the silero-vad library (install it beforehand with `pip install silero-vad`);
- `tune_8k` - selects which Silero-VAD head to fine-tune. If `True`, the head with the 8000 Hz sampling rate is tuned, otherwise the 16000 Hz one;
- `model_save_path` - path where the fine-tuned model is saved;
- `noise_loss` - loss coefficient applied to non-speech audio windows;
- `max_train_length_sec` - maximum audio length in seconds at the fine-tuning stage. Longer audio is cut to this value;
- `aug_prob` - probability of applying augmentations to an audio file during fine-tuning;
- `learning_rate` - fine-tuning learning rate;
- `batch_size` - batch size for fine-tuning and validation;
- `num_workers` - number of workers used for data loading;
- `num_epochs` - number of fine-tuning epochs. One epoch is one full pass over the training data;
- `device` - `cpu` or `cuda`.

## Fine-tuning

Fine-tuning is started with

`python tune.py`

It runs for `num_epochs` epochs; the checkpoint with the best ROC-AUC on the validation set is saved to `model_save_path` in jit format.
## Threshold search

The entry and exit thresholds can be selected with

`python search_thresholds.py`

This script uses the configuration file described above. The model specified in the configuration is used to search for the optimal thresholds on the validation dataset.

## Citation

```
@misc{Silero VAD,
  author = {Silero Team},
  title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad}},
  commit = {insert_some_commit_here},
  email = {hello@silero.ai}
}
```
0  tuning/__init__.py  Normal file

17  tuning/config.yml  Normal file
@@ -0,0 +1,17 @@
jit_model_path: ''  # path to a Silero-VAD model in jit format; this model will be fine-tuned. If left empty, the model is downloaded automatically
use_torchhub: True  # the jit model is downloaded via torchhub if True, or via pip if False

tune_8k: False  # fine-tunes the 16k head if False, the 8k head if True
train_dataset_path: 'train_dataset_path.feather'  # path to the training dataset in feather format, see README for details
val_dataset_path: 'val_dataset_path.feather'  # path to the validation dataset in feather format, see README for details
model_save_path: 'model_save_path.jit'  # path where the fine-tuned model will be saved

noise_loss: 0.5  # coefficient applied to the loss on non-speech windows
max_train_length_sec: 8  # audio longer than this value will be cut to it during tuning
aug_prob: 0.4  # probability of applying augmentations to audio during fine-tuning

learning_rate: 5e-4  # learning rate for model fine-tuning
batch_size: 128  # batch size for fine-tuning and validation
num_workers: 4  # number of workers used for the dataloaders
num_epochs: 20  # number of fine-tuning epochs, 1 epoch = one full pass over the training data
device: 'cuda'  # cpu or cuda, the device fine-tuning runs on
BIN  tuning/example_dataframe.feather  Normal file
Binary file not shown.

36  tuning/search_thresholds.py  Normal file
@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)

if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
                                         batch_size=config.batch_size,
                                         collate_fn=SileroVadPadder,
                                         num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)

    print('Making predicts...')
    all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
    print('Calculating thresholds...')
    best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')
65  tuning/tune.py  Normal file
@@ -0,0 +1,65 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch.nn as nn
import torch


if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    train_dataset = SileroVadDataset(config, mode='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               collate_fn=SileroVadPadder,
                                               num_workers=config.num_workers)

    val_dataset = SileroVadDataset(config, mode='val')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.batch_size,
                                             collate_fn=SileroVadPadder,
                                             num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)
    # only the decoder head is fine-tuned; it is initialized from the weights
    # of the selected (8k or 16k) head of the pretrained jit model
    decoder = VADDecoderRNNJIT().to(config.device)
    decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
    decoder.train()
    params = decoder.parameters()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
                                 lr=config.learning_rate)
    criterion = nn.BCELoss(reduction='none')

    best_val_roc = 0
    for i in range(config.num_epochs):
        print(f'Starting epoch {i + 1}')
        train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
        val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
        print(f'Metrics after epoch {i + 1}:\n'
              f'\tTrain loss: {round(train_loss, 3)}\n'
              f'\tValidation loss: {round(val_loss, 3)}\n'
              f'\tValidation ROC-AUC: {round(val_roc, 3)}')

        if val_roc > best_val_roc:
            print('New best ROC-AUC, saving model')
            best_val_roc = val_roc
            if config.tune_8k:
                model._model_8k.decoder.load_state_dict(decoder.state_dict())
            else:
                model._model.decoder.load_state_dict(decoder.state_dict())
            torch.jit.save(model, config.model_save_path)
    print('Done')
357  tuning/utils.py  Normal file
@@ -0,0 +1,357 @@
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')


def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):

    wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sampling_rate:
        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    if normalize and wav.abs().max() != 0:
        wav = wav / wav.abs().max()

    return wav.squeeze(0)


def build_audiomentations_augs(p):
    from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
        LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
        Aliasing, AddGaussianNoise
    transforms = [Aliasing(p=1),
                  AddGaussianNoise(p=1),
                  AirAbsorption(p=1),
                  BandPassFilter(p=1),
                  BandStopFilter(p=1),
                  ClippingDistortion(p=1),
                  HighPassFilter(p=1),
                  HighShelfFilter(p=1),
                  LowPassFilter(p=1),
                  LowShelfFilter(p=1),
                  Mp3Compression(p=1),
                  PeakingFilter(p=1),
                  PitchShift(p=1),
                  RoomSimulator(p=1, leave_length_unchanged=True),
                  SevenBandParametricEQ(p=1)]
    tr = SomeOf((1, 3), transforms=transforms, p=p)
    return tr


class SileroVadDataset(Dataset):
    def __init__(self,
                 config,
                 mode='train'):

        self.num_samples = 512  # constant, do not change
        self.sr = 16000  # constant, do not change

        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec
        self.max_train_length_samples = config.max_train_length_sec * self.sr

        assert self.max_train_length_samples % self.num_samples == 0
        assert mode in ['train', 'val']

        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE : {len(self.dataframe)}')

        if mode == 'train':
            self.augs = build_audiomentations_augs(p=config.aug_prob)
        else:
            self.augs = None

    def __getitem__(self, idx):
        idx = None if self.mode == 'train' else idx  # training samples are drawn randomly
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
            if len(wav) > self.max_train_length_samples:
                wav = wav[:self.max_train_length_samples]
                gt = gt[:int(self.max_train_length_samples / self.num_samples)]
                mask = mask[:int(self.max_train_length_samples / self.num_samples)]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            transform = torchaudio.transforms.Resample(orig_freq=self.sr,
                                                       new_freq=8000)
            wav = transform(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()

        # zero-pad so the audio splits into an integer number of 512-sample windows
        if len(wav) % self.num_samples != 0:
            pad_num = self.num_samples - (len(wav) % self.num_samples)
            wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))

        assert len(gt) == len(wav) / self.num_samples

        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        gt = np.zeros(audio_length_samples)

        for i in annotation:
            gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1

        # one label per 512-sample window: a window counts as speech if more
        # than half of its samples fall inside an annotated speech segment
        squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
        squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
        mask = np.ones(len(squeezed_predicts))
        mask[squeezed_predicts == 0] = self.noise_loss  # down-weight non-speech windows
        return squeezed_predicts, mask

    def add_augs(self, wav):
        while True:
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
                    return wav
                return wav_aug
            except Exception:
                continue
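An illustrative check of the shapes one dataset item carries, assuming a prepared `config.yml`, a valid validation dataframe, and the 16 kHz head (`tune_8k: False`):

```python
from omegaconf import OmegaConf
from utils import SileroVadDataset

config = OmegaConf.load('config.yml')
dataset = SileroVadDataset(config, mode='val')

wav, gt, mask = dataset[0]
# wav  - raw samples, zero-padded to a multiple of 512
# gt   - one 0/1 speech label per 512-sample window
# mask - per-window loss weight: 1.0 for speech windows, noise_loss otherwise
assert len(wav) == len(gt) * 512
```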
def SileroVadPadder(batch):
    wavs = [batch[i][0] for i in range(len(batch))]
    labels = [batch[i][1] for i in range(len(batch))]
    masks = [batch[i][2] for i in range(len(batch))]

    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0)

    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)

    masks = torch.nn.utils.rnn.pad_sequence(
        masks, batch_first=True, padding_value=0)

    return wavs, labels, masks


class VADDecoderRNNJIT(nn.Module):

    def __init__(self):
        super(VADDecoderRNNJIT, self).__init__()

        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(nn.Dropout(0.1),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 1, kernel_size=1),
                                     nn.Sigmoid())

    def forward(self, x, state=torch.zeros(0)):
        x = x.squeeze(-1)
        if len(state):
            h, c = self.rnn(x, (state[0], state[1]))
        else:
            h, c = self.rnn(x)

        x = h.unsqueeze(-1).float()
        state = torch.stack([h, c])
        x = self.decoder(x)
        return x, state


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):

    losses = AverageMeter()
    decoder.train()

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.enable_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            # left-pad with the context window expected by the STFT layer
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            optimizer.zero_grad()  # clear gradients accumulated by the previous step
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg


def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):

    losses = AverageMeter()
    decoder.eval()

    predicts = []
    gts = []

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.no_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            # windows added by the padder (mask == 0) are excluded from ROC-AUC
            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            losses.update(loss.item(), masks.numel())
        score = roc_auc_score(gts, predicts)

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, round(score, 3)


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def predict(model, loader, device, sr):
    with torch.no_grad():
        all_predicts = []
        all_gts = []
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            x = x.to(device)
            out = model.audio_forward(x, sr=sr)

            for i, out_chunk in enumerate(out):
                predict = out_chunk[masks[i] != 0].cpu().tolist()
                gt = targets[i, masks[i] != 0].cpu().tolist()

                all_predicts.append(predict)
                all_gts.append(gt)
    return all_predicts, all_gts


def calculate_best_thresholds(all_predicts, all_gts):
    best_acc = 0
    for ths_enter in tqdm(np.linspace(0, 1, 20)):
        for ths_exit in np.linspace(0, 1, 20):
            # the exit threshold must stay strictly below the enter threshold
            if ths_exit >= ths_enter:
                continue

            accs = []
            for j, predict in enumerate(all_predicts):
                predict_bool = []
                is_speech = False
                for i in predict:
                    if i >= ths_enter:
                        is_speech = True
                        predict_bool.append(1)
                    elif i <= ths_exit:
                        is_speech = False
                        predict_bool.append(0)
                    else:
                        # between the two thresholds: keep the current state
                        val = 1 if is_speech else 0
                        predict_bool.append(val)

                score = round(accuracy_score(all_gts[j], predict_bool), 4)
                accs.append(score)

            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc
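The enter/exit pair found by `calculate_best_thresholds` implements hysteresis: probabilities above the enter threshold switch speech on, probabilities below the exit threshold switch it off, and anything in between holds the current state. A minimal standalone sketch of the same decision rule (the probabilities and thresholds below are illustrative):

```python
def apply_thresholds(probs, ths_enter, ths_exit):
    # hysteresis decoding of per-window speech probabilities
    is_speech = False
    labels = []
    for p in probs:
        if p >= ths_enter:
            is_speech = True
        elif p <= ths_exit:
            is_speech = False
        labels.append(1 if is_speech else 0)
    return labels

print(apply_thresholds([0.1, 0.8, 0.5, 0.4, 0.2], ths_enter=0.7, ths_exit=0.3))
# -> [0, 1, 1, 1, 0]
```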