55 Commits
v6.0 ... master

Author SHA1 Message Date
Alexander Veysov
2688a6e352 Merge pull request #747 from d-e-s-o/topic/ort-rc.10
Update ort dependency to 2.0.0-rc.10
2025-12-30 07:05:45 +03:00
Daniel Müller
c5542cd4a8 Update ort dependency to 2.0.0-rc.10
Update the ort dependency from 2.0.0-rc.2 to 2.0.0-rc.10 and adapt the code
to work with the new API. This includes:
- Updating ndarray to 0.16 to match ort's requirements
- Using Session and Value from their new module locations
- Adapting to the new Value::from_array() and try_extract_tensor() APIs
- Converting SessionInputs from Value references

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-29 19:29:22 -08:00
Alexander Veysov
4725c40105 Merge pull request #746 from d-e-s-o/topic/fix-rust
Fix `rust-example`
2025-12-29 09:34:47 +03:00
Daniel Müller
cfe63384f0 Update model plumbing for Rust example
The v6.2 models broke the Rust example. Update the logic for driving
them to reflect what the reference Python code does.

Fixes: #745
Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-28 07:15:01 -08:00
Daniel Müller
2a08f0b90d Remove 'load-dynamic' feature of 'ort' dependency
It's unclear why we'd want this feature. It seems to make things even
less isolated and self-contained than it already is, which certainly
isn't a boon for an example.
2025-12-27 06:36:07 -08:00
Daniel Müller
21ffe8576e Fix model path in Rust example 2025-12-25 18:25:33 -08:00
Dimitrii Voronin
d5b52843f7 Merge pull request #736 from snakers4/adamnsandle
add tinygrad model
2025-12-10 16:35:36 +03:00
adamnsandle
fb7d7c7f5d add tinygrad model 2025-12-10 13:31:25 +00:00
Dimitrii Voronin
e7c3d6f2bd Merge pull request #734 from snakers4/adamnsandle
Adamnsandle
2025-12-08 10:27:37 +03:00
adamnsandle
390614894d Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-12-08 07:26:37 +00:00
adamnsandle
33eb4c7f84 fx ifless model 2025-12-08 07:26:18 +00:00
Dimitrii Voronin
c913b0c4b3 Merge pull request #732 from snakers4/adamnsandle
add ifless model
2025-12-05 16:58:29 +03:00
adamnsandle
4dd2e8f6f9 add ifless model 2025-12-05 13:57:43 +00:00
Alexander Veysov
63fe03add7 Merge pull request #727 from dfengpo/master
delete debug code
2025-11-25 13:31:10 +03:00
dongfp
29a582ba37 fix 2025-11-25 16:46:03 +08:00
Alexander Veysov
3ca476e4fb Merge pull request #722 from dfengpo/master
修复C# CalculateProb方法计算句子EndOffset的bug
2025-11-10 11:04:45 +03:00
Alexander Veysov
7de462944a Update README.md 2025-11-10 10:59:13 +03:00
Alexander Veysov
12b0121993 Merge pull request #721 from NathanJHLee/feature/onnx-libtorch-cpp-examples
Add C++ examples supporting ONNX & LibTorch; rename legacy folder
2025-11-10 10:58:27 +03:00
dongfp
7b0aaa1c4c 修复CalculateProb方法计算句子EndOffset的bug
修改语法提示
2025-11-10 15:58:20 +08:00
NathanLee
540eff3e24 Rename cpp_libtorch to cpp_libtorch_deprecated 2025-11-10 07:32:10 +00:00
NathanLee
dfeba4fc0f Add C++ folder for supporting ONNX & LibTorch 2025-11-10 07:31:58 +00:00
Dimitrii Voronin
be95df9152 Merge pull request #719 from snakers4/adamnsandle
Adamnsandle
2025-11-06 11:25:49 +03:00
adamnsandle
ec56fe50a5 fx workflow 2025-11-06 08:18:46 +00:00
adamnsandle
dea5980320 fx workflow 2025-11-06 08:04:02 +00:00
adamnsandle
90d9ce7695 fx workflow 2025-11-06 07:49:44 +00:00
adamnsandle
c56dbb11ac Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2025-11-06 07:36:38 +00:00
adamnsandle
9b686893ad fx test workflow 2025-11-06 07:36:23 +00:00
Dimitrii Voronin
6979fbd535 Merge pull request #717 from snakers4/adamnsandle
v6.2.0 release
2025-11-06 10:28:00 +03:00
adamnsandle
1cff663de5 fix version to 6.2.0 2025-11-06 07:27:07 +00:00
adamnsandle
bfdc019302 add v6.2 model 2025-11-06 07:23:43 +00:00
Alexander Veysov
c0c0ffa0c5 Merge pull request #714 from Purfview/patch-4
Fix type hint for min_silence_at_max_speech (float -> int)
2025-11-05 08:44:00 +03:00
Alexander Veysov
3f0c9ead54 Update pyproject.toml 2025-11-05 08:38:07 +03:00
Purfview
556a442942 Fix type hint for min_silence_at_max_speech (float -> int) 2025-11-04 08:30:01 +00:00
Dimitrii Voronin
9623ce72da Merge pull request #710 from Purfview/patch-3
Fixes and refines - use_max_poss_sil_at_max_speech arg
2025-10-29 12:36:58 +03:00
Dimitrii Voronin
b6dd0599fc Merge pull request #712 from snakers4/adamnsandle
drop_chunks fix
2025-10-29 12:16:10 +03:00
adamnsandle
d8f88c9157 drop_chunks fix 2025-10-29 09:14:45 +00:00
Purfview
b15a216b47 Reword a comment 2025-10-24 10:30:34 +01:00
Purfview
2389039408 Fixes and refines - use_max_poss_sil_at_max_speech arg
Removed redundant "if temp_end != 0:" check.
Multiple "window_size_samples * i" - assigned to a variable.
Restored the previous functionality (which was broken) when use_max_poss_sil_at_max_speech=False.

@shashank14k was your https://github.com/snakers4/silero-vad/pull/664 PR still WIP when it was merged?
Anyway, please test if use_max_poss_sil_at_max_speech=True behaviour is same, and "False" is same as before your PR.
2025-10-24 07:46:41 +01:00
Alexander Veysov
df22fcaec8 Merge pull request #708 from Purfview/patch-2
Removes redundant hop_size_samples variable
2025-10-23 15:58:00 +03:00
Purfview
81e8a48e25 Removes redundant hop_size_samples variable
Remove redundant hop_size_samples variable
2025-10-23 05:23:18 +01:00
Alexander Veysov
a14a23faa7 Merge pull request #707 from Purfview/patch-1
Fixes few typos
2025-10-23 06:35:58 +03:00
Purfview
a30b5843c1 Fixes various typos 2025-10-23 04:02:13 +01:00
Dimitrii Voronin
a66c890188 Merge pull request #704 from snakers4/adamnsandle
resolve torchaudio 2.9 utils
2025-10-17 15:50:20 +03:00
adamnsandle
77c91a91fa resolve torchaudio 2.9 utils 2025-10-17 12:35:40 +00:00
Alexander Veysov
33093c6f1b Update utils.py 2025-10-14 14:51:23 +03:00
Alexander Veysov
dc0b62e1e4 Merge pull request #699 from JiJiJiang/master
fix bug in tuning/utils.py: add optimizer.zero_grad() before loss.bac…
2025-10-14 14:50:58 +03:00
Hongji Wang
64fb49e1c8 fix bug in tuning/utils.py: add optimizer.zero_grad() before loss.backward() 2025-10-13 20:50:29 +08:00
Alexander Veysov
55ba6e2825 Merge pull request #697 from VvvvvGH/java-example-v6
Update java example for v6
2025-10-11 11:41:15 +03:00
GH
b90f8c012f Update SlieroVadOnnxModel.java 2025-10-11 16:21:57 +08:00
GH
25a778c798 Update SlieroVadDetector.java 2025-10-11 16:21:45 +08:00
GH
3d860e6ace Update App.java 2025-10-11 16:21:32 +08:00
GH
f5ea01bfda Update pom.xml 2025-10-11 16:21:03 +08:00
Alexander Veysov
dd651a54a5 Merge pull request #695 from mpariente/master
Remove ipdb and raise error directly in get_speech_timestamps
2025-10-11 08:07:18 +03:00
Manuel Pariente
f1175c902f Remove ipdb and raise error directly 2025-10-10 10:46:44 +02:00
Alexander Veysov
7819fd911b Update README.md 2025-10-09 17:34:33 +03:00
38 changed files with 2626 additions and 732 deletions

View File

@@ -24,6 +24,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install build hatchling pytest soundfile pip install build hatchling pytest soundfile
pip install .[test]
- name: Build package - name: Build package
run: python -m build --wheel --outdir dist run: python -m build --wheel --outdir dist

View File

@@ -1,6 +1,6 @@
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [![downloads](https://img.shields.io/pypi/dm/silero-vad?style=for-the-badge)](https://pypi.org/project/silero-vad/) [![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [![downloads](https://img.shields.io/pypi/dm/silero-vad?style=for-the-badge)](https://pypi.org/project/silero-vad/)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [![Test Package](https://github.com/snakers4/silero-vad/actions/workflows/test.yml/badge.svg)](https://github.com/snakers4/silero-vad/actions/workflows/test.yml) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [![Test Package](https://github.com/snakers4/silero-vad/actions/workflows/test.yml/badge.svg)](https://github.com/snakers4/silero-vad/actions/workflows/test.yml) [![Pypi version](https://img.shields.io/pypi/v/silero-vad)](https://pypi.org/project/silero-vad/) [![Python version](https://img.shields.io/pypi/pyversions/silero-vad)](https://pypi.org/project/silero-vad)
![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png) ![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png)

49
examples/c++/README.md Normal file
View File

@@ -0,0 +1,49 @@
# Silero-VAD V6 in C++ (based on LibTorch)
This is the source code for Silero-VAD V6 in C++, utilizing LibTorch & Onnxruntime.
You should compare its results with the Python version.
Results at 16 and 8kHz have been tested. Batch and CUDA inference options are deprecated.
## Requirements
- GCC 11.4.0 (GCC >= 5.1)
- Onnxruntime 1.11.0 (other versions are also acceptable)
- LibTorch 1.13.0 (other versions are also acceptable)
## Download LibTorch
```bash
-Onnxruntime
$wget https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz
$tar -xvf onnxruntime-linux-x64-1.11.1.tgz
$ln -s onnxruntime-linux-x64-1.11.1 onnxruntime-linux #soft-link
-Libtorch
$wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
$unzip libtorch-shared-with-deps-1.13.0+cpu.zip
```
## Compilation
```bash
-ONNX-build
$g++ main.cc silero.cc -I ./onnxruntime-linux/include/ -L ./onnxruntime-linux/lib/ -lonnxruntime -Wl,-rpath,./onnxruntime-linux/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_ONNX
-TORCH-build
$g++ main.cc silero.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_TORCH
```
## Optional Compilation Flags
-DUSE_TORCH
-DUSE_ONNX
## Run the Program
To run the program, use the following command:
`./silero <sample.wav> <SampleRate> <threshold>`
`./silero aepyx.wav 16000 0.5`
`./silero aepyx_8k.wav 8000 0.5`
The sample file aepyx.wav is part of the Voxconverse dataset.
File details: aepyx.wav is a 16kHz, 16-bit audio file.
File details: aepyx_8k.wav is a 8kHz, 16-bit audio file.

BIN
examples/c++/aepyx.wav Normal file

Binary file not shown.

BIN
examples/c++/aepyx_8k.wav Normal file

Binary file not shown.

61
examples/c++/main.cc Normal file
View File

@@ -0,0 +1,61 @@
#include <iostream>
#include "silero.h"
#include "wav.h"
int main(int argc, char* argv[]) {
if(argc != 4){
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
return 1;
}
std::string wav_path = argv[1];
float sample_rate = std::stof(argv[2]);
float threshold = std::stof(argv[3]);
if (sample_rate != 16000 && sample_rate != 8000) {
std::cout<<"Unsupported sample rate (only 16000 or 8000)."<<std::endl;
exit (0);
}
//Load Model
#ifdef USE_TORCH
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
#elif USE_ONNX
std::string model_path = "../../src/silero_vad/data/silero_vad.onnx";
#endif
silero::VadIterator vad(model_path);
vad.threshold=threshold; //(Default:0.5)
vad.sample_rate=sample_rate; //16000Hz,8000Hz. (Default:16000)
vad.print_as_samples=false; //if true, it prints time-stamp with samples. otherwise, in seconds
//(Default:false)
vad.SetVariables();
// Read wav
wav::WavReader wav_reader(wav_path);
std::vector<float> input_wav(wav_reader.num_samples());
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
vad.SpeechProbs(input_wav);
std::vector<silero::Interval> speeches = vad.GetSpeechTimestamps();
for(const auto& speech : speeches){
if(vad.print_as_samples){
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
}
else{
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
}
}
return 0;
}

273
examples/c++/silero.cc Normal file
View File

@@ -0,0 +1,273 @@
// silero.cc
// Author : NathanJHLee
// Created On : 2025-11-10
// Description : silero 6.2 system for onnx-runtime(c++) and torch-script(c++)
// Version : 1.3
#include "silero.h"
namespace silero {
#ifdef USE_TORCH
VadIterator::VadIterator(const std::string &model_path,
float threshold,
int sample_rate,
int window_size_ms,
int speech_pad_ms,
int min_silence_duration_ms,
int min_speech_duration_ms,
int max_duration_merge_ms,
bool print_as_samples)
: threshold(threshold), sample_rate(sample_rate), window_size_ms(window_size_ms),
speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
print_as_samples(print_as_samples)
{
init_torch_model(model_path);
}
VadIterator::~VadIterator(){
}
void VadIterator::init_torch_model(const std::string& model_path) {
at::set_num_threads(1);
model = torch::jit::load(model_path);
model.eval();
torch::NoGradGuard no_grad;
std::cout<<"Silero libtorch-Model loaded successfully"<<std::endl;
}
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
int num_samples = input_wav.size();
int num_chunks = num_samples / window_size_samples;
int remainder_samples = num_samples % window_size_samples;
total_sample_size += num_samples;
std::vector<torch::Tensor> chunks;
for (int i = 0; i < num_chunks; i++) {
float* chunk_start = input_wav.data() + i * window_size_samples;
torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
chunks.push_back(chunk);
if (i == num_chunks - 1 && remainder_samples > 0) {
int remaining_samples = num_samples - num_chunks * window_size_samples;
float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remaining_samples}, torch::kFloat32);
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples - remaining_samples}, torch::kFloat32)}, 1);
chunks.push_back(padded_chunk);
}
}
if (!chunks.empty()) {
std::vector<torch::Tensor> outputs;
torch::Tensor batched_chunks = torch::stack(chunks);
for (size_t i = 0; i < chunks.size(); i++) {
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks[i]);
inputs.push_back(sample_rate);
torch::Tensor output = model.forward(inputs).toTensor();
outputs.push_back(output);
}
torch::Tensor all_outputs = torch::stack(outputs);
for (size_t i = 0; i < chunks.size(); i++) {
float output_f = all_outputs[i].item<float>();
outputs_prob.push_back(output_f);
//////To print Probs by libtorch
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
}
}
}
#elif USE_ONNX
VadIterator::VadIterator(const std::string &model_path,
float threshold,
int sample_rate,
int window_size_ms,
int speech_pad_ms,
int min_silence_duration_ms,
int min_speech_duration_ms,
int max_duration_merge_ms,
bool print_as_samples)
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms),
speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
print_as_samples(print_as_samples),
env(ORT_LOGGING_LEVEL_ERROR, "Vad"), session_options(), session(nullptr), allocator(),
memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU)), context_samples(64),
_context(64, 0.0f), current_sample(0), size_state(2 * 1 * 128),
input_node_names({"input", "state", "sr"}), output_node_names({"output", "stateN"}),
state_node_dims{2, 1, 128}, sr_node_dims{1}
{
init_onnx_model(model_path);
}
VadIterator::~VadIterator(){
}
void VadIterator::init_onnx_model(const std::string& model_path) {
int inter_threads=1;
int intra_threads=1;
session_options.SetIntraOpNumThreads(intra_threads);
session_options.SetInterOpNumThreads(inter_threads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
std::cout<<"Silero onnx-Model loaded successfully"<<std::endl;
}
float VadIterator::predict(const std::vector<float>& data_chunk) {
// _context와 현재 청크를 결합하여 입력 데이터 구성
std::vector<float> new_data(effective_window_size, 0.0f);
std::copy(_context.begin(), _context.end(), new_data.begin());
std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
input = new_data;
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
memory_info, input.data(), input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
ort_inputs.clear();
ort_inputs.push_back(std::move(input_ort));
ort_inputs.push_back(std::move(state_ort));
ort_inputs.push_back(std::move(sr_ort));
ort_outputs = session->Run(
Ort::RunOptions{ nullptr },
input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
output_node_names.data(), output_node_names.size());
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0]; // ONNX 출력: 첫 번째 값이 음성 확률
float* stateN = ort_outputs[1].GetTensorMutableData<float>(); // 두 번째 출력값: 상태 업데이트
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
// _context 업데이트: new_data의 마지막 context_samples 유지
return speech_prob;
}
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
reset_states();
total_sample_size = static_cast<int>(input_wav.size());
for (size_t j = 0; j < static_cast<size_t>(total_sample_size); j += window_size_samples) {
if (j + window_size_samples > static_cast<size_t>(total_sample_size))
break;
std::vector<float> chunk(input_wav.begin() + j, input_wav.begin() + j + window_size_samples);
float speech_prob = predict(chunk);
outputs_prob.push_back(speech_prob);
}
}
#endif
void VadIterator::reset_states() {
triggered = false;
current_sample = 0;
temp_end = 0;
outputs_prob.clear();
total_sample_size = 0;
#ifdef USE_TORCH
model.run_method("reset_states"); // Reset model states if applicable
#elif USE_ONNX
std::memset(_state.data(), 0, _state.size() * sizeof(float));
std::fill(_context.begin(), _context.end(), 0.0f);
#endif
}
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
std::vector<Interval> speeches = DoVad();
if(!print_as_samples){
for (auto& speech : speeches) {
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches;
}
void VadIterator::SetVariables(){
// Initialize internal engine parameters
init_engine(window_size_ms);
}
void VadIterator::init_engine(int window_size_ms) {
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
window_size_samples = sample_rate / 1000 * window_size_ms;
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
#ifdef USE_ONNX
//for ONNX
context_samples=window_size_samples / 8;
_context.assign(context_samples, 0.0f);
effective_window_size = window_size_samples + context_samples; // 예: 512 + 64 = 576 samples
input_node_dims[0] = 1;
input_node_dims[1] = effective_window_size;
_state.resize(size_state);
sr.resize(1);
sr[0] = sample_rate;
#endif
}
std::vector<Interval> VadIterator::DoVad() {
std::vector<Interval> speeches;
for (size_t i = 0; i < outputs_prob.size(); ++i) {
float speech_prob = outputs_prob[i];
current_sample += window_size_samples;
if (speech_prob >= threshold && temp_end != 0) {
temp_end = 0;
}
if (speech_prob >= threshold) {
if (!triggered) {
triggered = true;
Interval segment;
segment.start = std::max(0, current_sample - speech_pad_samples - window_size_samples);
speeches.push_back(segment);
}
}else {
if (triggered) {
if (speech_prob < threshold - 0.15f) {
if (temp_end == 0) {
temp_end = current_sample;
}
if (current_sample - temp_end >= min_silence_samples) {
Interval& segment = speeches.back();
segment.end = temp_end + speech_pad_samples - window_size_samples;
temp_end = 0;
triggered = false;
}
}
}
}
}
if (triggered) {
std::cout<<"Finalizing active speech segment at stream end."<<std::endl;
Interval& segment = speeches.back();
segment.end = total_sample_size;
triggered = false;
}
speeches.erase(std::remove_if(speeches.begin(), speeches.end(),
[this](const Interval& speech) {
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
}), speeches.end());
reset_states();
return speeches;
}
} // namespace silero

123
examples/c++/silero.h Normal file
View File

@@ -0,0 +1,123 @@
#ifndef SILERO_H
#define SILERO_H
// silero.h
// Author : NathanJHLee
// Created On : 2025-11-10
// Description : silero 6.2 system for onnx-runtime(c++) and torch-script(c++)
// Version : 1.3
#include <string>
#include <vector>
#include <iostream>
#include <fstream>
#include <chrono>
#include <algorithm>
#include <cstring>
#ifdef USE_TORCH
#include <torch/torch.h>
#include <torch/script.h>
#elif USE_ONNX
#include "onnxruntime_cxx_api.h"
#endif
namespace silero {
struct Interval {
float start;
float end;
int numberOfSubseg;
void initialize() {
start = 0;
end = 0;
numberOfSubseg = 0;
}
};
class VadIterator {
public:
VadIterator(const std::string &model_path,
float threshold = 0.5,
int sample_rate = 16000,
int window_size_ms = 32,
int speech_pad_ms = 30,
int min_silence_duration_ms = 100,
int min_speech_duration_ms = 250,
int max_duration_merge_ms = 300,
bool print_as_samples = false);
~VadIterator();
// Batch (non-streaming) interface (for backward compatibility)
void SpeechProbs(std::vector<float>& input_wav);
std::vector<Interval> GetSpeechTimestamps();
void SetVariables();
// Public parameters (can be modified by user)
float threshold;
int sample_rate;
int window_size_ms;
int min_speech_duration_ms;
int max_duration_merge_ms;
bool print_as_samples;
private:
#ifdef USE_TORCH
torch::jit::script::Module model;
void init_torch_model(const std::string& model_path);
#elif USE_ONNX
Ort::Env env; // 환경 객체
Ort::SessionOptions session_options; // 세션 옵션
std::shared_ptr<Ort::Session> session; // ONNX 세션
Ort::AllocatorWithDefaultOptions allocator; // 기본 할당자
Ort::MemoryInfo memory_info; // 메모리 정보 (CPU)
void init_onnx_model(const std::string& model_path);
float predict(const std::vector<float>& data_chunk);
//const int context_samples; // 예: 64 samples
int context_samples; // 예: 64 samples
std::vector<float> _context; // 초기값 모두 0
int effective_window_size;
// ONNX 입력/출력 관련 버퍼 및 노드 이름들
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_node_names;
std::vector<float> input;
unsigned int size_state; // 고정값: 2*1*128
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2]; // [1, effective_window_size]
const int64_t state_node_dims[3]; // [ 2, 1, 128 ]
const int64_t sr_node_dims[1]; // [ 1 ]
std::vector<Ort::Value> ort_outputs;
std::vector<const char*> output_node_names; // 기본값: [ "output", "stateN" ]
#endif
std::vector<float> outputs_prob; // used in batch mode
int min_silence_samples;
int min_speech_samples;
int speech_pad_samples;
int window_size_samples;
int duration_merge_samples;
int current_sample = 0;
int total_sample_size = 0;
int min_silence_duration_ms;
int speech_pad_ms;
bool triggered = false;
int temp_end = 0;
int global_end = 0;
int erase_tail_count = 0;
void init_engine(int window_size_ms);
void reset_states();
std::vector<Interval> DoVad();
};
} // namespace silero
#endif // SILERO_H

237
examples/c++/wav.h Normal file
View File

@@ -0,0 +1,237 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
// #include "utils/log.h"
namespace wav {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
if (NULL == fp) {
std::cout << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
printf("WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_;
std::cout << "num_channel_ :" << num_channel_ << std::endl;
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: {
char sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 16: {
int16_t sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 32:
{
if (header.format == 1) //S32
{
int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
}
default:
printf("unsupported quantization bits\n");
break;
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wav
#endif // FRONTEND_WAV_H_

View File

@@ -0,0 +1,45 @@
# Silero-VAD V5 in C++ (based on LibTorch)
This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested.
Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
## Requirements
- GCC 11.4.0 (GCC >= 5.1)
- LibTorch 1.13.0 (other versions are also acceptable)
## Download LibTorch
```bash
-CPU Version
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
unzip libtorch-shared-with-deps-1.13.0+cpu.zip'
-CUDA Version
wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
unzip libtorch-shared-with-deps-1.13.0+cu116.zip
```
## Compilation
```bash
-CPU Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
-CUDA Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
```
## Optional Compilation Flags
-DUSE_BATCH: Enable batch inference
-DUSE_GPU: Use GPU for inference
## Run the Program
To run the program, use the following command:
`./silero aepyx.wav 16000 0.5`
The sample file aepyx.wav is part of the Voxconverse dataset.
File details: aepyx.wav is a 16kHz, 16-bit audio file.

Binary file not shown.

View File

@@ -0,0 +1,54 @@
#include <iostream>
#include "silero_torch.h"
#include "wav.h"
int main(int argc, char* argv[]) {
if(argc != 4){
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
return 1;
}
std::string wav_path = argv[1];
float sample_rate = std::stof(argv[2]);
float threshold = std::stof(argv[3]);
//Load Model
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
silero::VadIterator vad(model_path);
vad.threshold=threshold; //(Default:0.5)
vad.sample_rate=sample_rate; //16000Hz,8000Hz. (Default:16000)
vad.print_as_samples=true; //if true, it prints time-stamp with samples. otherwise, in seconds
//(Default:false)
vad.SetVariables();
// Read wav
wav::WavReader wav_reader(wav_path);
std::vector<float> input_wav(wav_reader.num_samples());
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
vad.SpeechProbs(input_wav);
std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
for(const auto& speech : speeches){
if(vad.print_as_samples){
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
}
else{
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
}
}
return 0;
}

Binary file not shown.

View File

@@ -0,0 +1,285 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#include "silero_torch.h"
namespace silero {
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
init_torch_model(model_path);
//init_engine(window_size_ms);
}
VadIterator::~VadIterator(){
}
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
// Set the sample rate (must match the model's expected sample rate)
// Process the waveform in chunks of 512 samples
int num_samples = input_wav.size();
int num_chunks = num_samples / window_size_samples;
int remainder_samples = num_samples % window_size_samples;
total_sample_size += num_samples;
torch::Tensor output;
std::vector<torch::Tensor> chunks;
for (int i = 0; i < num_chunks; i++) {
float* chunk_start = input_wav.data() + i *window_size_samples;
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
chunks.push_back(chunk);
if(i==num_chunks-1 && remainder_samples>0){//마지막 chunk && 나머지가 존재
int remaining_samples = num_samples - num_chunks * window_size_samples;
//std::cout<<"Remainder size : "<<remaining_samples;
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
torch::kFloat32);
// Pad the remainder chunk to match window_size_samples
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
- remaining_samples}, torch::kFloat32)}, 1);
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
chunks.push_back(padded_chunk);
}
}
if (!chunks.empty()) {
#ifdef USE_BATCH
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
//batched_chunks = batched_chunks.squeeze(1);
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
#endif
// Prepare input for model
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks); // Batch of chunks
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
// Run inference on the batch
torch::NoGradGuard no_grad;
torch::Tensor output = model.forward(inputs).toTensor();
#ifdef USE_GPU
output = output.to(at::kCPU); // Move the output back to CPU once
#endif
// Collect output probabilities
for (int i = 0; i < chunks.size(); i++) {
float output_f = output[i].item<float>();
outputs_prob.push_back(output_f);
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
}
#else
std::vector<torch::Tensor> outputs;
torch::Tensor batched_chunks = torch::stack(chunks);
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA);
#endif
for (int i = 0; i < chunks.size(); i++) {
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks[i]);
inputs.push_back(sample_rate);
torch::Tensor output = model.forward(inputs).toTensor();
outputs.push_back(output);
}
torch::Tensor all_outputs = torch::stack(outputs);
#ifdef USE_GPU
all_outputs = all_outputs.to(at::kCPU);
#endif
for (int i = 0; i < chunks.size(); i++) {
float output_f = all_outputs[i].item<float>();
outputs_prob.push_back(output_f);
}
#endif
}
}
std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
std::vector<SpeechSegment> speeches = DoVad();
#ifdef USE_BATCH
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
//It could be better get reasonable output because of distorted probs.
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
if(!print_as_samples){
for (auto& speech : speeches_merge) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches_merge;
#else
if(!print_as_samples){
for (auto& speech : speeches) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches;
#endif
}
void VadIterator::SetVariables(){
init_engine(window_size_ms);
}
void VadIterator::init_engine(int window_size_ms) {
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
window_size_samples = sample_rate / 1000 * window_size_ms;
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}
void VadIterator::init_torch_model(const std::string& model_path) {
at::set_num_threads(1);
model = torch::jit::load(model_path);
#ifdef USE_GPU
if (!torch::cuda::is_available()) {
std::cout<<"CUDA is not available! Please check your GPU settings"<<std::endl;
throw std::runtime_error("CUDA is not available!");
model.to(at::Device(at::kCPU));
} else {
std::cout<<"CUDA available! Running on '0'th GPU"<<std::endl;
model.to(at::Device(at::kCUDA, 0)); //select 0'th machine
}
#endif
model.eval();
torch::NoGradGuard no_grad;
std::cout << "Model loaded successfully"<<std::endl;
}
void VadIterator::reset_states() {
triggered = false;
current_sample = 0;
temp_end = 0;
outputs_prob.clear();
model.run_method("reset_states");
total_sample_size = 0;
}
std::vector<SpeechSegment> VadIterator::DoVad() {
std::vector<SpeechSegment> speeches;
for (size_t i = 0; i < outputs_prob.size(); ++i) {
float speech_prob = outputs_prob[i];
//std::cout << speech_prob << std::endl;
//std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
//std::cout << speech_prob << " ";
current_sample += window_size_samples;
if (speech_prob >= threshold && temp_end != 0) {
temp_end = 0;
}
if (speech_prob >= threshold && !triggered) {
triggered = true;
SpeechSegment segment;
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
speeches.push_back(segment);
continue;
}
if (speech_prob < threshold - 0.15f && triggered) {
if (temp_end == 0) {
temp_end = current_sample;
}
if (current_sample - temp_end < min_silence_samples) {
continue;
} else {
SpeechSegment& segment = speeches.back();
segment.end = temp_end + speech_pad_samples - window_size_samples;
temp_end = 0;
triggered = false;
}
}
}
if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나?
std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
SpeechSegment& segment = speeches.back();
segment.end = total_sample_size; // 현재 샘플을 마지막 구간의 종료 시간으로 설정
triggered = false; // VAD 상태 초기화
}
speeches.erase(
std::remove_if(
speeches.begin(),
speeches.end(),
[this](const SpeechSegment& speech) {
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
//min_speech_samples is 4000samples(0.25sec)
//여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함.
}
),
speeches.end()
);
//std::cout<<std::endl;
//std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;
reset_states();
return speeches;
}
std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
std::vector<SpeechSegment> mergedSpeeches;
if (speeches.empty()) {
return mergedSpeeches; // 빈 벡터 반환
}
// 첫 번째 구간으로 초기화
SpeechSegment currentSegment = speeches[0];
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
// 현재 구간의 끝점을 업데이트
currentSegment.end = speeches[i].end;
} else {
// 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작
mergedSpeeches.push_back(currentSegment);
currentSegment = speeches[i];
}
}
// 마지막 구간 추가
mergedSpeeches.push_back(currentSegment);
return mergedSpeeches;
}
}

View File

@@ -0,0 +1,75 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H
#include <string>
#include <memory>
#include <stdexcept>
#include <iostream>
#include <memory>
#include <vector>
#include <fstream>
#include <chrono>
#include <torch/torch.h>
#include <torch/script.h>
namespace silero{
struct SpeechSegment{
int start;
int end;
};
class VadIterator{
public:
VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
~VadIterator();
void SpeechProbs(std::vector<float>& input_wav);
std::vector<silero::SpeechSegment> GetSpeechTimestamps();
void SetVariables();
float threshold;
int sample_rate;
int window_size_ms;
int min_speech_duration_ms;
int max_duration_merge_ms;
bool print_as_samples;
private:
torch::jit::script::Module model;
std::vector<float> outputs_prob;
int min_silence_samples;
int min_speech_samples;
int speech_pad_samples;
int window_size_samples;
int duration_merge_samples;
int current_sample = 0;
int total_sample_size=0;
int min_silence_duration_ms;
int speech_pad_ms;
bool triggered = false;
int temp_end = 0;
void init_engine(int window_size_ms);
void init_torch_model(const std::string& model_path);
void reset_states();
std::vector<SpeechSegment> DoVad();
std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
};
}
#endif // SILERO_TORCH_H

View File

@@ -0,0 +1,235 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
// #include "utils/log.h"
namespace wav {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
if (NULL == fp) {
std::cout << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
printf("WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_;
std::cout << "num_channel_ :" << num_channel_ << std::endl;
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: {
char sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 16: {
int16_t sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 32:
{
if (header.format == 1) //S32
{
int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
}
default:
printf("unsupported quantization bits\n");
break;
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

View File

@@ -21,7 +21,7 @@ class Program
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE)); List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
//Console.WriteLine(speechTimeList.ToJson()); //Console.WriteLine(speechTimeList.ToJson());
StringBuilder sb = new StringBuilder(); StringBuilder sb = new();
foreach (var speechSegment in speechTimeList) foreach (var speechSegment in speechTimeList)
{ {
sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n"); sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");

View File

@@ -53,28 +53,26 @@ public class SileroVadDetector
{ {
Reset(); Reset();
using (var audioFile = new AudioFileReader(wavFile.FullName)) using var audioFile = new AudioFileReader(wavFile.FullName);
List<float> speechProbList = [];
this._audioLengthSamples = (int)(audioFile.Length / 2);
float[] buffer = new float[this._windowSizeSample];
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
{ {
List<float> speechProbList = new List<float>(); float speechProb = _model.Call([buffer], _samplingRate)[0];
this._audioLengthSamples = (int)(audioFile.Length / 2); speechProbList.Add(speechProb);
float[] buffer = new float[this._windowSizeSample];
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
{
float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
speechProbList.Add(speechProb);
}
return CalculateProb(speechProbList);
} }
return CalculateProb(speechProbList);
} }
private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList) private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
{ {
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>(); List<SileroSpeechSegment> result = [];
bool triggered = false; bool triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0; int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment(); SileroSpeechSegment segment = new();
for (int i = 0; i < speechProbList.Count; i++) for (int i = 0; i < speechProbList.Count; i++)
{ {
@@ -164,7 +162,8 @@ public class SileroVadDetector
if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples) if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
{ {
segment.EndOffset = _audioLengthSamples; //segment.EndOffset = _audioLengthSamples;
segment.EndOffset = speechProbList.Count * _windowSizeSample;
result.Add(segment); result.Add(segment);
} }
@@ -182,7 +181,7 @@ public class SileroVadDetector
int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value; int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
if (silenceDuration < 2 * _speechPadSamples) if (silenceDuration < 2 * _speechPadSamples)
{ {
item.EndOffset = item.EndOffset + (silenceDuration / 2); item.EndOffset += (silenceDuration / 2);
nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2)); nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
} }
else else
@@ -200,9 +199,9 @@ public class SileroVadDetector
return MergeListAndCalculateSecond(result, _samplingRate); return MergeListAndCalculateSecond(result, _samplingRate);
} }
private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate) private static List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
{ {
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>(); List<SileroSpeechSegment> result = [];
if (original == null || original.Count == 0) if (original == null || original.Count == 0)
{ {
return result; return result;
@@ -216,7 +215,7 @@ public class SileroVadDetector
for (int i = 1; i < original.Count; i++) for (int i = 1; i < original.Count; i++)
{ {
SileroSpeechSegment segment = original[i]; SileroSpeechSegment segment = original[i];
if (segment.StartOffset > right) if (segment.StartOffset > right)
{ {
result.Add(new SileroSpeechSegment(left, right, result.Add(new SileroSpeechSegment(left, right,
@@ -242,7 +241,7 @@ public class SileroVadDetector
return result; return result;
} }
private float CalculateSecondByOffset(int offset, int samplingRate) private static float CalculateSecondByOffset(int offset, int samplingRate)
{ {
float secondValue = offset * 1.0f / samplingRate; float secondValue = offset * 1.0f / samplingRate;
return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f; return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;

View File

@@ -1,5 +1,6 @@
using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors; using Microsoft.ML.OnnxRuntime.Tensors;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@@ -7,214 +8,208 @@ using System.Linq;
namespace VADdotnet; namespace VADdotnet;
public class SileroVadOnnxModel : IDisposable public class SileroVadOnnxModel : IDisposable
{
private readonly InferenceSession session;
private float[][][] state;
private float[][] context;
private int lastSr = 0;
private int lastBatchSize = 0;
private static readonly List<int> SAMPLE_RATES = [8000, 16000];
public SileroVadOnnxModel(string modelPath)
{ {
private readonly InferenceSession session; var sessionOptions = new SessionOptions
private float[][][] state;
private float[][] context;
private int lastSr = 0;
private int lastBatchSize = 0;
private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };
public SileroVadOnnxModel(string modelPath)
{ {
var sessionOptions = new SessionOptions(); InterOpNumThreads = 1,
sessionOptions.InterOpNumThreads = 1; IntraOpNumThreads = 1,
sessionOptions.IntraOpNumThreads = 1; EnableCpuMemArena = true
sessionOptions.EnableCpuMemArena = true; };
session = new InferenceSession(modelPath, sessionOptions); session = new InferenceSession(modelPath, sessionOptions);
ResetStates();
}
public void ResetStates()
{
state = new float[2][][];
state[0] = new float[1][];
state[1] = new float[1][];
state[0][0] = new float[128];
state[1][0] = new float[128];
context = [];
lastSr = 0;
lastBatchSize = 0;
}
public void Dispose()
{
GC.SuppressFinalize(this);
}
public class ValidationResult(float[][] x, int sr)
{
public float[][] X { get; } = x;
public int Sr { get; } = sr;
}
private static ValidationResult ValidateInput(float[][] x, int sr)
{
if (x.Length == 1)
{
x = [x[0]];
}
if (x.Length > 2)
{
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
}
if (sr != 16000 && (sr % 16000 == 0))
{
int step = sr / 16000;
float[][] reducedX = new float[x.Length][];
for (int i = 0; i < x.Length; i++)
{
float[] current = x[i];
float[] newArr = new float[(current.Length + step - 1) / step];
for (int j = 0, index = 0; j < current.Length; j += step, index++)
{
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
if (!SAMPLE_RATES.Contains(sr))
{
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
}
if (((float)sr) / x[0].Length > 31.25)
{
throw new ArgumentException("Input audio is too short");
}
return new ValidationResult(x, sr);
}
private static float[][] Concatenate(float[][] a, float[][] b)
{
if (a.Length != b.Length)
{
throw new ArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.Length;
int colsA = a[0].Length;
int colsB = b[0].Length;
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[colsA + colsB];
Array.Copy(a[i], 0, result[i], 0, colsA);
Array.Copy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] GetLastColumns(float[][] array, int contextSize)
{
int rows = array.Length;
int cols = array[0].Length;
if (contextSize > cols)
{
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[contextSize];
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
public float[] Call(float[][] x, int sr)
{
var result = ValidateInput(x, sr);
x = result.X;
sr = result.Sr;
int numberSamples = sr == 16000 ? 512 : 256;
if (x[0].Length != numberSamples)
{
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.Length;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0)
{
ResetStates();
}
if (lastSr != 0 && lastSr != sr)
{
ResetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize)
{
ResetStates(); ResetStates();
} }
public void ResetStates() if (context.Length == 0)
{ {
state = new float[2][][]; context = new float[batchSize][];
state[0] = new float[1][]; for (int i = 0; i < batchSize; i++)
state[1] = new float[1][];
state[0][0] = new float[128];
state[1][0] = new float[128];
context = Array.Empty<float[]>();
lastSr = 0;
lastBatchSize = 0;
}
public void Dispose()
{
session?.Dispose();
}
public class ValidationResult
{
public float[][] X { get; }
public int Sr { get; }
public ValidationResult(float[][] x, int sr)
{ {
X = x; context[i] = new float[contextSize];
Sr = sr;
} }
} }
private ValidationResult ValidateInput(float[][] x, int sr) x = Concatenate(context, x);
{
if (x.Length == 1) var inputs = new List<NamedOnnxValue>
{ {
x = new float[][] { x[0] }; NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), [x.Length, x[0].Length])),
} NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, [1])),
if (x.Length > 2) NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), [state.Length, state[0].Length, state[0][0].Length]))
{
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
}
if (sr != 16000 && (sr % 16000 == 0))
{
int step = sr / 16000;
float[][] reducedX = new float[x.Length][];
for (int i = 0; i < x.Length; i++)
{
float[] current = x[i];
float[] newArr = new float[(current.Length + step - 1) / step];
for (int j = 0, index = 0; j < current.Length; j += step, index++)
{
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
if (!SAMPLE_RATES.Contains(sr))
{
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
}
if (((float)sr) / x[0].Length > 31.25)
{
throw new ArgumentException("Input audio is too short");
}
return new ValidationResult(x, sr);
}
private static float[][] Concatenate(float[][] a, float[][] b)
{
if (a.Length != b.Length)
{
throw new ArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.Length;
int colsA = a[0].Length;
int colsB = b[0].Length;
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[colsA + colsB];
Array.Copy(a[i], 0, result[i], 0, colsA);
Array.Copy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] GetLastColumns(float[][] array, int contextSize)
{
int rows = array.Length;
int cols = array[0].Length;
if (contextSize > cols)
{
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][];
for (int i = 0; i < rows; i++)
{
result[i] = new float[contextSize];
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
public float[] Call(float[][] x, int sr)
{
var result = ValidateInput(x, sr);
x = result.X;
sr = result.Sr;
int numberSamples = sr == 16000 ? 512 : 256;
if (x[0].Length != numberSamples)
{
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.Length;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0)
{
ResetStates();
}
if (lastSr != 0 && lastSr != sr)
{
ResetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize)
{
ResetStates();
}
if (context.Length == 0)
{
context = new float[batchSize][];
for (int i = 0; i < batchSize; i++)
{
context[i] = new float[contextSize];
}
}
x = Concatenate(context, x);
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
}; };
using (var outputs = session.Run(inputs)) using var outputs = session.Run(inputs);
var output = outputs.First(o => o.Name == "output").AsTensor<float>();
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
context = GetLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
state = new float[newState.Dimensions[0]][][];
for (int i = 0; i < newState.Dimensions[0]; i++)
{
state[i] = new float[newState.Dimensions[1]][];
for (int j = 0; j < newState.Dimensions[1]; j++)
{ {
var output = outputs.First(o => o.Name == "output").AsTensor<float>(); state[i][j] = new float[newState.Dimensions[2]];
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>(); for (int k = 0; k < newState.Dimensions[2]; k++)
context = GetLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
state = new float[newState.Dimensions[0]][][];
for (int i = 0; i < newState.Dimensions[0]; i++)
{ {
state[i] = new float[newState.Dimensions[1]][]; state[i][j][k] = newState[i, j, k];
for (int j = 0; j < newState.Dimensions[1]; j++)
{
state[i][j] = new float[newState.Dimensions[2]];
for (int k = 0; k < newState.Dimensions[2]; k++)
{
state[i][j][k] = newState[i, j, k];
}
}
} }
return output.ToArray();
} }
} }
return [.. output];
} }
}

View File

@@ -0,0 +1,25 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.14.36616.10 d17.14
MinimumVisualStudioVersion = 10.0.40219.1
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VadDotNet", "VadDotNet.csproj", "{F36E1741-EDDB-90C7-7501-4911058F8996}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{F36E1741-EDDB-90C7-7501-4911058F8996}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F36E1741-EDDB-90C7-7501-4911058F8996}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F36E1741-EDDB-90C7-7501-4911058F8996}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F36E1741-EDDB-90C7-7501-4911058F8996}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {DFC4CEE8-1034-46B4-A5F4-D1649B3543E6}
EndGlobalSection
EndGlobal

View File

@@ -1,30 +1,31 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId> <groupId>org.example</groupId>
<artifactId>java-example</artifactId> <artifactId>java-example</artifactId>
<version>1.0-SNAPSHOT</version> <version>1.0-SNAPSHOT</version>
<packaging>jar</packaging> <packaging>jar</packaging>
<name>sliero-vad-example</name> <name>sliero-vad-example</name>
<url>http://maven.apache.org</url> <url>http://maven.apache.org</url>
<properties> <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties> </properties>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
<version>3.8.1</version> <version>3.8.1</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime -->
<groupId>com.microsoft.onnxruntime</groupId> <dependency>
<artifactId>onnxruntime</artifactId> <groupId>com.microsoft.onnxruntime</groupId>
<version>1.16.0-rc1</version> <artifactId>onnxruntime</artifactId>
</dependency> <version>1.23.1</version>
</dependencies> </dependency>
</dependencies>
</project> </project>

View File

@@ -2,68 +2,263 @@ package org.example;
import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtException;
import javax.sound.sampled.*; import javax.sound.sampled.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
/**
* Silero VAD Java Example
* Voice Activity Detection using ONNX model
*
* @author VvvvvGH
*/
public class App { public class App {
private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx"; // ONNX model path - using the model file from the project
private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx";
// Test audio file path
private static final String AUDIO_FILE_PATH = "../../en_example.wav";
// Sampling rate
private static final int SAMPLE_RATE = 16000; private static final int SAMPLE_RATE = 16000;
private static final float START_THRESHOLD = 0.6f; // Speech threshold (consistent with Python default)
private static final float END_THRESHOLD = 0.45f; private static final float THRESHOLD = 0.5f;
private static final int MIN_SILENCE_DURATION_MS = 600; // Negative threshold (used to determine speech end)
private static final int SPEECH_PAD_MS = 500; private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15
private static final int WINDOW_SIZE_SAMPLES = 2048; // Minimum speech duration (milliseconds)
private static final int MIN_SPEECH_DURATION_MS = 250;
// Minimum silence duration (milliseconds)
private static final int MIN_SILENCE_DURATION_MS = 100;
// Speech padding (milliseconds)
private static final int SPEECH_PAD_MS = 30;
// Window size (samples) - 512 samples for 16kHz
private static final int WINDOW_SIZE_SAMPLES = 512;
public static void main(String[] args) { public static void main(String[] args) {
// Initialize the Voice Activity Detector System.out.println("=".repeat(60));
SlieroVadDetector vadDetector; System.out.println("Silero VAD Java ONNX Example");
System.out.println("=".repeat(60));
// Load ONNX model
SlieroVadOnnxModel model;
try { try {
vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); System.out.println("Loading ONNX model: " + MODEL_PATH);
model = new SlieroVadOnnxModel(MODEL_PATH);
System.out.println("Model loaded successfully!");
} catch (OrtException e) { } catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage()); System.err.println("Failed to load model: " + e.getMessage());
e.printStackTrace();
return; return;
} }
// Set audio format // Read WAV file
AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false); float[] audioData;
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
// Get the target data line and open it with the specified format
TargetDataLine targetDataLine;
try { try {
targetDataLine = (TargetDataLine) AudioSystem.getLine(info); System.out.println("\nReading audio file: " + AUDIO_FILE_PATH);
targetDataLine.open(format); audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH);
targetDataLine.start(); System.out.println("Audio file read successfully, samples: " + audioData.length);
} catch (LineUnavailableException e) { System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds");
System.err.println("Error opening target data line: " + e.getMessage()); } catch (Exception e) {
System.err.println("Failed to read audio file: " + e.getMessage());
e.printStackTrace();
return; return;
} }
// Main loop to continuously read data and apply Voice Activity Detection // Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps)
while (targetDataLine.isOpen()) { System.out.println("\nDetecting speech segments...");
byte[] data = new byte[WINDOW_SIZE_SAMPLES]; List<Map<String, Integer>> speechTimestamps;
try {
int numBytesRead = targetDataLine.read(data, 0, data.length); speechTimestamps = getSpeechTimestamps(
if (numBytesRead <= 0) { audioData,
System.err.println("Error reading data from target data line."); model,
continue; THRESHOLD,
} SAMPLE_RATE,
MIN_SPEECH_DURATION_MS,
// Apply the Voice Activity Detector to the data and get the result MIN_SILENCE_DURATION_MS,
Map<String, Double> detectResult; SPEECH_PAD_MS,
try { NEG_THRESHOLD
detectResult = vadDetector.apply(data, true); );
} catch (Exception e) { } catch (OrtException e) {
System.err.println("Error applying VAD detector: " + e.getMessage()); System.err.println("Failed to detect speech timestamps: " + e.getMessage());
continue; e.printStackTrace();
} return;
if (!detectResult.isEmpty()) {
System.out.println(detectResult);
}
} }
// Close the target data line to release audio resources // Output detection results
targetDataLine.close(); System.out.println("\nDetected speech timestamps (in samples):");
for (Map<String, Integer> timestamp : speechTimestamps) {
System.out.println(timestamp);
}
// Output summary
System.out.println("\n" + "=".repeat(60));
System.out.println("Detection completed!");
System.out.println("Total detected " + speechTimestamps.size() + " speech segments");
System.out.println("=".repeat(60));
// Close model
try {
model.close();
} catch (OrtException e) {
System.err.println("Error closing model: " + e.getMessage());
}
} }
/**
* Get speech timestamps
* Implements the same logic as Python's get_speech_timestamps
*
* @param audio Audio data (float array)
* @param model ONNX model
* @param threshold Speech threshold
* @param samplingRate Sampling rate
* @param minSpeechDurationMs Minimum speech duration (milliseconds)
* @param minSilenceDurationMs Minimum silence duration (milliseconds)
* @param speechPadMs Speech padding (milliseconds)
* @param negThreshold Negative threshold (used to determine speech end)
* @return List of speech timestamps
*/
private static List<Map<String, Integer>> getSpeechTimestamps(
float[] audio,
SlieroVadOnnxModel model,
float threshold,
int samplingRate,
int minSpeechDurationMs,
int minSilenceDurationMs,
int speechPadMs,
float negThreshold) throws OrtException {
// Reset model states
model.resetStates();
// Calculate parameters
int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000;
int speechPadSamples = samplingRate * speechPadMs / 1000;
int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000;
int windowSizeSamples = samplingRate == 16000 ? 512 : 256;
int audioLengthSamples = audio.length;
// Calculate speech probabilities for all audio chunks
List<Float> speechProbs = new ArrayList<>();
for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) {
float[] chunk = new float[windowSizeSamples];
int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart);
System.arraycopy(audio, currentStart, chunk, 0, chunkLength);
// Pad with zeros if chunk is shorter than window size
if (chunkLength < windowSizeSamples) {
for (int i = chunkLength; i < windowSizeSamples; i++) {
chunk[i] = 0.0f;
}
}
float speechProb = model.call(new float[][]{chunk}, samplingRate)[0];
speechProbs.add(speechProb);
}
// Detect speech segments using the same algorithm as Python
boolean triggered = false;
List<Map<String, Integer>> speeches = new ArrayList<>();
Map<String, Integer> currentSpeech = null;
int tempEnd = 0;
for (int i = 0; i < speechProbs.size(); i++) {
float speechProb = speechProbs.get(i);
// Reset temporary end if speech probability exceeds threshold
if (speechProb >= threshold && tempEnd != 0) {
tempEnd = 0;
}
// Detect speech start
if (speechProb >= threshold && !triggered) {
triggered = true;
currentSpeech = new HashMap<>();
currentSpeech.put("start", windowSizeSamples * i);
continue;
}
// Detect speech end
if (speechProb < negThreshold && triggered) {
if (tempEnd == 0) {
tempEnd = windowSizeSamples * i;
}
if (windowSizeSamples * i - tempEnd < minSilenceSamples) {
continue;
} else {
currentSpeech.put("end", tempEnd);
if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) {
speeches.add(currentSpeech);
}
currentSpeech = null;
tempEnd = 0;
triggered = false;
}
}
}
// Handle the last speech segment
if (currentSpeech != null &&
(audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) {
currentSpeech.put("end", audioLengthSamples);
speeches.add(currentSpeech);
}
// Add speech padding - same logic as Python
for (int i = 0; i < speeches.size(); i++) {
Map<String, Integer> speech = speeches.get(i);
if (i == 0) {
speech.put("start", Math.max(0, speech.get("start") - speechPadSamples));
}
if (i != speeches.size() - 1) {
int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end");
if (silenceDuration < 2 * speechPadSamples) {
speech.put("end", speech.get("end") + silenceDuration / 2);
speeches.get(i + 1).put("start",
Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2));
} else {
speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
speeches.get(i + 1).put("start",
Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples));
}
} else {
speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
}
}
return speeches;
}
/**
* Read WAV file and return as float array
*
* @param filePath WAV file path
* @return Audio data as float array (normalized to -1.0 to 1.0)
*/
private static float[] readWavFileAsFloatArray(String filePath)
throws UnsupportedAudioFileException, IOException {
File audioFile = new File(filePath);
AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile);
// Get audio format information
AudioFormat format = audioStream.getFormat();
System.out.println("Audio format: " + format);
// Read all audio data
byte[] audioBytes = audioStream.readAllBytes();
audioStream.close();
// Convert to float array
float[] audioData = new float[audioBytes.length / 2];
for (int i = 0; i < audioData.length; i++) {
// 16-bit PCM: two bytes per sample (little-endian)
short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8));
audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0
}
return audioData;
}
} }

View File

@@ -8,25 +8,30 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
/**
* Silero VAD Detector
* Real-time voice activity detection
*
* @author VvvvvGH
*/
public class SlieroVadDetector { public class SlieroVadDetector {
// OnnxModel model used for speech processing // ONNX model for speech processing
private final SlieroVadOnnxModel model; private final SlieroVadOnnxModel model;
// Threshold for speech start // Speech start threshold
private final float startThreshold; private final float startThreshold;
// Threshold for speech end // Speech end threshold
private final float endThreshold; private final float endThreshold;
// Sampling rate // Sampling rate
private final int samplingRate; private final int samplingRate;
// Minimum number of silence samples to determine the end threshold of speech // Minimum silence samples to determine speech end
private final float minSilenceSamples; private final float minSilenceSamples;
// Additional number of samples for speech start or end to calculate speech start or end time // Speech padding samples for calculating speech boundaries
private final float speechPadSamples; private final float speechPadSamples;
// Whether in the triggered state (i.e. whether speech is being detected) // Triggered state (whether speech is being detected)
private boolean triggered; private boolean triggered;
// Temporarily stored number of speech end samples // Temporary speech end sample position
private int tempEnd; private int tempEnd;
// Number of samples currently being processed // Current sample position
private int currentSample; private int currentSample;
@@ -36,23 +41,25 @@ public class SlieroVadDetector {
int samplingRate, int samplingRate,
int minSilenceDurationMs, int minSilenceDurationMs,
int speechPadMs) throws OrtException { int speechPadMs) throws OrtException {
// Check if the sampling rate is 8000 or 16000, if not, throw an exception // Validate sampling rate
if (samplingRate != 8000 && samplingRate != 16000) { if (samplingRate != 8000 && samplingRate != 16000) {
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]"); throw new IllegalArgumentException("Does not support sampling rates other than [8000, 16000]");
} }
// Initialize the parameters // Initialize parameters
this.model = new SlieroVadOnnxModel(modelPath); this.model = new SlieroVadOnnxModel(modelPath);
this.startThreshold = startThreshold; this.startThreshold = startThreshold;
this.endThreshold = endThreshold; this.endThreshold = endThreshold;
this.samplingRate = samplingRate; this.samplingRate = samplingRate;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f; this.speechPadSamples = samplingRate * speechPadMs / 1000f;
// Reset the state // Reset state
reset(); reset();
} }
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count /**
* Reset detector state
*/
public void reset() { public void reset() {
model.resetStates(); model.resetStates();
triggered = false; triggered = false;
@@ -60,21 +67,27 @@ public class SlieroVadDetector {
currentSample = 0; currentSample = 0;
} }
// apply method for processing the audio array, returning possible speech start or end times /**
* Process audio data and detect speech events
*
* @param data Audio data as byte array
* @param returnSeconds Whether to return timestamps in seconds
* @return Speech event (start or end) or empty map if no event
*/
public Map<String, Double> apply(byte[] data, boolean returnSeconds) { public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
// Convert the byte array to a float array // Convert byte array to float array
float[] audioData = new float[data.length / 2]; float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) { for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f; audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
} }
// Get the length of the audio array as the window size // Get window size from audio data length
int windowSizeSamples = audioData.length; int windowSizeSamples = audioData.length;
// Update the current sample count // Update current sample position
currentSample += windowSizeSamples; currentSample += windowSizeSamples;
// Call the model to get the prediction probability of speech // Get speech probability from model
float speechProb = 0; float speechProb = 0;
try { try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0]; speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
@@ -82,19 +95,18 @@ public class SlieroVadDetector {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time // Reset temporary end if speech probability exceeds threshold
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
if (speechProb >= startThreshold && tempEnd != 0) { if (speechProb >= startThreshold && tempEnd != 0) {
tempEnd = 0; tempEnd = 0;
} }
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time // Detect speech start
if (speechProb >= startThreshold && !triggered) { if (speechProb >= startThreshold && !triggered) {
triggered = true; triggered = true;
int speechStart = (int) (currentSample - speechPadSamples); int speechStart = (int) (currentSample - speechPadSamples);
speechStart = Math.max(speechStart, 0); speechStart = Math.max(speechStart, 0);
Map<String, Double> result = new HashMap<>(); Map<String, Double> result = new HashMap<>();
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter // Return in seconds or samples based on returnSeconds parameter
if (returnSeconds) { if (returnSeconds) {
double speechStartSeconds = speechStart / (double) samplingRate; double speechStartSeconds = speechStart / (double) samplingRate;
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue(); double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
@@ -106,18 +118,17 @@ public class SlieroVadDetector {
return result; return result;
} }
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time // Detect speech end
if (speechProb < endThreshold && triggered) { if (speechProb < endThreshold && triggered) {
// Initialize or update the temporary end time // Initialize or update temporary end position
if (tempEnd == 0) { if (tempEnd == 0) {
tempEnd = currentSample; tempEnd = currentSample;
} }
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null // Wait for minimum silence duration before confirming speech end
// This indicates that it is not yet possible to determine whether the speech has ended
if (currentSample - tempEnd < minSilenceSamples) { if (currentSample - tempEnd < minSilenceSamples) {
return Collections.emptyMap(); return Collections.emptyMap();
} else { } else {
// Calculate the speech end time, reset the trigger state and temporary end time // Calculate speech end time and reset state
int speechEnd = (int) (tempEnd + speechPadSamples); int speechEnd = (int) (tempEnd + speechPadSamples);
tempEnd = 0; tempEnd = 0;
triggered = false; triggered = false;
@@ -134,7 +145,7 @@ public class SlieroVadDetector {
} }
} }
// If the above conditions are not met, return null by default // No speech event detected
return Collections.emptyMap(); return Collections.emptyMap();
} }

View File

@@ -9,42 +9,58 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
/**
* Silero VAD ONNX Model Wrapper
*
* @author VvvvvGH
*/
public class SlieroVadOnnxModel { public class SlieroVadOnnxModel {
// Define private variable OrtSession // ONNX runtime session
private final OrtSession session; private final OrtSession session;
private float[][][] h; // Model state - dimensions: [2, batch_size, 128]
private float[][][] c; private float[][][] state;
// Define the last sample rate // Context - stores the tail of the previous audio chunk
private float[][] context;
// Last sample rate
private int lastSr = 0; private int lastSr = 0;
// Define the last batch size // Last batch size
private int lastBatchSize = 0; private int lastBatchSize = 0;
// Define a list of supported sample rates // Supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000); private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor // Constructor
public SlieroVadOnnxModel(String modelPath) throws OrtException { public SlieroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment // Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object // Create ONNX session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations // Set InterOp thread count to 1 (for parallel processing of different graph operations)
opts.setInterOpNumThreads(1); opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation // Set IntraOp thread count to 1 (for parallel processing within a single operation)
opts.setIntraOpNumThreads(1); opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization // Enable CPU execution optimization
opts.addCPU(true); opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options // Create ONNX session with the environment, model path, and options
session = env.createSession(modelPath, opts); session = env.createSession(modelPath, opts);
// Reset states // Reset states
resetStates(); resetStates();
} }
/** /**
* Reset states * Reset states with default batch size
*/ */
void resetStates() { void resetStates() {
h = new float[2][1][64]; resetStates(1);
c = new float[2][1][64]; }
/**
* Reset states with specific batch size
*
* @param batchSize Batch size for state initialization
*/
void resetStates(int batchSize) {
state = new float[2][batchSize][128];
context = new float[0][]; // Empty context
lastSr = 0; lastSr = 0;
lastBatchSize = 0; lastBatchSize = 0;
} }
@@ -54,13 +70,12 @@ public class SlieroVadOnnxModel {
} }
/** /**
* Define inner class ValidationResult * Inner class for validation result
*/ */
public static class ValidationResult { public static class ValidationResult {
public final float[][] x; public final float[][] x;
public final int sr; public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) { public ValidationResult(float[][] x, int sr) {
this.x = x; this.x = x;
this.sr = sr; this.sr = sr;
@@ -68,19 +83,23 @@ public class SlieroVadOnnxModel {
} }
/** /**
* Function to validate input data * Validate input data
*
* @param x Audio data array
* @param sr Sample rate
* @return Validated input data and sample rate
*/ */
private ValidationResult validateInput(float[][] x, int sr) { private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1 // Ensure input is at least 2D
if (x.length == 1) { if (x.length == 1) {
x = new float[][]{x[0]}; x = new float[][]{x[0]};
} }
// Throw an exception when the input data dimension is greater than 2 // Check if input dimension is valid
if (x.length > 2) { if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
} }
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 // Downsample if sample rate is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) { if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000; int step = sr / 16000;
float[][] reducedX = new float[x.length][]; float[][] reducedX = new float[x.length][];
@@ -100,22 +119,26 @@ public class SlieroVadOnnxModel {
sr = 16000; sr = 16000;
} }
// If the sample rate is not in the list of supported sample rates, throw an exception // Validate sample rate
if (!SAMPLE_RATES.contains(sr)) { if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
} }
// If the input audio block is too short, throw an exception // Check if audio chunk is too short
if (((float) sr) / x[0].length > 31.25) { if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short"); throw new IllegalArgumentException("Input audio is too short");
} }
// Return the validated result
return new ValidationResult(x, sr); return new ValidationResult(x, sr);
} }
/** /**
* Method to call the ONNX model * Call the ONNX model for inference
*
* @param x Audio data array
* @param sr Sample rate
* @return Speech probability output
* @throws OrtException If ONNX runtime error occurs
*/ */
public float[] call(float[][] x, int sr) throws OrtException { public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr); ValidationResult result = validateInput(x, sr);
@@ -123,38 +146,62 @@ public class SlieroVadOnnxModel {
sr = result.sr; sr = result.sr;
int batchSize = x.length; int batchSize = x.length;
int numSamples = sr == 16000 ? 512 : 256;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) { // Reset states only when sample rate or batch size changes
resetStates(); if (lastSr != 0 && lastSr != sr) {
resetStates(batchSize);
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates(batchSize);
} else if (lastBatchSize == 0) {
// First call - state is already initialized, just set batch size
lastBatchSize = batchSize;
}
// Initialize context if needed
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
// Concatenate context and input
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
for (int i = 0; i < batchSize; i++) {
// Copy context
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
// Copy input
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
} }
OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null; OnnxTensor inputTensor = null;
OnnxTensor hTensor = null; OnnxTensor stateTensor = null;
OnnxTensor cTensor = null;
OnnxTensor srTensor = null; OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null; OrtSession.Result ortOutputs = null;
try { try {
// Create input tensors // Create input tensors
inputTensor = OnnxTensor.createTensor(env, x); inputTensor = OnnxTensor.createTensor(env, xWithContext);
hTensor = OnnxTensor.createTensor(env, h); stateTensor = OnnxTensor.createTensor(env, state);
cTensor = OnnxTensor.createTensor(env, c);
srTensor = OnnxTensor.createTensor(env, new long[]{sr}); srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>(); Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor); inputs.put("input", inputTensor);
inputs.put("sr", srTensor); inputs.put("sr", srTensor);
inputs.put("h", hTensor); inputs.put("state", stateTensor);
inputs.put("c", cTensor);
// Call the ONNX model for calculation // Run ONNX model inference
ortOutputs = session.run(inputs); ortOutputs = session.run(inputs);
// Get the output results // Get output results
float[][] output = (float[][]) ortOutputs.get(0).getValue(); float[][] output = (float[][]) ortOutputs.get(0).getValue();
h = (float[][][]) ortOutputs.get(1).getValue(); state = (float[][][]) ortOutputs.get(1).getValue();
c = (float[][][]) ortOutputs.get(2).getValue();
// Update context - save the last contextSize samples from input
for (int i = 0; i < batchSize; i++) {
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
context[i], 0, contextSize);
}
lastSr = sr; lastSr = sr;
lastBatchSize = batchSize; lastBatchSize = batchSize;
@@ -163,11 +210,8 @@ public class SlieroVadOnnxModel {
if (inputTensor != null) { if (inputTensor != null) {
inputTensor.close(); inputTensor.close();
} }
if (hTensor != null) { if (stateTensor != null) {
hTensor.close(); stateTensor.close();
}
if (cTensor != null) {
cTensor.close();
} }
if (srTensor != null) { if (srTensor != null) {
srTensor.close(); srTensor.close();

View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "adler" name = "adler"
@@ -20,6 +20,12 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@@ -42,10 +48,16 @@ dependencies = [
] ]
[[package]] [[package]]
name = "bumpalo" name = "byteorder"
version = "3.16.0" version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
[[package]] [[package]]
name = "cc" name = "cc"
@@ -59,6 +71,22 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.12" version = "0.2.12"
@@ -77,12 +105,6 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]] [[package]]
name = "crypto-common" name = "crypto-common"
version = "0.1.6" version = "0.1.6"
@@ -93,6 +115,16 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "der"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"pem-rfc7468",
"zeroize",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@@ -110,9 +142,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]] [[package]]
name = "filetime" name = "filetime"
version = "0.2.23" version = "0.2.23"
@@ -122,7 +160,7 @@ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall", "redox_syscall",
"windows-sys", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@@ -136,14 +174,20 @@ dependencies = [
] ]
[[package]] [[package]]
name = "form_urlencoded" name = "foreign-types"
version = "1.2.1" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [ dependencies = [
"percent-encoding", "foreign-types-shared",
] ]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@@ -154,27 +198,6 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]] [[package]]
name = "hound" name = "hound"
version = "3.5.1" version = "3.5.1"
@@ -182,23 +205,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
[[package]] [[package]]
name = "idna" name = "http"
version = "0.5.0" version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
dependencies = [ dependencies = [
"unicode-bidi", "bytes",
"unicode-normalization", "itoa",
] ]
[[package]] [[package]]
name = "js-sys" name = "httparse"
version = "0.3.69" version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
dependencies = [
"wasm-bindgen", [[package]]
] name = "itoa"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]] [[package]]
name = "libc" name = "libc"
@@ -206,16 +232,6 @@ version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "libloading"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
"cfg-if",
"windows-targets",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.14" version = "0.4.14"
@@ -224,9 +240,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.21" version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]] [[package]]
name = "matrixmultiply" name = "matrixmultiply"
@@ -248,15 +264,34 @@ dependencies = [
] ]
[[package]] [[package]]
name = "ndarray" name = "native-tls"
version = "0.15.6" version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [ dependencies = [
"matrixmultiply", "matrixmultiply",
"num-complex", "num-complex",
"num-integer", "num-integer",
"num-traits", "num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer", "rawpointer",
] ]
@@ -294,33 +329,83 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]] [[package]]
name = "ort" name = "openssl"
version = "2.0.0-rc.2" version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.5.0",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "ort"
version = "2.0.0-rc.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721"
dependencies = [ dependencies = [
"half",
"js-sys",
"libloading",
"ndarray", "ndarray",
"ort-sys", "ort-sys",
"thiserror", "smallvec",
"tracing", "tracing",
"web-sys",
] ]
[[package]] [[package]]
name = "ort-sys" name = "ort-sys"
version = "2.0.0-rc.2" version = "2.0.0-rc.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890"
dependencies = [ dependencies = [
"flate2", "flate2",
"pkg-config",
"sha2", "sha2",
"tar", "tar",
"ureq", "ureq",
] ]
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
dependencies = [
"base64ct",
]
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.1" version = "2.3.1"
@@ -333,6 +418,27 @@ version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pkg-config"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "portable-atomic"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950"
[[package]]
name = "portable-atomic-util"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
dependencies = [
"portable-atomic",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.84" version = "1.0.84"
@@ -366,21 +472,6 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
] ]
[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
"getrandom",
"libc",
"spin",
"untrusted",
"windows-sys",
]
[[package]] [[package]]
name = "rust-example" name = "rust-example"
version = "0.1.0" version = "0.1.0"
@@ -400,38 +491,48 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys", "windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
] ]
[[package]] [[package]]
name = "rustls-pki-types" name = "rustls-pki-types"
version = "1.7.0" version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
dependencies = [
"zeroize",
]
[[package]] [[package]]
name = "rustls-webpki" name = "schannel"
version = "0.102.4" version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [ dependencies = [
"ring", "windows-sys 0.61.2",
"rustls-pki-types", ]
"untrusted",
[[package]]
name = "security-framework"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
dependencies = [
"bitflags 2.5.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
] ]
[[package]] [[package]]
@@ -446,16 +547,21 @@ dependencies = [
] ]
[[package]] [[package]]
name = "spin" name = "smallvec"
version = "0.9.8" version = "2.0.0-alpha.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b"
[[package]] [[package]]
name = "subtle" name = "socks"
version = "2.5.0" version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]] [[package]]
name = "syn" name = "syn"
@@ -480,40 +586,18 @@ dependencies = [
] ]
[[package]] [[package]]
name = "thiserror" name = "tempfile"
version = "1.0.61" version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [ dependencies = [
"thiserror-impl", "cfg-if",
"fastrand",
"once_cell",
"rustix",
"windows-sys 0.59.0",
] ]
[[package]]
name = "thiserror-impl"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.40" version = "0.1.40"
@@ -521,21 +605,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [ dependencies = [
"pin-project-lite", "pin-project-lite",
"tracing-attributes",
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "tracing-core" name = "tracing-core"
version = "0.1.32" version = "0.1.32"
@@ -551,60 +623,54 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicode-bidi"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.12" version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unicode-normalization"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
dependencies = [
"tinyvec",
]
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "ureq" name = "ureq"
version = "2.9.7" version = "3.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
dependencies = [ dependencies = [
"base64", "base64",
"der",
"log", "log",
"once_cell", "native-tls",
"rustls", "percent-encoding",
"rustls-pki-types", "rustls-pki-types",
"rustls-webpki", "socks",
"url", "ureq-proto",
"webpki-roots", "utf-8",
"webpki-root-certs",
] ]
[[package]] [[package]]
name = "url" name = "ureq-proto"
version = "2.5.0" version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
dependencies = [ dependencies = [
"form_urlencoded", "base64",
"idna", "http",
"percent-encoding", "httparse",
"log",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"
@@ -612,84 +678,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]] [[package]]
name = "wasi" name = "webpki-root-certs"
version = "0.11.0+wasi-snapshot-preview1" version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b"
[[package]]
name = "wasm-bindgen"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
[[package]]
name = "web-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
@@ -700,10 +724,28 @@ dependencies = [
] ]
[[package]] [[package]]
name = "windows-targets" name = "windows-sys"
version = "0.52.5" version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [ dependencies = [
"windows_aarch64_gnullvm", "windows_aarch64_gnullvm",
"windows_aarch64_msvc", "windows_aarch64_msvc",
@@ -717,51 +759,51 @@ dependencies = [
[[package]] [[package]]
name = "windows_aarch64_gnullvm" name = "windows_aarch64_gnullvm"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]] [[package]]
name = "windows_i686_gnullvm" name = "windows_i686_gnullvm"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.52.5" version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]] [[package]]
name = "xattr" name = "xattr"

View File

@@ -4,6 +4,6 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] } ort = { version = "=2.0.0-rc.10", features = ["ndarray"] }
ndarray = "0.15" ndarray = "0.16"
hound = "3" hound = "3"

View File

@@ -4,7 +4,7 @@ mod vad_iter;
fn main() { fn main() {
let model_path = std::env::var("SILERO_MODEL_PATH") let model_path = std::env::var("SILERO_MODEL_PATH")
.unwrap_or_else(|_| String::from("../../files/silero_vad.onnx")); .unwrap_or_else(|_| String::from("../../src/silero_vad/data/silero_vad.onnx"));
let audio_path = std::env::args() let audio_path = std::env::args()
.nth(1) .nth(1)
.unwrap_or_else(|| String::from("recorder.wav")); .unwrap_or_else(|| String::from("recorder.wav"));

View File

@@ -1,12 +1,17 @@
use crate::utils; use crate::utils;
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use ort::session::Session;
use ort::value::Value;
use std::mem::take;
use std::path::Path; use std::path::Path;
#[derive(Debug)] #[derive(Debug)]
pub struct Silero { pub struct Silero {
session: ort::Session, session: Session,
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>, sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
context: Array1<f32>,
context_size: usize,
} }
impl Silero { impl Silero {
@@ -14,18 +19,24 @@ impl Silero {
sample_rate: utils::SampleRate, sample_rate: utils::SampleRate,
model_path: impl AsRef<Path>, model_path: impl AsRef<Path>,
) -> Result<Self, ort::Error> { ) -> Result<Self, ort::Error> {
let session = ort::Session::builder()?.commit_from_file(model_path)?; let session = Session::builder()?.commit_from_file(model_path)?;
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice()); let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); let sample_rate_val: i64 = sample_rate.into();
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
let context = Array1::<f32>::zeros(context_size);
let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();
Ok(Self { Ok(Self {
session, session,
sample_rate, sample_rate,
state, state,
context,
context_size,
}) })
} }
pub fn reset(&mut self) { pub fn reset(&mut self) {
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice()); self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
self.context = Array1::<f32>::zeros(self.context_size);
} }
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> { pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
@@ -33,22 +44,41 @@ impl Silero {
.iter() .iter()
.map(|x| (*x as f32) / (i16::MAX as f32)) .map(|x| (*x as f32) / (i16::MAX as f32))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
frame = frame.slice(s![.., ..480]).to_owned(); // Concatenate context with input
let inps = ort::inputs![ let mut input_with_context = Vec::with_capacity(self.context_size + data.len());
frame, input_with_context.extend_from_slice(self.context.as_slice().unwrap());
std::mem::take(&mut self.state), input_with_context.extend_from_slice(&data);
self.sample_rate.clone(),
]?; let frame =
let res = self Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)
.session .unwrap();
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); let frame_value = Value::from_array(frame)?;
Ok(*res["output"] let state_value = Value::from_array(take(&mut self.state))?;
.try_extract_raw_tensor::<f32>() let sr_value = Value::from_array(self.sample_rate.clone())?;
let res = self.session.run([
(&frame_value).into(),
(&state_value).into(),
(&sr_value).into(),
])?;
let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;
let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();
self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();
// Update context with last context_size samples from the input
if data.len() >= self.context_size {
self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());
}
let prob = *res["output"]
.try_extract_tensor::<f32>()
.unwrap() .unwrap()
.1 .1
.first() .first()
.unwrap()) .unwrap();
Ok(prob)
} }
} }

View File

@@ -36,7 +36,7 @@ pub struct VadParams {
impl Default for VadParams { impl Default for VadParams {
fn default() -> Self { fn default() -> Self {
Self { Self {
frame_size: 64, frame_size: 32, // 32ms for 512 samples at 16kHz
threshold: 0.5, threshold: 0.5,
min_silence_duration_ms: 0, min_silence_duration_ms: 0,
speech_pad_ms: 64, speech_pad_ms: 64,

View File

@@ -3,7 +3,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[project] [project]
name = "silero-vad" name = "silero-vad"
version = "6.0.0" version = "6.2.0"
authors = [ authors = [
{name="Silero Team", email="hello@silero.ai"}, {name="Silero Team", email="hello@silero.ai"},
] ]
@@ -28,6 +28,7 @@ classifiers = [
"Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering",
] ]
dependencies = [ dependencies = [
"packaging",
"torch>=1.12.0", "torch>=1.12.0",
"torchaudio>=0.12.0", "torchaudio>=0.12.0",
"onnxruntime>=1.16.1", "onnxruntime>=1.16.1",
@@ -36,3 +37,10 @@ dependencies = [
[project.urls] [project.urls]
Homepage = "https://github.com/snakers4/silero-vad" Homepage = "https://github.com/snakers4/silero-vad"
Issues = "https://github.com/snakers4/silero-vad/issues" Issues = "https://github.com/snakers4/silero-vad/issues"
[project.optional-dependencies]
test = [
"pytest",
"soundfile",
"torch<2.9",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,71 @@
from tinygrad import nn
class TinySileroVAD:
def __init__(self):
"""
from tinygrad.nn.state import safe_load, load_state_dict
tiny_model = TinySileroVAD()
state_dict = safe_load('data/silero_vad_16k.safetensors')
load_state_dict(tiny_model, state_dict)
"""
self.n_fft = 256
self.stride = 128
self.pad = 64
self.cutoff = int(self.n_fft // 2) + 1
self.stft_conv = nn.Conv1d(1, 258, kernel_size=256, stride=self.stride, padding=0, bias=False)
self.conv1 = nn.Conv1d(129, 128, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1)
self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
self.lstm_cell = nn.LSTMCell(128, 128)
self.final_conv = nn.Conv1d(128, 1, 1)
def __call__(self, x, state=None):
"""
# full audio example:
import torch
from tinygrad import Tensor
wav = read_audio(audio_path, sampling_rate=16000).unsqueeze(0)
num_samples = 512
context_size = 64
context = Tensor(np.zeros((1, context_size))).float()
outs = []
state = None
if wav.shape[1] % num_samples:
pad_num = num_samples - (wav.shape[1] % num_samples)
wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0)
wav = torch.nn.functional.pad(wav, (context_size, 0))
wav = Tensor(wav.numpy()).float()
for i in tqdm(range(context_size, wav.shape[1], num_samples)):
wavs_batch = wav[:, i-context_size:i+num_samples]
out_chunk, state = tiny_model(wavs_batch, state)
#outs.append(out_chunk.numpy())
outs.append(out_chunk)
predict = outs[0].cat(*outs[1:], dim=1).numpy()
"""
if state is not None:
state = (state[0], state[1])
x = x.pad((0, self.pad), "reflect").unsqueeze(1)
x = self.stft_conv(x)
x = (x[:, :self.cutoff, :]**2 + x[:, self.cutoff:, :]**2).sqrt()
x = self.conv1(x).relu()
x = self.conv2(x).relu()
x = self.conv3(x).relu()
x = self.conv4(x).relu().squeeze(-1)
h, c = self.lstm_cell(x, state)
x = h.unsqueeze(-1)
state = h.stack(c, dim=0)
x = x.relu()
x = self.final_conv(x).sigmoid()
x = x.squeeze(1).mean(axis=1).unsqueeze(1)
return x, state

View File

@@ -2,6 +2,7 @@ import torch
import torchaudio import torchaudio
from typing import Callable, List from typing import Callable, List
import warnings import warnings
from packaging import version
languages = ['ru', 'en', 'de', 'es'] languages = ['ru', 'en', 'de', 'es']
@@ -134,40 +135,60 @@ class Validator():
return outs return outs
def read_audio(path: str, def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
sampling_rate: int = 16000): ta_ver = version.parse(torchaudio.__version__)
list_backends = torchaudio.list_audio_backends() if ta_ver < version.parse("2.9"):
try:
effects = [['channels', '1'],['rate', str(sampling_rate)]]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
else:
try:
wav, sr = torchaudio.load(path)
except:
try:
from torchcodec.decoders import AudioDecoder
samples = AudioDecoder(path).get_all_samples()
wav = samples.data
sr = samples.sample_rate
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \ if wav.ndim > 1 and wav.size(0) > 1:
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)' wav = wav.mean(dim=0, keepdim=True)
try: if sr != sampling_rate:
effects = [ wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
['channels', '1'],
['rate', str(sampling_rate)]
]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
assert sr == sampling_rate
return wav.squeeze(0) return wav.squeeze(0)
def save_audio(path: str, def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
tensor: torch.Tensor, tensor = tensor.detach().cpu()
sampling_rate: int = 16000): if tensor.ndim == 1:
torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16) tensor = tensor.unsqueeze(0)
ta_ver = version.parse(torchaudio.__version__)
try:
torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
except Exception:
if ta_ver >= version.parse("2.9"):
try:
from torchcodec.encoders import AudioEncoder
encoder = AudioEncoder(tensor, sample_rate=16000)
encoder.to_file(path)
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
else:
raise
def init_jit_model(model_path: str, def init_jit_model(model_path: str,
@@ -202,7 +223,7 @@ def get_speech_timestamps(audio: torch.Tensor,
progress_tracking_callback: Callable[[float], None] = None, progress_tracking_callback: Callable[[float], None] = None,
neg_threshold: float = None, neg_threshold: float = None,
window_size_samples: int = 512, window_size_samples: int = 512,
min_silence_at_max_speech: float = 98, min_silence_at_max_speech: int = 98,
use_max_poss_sil_at_max_speech: bool = True): use_max_poss_sil_at_max_speech: bool = True):
""" """
@@ -227,7 +248,7 @@ def get_speech_timestamps(audio: torch.Tensor,
max_speech_duration_s: int (default - inf) max_speech_duration_s: int (default - inf)
Maximum duration of speech chunks in seconds Maximum duration of speech chunks in seconds
Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent agressive cutting. Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
Otherwise, they will be split aggressively just before max_speech_duration_s. Otherwise, they will be split aggressively just before max_speech_duration_s.
min_silence_duration_ms: int (default - 100 milliseconds) min_silence_duration_ms: int (default - 100 milliseconds)
@@ -251,7 +272,7 @@ def get_speech_timestamps(audio: torch.Tensor,
neg_threshold: float (default = threshold - 0.15) neg_threshold: float (default = threshold - 0.15)
Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH. Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
min_silence_at_max_speech: float (default - 98ms) min_silence_at_max_speech: int (default - 98ms)
Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached
use_max_poss_sil_at_max_speech: bool (default - True) use_max_poss_sil_at_max_speech: bool (default - True)
@@ -289,7 +310,6 @@ def get_speech_timestamps(audio: torch.Tensor,
raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates") raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
window_size_samples = 512 if sampling_rate == 16000 else 256 window_size_samples = 512 if sampling_rate == 16000 else 256
hop_size_samples = int(window_size_samples)
model.reset_states() model.reset_states()
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
@@ -301,17 +321,14 @@ def get_speech_timestamps(audio: torch.Tensor,
audio_length_samples = len(audio) audio_length_samples = len(audio)
speech_probs = [] speech_probs = []
for current_start_sample in range(0, audio_length_samples, hop_size_samples): for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample: current_start_sample + window_size_samples] chunk = audio[current_start_sample: current_start_sample + window_size_samples]
if len(chunk) < window_size_samples: if len(chunk) < window_size_samples:
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk)))) chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
try: speech_prob = model(chunk, sampling_rate).item()
speech_prob = model(chunk, sampling_rate).item()
except Exception as e:
import ipdb; ipdb.set_trace()
speech_probs.append(speech_prob) speech_probs.append(speech_prob)
# caculate progress and seng it to callback function # calculate progress and send it to callback function
progress = current_start_sample + hop_size_samples progress = current_start_sample + window_size_samples
if progress > audio_length_samples: if progress > audio_length_samples:
progress = audio_length_samples progress = audio_length_samples
progress_percent = (progress / audio_length_samples) * 100 progress_percent = (progress / audio_length_samples) * 100
@@ -329,53 +346,70 @@ def get_speech_timestamps(audio: torch.Tensor,
possible_ends = [] possible_ends = []
for i, speech_prob in enumerate(speech_probs): for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end: cur_sample = window_size_samples * i
if temp_end != 0:
sil_dur = (hop_size_samples * i) - temp_end
if sil_dur > min_silence_samples_at_max_speech:
possible_ends.append((temp_end, sil_dur))
temp_end = 0
if next_start < prev_end:
next_start = hop_size_samples * i
# If speech returns after a temp_end, record candidate silence if long enough and clear temp_end
if (speech_prob >= threshold) and temp_end:
sil_dur = cur_sample - temp_end
if sil_dur > min_silence_samples_at_max_speech:
possible_ends.append((temp_end, sil_dur))
temp_end = 0
if next_start < prev_end:
next_start = cur_sample
# Start of speech
if (speech_prob >= threshold) and not triggered: if (speech_prob >= threshold) and not triggered:
triggered = True triggered = True
current_speech['start'] = hop_size_samples * i current_speech['start'] = cur_sample
continue continue
if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples: # Max speech length reached: decide where to cut
if possible_ends: if triggered and (cur_sample - current_speech['start'] > max_speech_samples):
if use_max_poss_sil_at_max_speech: if use_max_poss_sil_at_max_speech and possible_ends:
prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk
else:
prev_end, dur = possible_ends[-1] # use the last possible silence segement
current_speech['end'] = prev_end current_speech['end'] = prev_end
speeches.append(current_speech) speeches.append(current_speech)
current_speech = {} current_speech = {}
next_start = prev_end + dur next_start = prev_end + dur
if next_start < prev_end + hop_size_samples * i: # previously reached silence (< neg_thres) and is still not speech (< thres)
#triggered = False if next_start < prev_end + cur_sample: # previously reached silence (< neg_thres) and is still not speech (< thres)
current_speech['start'] = next_start current_speech['start'] = next_start
else: else:
triggered = False triggered = False
#current_speech['start'] = next_start
prev_end = next_start = temp_end = 0 prev_end = next_start = temp_end = 0
possible_ends = [] possible_ends = []
else: else:
current_speech['end'] = hop_size_samples * i # Legacy max-speech cut (use_max_poss_sil_at_max_speech=False): prefer last valid silence (prev_end) if available
speeches.append(current_speech) if prev_end:
current_speech = {} current_speech['end'] = prev_end
prev_end = next_start = temp_end = 0 speeches.append(current_speech)
triggered = False current_speech = {}
possible_ends = [] if next_start < prev_end:
continue triggered = False
else:
current_speech['start'] = next_start
prev_end = next_start = temp_end = 0
possible_ends = []
else:
# No prev_end -> fallback to cutting at current sample
current_speech['end'] = cur_sample
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
possible_ends = []
continue
# Silence detection while in speech
if (speech_prob < neg_threshold) and triggered: if (speech_prob < neg_threshold) and triggered:
if not temp_end: if not temp_end:
temp_end = hop_size_samples * i temp_end = cur_sample
# if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence sil_dur_now = cur_sample - temp_end
# prev_end = temp_end
if (hop_size_samples * i) - temp_end < min_silence_samples: if not use_max_poss_sil_at_max_speech and sil_dur_now > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
prev_end = temp_end
if sil_dur_now < min_silence_samples:
continue continue
else: else:
current_speech['end'] = temp_end current_speech['end'] = temp_end
@@ -416,7 +450,7 @@ def get_speech_timestamps(audio: torch.Tensor,
speech_dict['end'] *= step speech_dict['end'] *= step
if visualize_probs: if visualize_probs:
make_visualization(speech_probs, hop_size_samples / sampling_rate) make_visualization(speech_probs, window_size_samples / sampling_rate)
return speeches return speeches
@@ -607,6 +641,8 @@ def drop_chunks(tss: List[dict],
chunks.append((wav[cur_start: i['start']])) chunks.append((wav[cur_start: i['start']]))
cur_start = i['end'] cur_start = i['end']
chunks.append(wav[cur_start:])
return torch.cat(chunks) return torch.cat(chunks)

View File

@@ -118,8 +118,6 @@ class SileroVadDataset(Dataset):
assert len(gt) == len(wav) / self.num_samples assert len(gt) == len(wav) / self.num_samples
mask[gt == 0]
return wav, gt, mask return wav, gt, mask
def get_ground_truth_annotated(self, annotation, audio_length_samples): def get_ground_truth_annotated(self, annotation, audio_length_samples):
@@ -240,6 +238,7 @@ def train(config,
loss = criterion(stacked, targets) loss = criterion(stacked, targets)
loss = (loss * masks).mean() loss = (loss * masks).mean()
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
losses.update(loss.item(), masks.numel()) losses.update(loss.item(), masks.numel())