diff --git a/README.md b/README.md
index 5050042..fdb6464 100644
--- a/README.md
+++ b/README.md
@@ -20,9 +20,9 @@ This repository also includes Number Detector and Language classifier [models](h
Real Time Example
-
+
https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
-
+
diff --git a/cpp/silero_vad_onnx_1.cpp b/cpp/silero_vad_onnx_1.cpp
new file mode 100644
index 0000000..c5f5590
--- /dev/null
+++ b/cpp/silero_vad_onnx_1.cpp
@@ -0,0 +1,290 @@
#include <cstdio>
#include <cstring>
#include <chrono>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"
#include "wav.h"
+
+class VadModel
+{
+ // OnnxRuntime resources
+ Ort::Env env;
+ Ort::SessionOptions session_options;
+ std::shared_ptr session = nullptr;
+ Ort::AllocatorWithDefaultOptions allocator;
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
+
+public:
+ void init_engine_threads(int inter_threads, int intra_threads)
+ {
+ // The method should be called in each thread/proc in multi-thread/proc work
+ session_options.SetIntraOpNumThreads(intra_threads);
+ session_options.SetInterOpNumThreads(inter_threads);
+ session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+ }
+
+ void init_onnx_model(const std::string &model_path)
+ {
+ // Init threads = 1 for
+ init_engine_threads(1, 1);
+ // Load model
+ session = std::make_shared(env, model_path.c_str(), session_options);
+ }
+
+ void reset_states()
+ {
+ // Call reset before each audio start
+ std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
+ std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
+ triggerd = false;
+ temp_end = 0;
+ current_sample = 0;
+ }
+
+ // Call it in predict func. if you prefer raw bytes input.
+ void bytes_to_float_tensor(const char *pcm_bytes)
+ {
+ std::memcpy(input.data(), pcm_bytes, window_size_samples * sizeof(int16_t));
+ for (int i = 0; i < window_size_samples; i++)
+ {
+ input[i] = static_cast(input[i]) / 32768; // int16_t normalized to float
+ }
+ }
+
+
+ void predict(const std::vector &data) // const char *data
+ {
+ // bytes_to_float_tensor(data);
+
+ // Infer
+ // Inputs
+ input.assign(data.begin(), data.end());
+ Ort::Value input_ort = Ort::Value::CreateTensor(
+ memory_info, input.data(), input.size(), input_node_dims, 2);
+ // std::cout << "input size:" << input.size() << std::endl;
+ Ort::Value sr_ort = Ort::Value::CreateTensor(
+ memory_info, sr.data(), sr.size(), sr_node_dims, 1);
+ Ort::Value h_ort = Ort::Value::CreateTensor(
+ memory_info, _h.data(), _h.size(), hc_node_dims, 3);
+ Ort::Value c_ort = Ort::Value::CreateTensor(
+ memory_info, _c.data(), _c.size(), hc_node_dims, 3);
+
+ ort_inputs.clear(); // clear inputs
+ ort_inputs.emplace_back(std::move(input_ort));
+ ort_inputs.emplace_back(std::move(sr_ort));
+ ort_inputs.emplace_back(std::move(h_ort));
+ ort_inputs.emplace_back(std::move(c_ort));
+
+ // Infer
+ ort_outputs = session->Run(
+ Ort::RunOptions{nullptr},
+ input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
+ output_node_names.data(), output_node_names.size());
+
+ // out put Probability & update h,c recursively
+ float output = ort_outputs[0].GetTensorMutableData()[0];
+ float *hn = ort_outputs[1].GetTensorMutableData();
+ std::memcpy(_h.data(), hn, size_hc * sizeof(float));
+ float *cn = ort_outputs[2].GetTensorMutableData();
+ std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+
+ // Push forward sample index
+ current_sample += window_size_samples;
+
+ // 1) Reset temp_end when > threshold
+ if ((output >= threshold) && (temp_end != 0))
+ {
+ temp_end = 0;
+ }
+ // 2) Trigger and start sentence
+ if ((output >= threshold) && (triggerd == false))
+ {
+ triggerd = true;
+ speech_start = current_sample - speech_pad_samples;
+ printf("{ start: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+ // 3) Speaking
+ if ((output >= (threshold - 0.15)) && (triggerd == true))
+ {
+ printf("{ speaking: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+ // 4) End
+ if ((output < (threshold - 0.15)) && (triggerd == true))
+ {
+
+ if (temp_end != 0)
+ {
+ temp_end = current_sample;
+ }
+ // a. silence < min_slience_samples, continue speaking
+ if ((current_sample - temp_end) < min_silence_samples)
+ {
+ printf("{ speaking: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+ // b. silence >= min_slience_samples, end speaking
+ else
+ {
+ speech_end = temp_end + speech_pad_samples;
+ temp_end = 0;
+ triggerd = false;
+ printf("{ end: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+ }
+ // 5) Silence
+ if ((output < threshold) && (triggerd == false))
+ {
+ printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+
+ }
+
+ // Print input output shape of the model
+ void GetInputOutputInfo(
+ const std::shared_ptr &session,
+ std::vector *in_names, std::vector *out_names)
+ {
+ Ort::AllocatorWithDefaultOptions allocator;
+ // Input info
+ int num_nodes = session->GetInputCount();
+ in_names->resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i)
+ {
+ char *name = session->GetInputName(i, allocator);
+ Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ ONNXTensorElementDataType type = tensor_info.GetElementType();
+ std::vector node_dims = tensor_info.GetShape();
+ std::stringstream shape;
+ for (auto j : node_dims)
+ {
+ shape << j;
+ shape << " ";
+ }
+ std::cout << "\tInput " << i << " : name=" << name << " type=" << type
+ << " dims=" << shape.str() << std::endl;
+ (*in_names)[i] = name;
+ }
+ // Output info
+ num_nodes = session->GetOutputCount();
+ out_names->resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i)
+ {
+ char *name = session->GetOutputName(i, allocator);
+ Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ ONNXTensorElementDataType type = tensor_info.GetElementType();
+ std::vector node_dims = tensor_info.GetShape();
+ std::stringstream shape;
+ for (auto j : node_dims)
+ {
+ shape << j;
+ shape << " ";
+ }
+ std::cout << "\tOutput " << i << " : name=" << name << " type=" << type
+ << " dims=" << shape.str() << std::endl;
+ ;
+ (*out_names)[i] = name;
+ }
+ }
+
+private:
+ // model config
+ int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
+ int sample_rate;
+ int sr_per_ms; // Assign when init, support 8 or 16
+ float threshold = 0.5;
+ int min_silence_samples; // sr_per_ms * #ms
+ int speech_pad_samples = 0; // Can be used in offline infer to get as much speech as possible
+
+ // model states
+ bool triggerd = false;
+ unsigned int speech_start = 0;
+ unsigned int speech_end = 0;
+ unsigned int temp_end = 0;
+ unsigned int current_sample = 0;
+ // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
+ float output;
+
+ // Onnx model
+ // Inputs
+ std::vector ort_inputs;
+
+ std::vector input_node_names = {"input", "sr", "h", "c"};
+ std::vector input;
+ std::vector sr;
+ unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
+ std::vector _h;
+ std::vector _c;
+
+ int64_t input_node_dims[2] = {};
+ const int64_t sr_node_dims[1] = {1};
+ const int64_t hc_node_dims[3] = {2, 1, 64};
+
+ // Outputs
+ std::vector ort_outputs;
+ std::vector output_node_names = {"output", "hn", "cn"};
+
+
+public:
+ // Construct init
+ VadModel(const std::string ModelPath,
+ int sample_rate, int frame_size,
+ float threshold, int min_silence_duration_ms, int speech_pad_ms)
+ {
+ init_onnx_model(ModelPath);
+ sr_per_ms = sample_rate / 1000;
+ min_silence_samples = sr_per_ms * min_silence_duration_ms;
+ speech_pad_samples = sr_per_ms * speech_pad_ms;
+ window_size_samples = frame_size * sr_per_ms; // Input 64ms/frame * 8ms = 512 samples/frame
+ input.resize(window_size_samples);
+ input_node_dims[0] = 1;
+ input_node_dims[1] = window_size_samples;
+ // std::cout << "== Input size" << input.size() << std::endl;
+ _h.resize(size_hc);
+ _c.resize(size_hc);
+ sr.resize(1);
+ }
+
+};
+
+int main()
+{
+
+ // Read wav
+ wav::WavReader wav_reader("silero-vad-master/test_audios/test0_for_vad.wav");
+
+ std::vector data(wav_reader.num_samples());
+ std::vector input_wav(wav_reader.num_samples());
+
+ for (int i = 0; i < wav_reader.num_samples(); i++)
+ {
+ data[i] = static_cast(*(wav_reader.data() + i));
+ }
+
+ for (int i = 0; i < wav_reader.num_samples(); i++)
+ {
+ input_wav[i] = static_cast(data[i]) / 32768;
+ }
+
+ std::string path = "silero-vad-master/files/silero_vad.onnx";
+ int test_sr = 8000;
+ int test_frame_ms = 64;
+ int test_window_samples = test_frame_ms * (test_sr/1000);
+ VadModel vad(path, test_sr, test_frame_ms);
+ // std::cout << "== 3" << std::endl;
+ // std::cout << vad.window_size_samples1() << std::endl;
+
+ for (int j = 0; j < wav_reader.num_samples(); j += test_window_samples)
+ {
+ std::vector r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples};
+ auto start = std::chrono::high_resolution_clock::now();
+ // Predict and print throughout process time
+ vad.predict(r);
+ auto end = std::chrono::high_resolution_clock::now();
+ auto elapsed_time = std::chrono::duration_cast(end-start);
+ std::cout << "== Elapsed time: " << elapsed_time.count() << "ns" << " ==" <
+#include
+#include
+#include
+#include
+
+#include
+
+// #include "utils/log.h"
+
+namespace wav {
+
// Canonical 44-byte PCM wav header (assumes the common packed layout).
struct WavHeader {
  char riff[4];                // "RIFF"
  unsigned int size;           // file size minus 8
  char wav[4];                 // "WAVE"
  char fmt[4];                 // "fmt "
  unsigned int fmt_size;       // fmt chunk size; 16 for plain PCM
  uint16_t format;             // 1 == PCM
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;         // bytes per frame across all channels
  uint16_t bit;                // bits per sample
  char data[4];                // "data"
  unsigned int data_size;      // payload size in bytes
};

// Reads a PCM wav file into a flat interleaved float buffer.
class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  // Open and fully decode `filename`. Returns false on open/parse failure.
  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open for binary read
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      fprintf(stderr,
              "WaveData: expect PCM format data "
              "to have fmt chunk of at least size 16.\n");
      fclose(fp);  // BUGFIX: don't leak the handle on the error path
      return false;
    } else if (header.fmt_size > 16) {
      // Extended fmt chunk: seek past the extra bytes, then re-read the
      // 8-byte chunk id + size that follow.
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // one flat interleaved buffer
    num_samples_ = num_data / num_channel_;

    // Decode each integer sample to float (no normalization here).
    for (int i = 0; i < num_data; ++i) {
      switch (bits_per_sample_) {
        case 8: {
          char sample;
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 16: {
          int16_t sample;
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 32: {
          int sample;
          fread(&sample, 1, sizeof(int), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        default:
          fprintf(stderr, "unsupported quantization bits");
          exit(1);
      }
    }
    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};
+
+class WavWriter {
+ public:
+ WavWriter(const float* data, int num_samples, int num_channel,
+ int sample_rate, int bits_per_sample)
+ : data_(data),
+ num_samples_(num_samples),
+ num_channel_(num_channel),
+ sample_rate_(sample_rate),
+ bits_per_sample_(bits_per_sample) {}
+
+ void Write(const std::string& filename) {
+ FILE* fp = fopen(filename.c_str(), "w");
+ // init char 'riff' 'WAVE' 'fmt ' 'data'
+ WavHeader header;
+ char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
+ 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
+ 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
+ memcpy(&header, wav_header, sizeof(header));
+ header.channels = num_channel_;
+ header.bit = bits_per_sample_;
+ header.sample_rate = sample_rate_;
+ header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
+ header.size = sizeof(header) - 8 + header.data_size;
+ header.bytes_per_second =
+ sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
+ header.block_size = num_channel_ * (bits_per_sample_ / 8);
+
+ fwrite(&header, 1, sizeof(header), fp);
+
+ for (int i = 0; i < num_samples_; ++i) {
+ for (int j = 0; j < num_channel_; ++j) {
+ switch (bits_per_sample_) {
+ case 8: {
+ char sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ case 16: {
+ int16_t sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ case 32: {
+ int sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ }
+ }
+ }
+ fclose(fp);
+ }
+
+ private:
+ const float* data_;
+ int num_samples_; // total float points in data_
+ int num_channel_;
+ int sample_rate_;
+ int bits_per_sample_;
+};
+
+} // namespace wenet
+
+#endif // FRONTEND_WAV_H_
diff --git a/runtime/cpp/README.md b/runtime/cpp/README.md
new file mode 100644
index 0000000..9cce823
--- /dev/null
+++ b/runtime/cpp/README.md
@@ -0,0 +1,50 @@
+# Stream example in C++
+
+Here's a simple example of the vad model in c++ onnxruntime.
+
+
+
+## Requirements
+
+The code has been tested in the environments below; feel free to try others.
+
+- WSL2 + Debian-bullseye (docker)
+- gcc 12.2.0
+- onnxruntime-linux-x64-1.12.1
+
+
+
+## Usage
+
+1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`
+
+2. Install onnxruntime-linux-x64-1.12.1
+
+  - Download the onnxruntime library:
+
+ `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`
+
+ - Unzip. Assume the path is `/root/onnxruntime-linux-x64-1.12.1`
+
+3. Modify wav path & Test configs in main function
+
+ `wav::WavReader wav_reader("${path_to_your_wav_file}");`
+
+ test sample rate, frame per ms, threshold...
+
+4. Build with gcc and run
+
+ ```bash
+ # Build
+ g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test
+
+ # Run
+ ./test
+ ```
+
+ build:
+
+ `g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test`
+
+ `./test`
+
diff --git a/runtime/cpp/silero-vad-onnx.cpp b/runtime/cpp/silero-vad-onnx.cpp
new file mode 100644
index 0000000..59846e1
--- /dev/null
+++ b/runtime/cpp/silero-vad-onnx.cpp
@@ -0,0 +1,253 @@
+#include
+#include
+#include
+#include
+#include
+
+#include "onnxruntime_cxx_api.h"
+#include "wav.h"
+
+class VadIterator
+{
+ // OnnxRuntime resources
+ Ort::Env env;
+ Ort::SessionOptions session_options;
+ std::shared_ptr session = nullptr;
+ Ort::AllocatorWithDefaultOptions allocator;
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
+
+public:
+ void init_engine_threads(int inter_threads, int intra_threads)
+ {
+ // The method should be called in each thread/proc in multi-thread/proc work
+ session_options.SetIntraOpNumThreads(intra_threads);
+ session_options.SetInterOpNumThreads(inter_threads);
+ session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+ }
+
+ void init_onnx_model(const std::string &model_path)
+ {
+ // Init threads = 1 for
+ init_engine_threads(1, 1);
+ // Load model
+ session = std::make_shared(env, model_path.c_str(), session_options);
+ }
+
+ void reset_states()
+ {
+ // Call reset before each audio start
+ std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
+ std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
+ triggerd = false;
+ temp_end = 0;
+ current_sample = 0;
+ }
+
+ // Call it in predict func. if you prefer raw bytes input.
+ void bytes_to_float_tensor(const char *pcm_bytes)
+ {
+ std::memcpy(input.data(), pcm_bytes, window_size_samples * sizeof(int16_t));
+ for (int i = 0; i < window_size_samples; i++)
+ {
+ input[i] = static_cast(input[i]) / 32768; // int16_t normalized to float
+ }
+ }
+
+
+ void predict(const std::vector &data)
+ {
+ // bytes_to_float_tensor(data);
+
+ // Infer
+ // Create ort tensors
+ input.assign(data.begin(), data.end());
+ Ort::Value input_ort = Ort::Value::CreateTensor(
+ memory_info, input.data(), input.size(), input_node_dims, 2);
+ Ort::Value sr_ort = Ort::Value::CreateTensor(
+ memory_info, sr.data(), sr.size(), sr_node_dims, 1);
+ Ort::Value h_ort = Ort::Value::CreateTensor(
+ memory_info, _h.data(), _h.size(), hc_node_dims, 3);
+ Ort::Value c_ort = Ort::Value::CreateTensor(
+ memory_info, _c.data(), _c.size(), hc_node_dims, 3);
+
+ // Clear and add inputs
+ ort_inputs.clear();
+ ort_inputs.emplace_back(std::move(input_ort));
+ ort_inputs.emplace_back(std::move(sr_ort));
+ ort_inputs.emplace_back(std::move(h_ort));
+ ort_inputs.emplace_back(std::move(c_ort));
+
+ // Infer
+ ort_outputs = session->Run(
+ Ort::RunOptions{nullptr},
+ input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
+ output_node_names.data(), output_node_names.size());
+
+ // Output probability & update h,c recursively
+ float output = ort_outputs[0].GetTensorMutableData()[0];
+ float *hn = ort_outputs[1].GetTensorMutableData();
+ std::memcpy(_h.data(), hn, size_hc * sizeof(float));
+ float *cn = ort_outputs[2].GetTensorMutableData();
+ std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+
+ // Push forward sample index
+ current_sample += window_size_samples;
+
+ // Reset temp_end when > threshold
+ if ((output >= threshold) && (temp_end != 0))
+ {
+ temp_end = 0;
+ }
+ // 1) Silence
+ if ((output < threshold) && (triggerd == false))
+ {
+ // printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+ // 2) Speaking
+ if ((output >= (threshold - 0.15)) && (triggerd == true))
+ {
+ // printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ }
+
+ // 3) Start
+ if ((output >= threshold) && (triggerd == false))
+ {
+ triggerd = true;
+ speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
+ printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate);
+ }
+
+ // 4) End
+ if ((output < (threshold - 0.15)) && (triggerd == true))
+ {
+
+ if (temp_end != 0)
+ {
+ temp_end = current_sample;
+ }
+ // a. silence < min_slience_samples, continue speaking
+ if ((current_sample - temp_end) < min_silence_samples)
+ {
+ // printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate);
+ // printf("");
+ }
+ // b. silence >= min_slience_samples, end speaking
+ else
+ {
+ speech_end = current_sample + speech_pad_samples;
+ temp_end = 0;
+ triggerd = false;
+ printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate);
+ }
+ }
+
+
+ }
+
+private:
+ // model config
+ int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
+ int sample_rate;
+ int sr_per_ms; // Assign when init, support 8 or 16
+ float threshold;
+ int min_silence_samples; // sr_per_ms * #ms
+ int speech_pad_samples; // usually a
+
+ // model states
+ bool triggerd = false;
+ unsigned int speech_start = 0;
+ unsigned int speech_end = 0;
+ unsigned int temp_end = 0;
+ unsigned int current_sample = 0;
+ // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
+ float output;
+
+ // Onnx model
+ // Inputs
+ std::vector ort_inputs;
+
+ std::vector input_node_names = {"input", "sr", "h", "c"};
+ std::vector input;
+ std::vector sr;
+ unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
+ std::vector _h;
+ std::vector _c;
+
+ int64_t input_node_dims[2] = {};
+ const int64_t sr_node_dims[1] = {1};
+ const int64_t hc_node_dims[3] = {2, 1, 64};
+
+ // Outputs
+ std::vector ort_outputs;
+ std::vector output_node_names = {"output", "hn", "cn"};
+
+
+public:
+ // Construction
+ VadIterator(const std::string ModelPath,
+ int Sample_rate, int frame_size,
+ float Threshold, int min_silence_duration_ms, int speech_pad_ms)
+ {
+ init_onnx_model(ModelPath);
+ sample_rate = Sample_rate;
+ sr_per_ms = sample_rate / 1000;
+ threshold = Threshold;
+ min_silence_samples = sr_per_ms * min_silence_duration_ms;
+ speech_pad_samples = sr_per_ms * speech_pad_ms;
+ window_size_samples = frame_size * sr_per_ms;
+
+ input.resize(window_size_samples);
+ input_node_dims[0] = 1;
+ input_node_dims[1] = window_size_samples;
+ // std::cout << "== Input size" << input.size() << std::endl;
+ _h.resize(size_hc);
+ _c.resize(size_hc);
+ sr.resize(1);
+ }
+
+};
+
+int main()
+{
+
+ // Read wav
+ wav::WavReader wav_reader("./test_for_vad.wav");
+ std::vector data(wav_reader.num_samples());
+ std::vector input_wav(wav_reader.num_samples());
+
+ for (int i = 0; i < wav_reader.num_samples(); i++)
+ {
+ data[i] = static_cast(*(wav_reader.data() + i));
+ }
+
+ for (int i = 0; i < wav_reader.num_samples(); i++)
+ {
+ input_wav[i] = static_cast(data[i]) / 32768;
+ }
+
+ // ===== Test configs =====
+ std::string path = "../files/silero_vad.onnx";
+ int test_sr = 8000;
+ int test_frame_ms = 64;
+ float test_threshold = 0.5f;
+ int test_min_silence_duration_ms = 0;
+ int test_speech_pad_ms = 0;
+ int test_window_samples = test_frame_ms * (test_sr/1000);
+
+ VadIterator vad(
+ path, test_sr, test_frame_ms, test_threshold,
+ test_min_silence_duration_ms, test_speech_pad_ms);
+
+ for (int j = 0; j < wav_reader.num_samples(); j += test_window_samples)
+ {
+ // std::cout << "== 4" << std::endl;
+ std::vector r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples};
+ auto start = std::chrono::high_resolution_clock::now();
+ // Predict and print throughout process time
+ vad.predict(r);
+ auto end = std::chrono::high_resolution_clock::now();
+ auto elapsed_time = std::chrono::duration_cast(end-start);
+ // std::cout << "== Elapsed time: " << 1.0*elapsed_time.count()/1000000 << "ms" << " ==" <
+#include
+#include
+#include
+#include
+
+#include
+
+// #include "utils/log.h"
+
+namespace wav {
+
// Canonical 44-byte PCM wav header (assumes the common packed layout).
struct WavHeader {
  char riff[4];                // "RIFF"
  unsigned int size;           // file size minus 8
  char wav[4];                 // "WAVE"
  char fmt[4];                 // "fmt "
  unsigned int fmt_size;       // fmt chunk size; 16 for plain PCM
  uint16_t format;             // 1 == PCM
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;         // bytes per frame across all channels
  uint16_t bit;                // bits per sample
  char data[4];                // "data"
  unsigned int data_size;      // payload size in bytes
};

// Reads a PCM wav file into a flat interleaved float buffer.
class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  // Open and fully decode `filename`. Returns false on open/parse failure.
  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open for binary read
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      fprintf(stderr,
              "WaveData: expect PCM format data "
              "to have fmt chunk of at least size 16.\n");
      fclose(fp);  // BUGFIX: don't leak the handle on the error path
      return false;
    } else if (header.fmt_size > 16) {
      // Extended fmt chunk: seek past the extra bytes, then re-read the
      // 8-byte chunk id + size that follow.
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // one flat interleaved buffer
    num_samples_ = num_data / num_channel_;

    // Decode each integer sample to float (no normalization here).
    for (int i = 0; i < num_data; ++i) {
      switch (bits_per_sample_) {
        case 8: {
          char sample;
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 16: {
          int16_t sample;
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        case 32: {
          int sample;
          fread(&sample, 1, sizeof(int), fp);
          data_[i] = static_cast<float>(sample);
          break;
        }
        default:
          fprintf(stderr, "unsupported quantization bits");
          exit(1);
      }
    }
    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};
+
+class WavWriter {
+ public:
+ WavWriter(const float* data, int num_samples, int num_channel,
+ int sample_rate, int bits_per_sample)
+ : data_(data),
+ num_samples_(num_samples),
+ num_channel_(num_channel),
+ sample_rate_(sample_rate),
+ bits_per_sample_(bits_per_sample) {}
+
+ void Write(const std::string& filename) {
+ FILE* fp = fopen(filename.c_str(), "w");
+ // init char 'riff' 'WAVE' 'fmt ' 'data'
+ WavHeader header;
+ char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
+ 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
+ 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
+ memcpy(&header, wav_header, sizeof(header));
+ header.channels = num_channel_;
+ header.bit = bits_per_sample_;
+ header.sample_rate = sample_rate_;
+ header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
+ header.size = sizeof(header) - 8 + header.data_size;
+ header.bytes_per_second =
+ sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
+ header.block_size = num_channel_ * (bits_per_sample_ / 8);
+
+ fwrite(&header, 1, sizeof(header), fp);
+
+ for (int i = 0; i < num_samples_; ++i) {
+ for (int j = 0; j < num_channel_; ++j) {
+ switch (bits_per_sample_) {
+ case 8: {
+ char sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ case 16: {
+ int16_t sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ case 32: {
+ int sample = static_cast(data_[i * num_channel_ + j]);
+ fwrite(&sample, 1, sizeof(sample), fp);
+ break;
+ }
+ }
+ }
+ }
+ fclose(fp);
+ }
+
+ private:
+ const float* data_;
+ int num_samples_; // total float points in data_
+ int num_channel_;
+ int sample_rate_;
+ int bits_per_sample_;
+};
+
+} // namespace wenet
+
+#endif // FRONTEND_WAV_H_