Mirror of https://github.com/snakers4/silero-vad.git (synced 2026-02-04 17:39:22 +08:00)

Commit: Add c++ onnxruntime example
@@ -20,9 +20,9 @@ This repository also includes Number Detector and Language classifier [models](h

<details>
<summary>Real Time Example</summary>

https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4

</details>

<br/>
290 cpp/silero_vad_onnx_1.cpp Normal file
@@ -0,0 +1,290 @@
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <chrono>

#include "onnxruntime_cxx_api.h"
#include "wav.h"

class VadModel
{
    // OnnxRuntime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

public:
    void init_engine_threads(int inter_threads, int intra_threads)
    {
        // The method should be called in each thread/proc in multi-thread/proc work
        session_options.SetIntraOpNumThreads(intra_threads);
        session_options.SetInterOpNumThreads(inter_threads);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    }

    void init_onnx_model(const std::string &model_path)
    {
        // Init threads = 1 for single-thread inference
        init_engine_threads(1, 1);
        // Load model
        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
    }

    void reset_states()
    {
        // Call reset before each audio starts
        std::memset(_h.data(), 0, _h.size() * sizeof(float));
        std::memset(_c.data(), 0, _c.size() * sizeof(float));
        triggerd = false;
        temp_end = 0;
        current_sample = 0;
    }

    // Call it in predict func. if you prefer raw bytes input.
    void bytes_to_float_tensor(const char *pcm_bytes)
    {
        // Read the buffer as little-endian int16 PCM and normalize to [-1, 1)
        const int16_t *pcm = reinterpret_cast<const int16_t *>(pcm_bytes);
        for (int i = 0; i < window_size_samples; i++)
        {
            input[i] = static_cast<float>(pcm[i]) / 32768; // int16_t normalized to float
        }
    }

    void predict(const std::vector<float> &data) // const char *data
    {
        // bytes_to_float_tensor(data);

        // Infer
        // Inputs
        input.assign(data.begin(), data.end());
        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
            memory_info, input.data(), input.size(), input_node_dims, 2);
        // std::cout << "input size:" << input.size() << std::endl;
        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
            memory_info, sr.data(), sr.size(), sr_node_dims, 1);
        Ort::Value h_ort = Ort::Value::CreateTensor<float>(
            memory_info, _h.data(), _h.size(), hc_node_dims, 3);
        Ort::Value c_ort = Ort::Value::CreateTensor<float>(
            memory_info, _c.data(), _c.size(), hc_node_dims, 3);

        ort_inputs.clear(); // clear inputs
        ort_inputs.emplace_back(std::move(input_ort));
        ort_inputs.emplace_back(std::move(sr_ort));
        ort_inputs.emplace_back(std::move(h_ort));
        ort_inputs.emplace_back(std::move(c_ort));

        // Infer
        ort_outputs = session->Run(
            Ort::RunOptions{nullptr},
            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
            output_node_names.data(), output_node_names.size());

        // Output probability & update h,c recursively
        float output = ort_outputs[0].GetTensorMutableData<float>()[0];
        float *hn = ort_outputs[1].GetTensorMutableData<float>();
        std::memcpy(_h.data(), hn, size_hc * sizeof(float));
        float *cn = ort_outputs[2].GetTensorMutableData<float>();
        std::memcpy(_c.data(), cn, size_hc * sizeof(float));

        // Push forward sample index
        current_sample += window_size_samples;

        // 1) Reset temp_end when > threshold
        if ((output >= threshold) && (temp_end != 0))
        {
            temp_end = 0;
        }
        // 2) Trigger and start sentence
        if ((output >= threshold) && (triggerd == false))
        {
            triggerd = true;
            speech_start = current_sample - speech_pad_samples;
            printf("{ start: %.3f s }\n", 1.0 * current_sample / sample_rate);
        }
        // 3) Speaking
        if ((output >= (threshold - 0.15)) && (triggerd == true))
        {
            printf("{ speaking: %.3f s }\n", 1.0 * current_sample / sample_rate);
        }
        // 4) End
        if ((output < (threshold - 0.15)) && (triggerd == true))
        {
            if (temp_end == 0) // mark the first silent frame only once
            {
                temp_end = current_sample;
            }
            // a. silence < min_silence_samples, continue speaking
            if ((current_sample - temp_end) < min_silence_samples)
            {
                printf("{ speaking: %.3f s }\n", 1.0 * current_sample / sample_rate);
            }
            // b. silence >= min_silence_samples, end speaking
            else
            {
                speech_end = temp_end + speech_pad_samples;
                temp_end = 0;
                triggerd = false;
                printf("{ end: %.3f s }\n", 1.0 * current_sample / sample_rate);
            }
        }
        // 5) Silence
        if ((output < threshold) && (triggerd == false))
        {
            printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate);
        }
    }

    // Print input output shape of the model
    void GetInputOutputInfo(
        const std::shared_ptr<Ort::Session> &session,
        std::vector<const char *> *in_names, std::vector<const char *> *out_names)
    {
        Ort::AllocatorWithDefaultOptions allocator;
        // Input info
        int num_nodes = session->GetInputCount();
        in_names->resize(num_nodes);
        for (int i = 0; i < num_nodes; ++i)
        {
            char *name = session->GetInputName(i, allocator);
            Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
            auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
            ONNXTensorElementDataType type = tensor_info.GetElementType();
            std::vector<int64_t> node_dims = tensor_info.GetShape();
            std::stringstream shape;
            for (auto j : node_dims)
            {
                shape << j;
                shape << " ";
            }
            std::cout << "\tInput " << i << " : name=" << name << " type=" << type
                      << " dims=" << shape.str() << std::endl;
            (*in_names)[i] = name;
        }
        // Output info
        num_nodes = session->GetOutputCount();
        out_names->resize(num_nodes);
        for (int i = 0; i < num_nodes; ++i)
        {
            char *name = session->GetOutputName(i, allocator);
            Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
            auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
            ONNXTensorElementDataType type = tensor_info.GetElementType();
            std::vector<int64_t> node_dims = tensor_info.GetShape();
            std::stringstream shape;
            for (auto j : node_dims)
            {
                shape << j;
                shape << " ";
            }
            std::cout << "\tOutput " << i << " : name=" << name << " type=" << type
                      << " dims=" << shape.str() << std::endl;
            (*out_names)[i] = name;
        }
    }

private:
    // model config
    int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
    int sample_rate;
    int sr_per_ms;              // Assign when init, support 8 or 16
    float threshold = 0.5;
    int min_silence_samples;    // sr_per_ms * #ms
    int speech_pad_samples = 0; // Can be used in offline infer to get as much speech as possible

    // model states
    bool triggerd = false;
    unsigned int speech_start = 0;
    unsigned int speech_end = 0;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
    float output;

    // Onnx model
    // Inputs
    std::vector<Ort::Value> ort_inputs;

    std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
    std::vector<float> input;
    std::vector<int64_t> sr;
    unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
    std::vector<float> _h;
    std::vector<float> _c;

    int64_t input_node_dims[2] = {};
    const int64_t sr_node_dims[1] = {1};
    const int64_t hc_node_dims[3] = {2, 1, 64};

    // Outputs
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char *> output_node_names = {"output", "hn", "cn"};

public:
    // Construct init
    VadModel(const std::string ModelPath,
             int sample_rate, int frame_size,
             float threshold, int min_silence_duration_ms, int speech_pad_ms)
    {
        init_onnx_model(ModelPath);
        this->sample_rate = sample_rate; // parameters shadow the members, so assign explicitly
        this->threshold = threshold;
        sr_per_ms = sample_rate / 1000;
        min_silence_samples = sr_per_ms * min_silence_duration_ms;
        speech_pad_samples = sr_per_ms * speech_pad_ms;
        window_size_samples = frame_size * sr_per_ms; // e.g. 64 ms/frame * 8 samples/ms = 512 samples/frame
        input.resize(window_size_samples);
        input_node_dims[0] = 1;
        input_node_dims[1] = window_size_samples;
        // std::cout << "== Input size" << input.size() << std::endl;
        _h.resize(size_hc);
        _c.resize(size_hc);
        sr.resize(1);
        sr[0] = sample_rate; // the model also takes the sample rate as an input tensor
    }
};

int main()
{
    // Read wav
    wav::WavReader wav_reader("silero-vad-master/test_audios/test0_for_vad.wav");

    std::vector<int16_t> data(wav_reader.num_samples());
    std::vector<float> input_wav(wav_reader.num_samples());

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        data[i] = static_cast<int16_t>(*(wav_reader.data() + i));
    }

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        input_wav[i] = static_cast<float>(data[i]) / 32768;
    }

    // ===== Test configs =====
    std::string path = "silero-vad-master/files/silero_vad.onnx";
    int test_sr = 8000;
    int test_frame_ms = 64;
    float test_threshold = 0.5f;
    int test_min_silence_duration_ms = 0;
    int test_speech_pad_ms = 0;
    int test_window_samples = test_frame_ms * (test_sr / 1000);

    VadModel vad(path, test_sr, test_frame_ms, test_threshold,
                 test_min_silence_duration_ms, test_speech_pad_ms);
    // std::cout << "== 3" << std::endl;
    // std::cout << vad.window_size_samples1() << std::endl;

    // Feed full windows only; the final partial frame (if any) is dropped
    for (int j = 0; j + test_window_samples <= wav_reader.num_samples(); j += test_window_samples)
    {
        std::vector<float> r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples};
        auto start = std::chrono::high_resolution_clock::now();
        // Predict and print per-frame processing time
        vad.predict(r);
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
        std::cout << "== Elapsed time: " << elapsed_time.count() << "ns" << " ==" << std::endl;
    }
}
205 cpp/wav.h Normal file
@@ -0,0 +1,205 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      fprintf(stderr,
              "WaveData: expect PCM format data "
              "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    for (int i = 0; i < num_data; ++i) {
      switch (bits_per_sample_) {
        case 8: {
          char sample;
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 16: {
          int16_t sample;
          fread(&sample, 1, sizeof(int16_t), fp);
          // std::cout << sample;
          data_[i] = static_cast<float>(sample);
          // std::cout << data_[i];
          break;
        }
        case 32: {
          int sample;
          fread(&sample, 1, sizeof(int), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        default:
          fprintf(stderr, "unsupported quantization bits\n");
          exit(1);
      }
    }
    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "wb");  // binary mode so the raw sample bytes are written verbatim
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_
50 runtime/cpp/README.md Normal file
@@ -0,0 +1,50 @@
# Stream example in C++

Here is a simple example of running the VAD model in C++ with ONNX Runtime.


## Requirements

The code was tested in the environments below; feel free to try others.

- WSL2 + Debian-bullseye (docker)
- gcc 12.2.0
- onnxruntime-linux-x64-1.12.1


## Usage

1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`

2. Install onnxruntime-linux-x64-1.12.1

   - Download the onnxruntime library:

     `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`

   - Unzip it. The path is assumed to be `/root/onnxruntime-linux-x64-1.12.1` below.

3. Modify the wav path and the test configs in the `main` function (a sketch of an adapted `main` is shown after the build step)

   `wav::WavReader wav_reader("${path_to_your_wav_file}");`

   e.g. test sample rate, frame length in ms, threshold, ...

4. Build with gcc and run

```bash
# Build
g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test

# Run
./test
```
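For step 3, the knobs all live at the top of `main()` in `silero-vad-onnx.cpp` (`test_sr`, `test_frame_ms`, `test_threshold`, `test_min_silence_duration_ms`, `test_speech_pad_ms`). Below is a minimal sketch of an adapted `main()` for 16 kHz audio; it assumes the `VadIterator` class has been moved into a hypothetical `vad_iterator.h` header, and the wav path and parameter values are only illustrative. In practice you can simply edit `main()` in place instead.

```cpp
#include <string>
#include <vector>

#include "wav.h"           // bundled WAV reader
#include "vad_iterator.h"  // hypothetical: VadIterator split out of silero-vad-onnx.cpp

int main() {
    // Read and normalize the wav exactly as the original main() does.
    wav::WavReader wav_reader("/path/to/your_16k_mono.wav");  // placeholder path
    std::vector<float> input_wav(wav_reader.num_samples());
    for (int i = 0; i < wav_reader.num_samples(); i++)
        input_wav[i] = wav_reader.data()[i] / 32768.0f;

    // Test configs: 16 kHz, 32 ms frames -> 512-sample windows.
    int test_sr = 16000;
    int test_frame_ms = 32;
    float test_threshold = 0.5f;
    int test_min_silence_duration_ms = 100;
    int test_speech_pad_ms = 30;
    int test_window_samples = test_frame_ms * (test_sr / 1000);

    VadIterator vad("../files/silero_vad.onnx", test_sr, test_frame_ms,
                    test_threshold, test_min_silence_duration_ms, test_speech_pad_ms);

    // Feed full windows only; the final partial frame is dropped.
    for (int j = 0; j + test_window_samples <= wav_reader.num_samples(); j += test_window_samples) {
        std::vector<float> chunk(&input_wav[j], &input_wav[j] + test_window_samples);
        vad.predict(chunk);  // prints { start: ... } / { end: ... } timestamps
    }
    return 0;
}
```

With `min_silence_duration_ms` and `speech_pad_ms` above zero, short pauses are bridged and a little context is kept around each segment; the exact values are up to you.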
253 runtime/cpp/silero-vad-onnx.cpp Normal file
@@ -0,0 +1,253 @@
#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <chrono>

#include "onnxruntime_cxx_api.h"
#include "wav.h"

class VadIterator
{
    // OnnxRuntime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

public:
    void init_engine_threads(int inter_threads, int intra_threads)
    {
        // The method should be called in each thread/proc in multi-thread/proc work
        session_options.SetIntraOpNumThreads(intra_threads);
        session_options.SetInterOpNumThreads(inter_threads);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    }

    void init_onnx_model(const std::string &model_path)
    {
        // Init threads = 1 for single-thread inference
        init_engine_threads(1, 1);
        // Load model
        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
    }

    void reset_states()
    {
        // Call reset before each audio starts
        std::memset(_h.data(), 0, _h.size() * sizeof(float));
        std::memset(_c.data(), 0, _c.size() * sizeof(float));
        triggerd = false;
        temp_end = 0;
        current_sample = 0;
    }

    // Call it in predict func. if you prefer raw bytes input.
    void bytes_to_float_tensor(const char *pcm_bytes)
    {
        // Read the buffer as little-endian int16 PCM and normalize to [-1, 1)
        const int16_t *pcm = reinterpret_cast<const int16_t *>(pcm_bytes);
        for (int i = 0; i < window_size_samples; i++)
        {
            input[i] = static_cast<float>(pcm[i]) / 32768; // int16_t normalized to float
        }
    }

    void predict(const std::vector<float> &data)
    {
        // bytes_to_float_tensor(data);

        // Infer
        // Create ort tensors
        input.assign(data.begin(), data.end());
        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
            memory_info, input.data(), input.size(), input_node_dims, 2);
        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
            memory_info, sr.data(), sr.size(), sr_node_dims, 1);
        Ort::Value h_ort = Ort::Value::CreateTensor<float>(
            memory_info, _h.data(), _h.size(), hc_node_dims, 3);
        Ort::Value c_ort = Ort::Value::CreateTensor<float>(
            memory_info, _c.data(), _c.size(), hc_node_dims, 3);

        // Clear and add inputs
        ort_inputs.clear();
        ort_inputs.emplace_back(std::move(input_ort));
        ort_inputs.emplace_back(std::move(sr_ort));
        ort_inputs.emplace_back(std::move(h_ort));
        ort_inputs.emplace_back(std::move(c_ort));

        // Infer
        ort_outputs = session->Run(
            Ort::RunOptions{nullptr},
            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
            output_node_names.data(), output_node_names.size());

        // Output probability & update h,c recursively
        float output = ort_outputs[0].GetTensorMutableData<float>()[0];
        float *hn = ort_outputs[1].GetTensorMutableData<float>();
        std::memcpy(_h.data(), hn, size_hc * sizeof(float));
        float *cn = ort_outputs[2].GetTensorMutableData<float>();
        std::memcpy(_c.data(), cn, size_hc * sizeof(float));

        // Push forward sample index
        current_sample += window_size_samples;

        // Reset temp_end when > threshold
        if ((output >= threshold) && (temp_end != 0))
        {
            temp_end = 0;
        }
        // 1) Silence
        if ((output < threshold) && (triggerd == false))
        {
            // printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate);
        }
        // 2) Speaking
        if ((output >= (threshold - 0.15)) && (triggerd == true))
        {
            // printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate);
        }

        // 3) Start
        if ((output >= threshold) && (triggerd == false))
        {
            triggerd = true;
            speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
            printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate);
        }

        // 4) End
        if ((output < (threshold - 0.15)) && (triggerd == true))
        {
            if (temp_end == 0) // mark the first silent frame only once
            {
                temp_end = current_sample;
            }
            // a. silence < min_silence_samples, continue speaking
            if ((current_sample - temp_end) < min_silence_samples)
            {
                // printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate);
            }
            // b. silence >= min_silence_samples, end speaking
            else
            {
                speech_end = current_sample + speech_pad_samples;
                temp_end = 0;
                triggerd = false;
                printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate);
            }
        }
    }

private:
    // model config
    int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
    int sample_rate;
    int sr_per_ms;           // Assign when init, support 8 or 16
    float threshold;
    int min_silence_samples; // sr_per_ms * #ms
    int speech_pad_samples;  // Can be used in offline infer to get as much speech as possible

    // model states
    bool triggerd = false;
    unsigned int speech_start = 0;
    unsigned int speech_end = 0;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
    float output;

    // Onnx model
    // Inputs
    std::vector<Ort::Value> ort_inputs;

    std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
    std::vector<float> input;
    std::vector<int64_t> sr;
    unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
    std::vector<float> _h;
    std::vector<float> _c;

    int64_t input_node_dims[2] = {};
    const int64_t sr_node_dims[1] = {1};
    const int64_t hc_node_dims[3] = {2, 1, 64};

    // Outputs
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char *> output_node_names = {"output", "hn", "cn"};

public:
    // Construction
    VadIterator(const std::string ModelPath,
                int Sample_rate, int frame_size,
                float Threshold, int min_silence_duration_ms, int speech_pad_ms)
    {
        init_onnx_model(ModelPath);
        sample_rate = Sample_rate;
        sr_per_ms = sample_rate / 1000;
        threshold = Threshold;
        min_silence_samples = sr_per_ms * min_silence_duration_ms;
        speech_pad_samples = sr_per_ms * speech_pad_ms;
        window_size_samples = frame_size * sr_per_ms;

        input.resize(window_size_samples);
        input_node_dims[0] = 1;
        input_node_dims[1] = window_size_samples;
        // std::cout << "== Input size" << input.size() << std::endl;
        _h.resize(size_hc);
        _c.resize(size_hc);
        sr.resize(1);
        sr[0] = sample_rate; // the model also takes the sample rate as an input tensor
    }
};

int main()
{
    // Read wav
    wav::WavReader wav_reader("./test_for_vad.wav");
    std::vector<int16_t> data(wav_reader.num_samples());
    std::vector<float> input_wav(wav_reader.num_samples());

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        data[i] = static_cast<int16_t>(*(wav_reader.data() + i));
    }

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
        input_wav[i] = static_cast<float>(data[i]) / 32768;
    }

    // ===== Test configs =====
    std::string path = "../files/silero_vad.onnx";
    int test_sr = 8000;
    int test_frame_ms = 64;
    float test_threshold = 0.5f;
    int test_min_silence_duration_ms = 0;
    int test_speech_pad_ms = 0;
    int test_window_samples = test_frame_ms * (test_sr / 1000);

    VadIterator vad(
        path, test_sr, test_frame_ms, test_threshold,
        test_min_silence_duration_ms, test_speech_pad_ms);

    // Feed full windows only; the final partial frame (if any) is dropped
    for (int j = 0; j + test_window_samples <= wav_reader.num_samples(); j += test_window_samples)
    {
        // std::cout << "== 4" << std::endl;
        std::vector<float> r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples};
        auto start = std::chrono::high_resolution_clock::now();
        // Predict and print per-frame processing time
        vad.predict(r);
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
        // std::cout << "== Elapsed time: " << 1.0 * elapsed_time.count() / 1000000 << "ms" << " ==" << std::endl;
    }
}
205 runtime/cpp/wav.h Normal file
@@ -0,0 +1,205 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      fprintf(stderr,
              "WaveData: expect PCM format data "
              "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    for (int i = 0; i < num_data; ++i) {
      switch (bits_per_sample_) {
        case 8: {
          char sample;
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 16: {
          int16_t sample;
          fread(&sample, 1, sizeof(int16_t), fp);
          // std::cout << sample;
          data_[i] = static_cast<float>(sample);
          // std::cout << data_[i];
          break;
        }
        case 32: {
          int sample;
          fread(&sample, 1, sizeof(int), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        default:
          fprintf(stderr, "unsupported quantization bits\n");
          exit(1);
      }
    }
    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "wb");  // binary mode so the raw sample bytes are written verbatim
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_