mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 01:49:22 +08:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2688a6e352 | ||
|
|
c5542cd4a8 | ||
|
|
4725c40105 | ||
|
|
cfe63384f0 | ||
|
|
2a08f0b90d | ||
|
|
21ffe8576e | ||
|
|
d5b52843f7 | ||
|
|
fb7d7c7f5d | ||
|
|
e7c3d6f2bd | ||
|
|
390614894d | ||
|
|
33eb4c7f84 | ||
|
|
c913b0c4b3 | ||
|
|
4dd2e8f6f9 | ||
|
|
63fe03add7 | ||
|
|
29a582ba37 | ||
|
|
3ca476e4fb | ||
|
|
7de462944a | ||
|
|
12b0121993 | ||
|
|
7b0aaa1c4c | ||
|
|
540eff3e24 | ||
|
|
dfeba4fc0f |
49
examples/c++/README.md
Normal file
49
examples/c++/README.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Silero-VAD V6 in C++ (based on LibTorch)
|
||||||
|
|
||||||
|
This is the source code for Silero-VAD V6 in C++, utilizing LibTorch & Onnxruntime.
|
||||||
|
You should compare its results with the Python version.
|
||||||
|
Results at 16 and 8kHz have been tested. Batch and CUDA inference options are deprecated.
|
||||||
|
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
- GCC 11.4.0 (GCC >= 5.1)
|
||||||
|
- Onnxruntime 1.11.0 (other versions are also acceptable)
|
||||||
|
- LibTorch 1.13.0 (other versions are also acceptable)
|
||||||
|
|
||||||
|
## Download LibTorch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-Onnxruntime
|
||||||
|
$wget https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz
|
||||||
|
$tar -xvf onnxruntime-linux-x64-1.11.1.tgz
|
||||||
|
$ln -s onnxruntime-linux-x64-1.11.1 onnxruntime-linux #soft-link
|
||||||
|
|
||||||
|
-Libtorch
|
||||||
|
$wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
|
||||||
|
$unzip libtorch-shared-with-deps-1.13.0+cpu.zip
|
||||||
|
```
|
||||||
|
|
||||||
|
## Compilation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-ONNX-build
|
||||||
|
$g++ main.cc silero.cc -I ./onnxruntime-linux/include/ -L ./onnxruntime-linux/lib/ -lonnxruntime -Wl,-rpath,./onnxruntime-linux/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_ONNX
|
||||||
|
|
||||||
|
-TORCH-build
|
||||||
|
$g++ main.cc silero.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_TORCH
|
||||||
|
```
|
||||||
|
|
||||||
|
## Optional Compilation Flags
|
||||||
|
-DUSE_TORCH
|
||||||
|
-DUSE_ONNX
|
||||||
|
|
||||||
|
## Run the Program
|
||||||
|
To run the program, use the following command:
|
||||||
|
|
||||||
|
`./silero <sample.wav> <SampleRate> <threshold>`
|
||||||
|
`./silero aepyx.wav 16000 0.5`
|
||||||
|
`./silero aepyx_8k.wav 8000 0.5`
|
||||||
|
|
||||||
|
The sample file aepyx.wav is part of the Voxconverse dataset.
|
||||||
|
File details: aepyx.wav is a 16kHz, 16-bit audio file.
|
||||||
|
File details: aepyx_8k.wav is a 8kHz, 16-bit audio file.
|
||||||
BIN
examples/c++/aepyx.wav
Normal file
BIN
examples/c++/aepyx.wav
Normal file
Binary file not shown.
BIN
examples/c++/aepyx_8k.wav
Normal file
BIN
examples/c++/aepyx_8k.wav
Normal file
Binary file not shown.
61
examples/c++/main.cc
Normal file
61
examples/c++/main.cc
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include "silero.h"
|
||||||
|
#include "wav.h"
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
|
||||||
|
if(argc != 4){
|
||||||
|
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
|
||||||
|
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string wav_path = argv[1];
|
||||||
|
float sample_rate = std::stof(argv[2]);
|
||||||
|
float threshold = std::stof(argv[3]);
|
||||||
|
|
||||||
|
if (sample_rate != 16000 && sample_rate != 8000) {
|
||||||
|
std::cout<<"Unsupported sample rate (only 16000 or 8000)."<<std::endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Load Model
|
||||||
|
#ifdef USE_TORCH
|
||||||
|
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
|
||||||
|
#elif USE_ONNX
|
||||||
|
std::string model_path = "../../src/silero_vad/data/silero_vad.onnx";
|
||||||
|
#endif
|
||||||
|
silero::VadIterator vad(model_path);
|
||||||
|
|
||||||
|
vad.threshold=threshold; //(Default:0.5)
|
||||||
|
vad.sample_rate=sample_rate; //16000Hz,8000Hz. (Default:16000)
|
||||||
|
vad.print_as_samples=false; //if true, it prints time-stamp with samples. otherwise, in seconds
|
||||||
|
//(Default:false)
|
||||||
|
|
||||||
|
vad.SetVariables();
|
||||||
|
|
||||||
|
// Read wav
|
||||||
|
wav::WavReader wav_reader(wav_path);
|
||||||
|
std::vector<float> input_wav(wav_reader.num_samples());
|
||||||
|
|
||||||
|
for (int i = 0; i < wav_reader.num_samples(); i++)
|
||||||
|
{
|
||||||
|
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
|
||||||
|
}
|
||||||
|
|
||||||
|
vad.SpeechProbs(input_wav);
|
||||||
|
|
||||||
|
std::vector<silero::Interval> speeches = vad.GetSpeechTimestamps();
|
||||||
|
for(const auto& speech : speeches){
|
||||||
|
if(vad.print_as_samples){
|
||||||
|
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
273
examples/c++/silero.cc
Normal file
273
examples/c++/silero.cc
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
// silero.cc
|
||||||
|
// Author : NathanJHLee
|
||||||
|
// Created On : 2025-11-10
|
||||||
|
// Description : silero 6.2 system for onnx-runtime(c++) and torch-script(c++)
|
||||||
|
// Version : 1.3
|
||||||
|
|
||||||
|
#include "silero.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace silero {
|
||||||
|
|
||||||
|
#ifdef USE_TORCH
|
||||||
|
VadIterator::VadIterator(const std::string &model_path,
|
||||||
|
float threshold,
|
||||||
|
int sample_rate,
|
||||||
|
int window_size_ms,
|
||||||
|
int speech_pad_ms,
|
||||||
|
int min_silence_duration_ms,
|
||||||
|
int min_speech_duration_ms,
|
||||||
|
int max_duration_merge_ms,
|
||||||
|
bool print_as_samples)
|
||||||
|
: threshold(threshold), sample_rate(sample_rate), window_size_ms(window_size_ms),
|
||||||
|
speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
|
||||||
|
min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
|
||||||
|
print_as_samples(print_as_samples)
|
||||||
|
{
|
||||||
|
init_torch_model(model_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
VadIterator::~VadIterator(){
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void VadIterator::init_torch_model(const std::string& model_path) {
|
||||||
|
at::set_num_threads(1);
|
||||||
|
model = torch::jit::load(model_path);
|
||||||
|
|
||||||
|
model.eval();
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
std::cout<<"Silero libtorch-Model loaded successfully"<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
|
||||||
|
int num_samples = input_wav.size();
|
||||||
|
int num_chunks = num_samples / window_size_samples;
|
||||||
|
int remainder_samples = num_samples % window_size_samples;
|
||||||
|
total_sample_size += num_samples;
|
||||||
|
|
||||||
|
std::vector<torch::Tensor> chunks;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_chunks; i++) {
|
||||||
|
float* chunk_start = input_wav.data() + i * window_size_samples;
|
||||||
|
torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
|
||||||
|
chunks.push_back(chunk);
|
||||||
|
|
||||||
|
if (i == num_chunks - 1 && remainder_samples > 0) {
|
||||||
|
int remaining_samples = num_samples - num_chunks * window_size_samples;
|
||||||
|
float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;
|
||||||
|
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remaining_samples}, torch::kFloat32);
|
||||||
|
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples - remaining_samples}, torch::kFloat32)}, 1);
|
||||||
|
chunks.push_back(padded_chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!chunks.empty()) {
|
||||||
|
std::vector<torch::Tensor> outputs;
|
||||||
|
torch::Tensor batched_chunks = torch::stack(chunks);
|
||||||
|
for (size_t i = 0; i < chunks.size(); i++) {
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
inputs.push_back(batched_chunks[i]);
|
||||||
|
inputs.push_back(sample_rate);
|
||||||
|
torch::Tensor output = model.forward(inputs).toTensor();
|
||||||
|
outputs.push_back(output);
|
||||||
|
}
|
||||||
|
torch::Tensor all_outputs = torch::stack(outputs);
|
||||||
|
for (size_t i = 0; i < chunks.size(); i++) {
|
||||||
|
float output_f = all_outputs[i].item<float>();
|
||||||
|
outputs_prob.push_back(output_f);
|
||||||
|
//////To print Probs by libtorch
|
||||||
|
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#elif USE_ONNX
|
||||||
|
|
||||||
|
VadIterator::VadIterator(const std::string &model_path,
|
||||||
|
float threshold,
|
||||||
|
int sample_rate,
|
||||||
|
int window_size_ms,
|
||||||
|
int speech_pad_ms,
|
||||||
|
int min_silence_duration_ms,
|
||||||
|
int min_speech_duration_ms,
|
||||||
|
int max_duration_merge_ms,
|
||||||
|
bool print_as_samples)
|
||||||
|
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms),
|
||||||
|
speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
|
||||||
|
min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
|
||||||
|
print_as_samples(print_as_samples),
|
||||||
|
env(ORT_LOGGING_LEVEL_ERROR, "Vad"), session_options(), session(nullptr), allocator(),
|
||||||
|
memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU)), context_samples(64),
|
||||||
|
_context(64, 0.0f), current_sample(0), size_state(2 * 1 * 128),
|
||||||
|
input_node_names({"input", "state", "sr"}), output_node_names({"output", "stateN"}),
|
||||||
|
state_node_dims{2, 1, 128}, sr_node_dims{1}
|
||||||
|
|
||||||
|
{
|
||||||
|
init_onnx_model(model_path);
|
||||||
|
}
|
||||||
|
VadIterator::~VadIterator(){
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::init_onnx_model(const std::string& model_path) {
|
||||||
|
int inter_threads=1;
|
||||||
|
int intra_threads=1;
|
||||||
|
session_options.SetIntraOpNumThreads(intra_threads);
|
||||||
|
session_options.SetInterOpNumThreads(inter_threads);
|
||||||
|
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||||
|
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
|
||||||
|
std::cout<<"Silero onnx-Model loaded successfully"<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
float VadIterator::predict(const std::vector<float>& data_chunk) {
|
||||||
|
// _context와 현재 청크를 결합하여 입력 데이터 구성
|
||||||
|
std::vector<float> new_data(effective_window_size, 0.0f);
|
||||||
|
std::copy(_context.begin(), _context.end(), new_data.begin());
|
||||||
|
std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
|
||||||
|
input = new_data;
|
||||||
|
|
||||||
|
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
|
||||||
|
memory_info, input.data(), input.size(), input_node_dims, 2);
|
||||||
|
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
|
||||||
|
memory_info, _state.data(), _state.size(), state_node_dims, 3);
|
||||||
|
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
|
||||||
|
ort_inputs.clear();
|
||||||
|
ort_inputs.push_back(std::move(input_ort));
|
||||||
|
ort_inputs.push_back(std::move(state_ort));
|
||||||
|
ort_inputs.push_back(std::move(sr_ort));
|
||||||
|
|
||||||
|
ort_outputs = session->Run(
|
||||||
|
Ort::RunOptions{ nullptr },
|
||||||
|
input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
|
||||||
|
output_node_names.data(), output_node_names.size());
|
||||||
|
|
||||||
|
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0]; // ONNX 출력: 첫 번째 값이 음성 확률
|
||||||
|
|
||||||
|
float* stateN = ort_outputs[1].GetTensorMutableData<float>(); // 두 번째 출력값: 상태 업데이트
|
||||||
|
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
|
||||||
|
|
||||||
|
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
|
||||||
|
// _context 업데이트: new_data의 마지막 context_samples 유지
|
||||||
|
|
||||||
|
return speech_prob;
|
||||||
|
}
|
||||||
|
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
|
||||||
|
reset_states();
|
||||||
|
total_sample_size = static_cast<int>(input_wav.size());
|
||||||
|
for (size_t j = 0; j < static_cast<size_t>(total_sample_size); j += window_size_samples) {
|
||||||
|
if (j + window_size_samples > static_cast<size_t>(total_sample_size))
|
||||||
|
break;
|
||||||
|
std::vector<float> chunk(input_wav.begin() + j, input_wav.begin() + j + window_size_samples);
|
||||||
|
float speech_prob = predict(chunk);
|
||||||
|
outputs_prob.push_back(speech_prob);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void VadIterator::reset_states() {
|
||||||
|
triggered = false;
|
||||||
|
current_sample = 0;
|
||||||
|
temp_end = 0;
|
||||||
|
outputs_prob.clear();
|
||||||
|
total_sample_size = 0;
|
||||||
|
|
||||||
|
#ifdef USE_TORCH
|
||||||
|
model.run_method("reset_states"); // Reset model states if applicable
|
||||||
|
#elif USE_ONNX
|
||||||
|
std::memset(_state.data(), 0, _state.size() * sizeof(float));
|
||||||
|
std::fill(_context.begin(), _context.end(), 0.0f);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
|
||||||
|
std::vector<Interval> speeches = DoVad();
|
||||||
|
if(!print_as_samples){
|
||||||
|
for (auto& speech : speeches) {
|
||||||
|
speech.start /= sample_rate;
|
||||||
|
speech.end /= sample_rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return speeches;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::SetVariables(){
|
||||||
|
// Initialize internal engine parameters
|
||||||
|
init_engine(window_size_ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::init_engine(int window_size_ms) {
|
||||||
|
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
|
||||||
|
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
|
||||||
|
window_size_samples = sample_rate / 1000 * window_size_ms;
|
||||||
|
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
|
||||||
|
#ifdef USE_ONNX
|
||||||
|
//for ONNX
|
||||||
|
context_samples=window_size_samples / 8;
|
||||||
|
_context.assign(context_samples, 0.0f);
|
||||||
|
|
||||||
|
effective_window_size = window_size_samples + context_samples; // 예: 512 + 64 = 576 samples
|
||||||
|
input_node_dims[0] = 1;
|
||||||
|
input_node_dims[1] = effective_window_size;
|
||||||
|
_state.resize(size_state);
|
||||||
|
sr.resize(1);
|
||||||
|
sr[0] = sample_rate;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Interval> VadIterator::DoVad() {
|
||||||
|
std::vector<Interval> speeches;
|
||||||
|
for (size_t i = 0; i < outputs_prob.size(); ++i) {
|
||||||
|
float speech_prob = outputs_prob[i];
|
||||||
|
current_sample += window_size_samples;
|
||||||
|
if (speech_prob >= threshold && temp_end != 0) {
|
||||||
|
temp_end = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (speech_prob >= threshold) {
|
||||||
|
if (!triggered) {
|
||||||
|
triggered = true;
|
||||||
|
Interval segment;
|
||||||
|
segment.start = std::max(0, current_sample - speech_pad_samples - window_size_samples);
|
||||||
|
speeches.push_back(segment);
|
||||||
|
}
|
||||||
|
}else {
|
||||||
|
if (triggered) {
|
||||||
|
if (speech_prob < threshold - 0.15f) {
|
||||||
|
if (temp_end == 0) {
|
||||||
|
temp_end = current_sample;
|
||||||
|
}
|
||||||
|
if (current_sample - temp_end >= min_silence_samples) {
|
||||||
|
Interval& segment = speeches.back();
|
||||||
|
segment.end = temp_end + speech_pad_samples - window_size_samples;
|
||||||
|
temp_end = 0;
|
||||||
|
triggered = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if (triggered) {
|
||||||
|
std::cout<<"Finalizing active speech segment at stream end."<<std::endl;
|
||||||
|
Interval& segment = speeches.back();
|
||||||
|
segment.end = total_sample_size;
|
||||||
|
triggered = false;
|
||||||
|
}
|
||||||
|
speeches.erase(std::remove_if(speeches.begin(), speeches.end(),
|
||||||
|
[this](const Interval& speech) {
|
||||||
|
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
|
||||||
|
}), speeches.end());
|
||||||
|
|
||||||
|
reset_states();
|
||||||
|
return speeches;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace silero
|
||||||
|
|
||||||
123
examples/c++/silero.h
Normal file
123
examples/c++/silero.h
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
#ifndef SILERO_H
|
||||||
|
#define SILERO_H
|
||||||
|
|
||||||
|
// silero.h
|
||||||
|
// Author : NathanJHLee
|
||||||
|
// Created On : 2025-11-10
|
||||||
|
// Description : silero 6.2 system for onnx-runtime(c++) and torch-script(c++)
|
||||||
|
// Version : 1.3
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <chrono>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
#ifdef USE_TORCH
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <torch/script.h>
|
||||||
|
#elif USE_ONNX
|
||||||
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace silero {
|
||||||
|
|
||||||
|
struct Interval {
|
||||||
|
float start;
|
||||||
|
float end;
|
||||||
|
int numberOfSubseg;
|
||||||
|
|
||||||
|
void initialize() {
|
||||||
|
start = 0;
|
||||||
|
end = 0;
|
||||||
|
numberOfSubseg = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class VadIterator {
|
||||||
|
public:
|
||||||
|
VadIterator(const std::string &model_path,
|
||||||
|
float threshold = 0.5,
|
||||||
|
int sample_rate = 16000,
|
||||||
|
int window_size_ms = 32,
|
||||||
|
int speech_pad_ms = 30,
|
||||||
|
int min_silence_duration_ms = 100,
|
||||||
|
int min_speech_duration_ms = 250,
|
||||||
|
int max_duration_merge_ms = 300,
|
||||||
|
bool print_as_samples = false);
|
||||||
|
~VadIterator();
|
||||||
|
|
||||||
|
// Batch (non-streaming) interface (for backward compatibility)
|
||||||
|
void SpeechProbs(std::vector<float>& input_wav);
|
||||||
|
std::vector<Interval> GetSpeechTimestamps();
|
||||||
|
void SetVariables();
|
||||||
|
|
||||||
|
// Public parameters (can be modified by user)
|
||||||
|
float threshold;
|
||||||
|
int sample_rate;
|
||||||
|
int window_size_ms;
|
||||||
|
int min_speech_duration_ms;
|
||||||
|
int max_duration_merge_ms;
|
||||||
|
bool print_as_samples;
|
||||||
|
|
||||||
|
private:
|
||||||
|
#ifdef USE_TORCH
|
||||||
|
torch::jit::script::Module model;
|
||||||
|
void init_torch_model(const std::string& model_path);
|
||||||
|
#elif USE_ONNX
|
||||||
|
Ort::Env env; // 환경 객체
|
||||||
|
Ort::SessionOptions session_options; // 세션 옵션
|
||||||
|
std::shared_ptr<Ort::Session> session; // ONNX 세션
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // 기본 할당자
|
||||||
|
Ort::MemoryInfo memory_info; // 메모리 정보 (CPU)
|
||||||
|
|
||||||
|
void init_onnx_model(const std::string& model_path);
|
||||||
|
float predict(const std::vector<float>& data_chunk);
|
||||||
|
|
||||||
|
//const int context_samples; // 예: 64 samples
|
||||||
|
int context_samples; // 예: 64 samples
|
||||||
|
std::vector<float> _context; // 초기값 모두 0
|
||||||
|
int effective_window_size;
|
||||||
|
|
||||||
|
// ONNX 입력/출력 관련 버퍼 및 노드 이름들
|
||||||
|
std::vector<Ort::Value> ort_inputs;
|
||||||
|
std::vector<const char*> input_node_names;
|
||||||
|
std::vector<float> input;
|
||||||
|
unsigned int size_state; // 고정값: 2*1*128
|
||||||
|
std::vector<float> _state;
|
||||||
|
std::vector<int64_t> sr;
|
||||||
|
int64_t input_node_dims[2]; // [1, effective_window_size]
|
||||||
|
const int64_t state_node_dims[3]; // [ 2, 1, 128 ]
|
||||||
|
const int64_t sr_node_dims[1]; // [ 1 ]
|
||||||
|
std::vector<Ort::Value> ort_outputs;
|
||||||
|
std::vector<const char*> output_node_names; // 기본값: [ "output", "stateN" ]
|
||||||
|
#endif
|
||||||
|
std::vector<float> outputs_prob; // used in batch mode
|
||||||
|
int min_silence_samples;
|
||||||
|
int min_speech_samples;
|
||||||
|
int speech_pad_samples;
|
||||||
|
int window_size_samples;
|
||||||
|
int duration_merge_samples;
|
||||||
|
int current_sample = 0;
|
||||||
|
int total_sample_size = 0;
|
||||||
|
int min_silence_duration_ms;
|
||||||
|
int speech_pad_ms;
|
||||||
|
bool triggered = false;
|
||||||
|
int temp_end = 0;
|
||||||
|
int global_end = 0;
|
||||||
|
int erase_tail_count = 0;
|
||||||
|
|
||||||
|
|
||||||
|
void init_engine(int window_size_ms);
|
||||||
|
void reset_states();
|
||||||
|
std::vector<Interval> DoVad();
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace silero
|
||||||
|
|
||||||
|
#endif // SILERO_H
|
||||||
|
|
||||||
237
examples/c++/wav.h
Normal file
237
examples/c++/wav.h
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
// Copyright (c) 2016 Personal (Binbin Zhang)
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef FRONTEND_WAV_H_
|
||||||
|
#define FRONTEND_WAV_H_
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// #include "utils/log.h"
|
||||||
|
|
||||||
|
namespace wav {
|
||||||
|
|
||||||
|
struct WavHeader {
|
||||||
|
char riff[4]; // "riff"
|
||||||
|
unsigned int size;
|
||||||
|
char wav[4]; // "WAVE"
|
||||||
|
char fmt[4]; // "fmt "
|
||||||
|
unsigned int fmt_size;
|
||||||
|
uint16_t format;
|
||||||
|
uint16_t channels;
|
||||||
|
unsigned int sample_rate;
|
||||||
|
unsigned int bytes_per_second;
|
||||||
|
uint16_t block_size;
|
||||||
|
uint16_t bit;
|
||||||
|
char data[4]; // "data"
|
||||||
|
unsigned int data_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WavReader {
|
||||||
|
public:
|
||||||
|
WavReader() : data_(nullptr) {}
|
||||||
|
explicit WavReader(const std::string& filename) { Open(filename); }
|
||||||
|
|
||||||
|
bool Open(const std::string& filename) {
|
||||||
|
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
|
||||||
|
if (NULL == fp) {
|
||||||
|
std::cout << "Error in read " << filename;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
WavHeader header;
|
||||||
|
fread(&header, 1, sizeof(header), fp);
|
||||||
|
if (header.fmt_size < 16) {
|
||||||
|
printf("WaveData: expect PCM format data "
|
||||||
|
"to have fmt chunk of at least size 16.\n");
|
||||||
|
return false;
|
||||||
|
} else if (header.fmt_size > 16) {
|
||||||
|
int offset = 44 - 8 + header.fmt_size - 16;
|
||||||
|
fseek(fp, offset, SEEK_SET);
|
||||||
|
fread(header.data, 8, sizeof(char), fp);
|
||||||
|
}
|
||||||
|
// check "riff" "WAVE" "fmt " "data"
|
||||||
|
|
||||||
|
// Skip any sub-chunks between "fmt" and "data". Usually there will
|
||||||
|
// be a single "fact" sub chunk, but on Windows there can also be a
|
||||||
|
// "list" sub chunk.
|
||||||
|
while (0 != strncmp(header.data, "data", 4)) {
|
||||||
|
// We will just ignore the data in these chunks.
|
||||||
|
fseek(fp, header.data_size, SEEK_CUR);
|
||||||
|
// read next sub chunk
|
||||||
|
fread(header.data, 8, sizeof(char), fp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (header.data_size == 0) {
|
||||||
|
int offset = ftell(fp);
|
||||||
|
fseek(fp, 0, SEEK_END);
|
||||||
|
header.data_size = ftell(fp) - offset;
|
||||||
|
fseek(fp, offset, SEEK_SET);
|
||||||
|
}
|
||||||
|
|
||||||
|
num_channel_ = header.channels;
|
||||||
|
sample_rate_ = header.sample_rate;
|
||||||
|
bits_per_sample_ = header.bit;
|
||||||
|
int num_data = header.data_size / (bits_per_sample_ / 8);
|
||||||
|
data_ = new float[num_data]; // Create 1-dim array
|
||||||
|
num_samples_ = num_data / num_channel_;
|
||||||
|
|
||||||
|
std::cout << "num_channel_ :" << num_channel_ << std::endl;
|
||||||
|
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
|
||||||
|
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
|
||||||
|
std::cout << "num_samples :" << num_data << std::endl;
|
||||||
|
std::cout << "num_data_size :" << header.data_size << std::endl;
|
||||||
|
|
||||||
|
switch (bits_per_sample_) {
|
||||||
|
case 8: {
|
||||||
|
char sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(char), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 16: {
|
||||||
|
int16_t sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(int16_t), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 32:
|
||||||
|
{
|
||||||
|
if (header.format == 1) //S32
|
||||||
|
{
|
||||||
|
int sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(int), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (header.format == 3) // IEEE-float
|
||||||
|
{
|
||||||
|
float sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(float), fp);
|
||||||
|
data_[i] = static_cast<float>(sample);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("unsupported quantization bits\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
printf("unsupported quantization bits\n");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
fclose(fp);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_channel() const { return num_channel_; }
|
||||||
|
int sample_rate() const { return sample_rate_; }
|
||||||
|
int bits_per_sample() const { return bits_per_sample_; }
|
||||||
|
int num_samples() const { return num_samples_; }
|
||||||
|
|
||||||
|
~WavReader() {
|
||||||
|
delete[] data_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float* data() const { return data_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_channel_;
|
||||||
|
int sample_rate_;
|
||||||
|
int bits_per_sample_;
|
||||||
|
int num_samples_; // sample points per channel
|
||||||
|
float* data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WavWriter {
|
||||||
|
public:
|
||||||
|
WavWriter(const float* data, int num_samples, int num_channel,
|
||||||
|
int sample_rate, int bits_per_sample)
|
||||||
|
: data_(data),
|
||||||
|
num_samples_(num_samples),
|
||||||
|
num_channel_(num_channel),
|
||||||
|
sample_rate_(sample_rate),
|
||||||
|
bits_per_sample_(bits_per_sample) {}
|
||||||
|
|
||||||
|
void Write(const std::string& filename) {
|
||||||
|
FILE* fp = fopen(filename.c_str(), "w");
|
||||||
|
// init char 'riff' 'WAVE' 'fmt ' 'data'
|
||||||
|
WavHeader header;
|
||||||
|
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
|
||||||
|
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
|
||||||
|
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
|
||||||
|
memcpy(&header, wav_header, sizeof(header));
|
||||||
|
header.channels = num_channel_;
|
||||||
|
header.bit = bits_per_sample_;
|
||||||
|
header.sample_rate = sample_rate_;
|
||||||
|
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
header.size = sizeof(header) - 8 + header.data_size;
|
||||||
|
header.bytes_per_second =
|
||||||
|
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
header.block_size = num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
|
||||||
|
fwrite(&header, 1, sizeof(header), fp);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_samples_; ++i) {
|
||||||
|
for (int j = 0; j < num_channel_; ++j) {
|
||||||
|
switch (bits_per_sample_) {
|
||||||
|
case 8: {
|
||||||
|
char sample = static_cast<char>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 16: {
|
||||||
|
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 32: {
|
||||||
|
int sample = static_cast<int>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fclose(fp);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const float* data_;
|
||||||
|
int num_samples_; // total float points in data_
|
||||||
|
int num_channel_;
|
||||||
|
int sample_rate_;
|
||||||
|
int bits_per_sample_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace wav
|
||||||
|
|
||||||
|
#endif // FRONTEND_WAV_H_
|
||||||
|
|
||||||
|
|
||||||
45
examples/cpp_libtorch_deprecated/README.md
Normal file
45
examples/cpp_libtorch_deprecated/README.md
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Silero-VAD V5 in C++ (based on LibTorch)
|
||||||
|
|
||||||
|
This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested.
|
||||||
|
|
||||||
|
Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- GCC 11.4.0 (GCC >= 5.1)
|
||||||
|
- LibTorch 1.13.0 (other versions are also acceptable)
|
||||||
|
|
||||||
|
## Download LibTorch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-CPU Version
|
||||||
|
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
|
||||||
|
unzip libtorch-shared-with-deps-1.13.0+cpu.zip'
|
||||||
|
|
||||||
|
-CUDA Version
|
||||||
|
wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
|
||||||
|
unzip libtorch-shared-with-deps-1.13.0+cu116.zip
|
||||||
|
```
|
||||||
|
|
||||||
|
## Compilation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
-CPU Version
|
||||||
|
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
|
||||||
|
|
||||||
|
-CUDA Version
|
||||||
|
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Optional Compilation Flags
|
||||||
|
-DUSE_BATCH: Enable batch inference
|
||||||
|
-DUSE_GPU: Use GPU for inference
|
||||||
|
|
||||||
|
## Run the Program
|
||||||
|
To run the program, use the following command:
|
||||||
|
|
||||||
|
`./silero aepyx.wav 16000 0.5`
|
||||||
|
|
||||||
|
The sample file aepyx.wav is part of the Voxconverse dataset.
|
||||||
|
File details: aepyx.wav is a 16kHz, 16-bit audio file.
|
||||||
BIN
examples/cpp_libtorch_deprecated/aepyx.wav
Normal file
BIN
examples/cpp_libtorch_deprecated/aepyx.wav
Normal file
Binary file not shown.
54
examples/cpp_libtorch_deprecated/main.cc
Normal file
54
examples/cpp_libtorch_deprecated/main.cc
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include "silero_torch.h"
|
||||||
|
#include "wav.h"
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
|
||||||
|
if(argc != 4){
|
||||||
|
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
|
||||||
|
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string wav_path = argv[1];
|
||||||
|
float sample_rate = std::stof(argv[2]);
|
||||||
|
float threshold = std::stof(argv[3]);
|
||||||
|
|
||||||
|
|
||||||
|
//Load Model
|
||||||
|
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
|
||||||
|
silero::VadIterator vad(model_path);
|
||||||
|
|
||||||
|
vad.threshold=threshold; //(Default:0.5)
|
||||||
|
vad.sample_rate=sample_rate; //16000Hz,8000Hz. (Default:16000)
|
||||||
|
vad.print_as_samples=true; //if true, it prints time-stamp with samples. otherwise, in seconds
|
||||||
|
//(Default:false)
|
||||||
|
|
||||||
|
vad.SetVariables();
|
||||||
|
|
||||||
|
// Read wav
|
||||||
|
wav::WavReader wav_reader(wav_path);
|
||||||
|
std::vector<float> input_wav(wav_reader.num_samples());
|
||||||
|
|
||||||
|
for (int i = 0; i < wav_reader.num_samples(); i++)
|
||||||
|
{
|
||||||
|
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
|
||||||
|
}
|
||||||
|
|
||||||
|
vad.SpeechProbs(input_wav);
|
||||||
|
|
||||||
|
std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
|
||||||
|
for(const auto& speech : speeches){
|
||||||
|
if(vad.print_as_samples){
|
||||||
|
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
BIN
examples/cpp_libtorch_deprecated/silero
Executable file
BIN
examples/cpp_libtorch_deprecated/silero
Executable file
Binary file not shown.
285
examples/cpp_libtorch_deprecated/silero_torch.cc
Normal file
285
examples/cpp_libtorch_deprecated/silero_torch.cc
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
//Author : Nathan Lee
|
||||||
|
//Created On : 2024-11-18
|
||||||
|
//Description : silero 5.1 system for torch-script(c++).
|
||||||
|
//Version : 1.0
|
||||||
|
|
||||||
|
|
||||||
|
#include "silero_torch.h"
|
||||||
|
|
||||||
|
namespace silero {
|
||||||
|
|
||||||
|
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
|
||||||
|
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
|
||||||
|
{
|
||||||
|
init_torch_model(model_path);
|
||||||
|
//init_engine(window_size_ms);
|
||||||
|
}
|
||||||
|
VadIterator::~VadIterator(){
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
|
||||||
|
// Set the sample rate (must match the model's expected sample rate)
|
||||||
|
// Process the waveform in chunks of 512 samples
|
||||||
|
int num_samples = input_wav.size();
|
||||||
|
int num_chunks = num_samples / window_size_samples;
|
||||||
|
int remainder_samples = num_samples % window_size_samples;
|
||||||
|
|
||||||
|
total_sample_size += num_samples;
|
||||||
|
|
||||||
|
torch::Tensor output;
|
||||||
|
std::vector<torch::Tensor> chunks;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_chunks; i++) {
|
||||||
|
|
||||||
|
float* chunk_start = input_wav.data() + i *window_size_samples;
|
||||||
|
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
|
||||||
|
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
|
||||||
|
chunks.push_back(chunk);
|
||||||
|
|
||||||
|
|
||||||
|
if(i==num_chunks-1 && remainder_samples>0){//마지막 chunk && 나머지가 존재
|
||||||
|
int remaining_samples = num_samples - num_chunks * window_size_samples;
|
||||||
|
//std::cout<<"Remainder size : "<<remaining_samples;
|
||||||
|
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
|
||||||
|
|
||||||
|
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
|
||||||
|
torch::kFloat32);
|
||||||
|
// Pad the remainder chunk to match window_size_samples
|
||||||
|
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
|
||||||
|
- remaining_samples}, torch::kFloat32)}, 1);
|
||||||
|
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
|
||||||
|
|
||||||
|
chunks.push_back(padded_chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!chunks.empty()) {
|
||||||
|
|
||||||
|
#ifdef USE_BATCH
|
||||||
|
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
|
||||||
|
//batched_chunks = batched_chunks.squeeze(1);
|
||||||
|
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
|
||||||
|
|
||||||
|
#ifdef USE_GPU
|
||||||
|
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
|
||||||
|
#endif
|
||||||
|
// Prepare input for model
|
||||||
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
inputs.push_back(batched_chunks); // Batch of chunks
|
||||||
|
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
|
||||||
|
|
||||||
|
// Run inference on the batch
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
torch::Tensor output = model.forward(inputs).toTensor();
|
||||||
|
#ifdef USE_GPU
|
||||||
|
output = output.to(at::kCPU); // Move the output back to CPU once
|
||||||
|
#endif
|
||||||
|
// Collect output probabilities
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
float output_f = output[i].item<float>();
|
||||||
|
outputs_prob.push_back(output_f);
|
||||||
|
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
std::vector<torch::Tensor> outputs;
|
||||||
|
torch::Tensor batched_chunks = torch::stack(chunks);
|
||||||
|
#ifdef USE_GPU
|
||||||
|
batched_chunks = batched_chunks.to(at::kCUDA);
|
||||||
|
#endif
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
inputs.push_back(batched_chunks[i]);
|
||||||
|
inputs.push_back(sample_rate);
|
||||||
|
|
||||||
|
torch::Tensor output = model.forward(inputs).toTensor();
|
||||||
|
outputs.push_back(output);
|
||||||
|
}
|
||||||
|
torch::Tensor all_outputs = torch::stack(outputs);
|
||||||
|
#ifdef USE_GPU
|
||||||
|
all_outputs = all_outputs.to(at::kCPU);
|
||||||
|
#endif
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
float output_f = all_outputs[i].item<float>();
|
||||||
|
outputs_prob.push_back(output_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
|
||||||
|
std::vector<SpeechSegment> speeches = DoVad();
|
||||||
|
|
||||||
|
#ifdef USE_BATCH
|
||||||
|
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
|
||||||
|
//It could be better get reasonable output because of distorted probs.
|
||||||
|
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
|
||||||
|
std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
|
||||||
|
if(!print_as_samples){
|
||||||
|
for (auto& speech : speeches_merge) { //samples to second
|
||||||
|
speech.start /= sample_rate;
|
||||||
|
speech.end /= sample_rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return speeches_merge;
|
||||||
|
#else
|
||||||
|
|
||||||
|
if(!print_as_samples){
|
||||||
|
for (auto& speech : speeches) { //samples to second
|
||||||
|
speech.start /= sample_rate;
|
||||||
|
speech.end /= sample_rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return speeches;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
void VadIterator::SetVariables(){
|
||||||
|
init_engine(window_size_ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::init_engine(int window_size_ms) {
|
||||||
|
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
|
||||||
|
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
|
||||||
|
window_size_samples = sample_rate / 1000 * window_size_ms;
|
||||||
|
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::init_torch_model(const std::string& model_path) {
|
||||||
|
at::set_num_threads(1);
|
||||||
|
model = torch::jit::load(model_path);
|
||||||
|
|
||||||
|
#ifdef USE_GPU
|
||||||
|
if (!torch::cuda::is_available()) {
|
||||||
|
std::cout<<"CUDA is not available! Please check your GPU settings"<<std::endl;
|
||||||
|
throw std::runtime_error("CUDA is not available!");
|
||||||
|
model.to(at::Device(at::kCPU));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
std::cout<<"CUDA available! Running on '0'th GPU"<<std::endl;
|
||||||
|
model.to(at::Device(at::kCUDA, 0)); //select 0'th machine
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
model.eval();
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
std::cout << "Model loaded successfully"<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VadIterator::reset_states() {
|
||||||
|
triggered = false;
|
||||||
|
current_sample = 0;
|
||||||
|
temp_end = 0;
|
||||||
|
outputs_prob.clear();
|
||||||
|
model.run_method("reset_states");
|
||||||
|
total_sample_size = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<SpeechSegment> VadIterator::DoVad() {
|
||||||
|
std::vector<SpeechSegment> speeches;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < outputs_prob.size(); ++i) {
|
||||||
|
float speech_prob = outputs_prob[i];
|
||||||
|
//std::cout << speech_prob << std::endl;
|
||||||
|
//std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
|
||||||
|
//std::cout << speech_prob << " ";
|
||||||
|
current_sample += window_size_samples;
|
||||||
|
|
||||||
|
if (speech_prob >= threshold && temp_end != 0) {
|
||||||
|
temp_end = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (speech_prob >= threshold && !triggered) {
|
||||||
|
triggered = true;
|
||||||
|
SpeechSegment segment;
|
||||||
|
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
|
||||||
|
speeches.push_back(segment);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (speech_prob < threshold - 0.15f && triggered) {
|
||||||
|
if (temp_end == 0) {
|
||||||
|
temp_end = current_sample;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (current_sample - temp_end < min_silence_samples) {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
SpeechSegment& segment = speeches.back();
|
||||||
|
segment.end = temp_end + speech_pad_samples - window_size_samples;
|
||||||
|
temp_end = 0;
|
||||||
|
triggered = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나?
|
||||||
|
std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
|
||||||
|
SpeechSegment& segment = speeches.back();
|
||||||
|
segment.end = total_sample_size; // 현재 샘플을 마지막 구간의 종료 시간으로 설정
|
||||||
|
triggered = false; // VAD 상태 초기화
|
||||||
|
}
|
||||||
|
|
||||||
|
speeches.erase(
|
||||||
|
std::remove_if(
|
||||||
|
speeches.begin(),
|
||||||
|
speeches.end(),
|
||||||
|
[this](const SpeechSegment& speech) {
|
||||||
|
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
|
||||||
|
//min_speech_samples is 4000samples(0.25sec)
|
||||||
|
//여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함.
|
||||||
|
}
|
||||||
|
),
|
||||||
|
speeches.end()
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
//std::cout<<std::endl;
|
||||||
|
//std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;
|
||||||
|
|
||||||
|
reset_states();
|
||||||
|
return speeches;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
|
||||||
|
std::vector<SpeechSegment> mergedSpeeches;
|
||||||
|
|
||||||
|
if (speeches.empty()) {
|
||||||
|
return mergedSpeeches; // 빈 벡터 반환
|
||||||
|
}
|
||||||
|
|
||||||
|
// 첫 번째 구간으로 초기화
|
||||||
|
SpeechSegment currentSegment = speeches[0];
|
||||||
|
|
||||||
|
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
|
||||||
|
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침
|
||||||
|
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
|
||||||
|
// 현재 구간의 끝점을 업데이트
|
||||||
|
currentSegment.end = speeches[i].end;
|
||||||
|
} else {
|
||||||
|
// 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작
|
||||||
|
mergedSpeeches.push_back(currentSegment);
|
||||||
|
currentSegment = speeches[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 마지막 구간 추가
|
||||||
|
mergedSpeeches.push_back(currentSegment);
|
||||||
|
|
||||||
|
return mergedSpeeches;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
75
examples/cpp_libtorch_deprecated/silero_torch.h
Normal file
75
examples/cpp_libtorch_deprecated/silero_torch.h
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
//Author : Nathan Lee
|
||||||
|
//Created On : 2024-11-18
|
||||||
|
//Description : silero 5.1 system for torch-script(c++).
|
||||||
|
//Version : 1.0
|
||||||
|
|
||||||
|
#ifndef SILERO_TORCH_H
|
||||||
|
#define SILERO_TORCH_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <torch/script.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace silero{
|
||||||
|
|
||||||
|
struct SpeechSegment{
|
||||||
|
int start;
|
||||||
|
int end;
|
||||||
|
};
|
||||||
|
|
||||||
|
class VadIterator{
|
||||||
|
public:
|
||||||
|
|
||||||
|
VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
|
||||||
|
int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
|
||||||
|
int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
|
||||||
|
~VadIterator();
|
||||||
|
|
||||||
|
|
||||||
|
void SpeechProbs(std::vector<float>& input_wav);
|
||||||
|
std::vector<silero::SpeechSegment> GetSpeechTimestamps();
|
||||||
|
void SetVariables();
|
||||||
|
|
||||||
|
float threshold;
|
||||||
|
int sample_rate;
|
||||||
|
int window_size_ms;
|
||||||
|
int min_speech_duration_ms;
|
||||||
|
int max_duration_merge_ms;
|
||||||
|
bool print_as_samples;
|
||||||
|
|
||||||
|
private:
|
||||||
|
torch::jit::script::Module model;
|
||||||
|
std::vector<float> outputs_prob;
|
||||||
|
int min_silence_samples;
|
||||||
|
int min_speech_samples;
|
||||||
|
int speech_pad_samples;
|
||||||
|
int window_size_samples;
|
||||||
|
int duration_merge_samples;
|
||||||
|
int current_sample = 0;
|
||||||
|
|
||||||
|
int total_sample_size=0;
|
||||||
|
|
||||||
|
int min_silence_duration_ms;
|
||||||
|
int speech_pad_ms;
|
||||||
|
bool triggered = false;
|
||||||
|
int temp_end = 0;
|
||||||
|
|
||||||
|
void init_engine(int window_size_ms);
|
||||||
|
void init_torch_model(const std::string& model_path);
|
||||||
|
void reset_states();
|
||||||
|
std::vector<SpeechSegment> DoVad();
|
||||||
|
std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
#endif // SILERO_TORCH_H
|
||||||
235
examples/cpp_libtorch_deprecated/wav.h
Normal file
235
examples/cpp_libtorch_deprecated/wav.h
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
// Copyright (c) 2016 Personal (Binbin Zhang)
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef FRONTEND_WAV_H_
|
||||||
|
#define FRONTEND_WAV_H_
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// #include "utils/log.h"
|
||||||
|
|
||||||
|
namespace wav {
|
||||||
|
|
||||||
|
struct WavHeader {
|
||||||
|
char riff[4]; // "riff"
|
||||||
|
unsigned int size;
|
||||||
|
char wav[4]; // "WAVE"
|
||||||
|
char fmt[4]; // "fmt "
|
||||||
|
unsigned int fmt_size;
|
||||||
|
uint16_t format;
|
||||||
|
uint16_t channels;
|
||||||
|
unsigned int sample_rate;
|
||||||
|
unsigned int bytes_per_second;
|
||||||
|
uint16_t block_size;
|
||||||
|
uint16_t bit;
|
||||||
|
char data[4]; // "data"
|
||||||
|
unsigned int data_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WavReader {
|
||||||
|
public:
|
||||||
|
WavReader() : data_(nullptr) {}
|
||||||
|
explicit WavReader(const std::string& filename) { Open(filename); }
|
||||||
|
|
||||||
|
bool Open(const std::string& filename) {
|
||||||
|
FILE* fp = fopen(filename.c_str(), "rb"); //文件读取
|
||||||
|
if (NULL == fp) {
|
||||||
|
std::cout << "Error in read " << filename;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
WavHeader header;
|
||||||
|
fread(&header, 1, sizeof(header), fp);
|
||||||
|
if (header.fmt_size < 16) {
|
||||||
|
printf("WaveData: expect PCM format data "
|
||||||
|
"to have fmt chunk of at least size 16.\n");
|
||||||
|
return false;
|
||||||
|
} else if (header.fmt_size > 16) {
|
||||||
|
int offset = 44 - 8 + header.fmt_size - 16;
|
||||||
|
fseek(fp, offset, SEEK_SET);
|
||||||
|
fread(header.data, 8, sizeof(char), fp);
|
||||||
|
}
|
||||||
|
// check "riff" "WAVE" "fmt " "data"
|
||||||
|
|
||||||
|
// Skip any sub-chunks between "fmt" and "data". Usually there will
|
||||||
|
// be a single "fact" sub chunk, but on Windows there can also be a
|
||||||
|
// "list" sub chunk.
|
||||||
|
while (0 != strncmp(header.data, "data", 4)) {
|
||||||
|
// We will just ignore the data in these chunks.
|
||||||
|
fseek(fp, header.data_size, SEEK_CUR);
|
||||||
|
// read next sub chunk
|
||||||
|
fread(header.data, 8, sizeof(char), fp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (header.data_size == 0) {
|
||||||
|
int offset = ftell(fp);
|
||||||
|
fseek(fp, 0, SEEK_END);
|
||||||
|
header.data_size = ftell(fp) - offset;
|
||||||
|
fseek(fp, offset, SEEK_SET);
|
||||||
|
}
|
||||||
|
|
||||||
|
num_channel_ = header.channels;
|
||||||
|
sample_rate_ = header.sample_rate;
|
||||||
|
bits_per_sample_ = header.bit;
|
||||||
|
int num_data = header.data_size / (bits_per_sample_ / 8);
|
||||||
|
data_ = new float[num_data]; // Create 1-dim array
|
||||||
|
num_samples_ = num_data / num_channel_;
|
||||||
|
|
||||||
|
std::cout << "num_channel_ :" << num_channel_ << std::endl;
|
||||||
|
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
|
||||||
|
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
|
||||||
|
std::cout << "num_samples :" << num_data << std::endl;
|
||||||
|
std::cout << "num_data_size :" << header.data_size << std::endl;
|
||||||
|
|
||||||
|
switch (bits_per_sample_) {
|
||||||
|
case 8: {
|
||||||
|
char sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(char), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 16: {
|
||||||
|
int16_t sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(int16_t), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 32:
|
||||||
|
{
|
||||||
|
if (header.format == 1) //S32
|
||||||
|
{
|
||||||
|
int sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(int), fp);
|
||||||
|
data_[i] = static_cast<float>(sample) / 32768;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (header.format == 3) // IEEE-float
|
||||||
|
{
|
||||||
|
float sample;
|
||||||
|
for (int i = 0; i < num_data; ++i) {
|
||||||
|
fread(&sample, 1, sizeof(float), fp);
|
||||||
|
data_[i] = static_cast<float>(sample);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("unsupported quantization bits\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
printf("unsupported quantization bits\n");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
fclose(fp);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_channel() const { return num_channel_; }
|
||||||
|
int sample_rate() const { return sample_rate_; }
|
||||||
|
int bits_per_sample() const { return bits_per_sample_; }
|
||||||
|
int num_samples() const { return num_samples_; }
|
||||||
|
|
||||||
|
~WavReader() {
|
||||||
|
delete[] data_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float* data() const { return data_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_channel_;
|
||||||
|
int sample_rate_;
|
||||||
|
int bits_per_sample_;
|
||||||
|
int num_samples_; // sample points per channel
|
||||||
|
float* data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WavWriter {
|
||||||
|
public:
|
||||||
|
WavWriter(const float* data, int num_samples, int num_channel,
|
||||||
|
int sample_rate, int bits_per_sample)
|
||||||
|
: data_(data),
|
||||||
|
num_samples_(num_samples),
|
||||||
|
num_channel_(num_channel),
|
||||||
|
sample_rate_(sample_rate),
|
||||||
|
bits_per_sample_(bits_per_sample) {}
|
||||||
|
|
||||||
|
void Write(const std::string& filename) {
|
||||||
|
FILE* fp = fopen(filename.c_str(), "w");
|
||||||
|
// init char 'riff' 'WAVE' 'fmt ' 'data'
|
||||||
|
WavHeader header;
|
||||||
|
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
|
||||||
|
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
|
||||||
|
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
|
||||||
|
memcpy(&header, wav_header, sizeof(header));
|
||||||
|
header.channels = num_channel_;
|
||||||
|
header.bit = bits_per_sample_;
|
||||||
|
header.sample_rate = sample_rate_;
|
||||||
|
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
header.size = sizeof(header) - 8 + header.data_size;
|
||||||
|
header.bytes_per_second =
|
||||||
|
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
header.block_size = num_channel_ * (bits_per_sample_ / 8);
|
||||||
|
|
||||||
|
fwrite(&header, 1, sizeof(header), fp);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_samples_; ++i) {
|
||||||
|
for (int j = 0; j < num_channel_; ++j) {
|
||||||
|
switch (bits_per_sample_) {
|
||||||
|
case 8: {
|
||||||
|
char sample = static_cast<char>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 16: {
|
||||||
|
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 32: {
|
||||||
|
int sample = static_cast<int>(data_[i * num_channel_ + j]);
|
||||||
|
fwrite(&sample, 1, sizeof(sample), fp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fclose(fp);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const float* data_;
|
||||||
|
int num_samples_; // total float points in data_
|
||||||
|
int num_channel_;
|
||||||
|
int sample_rate_;
|
||||||
|
int bits_per_sample_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace wenet
|
||||||
|
|
||||||
|
#endif // FRONTEND_WAV_H_
|
||||||
@@ -21,7 +21,7 @@ class Program
|
|||||||
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
||||||
List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
|
List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
|
||||||
//Console.WriteLine(speechTimeList.ToJson());
|
//Console.WriteLine(speechTimeList.ToJson());
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new();
|
||||||
foreach (var speechSegment in speechTimeList)
|
foreach (var speechSegment in speechTimeList)
|
||||||
{
|
{
|
||||||
sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
|
sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
|
||||||
|
|||||||
@@ -53,28 +53,26 @@ public class SileroVadDetector
|
|||||||
{
|
{
|
||||||
Reset();
|
Reset();
|
||||||
|
|
||||||
using (var audioFile = new AudioFileReader(wavFile.FullName))
|
using var audioFile = new AudioFileReader(wavFile.FullName);
|
||||||
|
List<float> speechProbList = [];
|
||||||
|
this._audioLengthSamples = (int)(audioFile.Length / 2);
|
||||||
|
float[] buffer = new float[this._windowSizeSample];
|
||||||
|
|
||||||
|
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
|
||||||
{
|
{
|
||||||
List<float> speechProbList = new List<float>();
|
float speechProb = _model.Call([buffer], _samplingRate)[0];
|
||||||
this._audioLengthSamples = (int)(audioFile.Length / 2);
|
speechProbList.Add(speechProb);
|
||||||
float[] buffer = new float[this._windowSizeSample];
|
|
||||||
|
|
||||||
while (audioFile.Read(buffer, 0, buffer.Length) > 0)
|
|
||||||
{
|
|
||||||
float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
|
|
||||||
speechProbList.Add(speechProb);
|
|
||||||
}
|
|
||||||
|
|
||||||
return CalculateProb(speechProbList);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return CalculateProb(speechProbList);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
|
private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
|
||||||
{
|
{
|
||||||
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
|
List<SileroSpeechSegment> result = [];
|
||||||
bool triggered = false;
|
bool triggered = false;
|
||||||
int tempEnd = 0, prevEnd = 0, nextStart = 0;
|
int tempEnd = 0, prevEnd = 0, nextStart = 0;
|
||||||
SileroSpeechSegment segment = new SileroSpeechSegment();
|
SileroSpeechSegment segment = new();
|
||||||
|
|
||||||
for (int i = 0; i < speechProbList.Count; i++)
|
for (int i = 0; i < speechProbList.Count; i++)
|
||||||
{
|
{
|
||||||
@@ -164,7 +162,8 @@ public class SileroVadDetector
|
|||||||
|
|
||||||
if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
|
if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
|
||||||
{
|
{
|
||||||
segment.EndOffset = _audioLengthSamples;
|
//segment.EndOffset = _audioLengthSamples;
|
||||||
|
segment.EndOffset = speechProbList.Count * _windowSizeSample;
|
||||||
result.Add(segment);
|
result.Add(segment);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +181,7 @@ public class SileroVadDetector
|
|||||||
int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
|
int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
|
||||||
if (silenceDuration < 2 * _speechPadSamples)
|
if (silenceDuration < 2 * _speechPadSamples)
|
||||||
{
|
{
|
||||||
item.EndOffset = item.EndOffset + (silenceDuration / 2);
|
item.EndOffset += (silenceDuration / 2);
|
||||||
nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
|
nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -200,9 +199,9 @@ public class SileroVadDetector
|
|||||||
return MergeListAndCalculateSecond(result, _samplingRate);
|
return MergeListAndCalculateSecond(result, _samplingRate);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
|
private static List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
|
||||||
{
|
{
|
||||||
List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
|
List<SileroSpeechSegment> result = [];
|
||||||
if (original == null || original.Count == 0)
|
if (original == null || original.Count == 0)
|
||||||
{
|
{
|
||||||
return result;
|
return result;
|
||||||
@@ -242,7 +241,7 @@ public class SileroVadDetector
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float CalculateSecondByOffset(int offset, int samplingRate)
|
private static float CalculateSecondByOffset(int offset, int samplingRate)
|
||||||
{
|
{
|
||||||
float secondValue = offset * 1.0f / samplingRate;
|
float secondValue = offset * 1.0f / samplingRate;
|
||||||
return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
|
return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
using Microsoft.ML.OnnxRuntime;
|
using Microsoft.ML.OnnxRuntime;
|
||||||
using Microsoft.ML.OnnxRuntime.Tensors;
|
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
@@ -7,214 +8,208 @@ using System.Linq;
|
|||||||
namespace VADdotnet;
|
namespace VADdotnet;
|
||||||
|
|
||||||
|
|
||||||
public class SileroVadOnnxModel : IDisposable
|
public class SileroVadOnnxModel : IDisposable
|
||||||
|
{
|
||||||
|
private readonly InferenceSession session;
|
||||||
|
private float[][][] state;
|
||||||
|
private float[][] context;
|
||||||
|
private int lastSr = 0;
|
||||||
|
private int lastBatchSize = 0;
|
||||||
|
private static readonly List<int> SAMPLE_RATES = [8000, 16000];
|
||||||
|
|
||||||
|
public SileroVadOnnxModel(string modelPath)
|
||||||
{
|
{
|
||||||
private readonly InferenceSession session;
|
var sessionOptions = new SessionOptions
|
||||||
private float[][][] state;
|
|
||||||
private float[][] context;
|
|
||||||
private int lastSr = 0;
|
|
||||||
private int lastBatchSize = 0;
|
|
||||||
private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };
|
|
||||||
|
|
||||||
public SileroVadOnnxModel(string modelPath)
|
|
||||||
{
|
{
|
||||||
var sessionOptions = new SessionOptions();
|
InterOpNumThreads = 1,
|
||||||
sessionOptions.InterOpNumThreads = 1;
|
IntraOpNumThreads = 1,
|
||||||
sessionOptions.IntraOpNumThreads = 1;
|
EnableCpuMemArena = true
|
||||||
sessionOptions.EnableCpuMemArena = true;
|
};
|
||||||
|
|
||||||
session = new InferenceSession(modelPath, sessionOptions);
|
session = new InferenceSession(modelPath, sessionOptions);
|
||||||
|
ResetStates();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void ResetStates()
|
||||||
|
{
|
||||||
|
state = new float[2][][];
|
||||||
|
state[0] = new float[1][];
|
||||||
|
state[1] = new float[1][];
|
||||||
|
state[0][0] = new float[128];
|
||||||
|
state[1][0] = new float[128];
|
||||||
|
context = [];
|
||||||
|
lastSr = 0;
|
||||||
|
lastBatchSize = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
GC.SuppressFinalize(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
public class ValidationResult(float[][] x, int sr)
|
||||||
|
{
|
||||||
|
public float[][] X { get; } = x;
|
||||||
|
public int Sr { get; } = sr;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ValidationResult ValidateInput(float[][] x, int sr)
|
||||||
|
{
|
||||||
|
if (x.Length == 1)
|
||||||
|
{
|
||||||
|
x = [x[0]];
|
||||||
|
}
|
||||||
|
if (x.Length > 2)
|
||||||
|
{
|
||||||
|
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sr != 16000 && (sr % 16000 == 0))
|
||||||
|
{
|
||||||
|
int step = sr / 16000;
|
||||||
|
float[][] reducedX = new float[x.Length][];
|
||||||
|
|
||||||
|
for (int i = 0; i < x.Length; i++)
|
||||||
|
{
|
||||||
|
float[] current = x[i];
|
||||||
|
float[] newArr = new float[(current.Length + step - 1) / step];
|
||||||
|
|
||||||
|
for (int j = 0, index = 0; j < current.Length; j += step, index++)
|
||||||
|
{
|
||||||
|
newArr[index] = current[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
reducedX[i] = newArr;
|
||||||
|
}
|
||||||
|
|
||||||
|
x = reducedX;
|
||||||
|
sr = 16000;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SAMPLE_RATES.Contains(sr))
|
||||||
|
{
|
||||||
|
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (((float)sr) / x[0].Length > 31.25)
|
||||||
|
{
|
||||||
|
throw new ArgumentException("Input audio is too short");
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ValidationResult(x, sr);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static float[][] Concatenate(float[][] a, float[][] b)
|
||||||
|
{
|
||||||
|
if (a.Length != b.Length)
|
||||||
|
{
|
||||||
|
throw new ArgumentException("The number of rows in both arrays must be the same.");
|
||||||
|
}
|
||||||
|
|
||||||
|
int rows = a.Length;
|
||||||
|
int colsA = a[0].Length;
|
||||||
|
int colsB = b[0].Length;
|
||||||
|
float[][] result = new float[rows][];
|
||||||
|
|
||||||
|
for (int i = 0; i < rows; i++)
|
||||||
|
{
|
||||||
|
result[i] = new float[colsA + colsB];
|
||||||
|
Array.Copy(a[i], 0, result[i], 0, colsA);
|
||||||
|
Array.Copy(b[i], 0, result[i], colsA, colsB);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static float[][] GetLastColumns(float[][] array, int contextSize)
|
||||||
|
{
|
||||||
|
int rows = array.Length;
|
||||||
|
int cols = array[0].Length;
|
||||||
|
|
||||||
|
if (contextSize > cols)
|
||||||
|
{
|
||||||
|
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
|
||||||
|
}
|
||||||
|
|
||||||
|
float[][] result = new float[rows][];
|
||||||
|
|
||||||
|
for (int i = 0; i < rows; i++)
|
||||||
|
{
|
||||||
|
result[i] = new float[contextSize];
|
||||||
|
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float[] Call(float[][] x, int sr)
|
||||||
|
{
|
||||||
|
var result = ValidateInput(x, sr);
|
||||||
|
x = result.X;
|
||||||
|
sr = result.Sr;
|
||||||
|
int numberSamples = sr == 16000 ? 512 : 256;
|
||||||
|
|
||||||
|
if (x[0].Length != numberSamples)
|
||||||
|
{
|
||||||
|
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
|
||||||
|
}
|
||||||
|
|
||||||
|
int batchSize = x.Length;
|
||||||
|
int contextSize = sr == 16000 ? 64 : 32;
|
||||||
|
|
||||||
|
if (lastBatchSize == 0)
|
||||||
|
{
|
||||||
|
ResetStates();
|
||||||
|
}
|
||||||
|
if (lastSr != 0 && lastSr != sr)
|
||||||
|
{
|
||||||
|
ResetStates();
|
||||||
|
}
|
||||||
|
if (lastBatchSize != 0 && lastBatchSize != batchSize)
|
||||||
|
{
|
||||||
ResetStates();
|
ResetStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void ResetStates()
|
if (context.Length == 0)
|
||||||
{
|
{
|
||||||
state = new float[2][][];
|
context = new float[batchSize][];
|
||||||
state[0] = new float[1][];
|
for (int i = 0; i < batchSize; i++)
|
||||||
state[1] = new float[1][];
|
|
||||||
state[0][0] = new float[128];
|
|
||||||
state[1][0] = new float[128];
|
|
||||||
context = Array.Empty<float[]>();
|
|
||||||
lastSr = 0;
|
|
||||||
lastBatchSize = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void Dispose()
|
|
||||||
{
|
|
||||||
session?.Dispose();
|
|
||||||
}
|
|
||||||
|
|
||||||
public class ValidationResult
|
|
||||||
{
|
|
||||||
public float[][] X { get; }
|
|
||||||
public int Sr { get; }
|
|
||||||
|
|
||||||
public ValidationResult(float[][] x, int sr)
|
|
||||||
{
|
{
|
||||||
X = x;
|
context[i] = new float[contextSize];
|
||||||
Sr = sr;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ValidationResult ValidateInput(float[][] x, int sr)
|
x = Concatenate(context, x);
|
||||||
{
|
|
||||||
if (x.Length == 1)
|
var inputs = new List<NamedOnnxValue>
|
||||||
{
|
{
|
||||||
x = new float[][] { x[0] };
|
NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), [x.Length, x[0].Length])),
|
||||||
}
|
NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, [1])),
|
||||||
if (x.Length > 2)
|
NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), [state.Length, state[0].Length, state[0][0].Length]))
|
||||||
{
|
|
||||||
throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sr != 16000 && (sr % 16000 == 0))
|
|
||||||
{
|
|
||||||
int step = sr / 16000;
|
|
||||||
float[][] reducedX = new float[x.Length][];
|
|
||||||
|
|
||||||
for (int i = 0; i < x.Length; i++)
|
|
||||||
{
|
|
||||||
float[] current = x[i];
|
|
||||||
float[] newArr = new float[(current.Length + step - 1) / step];
|
|
||||||
|
|
||||||
for (int j = 0, index = 0; j < current.Length; j += step, index++)
|
|
||||||
{
|
|
||||||
newArr[index] = current[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
reducedX[i] = newArr;
|
|
||||||
}
|
|
||||||
|
|
||||||
x = reducedX;
|
|
||||||
sr = 16000;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!SAMPLE_RATES.Contains(sr))
|
|
||||||
{
|
|
||||||
throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (((float)sr) / x[0].Length > 31.25)
|
|
||||||
{
|
|
||||||
throw new ArgumentException("Input audio is too short");
|
|
||||||
}
|
|
||||||
|
|
||||||
return new ValidationResult(x, sr);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static float[][] Concatenate(float[][] a, float[][] b)
|
|
||||||
{
|
|
||||||
if (a.Length != b.Length)
|
|
||||||
{
|
|
||||||
throw new ArgumentException("The number of rows in both arrays must be the same.");
|
|
||||||
}
|
|
||||||
|
|
||||||
int rows = a.Length;
|
|
||||||
int colsA = a[0].Length;
|
|
||||||
int colsB = b[0].Length;
|
|
||||||
float[][] result = new float[rows][];
|
|
||||||
|
|
||||||
for (int i = 0; i < rows; i++)
|
|
||||||
{
|
|
||||||
result[i] = new float[colsA + colsB];
|
|
||||||
Array.Copy(a[i], 0, result[i], 0, colsA);
|
|
||||||
Array.Copy(b[i], 0, result[i], colsA, colsB);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static float[][] GetLastColumns(float[][] array, int contextSize)
|
|
||||||
{
|
|
||||||
int rows = array.Length;
|
|
||||||
int cols = array[0].Length;
|
|
||||||
|
|
||||||
if (contextSize > cols)
|
|
||||||
{
|
|
||||||
throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
|
|
||||||
}
|
|
||||||
|
|
||||||
float[][] result = new float[rows][];
|
|
||||||
|
|
||||||
for (int i = 0; i < rows; i++)
|
|
||||||
{
|
|
||||||
result[i] = new float[contextSize];
|
|
||||||
Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float[] Call(float[][] x, int sr)
|
|
||||||
{
|
|
||||||
var result = ValidateInput(x, sr);
|
|
||||||
x = result.X;
|
|
||||||
sr = result.Sr;
|
|
||||||
int numberSamples = sr == 16000 ? 512 : 256;
|
|
||||||
|
|
||||||
if (x[0].Length != numberSamples)
|
|
||||||
{
|
|
||||||
throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
|
|
||||||
}
|
|
||||||
|
|
||||||
int batchSize = x.Length;
|
|
||||||
int contextSize = sr == 16000 ? 64 : 32;
|
|
||||||
|
|
||||||
if (lastBatchSize == 0)
|
|
||||||
{
|
|
||||||
ResetStates();
|
|
||||||
}
|
|
||||||
if (lastSr != 0 && lastSr != sr)
|
|
||||||
{
|
|
||||||
ResetStates();
|
|
||||||
}
|
|
||||||
if (lastBatchSize != 0 && lastBatchSize != batchSize)
|
|
||||||
{
|
|
||||||
ResetStates();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (context.Length == 0)
|
|
||||||
{
|
|
||||||
context = new float[batchSize][];
|
|
||||||
for (int i = 0; i < batchSize; i++)
|
|
||||||
{
|
|
||||||
context[i] = new float[contextSize];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
x = Concatenate(context, x);
|
|
||||||
|
|
||||||
var inputs = new List<NamedOnnxValue>
|
|
||||||
{
|
|
||||||
NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
|
|
||||||
NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
|
|
||||||
NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using (var outputs = session.Run(inputs))
|
using var outputs = session.Run(inputs);
|
||||||
|
var output = outputs.First(o => o.Name == "output").AsTensor<float>();
|
||||||
|
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
|
||||||
|
|
||||||
|
context = GetLastColumns(x, contextSize);
|
||||||
|
lastSr = sr;
|
||||||
|
lastBatchSize = batchSize;
|
||||||
|
|
||||||
|
state = new float[newState.Dimensions[0]][][];
|
||||||
|
for (int i = 0; i < newState.Dimensions[0]; i++)
|
||||||
|
{
|
||||||
|
state[i] = new float[newState.Dimensions[1]][];
|
||||||
|
for (int j = 0; j < newState.Dimensions[1]; j++)
|
||||||
{
|
{
|
||||||
var output = outputs.First(o => o.Name == "output").AsTensor<float>();
|
state[i][j] = new float[newState.Dimensions[2]];
|
||||||
var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
|
for (int k = 0; k < newState.Dimensions[2]; k++)
|
||||||
|
|
||||||
context = GetLastColumns(x, contextSize);
|
|
||||||
lastSr = sr;
|
|
||||||
lastBatchSize = batchSize;
|
|
||||||
|
|
||||||
state = new float[newState.Dimensions[0]][][];
|
|
||||||
for (int i = 0; i < newState.Dimensions[0]; i++)
|
|
||||||
{
|
{
|
||||||
state[i] = new float[newState.Dimensions[1]][];
|
state[i][j][k] = newState[i, j, k];
|
||||||
for (int j = 0; j < newState.Dimensions[1]; j++)
|
|
||||||
{
|
|
||||||
state[i][j] = new float[newState.Dimensions[2]];
|
|
||||||
for (int k = 0; k < newState.Dimensions[2]; k++)
|
|
||||||
{
|
|
||||||
state[i][j][k] = newState[i, j, k];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return output.ToArray();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return [.. output];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|||||||
25
examples/csharp/VadDotNet.sln
Normal file
25
examples/csharp/VadDotNet.sln
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
|
||||||
|
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||||
|
# Visual Studio Version 17
|
||||||
|
VisualStudioVersion = 17.14.36616.10 d17.14
|
||||||
|
MinimumVisualStudioVersion = 10.0.40219.1
|
||||||
|
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VadDotNet", "VadDotNet.csproj", "{F36E1741-EDDB-90C7-7501-4911058F8996}"
|
||||||
|
EndProject
|
||||||
|
Global
|
||||||
|
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||||
|
Debug|Any CPU = Debug|Any CPU
|
||||||
|
Release|Any CPU = Release|Any CPU
|
||||||
|
EndGlobalSection
|
||||||
|
GlobalSection(ProjectConfigurationPlatforms) = postSolution
|
||||||
|
{F36E1741-EDDB-90C7-7501-4911058F8996}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||||
|
{F36E1741-EDDB-90C7-7501-4911058F8996}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||||
|
{F36E1741-EDDB-90C7-7501-4911058F8996}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||||
|
{F36E1741-EDDB-90C7-7501-4911058F8996}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||||
|
EndGlobalSection
|
||||||
|
GlobalSection(SolutionProperties) = preSolution
|
||||||
|
HideSolutionNode = FALSE
|
||||||
|
EndGlobalSection
|
||||||
|
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||||
|
SolutionGuid = {DFC4CEE8-1034-46B4-A5F4-D1649B3543E6}
|
||||||
|
EndGlobalSection
|
||||||
|
EndGlobal
|
||||||
612
examples/rust-example/Cargo.lock
generated
612
examples/rust-example/Cargo.lock
generated
@@ -1,6 +1,6 @@
|
|||||||
# This file is automatically @generated by Cargo.
|
# This file is automatically @generated by Cargo.
|
||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 3
|
version = 4
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "adler"
|
name = "adler"
|
||||||
@@ -20,6 +20,12 @@ version = "0.22.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "base64ct"
|
||||||
|
version = "1.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
@@ -42,10 +48,16 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "byteorder"
|
||||||
version = "3.16.0"
|
version = "1.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytes"
|
||||||
|
version = "1.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
@@ -59,6 +71,22 @@ version = "1.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "core-foundation"
|
||||||
|
version = "0.9.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
|
||||||
|
dependencies = [
|
||||||
|
"core-foundation-sys",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "core-foundation-sys"
|
||||||
|
version = "0.8.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cpufeatures"
|
name = "cpufeatures"
|
||||||
version = "0.2.12"
|
version = "0.2.12"
|
||||||
@@ -77,12 +105,6 @@ dependencies = [
|
|||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "crunchy"
|
|
||||||
version = "0.2.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crypto-common"
|
name = "crypto-common"
|
||||||
version = "0.1.6"
|
version = "0.1.6"
|
||||||
@@ -93,6 +115,16 @@ dependencies = [
|
|||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "der"
|
||||||
|
version = "0.7.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
|
||||||
|
dependencies = [
|
||||||
|
"pem-rfc7468",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "digest"
|
name = "digest"
|
||||||
version = "0.10.7"
|
version = "0.10.7"
|
||||||
@@ -110,9 +142,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
|
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"windows-sys",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fastrand"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "filetime"
|
name = "filetime"
|
||||||
version = "0.2.23"
|
version = "0.2.23"
|
||||||
@@ -122,7 +160,7 @@ dependencies = [
|
|||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
"redox_syscall",
|
"redox_syscall",
|
||||||
"windows-sys",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -136,14 +174,20 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "form_urlencoded"
|
name = "foreign-types"
|
||||||
version = "1.2.1"
|
version = "0.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
|
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"percent-encoding",
|
"foreign-types-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "foreign-types-shared"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "generic-array"
|
name = "generic-array"
|
||||||
version = "0.14.7"
|
version = "0.14.7"
|
||||||
@@ -154,27 +198,6 @@ dependencies = [
|
|||||||
"version_check",
|
"version_check",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "getrandom"
|
|
||||||
version = "0.2.15"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"libc",
|
|
||||||
"wasi",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "half"
|
|
||||||
version = "2.4.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"crunchy",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hound"
|
name = "hound"
|
||||||
version = "3.5.1"
|
version = "3.5.1"
|
||||||
@@ -182,23 +205,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
|
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "http"
|
||||||
version = "0.5.0"
|
version = "1.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
|
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"unicode-bidi",
|
"bytes",
|
||||||
"unicode-normalization",
|
"itoa",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "js-sys"
|
name = "httparse"
|
||||||
version = "0.3.69"
|
version = "1.10.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
|
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||||
dependencies = [
|
|
||||||
"wasm-bindgen",
|
[[package]]
|
||||||
]
|
name = "itoa"
|
||||||
|
version = "1.0.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
@@ -206,16 +232,6 @@ version = "0.2.155"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libloading"
|
|
||||||
version = "0.8.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"windows-targets",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.4.14"
|
version = "0.4.14"
|
||||||
@@ -224,9 +240,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.21"
|
version = "0.4.29"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matrixmultiply"
|
name = "matrixmultiply"
|
||||||
@@ -248,15 +264,34 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ndarray"
|
name = "native-tls"
|
||||||
version = "0.15.6"
|
version = "0.2.14"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"openssl",
|
||||||
|
"openssl-probe",
|
||||||
|
"openssl-sys",
|
||||||
|
"schannel",
|
||||||
|
"security-framework",
|
||||||
|
"security-framework-sys",
|
||||||
|
"tempfile",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.16.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"matrixmultiply",
|
"matrixmultiply",
|
||||||
"num-complex",
|
"num-complex",
|
||||||
"num-integer",
|
"num-integer",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
|
"portable-atomic",
|
||||||
|
"portable-atomic-util",
|
||||||
"rawpointer",
|
"rawpointer",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -294,33 +329,83 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ort"
|
name = "openssl"
|
||||||
version = "2.0.0-rc.2"
|
version = "0.10.75"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
|
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.5.0",
|
||||||
|
"cfg-if",
|
||||||
|
"foreign-types",
|
||||||
|
"libc",
|
||||||
|
"once_cell",
|
||||||
|
"openssl-macros",
|
||||||
|
"openssl-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-macros"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-probe"
|
||||||
|
version = "0.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-sys"
|
||||||
|
version = "0.9.111"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"libc",
|
||||||
|
"pkg-config",
|
||||||
|
"vcpkg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ort"
|
||||||
|
version = "2.0.0-rc.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"half",
|
|
||||||
"js-sys",
|
|
||||||
"libloading",
|
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"ort-sys",
|
"ort-sys",
|
||||||
"thiserror",
|
"smallvec",
|
||||||
"tracing",
|
"tracing",
|
||||||
"web-sys",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ort-sys"
|
name = "ort-sys"
|
||||||
version = "2.0.0-rc.2"
|
version = "2.0.0-rc.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
|
checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"flate2",
|
"flate2",
|
||||||
|
"pkg-config",
|
||||||
"sha2",
|
"sha2",
|
||||||
"tar",
|
"tar",
|
||||||
"ureq",
|
"ureq",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pem-rfc7468"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
|
||||||
|
dependencies = [
|
||||||
|
"base64ct",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.3.1"
|
version = "2.3.1"
|
||||||
@@ -333,6 +418,27 @@ version = "0.2.14"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
|
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pkg-config"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "portable-atomic"
|
||||||
|
version = "1.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "portable-atomic-util"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
|
||||||
|
dependencies = [
|
||||||
|
"portable-atomic",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.84"
|
version = "1.0.84"
|
||||||
@@ -366,21 +472,6 @@ dependencies = [
|
|||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ring"
|
|
||||||
version = "0.17.8"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
|
|
||||||
dependencies = [
|
|
||||||
"cc",
|
|
||||||
"cfg-if",
|
|
||||||
"getrandom",
|
|
||||||
"libc",
|
|
||||||
"spin",
|
|
||||||
"untrusted",
|
|
||||||
"windows-sys",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rust-example"
|
name = "rust-example"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -400,38 +491,48 @@ dependencies = [
|
|||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys",
|
"linux-raw-sys",
|
||||||
"windows-sys",
|
"windows-sys 0.52.0",
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rustls"
|
|
||||||
version = "0.22.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
|
|
||||||
dependencies = [
|
|
||||||
"log",
|
|
||||||
"ring",
|
|
||||||
"rustls-pki-types",
|
|
||||||
"rustls-webpki",
|
|
||||||
"subtle",
|
|
||||||
"zeroize",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pki-types"
|
name = "rustls-pki-types"
|
||||||
version = "1.7.0"
|
version = "1.13.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
|
checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282"
|
||||||
|
dependencies = [
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-webpki"
|
name = "schannel"
|
||||||
version = "0.102.4"
|
version = "0.1.28"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
|
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ring",
|
"windows-sys 0.61.2",
|
||||||
"rustls-pki-types",
|
]
|
||||||
"untrusted",
|
|
||||||
|
[[package]]
|
||||||
|
name = "security-framework"
|
||||||
|
version = "2.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.5.0",
|
||||||
|
"core-foundation",
|
||||||
|
"core-foundation-sys",
|
||||||
|
"libc",
|
||||||
|
"security-framework-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "security-framework-sys"
|
||||||
|
version = "2.15.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
|
||||||
|
dependencies = [
|
||||||
|
"core-foundation-sys",
|
||||||
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -446,16 +547,21 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "smallvec"
|
||||||
version = "0.9.8"
|
version = "2.0.0-alpha.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "socks"
|
||||||
version = "2.5.0"
|
version = "0.3.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
|
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
@@ -480,40 +586,18 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "tempfile"
|
||||||
version = "1.0.61"
|
version = "3.12.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
|
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"cfg-if",
|
||||||
|
"fastrand",
|
||||||
|
"once_cell",
|
||||||
|
"rustix",
|
||||||
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "thiserror-impl"
|
|
||||||
version = "1.0.61"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tinyvec"
|
|
||||||
version = "1.6.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
|
|
||||||
dependencies = [
|
|
||||||
"tinyvec_macros",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tinyvec_macros"
|
|
||||||
version = "0.1.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing"
|
name = "tracing"
|
||||||
version = "0.1.40"
|
version = "0.1.40"
|
||||||
@@ -521,21 +605,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing-attributes",
|
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tracing-attributes"
|
|
||||||
version = "0.1.27"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-core"
|
name = "tracing-core"
|
||||||
version = "0.1.32"
|
version = "0.1.32"
|
||||||
@@ -551,60 +623,54 @@ version = "1.17.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "unicode-bidi"
|
|
||||||
version = "0.3.15"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.12"
|
version = "1.0.12"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "unicode-normalization"
|
|
||||||
version = "0.1.23"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
|
|
||||||
dependencies = [
|
|
||||||
"tinyvec",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "untrusted"
|
|
||||||
version = "0.9.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ureq"
|
name = "ureq"
|
||||||
version = "2.9.7"
|
version = "3.1.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
|
checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
|
"der",
|
||||||
"log",
|
"log",
|
||||||
"once_cell",
|
"native-tls",
|
||||||
"rustls",
|
"percent-encoding",
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
"rustls-webpki",
|
"socks",
|
||||||
"url",
|
"ureq-proto",
|
||||||
"webpki-roots",
|
"utf-8",
|
||||||
|
"webpki-root-certs",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "url"
|
name = "ureq-proto"
|
||||||
version = "2.5.0"
|
version = "0.5.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
|
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"form_urlencoded",
|
"base64",
|
||||||
"idna",
|
"http",
|
||||||
"percent-encoding",
|
"httparse",
|
||||||
|
"log",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf-8"
|
||||||
|
version = "0.7.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vcpkg"
|
||||||
|
version = "0.2.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "version_check"
|
name = "version_check"
|
||||||
version = "0.9.4"
|
version = "0.9.4"
|
||||||
@@ -612,84 +678,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "webpki-root-certs"
|
||||||
version = "0.11.0+wasi-snapshot-preview1"
|
version = "1.0.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-bindgen"
|
|
||||||
version = "0.2.92"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
|
|
||||||
dependencies = [
|
|
||||||
"cfg-if",
|
|
||||||
"wasm-bindgen-macro",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-bindgen-backend"
|
|
||||||
version = "0.2.92"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
|
|
||||||
dependencies = [
|
|
||||||
"bumpalo",
|
|
||||||
"log",
|
|
||||||
"once_cell",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
"wasm-bindgen-shared",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-bindgen-macro"
|
|
||||||
version = "0.2.92"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
|
|
||||||
dependencies = [
|
|
||||||
"quote",
|
|
||||||
"wasm-bindgen-macro-support",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-bindgen-macro-support"
|
|
||||||
version = "0.2.92"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn",
|
|
||||||
"wasm-bindgen-backend",
|
|
||||||
"wasm-bindgen-shared",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-bindgen-shared"
|
|
||||||
version = "0.2.92"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "web-sys"
|
|
||||||
version = "0.3.69"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
|
|
||||||
dependencies = [
|
|
||||||
"js-sys",
|
|
||||||
"wasm-bindgen",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "webpki-roots"
|
|
||||||
version = "0.26.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi"
|
||||||
|
version = "0.3.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-i686-pc-windows-gnu",
|
||||||
|
"winapi-x86_64-pc-windows-gnu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-i686-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-link"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-sys"
|
name = "windows-sys"
|
||||||
version = "0.52.0"
|
version = "0.52.0"
|
||||||
@@ -700,10 +724,28 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-targets"
|
name = "windows-sys"
|
||||||
version = "0.52.5"
|
version = "0.59.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
|
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-sys"
|
||||||
|
version = "0.61.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
|
||||||
|
dependencies = [
|
||||||
|
"windows-link",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-targets"
|
||||||
|
version = "0.52.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows_aarch64_gnullvm",
|
"windows_aarch64_gnullvm",
|
||||||
"windows_aarch64_msvc",
|
"windows_aarch64_msvc",
|
||||||
@@ -717,51 +759,51 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_gnullvm"
|
name = "windows_aarch64_gnullvm"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
|
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_aarch64_msvc"
|
name = "windows_aarch64_msvc"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
|
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnu"
|
name = "windows_i686_gnu"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
|
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_gnullvm"
|
name = "windows_i686_gnullvm"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
|
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_i686_msvc"
|
name = "windows_i686_msvc"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
|
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnu"
|
name = "windows_x86_64_gnu"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
|
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_gnullvm"
|
name = "windows_x86_64_gnullvm"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
|
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows_x86_64_msvc"
|
name = "windows_x86_64_msvc"
|
||||||
version = "0.52.5"
|
version = "0.52.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xattr"
|
name = "xattr"
|
||||||
|
|||||||
@@ -4,6 +4,6 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] }
|
ort = { version = "=2.0.0-rc.10", features = ["ndarray"] }
|
||||||
ndarray = "0.15"
|
ndarray = "0.16"
|
||||||
hound = "3"
|
hound = "3"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ mod vad_iter;
|
|||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let model_path = std::env::var("SILERO_MODEL_PATH")
|
let model_path = std::env::var("SILERO_MODEL_PATH")
|
||||||
.unwrap_or_else(|_| String::from("../../files/silero_vad.onnx"));
|
.unwrap_or_else(|_| String::from("../../src/silero_vad/data/silero_vad.onnx"));
|
||||||
let audio_path = std::env::args()
|
let audio_path = std::env::args()
|
||||||
.nth(1)
|
.nth(1)
|
||||||
.unwrap_or_else(|| String::from("recorder.wav"));
|
.unwrap_or_else(|| String::from("recorder.wav"));
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
use crate::utils;
|
use crate::utils;
|
||||||
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
||||||
|
use ort::session::Session;
|
||||||
|
use ort::value::Value;
|
||||||
|
use std::mem::take;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Silero {
|
pub struct Silero {
|
||||||
session: ort::Session,
|
session: Session,
|
||||||
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
||||||
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
||||||
|
context: Array1<f32>,
|
||||||
|
context_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Silero {
|
impl Silero {
|
||||||
@@ -14,18 +19,24 @@ impl Silero {
|
|||||||
sample_rate: utils::SampleRate,
|
sample_rate: utils::SampleRate,
|
||||||
model_path: impl AsRef<Path>,
|
model_path: impl AsRef<Path>,
|
||||||
) -> Result<Self, ort::Error> {
|
) -> Result<Self, ort::Error> {
|
||||||
let session = ort::Session::builder()?.commit_from_file(model_path)?;
|
let session = Session::builder()?.commit_from_file(model_path)?;
|
||||||
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||||
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
|
let sample_rate_val: i64 = sample_rate.into();
|
||||||
|
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
|
||||||
|
let context = Array1::<f32>::zeros(context_size);
|
||||||
|
let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
session,
|
session,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
state,
|
state,
|
||||||
|
context,
|
||||||
|
context_size,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
pub fn reset(&mut self) {
|
||||||
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||||
|
self.context = Array1::<f32>::zeros(self.context_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
|
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
|
||||||
@@ -33,22 +44,41 @@ impl Silero {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|x| (*x as f32) / (i16::MAX as f32))
|
.map(|x| (*x as f32) / (i16::MAX as f32))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
|
|
||||||
frame = frame.slice(s![.., ..480]).to_owned();
|
// Concatenate context with input
|
||||||
let inps = ort::inputs![
|
let mut input_with_context = Vec::with_capacity(self.context_size + data.len());
|
||||||
frame,
|
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
|
||||||
std::mem::take(&mut self.state),
|
input_with_context.extend_from_slice(&data);
|
||||||
self.sample_rate.clone(),
|
|
||||||
]?;
|
let frame =
|
||||||
let res = self
|
Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)
|
||||||
.session
|
.unwrap();
|
||||||
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
|
|
||||||
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
|
let frame_value = Value::from_array(frame)?;
|
||||||
Ok(*res["output"]
|
let state_value = Value::from_array(take(&mut self.state))?;
|
||||||
.try_extract_raw_tensor::<f32>()
|
let sr_value = Value::from_array(self.sample_rate.clone())?;
|
||||||
|
|
||||||
|
let res = self.session.run([
|
||||||
|
(&frame_value).into(),
|
||||||
|
(&state_value).into(),
|
||||||
|
(&sr_value).into(),
|
||||||
|
])?;
|
||||||
|
|
||||||
|
let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;
|
||||||
|
let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();
|
||||||
|
self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();
|
||||||
|
|
||||||
|
// Update context with last context_size samples from the input
|
||||||
|
if data.len() >= self.context_size {
|
||||||
|
self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
let prob = *res["output"]
|
||||||
|
.try_extract_tensor::<f32>()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.1
|
.1
|
||||||
.first()
|
.first()
|
||||||
.unwrap())
|
.unwrap();
|
||||||
|
Ok(prob)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ pub struct VadParams {
|
|||||||
impl Default for VadParams {
|
impl Default for VadParams {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
frame_size: 64,
|
frame_size: 32, // 32ms for 512 samples at 16kHz
|
||||||
threshold: 0.5,
|
threshold: 0.5,
|
||||||
min_silence_duration_ms: 0,
|
min_silence_duration_ms: 0,
|
||||||
speech_pad_ms: 64,
|
speech_pad_ms: 64,
|
||||||
|
|||||||
BIN
src/silero_vad/data/silero_vad_16k.safetensors
Executable file
BIN
src/silero_vad/data/silero_vad_16k.safetensors
Executable file
Binary file not shown.
BIN
src/silero_vad/data/silero_vad_op18_ifless.onnx
Normal file
BIN
src/silero_vad/data/silero_vad_op18_ifless.onnx
Normal file
Binary file not shown.
71
src/silero_vad/tinygrad_model.py
Normal file
71
src/silero_vad/tinygrad_model.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from tinygrad import nn
|
||||||
|
|
||||||
|
|
||||||
|
class TinySileroVAD:
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
from tinygrad.nn.state import safe_load, load_state_dict
|
||||||
|
|
||||||
|
tiny_model = TinySileroVAD()
|
||||||
|
state_dict = safe_load('data/silero_vad_16k.safetensors')
|
||||||
|
load_state_dict(tiny_model, state_dict)
|
||||||
|
"""
|
||||||
|
self.n_fft = 256
|
||||||
|
self.stride = 128
|
||||||
|
self.pad = 64
|
||||||
|
self.cutoff = int(self.n_fft // 2) + 1
|
||||||
|
|
||||||
|
self.stft_conv = nn.Conv1d(1, 258, kernel_size=256, stride=self.stride, padding=0, bias=False)
|
||||||
|
self.conv1 = nn.Conv1d(129, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.conv2 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1)
|
||||||
|
self.conv3 = nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1)
|
||||||
|
self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
self.lstm_cell = nn.LSTMCell(128, 128)
|
||||||
|
self.final_conv = nn.Conv1d(128, 1, 1)
|
||||||
|
|
||||||
|
def __call__(self, x, state=None):
|
||||||
|
"""
|
||||||
|
# full audio example:
|
||||||
|
import torch
|
||||||
|
from tinygrad import Tensor
|
||||||
|
|
||||||
|
wav = read_audio(audio_path, sampling_rate=16000).unsqueeze(0)
|
||||||
|
num_samples = 512
|
||||||
|
context_size = 64
|
||||||
|
context = Tensor(np.zeros((1, context_size))).float()
|
||||||
|
outs = []
|
||||||
|
state = None
|
||||||
|
if wav.shape[1] % num_samples:
|
||||||
|
pad_num = num_samples - (wav.shape[1] % num_samples)
|
||||||
|
wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0)
|
||||||
|
|
||||||
|
wav = torch.nn.functional.pad(wav, (context_size, 0))
|
||||||
|
|
||||||
|
wav = Tensor(wav.numpy()).float()
|
||||||
|
|
||||||
|
for i in tqdm(range(context_size, wav.shape[1], num_samples)):
|
||||||
|
wavs_batch = wav[:, i-context_size:i+num_samples]
|
||||||
|
out_chunk, state = tiny_model(wavs_batch, state)
|
||||||
|
#outs.append(out_chunk.numpy())
|
||||||
|
outs.append(out_chunk)
|
||||||
|
|
||||||
|
predict = outs[0].cat(*outs[1:], dim=1).numpy()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if state is not None:
|
||||||
|
state = (state[0], state[1])
|
||||||
|
x = x.pad((0, self.pad), "reflect").unsqueeze(1)
|
||||||
|
x = self.stft_conv(x)
|
||||||
|
x = (x[:, :self.cutoff, :]**2 + x[:, self.cutoff:, :]**2).sqrt()
|
||||||
|
x = self.conv1(x).relu()
|
||||||
|
x = self.conv2(x).relu()
|
||||||
|
x = self.conv3(x).relu()
|
||||||
|
x = self.conv4(x).relu().squeeze(-1)
|
||||||
|
h, c = self.lstm_cell(x, state)
|
||||||
|
x = h.unsqueeze(-1)
|
||||||
|
state = h.stack(c, dim=0)
|
||||||
|
x = x.relu()
|
||||||
|
x = self.final_conv(x).sigmoid()
|
||||||
|
x = x.squeeze(1).mean(axis=1).unsqueeze(1)
|
||||||
|
return x, state
|
||||||
Reference in New Issue
Block a user