From 7198087152d6487d7cfb80394b7753c60b24a3d5 Mon Sep 17 00:00:00 2001
From: yuGAN6
Date: Sun, 11 Dec 2022 13:06:21 +0800
Subject: [PATCH] Move into examples

---
 examples/cpp/README.md           |  43 ++++++
 examples/cpp/silero-vad-onnx.cpp | 253 +++++++++++++++++++++++++++++++
 examples/cpp/wav.h               | 205 +++++++++++++++++++++++
 3 files changed, 501 insertions(+)
 create mode 100644 examples/cpp/README.md
 create mode 100644 examples/cpp/silero-vad-onnx.cpp
 create mode 100644 examples/cpp/wav.h

diff --git a/examples/cpp/README.md b/examples/cpp/README.md
new file mode 100644
index 0000000..93a6791
--- /dev/null
+++ b/examples/cpp/README.md
@@ -0,0 +1,43 @@
+# Stream example in C++
+
+Here's a simple example of running the VAD model in C++ with ONNX Runtime.
+
+
+
+## Requirements
+
+The code was tested in the environment below; feel free to try others.
+
+- WSL2 + Debian-bullseye (docker)
+- gcc 12.2.0
+- onnxruntime-linux-x64-1.12.1
+
+
+
+## Usage
+
+1. Install gcc 12.2.0, or just pull the Docker image with `docker pull gcc:12.2.0-bullseye`
+
+2. Install onnxruntime-linux-x64-1.12.1
+
+   - Download the onnxruntime library:
+
+     `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`
+
+   - Unpack it. The commands below assume the path is `/root/onnxruntime-linux-x64-1.12.1`
+
+3. Modify the wav path and the test configs in the `main` function
+
+   `wav::WavReader wav_reader("${path_to_your_wav_file}");`
+
+   The test configs cover the sample rate, frame length in ms, threshold, etc. (see `test_sr`, `test_frame_ms`, `test_threshold`, ... in `main`).
+
+4. Build with gcc and run
+
+   ```bash
+   # Build
+   g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test
+
+   # Run
+   ./test
+   ```
\ No newline at end of file
diff --git a/examples/cpp/silero-vad-onnx.cpp b/examples/cpp/silero-vad-onnx.cpp
new file mode 100644
index 0000000..59846e1
--- /dev/null
+++ b/examples/cpp/silero-vad-onnx.cpp
@@ -0,0 +1,253 @@
+#include <chrono>
+#include <cstdio>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"
+#include "wav.h"
+
+class VadIterator
+{
+    // OnnxRuntime resources
+    Ort::Env env;
+    Ort::SessionOptions session_options;
+    std::shared_ptr<Ort::Session> session = nullptr;
+    Ort::AllocatorWithDefaultOptions allocator;
+    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
+
+public:
+    void init_engine_threads(int inter_threads, int intra_threads)
+    {
+        // Call this in every thread/process when doing multi-threaded/multi-process work
+        session_options.SetIntraOpNumThreads(intra_threads);
+        session_options.SetInterOpNumThreads(inter_threads);
+        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+    }
+
+    void init_onnx_model(const std::string &model_path)
+    {
+        // Init threads = 1 for single-stream inference
+        init_engine_threads(1, 1);
+        // Load model
+        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
+    }
+
+    void reset_states()
+    {
+        // Call reset before each audio stream starts
+        std::memset(_h.data(), 0, _h.size() * sizeof(float));
+        std::memset(_c.data(), 0, _c.size() * sizeof(float));
+        triggered = false;
+        temp_end = 0;
+        current_sample = 0;
+    }
+
+    // Call this from predict() if you prefer raw bytes (16-bit PCM) input.
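+    // A minimal sketch of that raw-bytes path (illustrative only, not wired up
+    // in main); it assumes `pcm_bytes` holds exactly window_size_samples
+    // 16-bit little-endian PCM samples:
+    //
+    //   void predict_bytes(const char *pcm_bytes)   // hypothetical helper
+    //   {
+    //       bytes_to_float_tensor(pcm_bytes);       // fills `input`
+    //       // ...then run the same inference and state update as predict(),
+    //       // skipping the input.assign() step.
+    //   }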
+    void bytes_to_float_tensor(const char *pcm_bytes)
+    {
+        // Interpret the bytes as 16-bit PCM and normalize to floats in [-1, 1)
+        const int16_t *pcm = reinterpret_cast<const int16_t *>(pcm_bytes);
+        for (int i = 0; i < window_size_samples; i++)
+        {
+            input[i] = static_cast<float>(pcm[i]) / 32768; // int16_t normalized to float
+        }
+    }
+
+    void predict(const std::vector<float> &data)
+    {
+        // bytes_to_float_tensor(data);
+
+        // Infer
+        // Create ort tensors
+        input.assign(data.begin(), data.end());
+        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
+            memory_info, input.data(), input.size(), input_node_dims, 2);
+        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
+            memory_info, sr.data(), sr.size(), sr_node_dims, 1);
+        Ort::Value h_ort = Ort::Value::CreateTensor<float>(
+            memory_info, _h.data(), _h.size(), hc_node_dims, 3);
+        Ort::Value c_ort = Ort::Value::CreateTensor<float>(
+            memory_info, _c.data(), _c.size(), hc_node_dims, 3);
+
+        // Clear and add inputs
+        ort_inputs.clear();
+        ort_inputs.emplace_back(std::move(input_ort));
+        ort_inputs.emplace_back(std::move(sr_ort));
+        ort_inputs.emplace_back(std::move(h_ort));
+        ort_inputs.emplace_back(std::move(c_ort));
+
+        // Infer
+        ort_outputs = session->Run(
+            Ort::RunOptions{nullptr},
+            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
+            output_node_names.data(), output_node_names.size());
+
+        // Output probability & update h,c recursively
+        float output = ort_outputs[0].GetTensorMutableData<float>()[0];
+        float *hn = ort_outputs[1].GetTensorMutableData<float>();
+        std::memcpy(_h.data(), hn, size_hc * sizeof(float));
+        float *cn = ort_outputs[2].GetTensorMutableData<float>();
+        std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+
+        // Push forward sample index
+        current_sample += window_size_samples;
+
+        // Reset temp_end once the probability is back above the threshold
+        if ((output >= threshold) && (temp_end != 0))
+        {
+            temp_end = 0;
+        }
+        // 1) Silence
+        if ((output < threshold) && (triggered == false))
+        {
+            // printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate);
+        }
+        // 2) Speaking
+        if ((output >= (threshold - 0.15)) && (triggered == true))
+        {
+            // printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate);
+        }
+
+        // 3) Start
+        if ((output >= threshold) && (triggered == false))
+        {
+            triggered = true;
+            speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get the precise start time point
+            printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate);
+        }
+
+        // 4) End
+        if ((output < (threshold - 0.15)) && (triggered == true))
+        {
+            // Remember where the potential silence started
+            if (temp_end == 0)
+            {
+                temp_end = current_sample;
+            }
+            // a. silence < min_silence_samples, continue speaking
+            if ((current_sample - temp_end) < min_silence_samples)
+            {
+                // printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate);
+            }
+            // b. silence >= min_silence_samples, end speaking
+            else
+            {
+                speech_end = current_sample + speech_pad_samples;
+                temp_end = 0;
+                triggered = false;
+                printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate);
+            }
+        }
+    }
+
+private:
+    // model config
+    int64_t window_size_samples; // Assign when init; supports 256/512/768 samples for 8 kHz and 512/1024/1536 for 16 kHz
+    int sample_rate;
+    int sr_per_ms;           // Assign when init, supports 8 or 16
+    float threshold;
+    int min_silence_samples; // sr_per_ms * #ms
+    int speech_pad_samples;  // sr_per_ms * #ms of padding around detected speech
+
+    // model states
+    bool triggered = false;
+    unsigned int speech_start = 0;
+    unsigned int speech_end = 0;
+    unsigned int temp_end = 0;
+    unsigned int current_sample = 0;
+    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
+    float output;
+
+    // Onnx model
+    // Inputs
+    std::vector<Ort::Value> ort_inputs;
+
+    std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
+    std::vector<float> input;
+    std::vector<int64_t> sr;
+    unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
+    std::vector<float> _h;
+    std::vector<float> _c;
+
+    int64_t input_node_dims[2] = {};
+    const int64_t sr_node_dims[1] = {1};
+    const int64_t hc_node_dims[3] = {2, 1, 64};
+
+    // Outputs
+    std::vector<Ort::Value> ort_outputs;
+    std::vector<const char *> output_node_names = {"output", "hn", "cn"};
+
+public:
+    // Construction
+    VadIterator(const std::string ModelPath,
+                int Sample_rate, int frame_size,
+                float Threshold, int min_silence_duration_ms, int speech_pad_ms)
+    {
+        init_onnx_model(ModelPath);
+        sample_rate = Sample_rate;
+        sr_per_ms = sample_rate / 1000;
+        threshold = Threshold;
+        min_silence_samples = sr_per_ms * min_silence_duration_ms;
+        speech_pad_samples = sr_per_ms * speech_pad_ms;
+        window_size_samples = frame_size * sr_per_ms;
+
+        input.resize(window_size_samples);
+        input_node_dims[0] = 1;
+        input_node_dims[1] = window_size_samples;
+        // std::cout << "== Input size" << input.size() << std::endl;
+        _h.resize(size_hc);
+        _c.resize(size_hc);
+        sr.resize(1);
+        sr[0] = sample_rate; // the model expects the sample rate as an input tensor
+    }
+};
+
+int main()
+{
+    // Read wav
+    wav::WavReader wav_reader("./test_for_vad.wav");
+    std::vector<int16_t> data(wav_reader.num_samples());
+    std::vector<float> input_wav(wav_reader.num_samples());
+
+    for (int i = 0; i < wav_reader.num_samples(); i++)
+    {
+        data[i] = static_cast<int16_t>(*(wav_reader.data() + i));
+    }
+
+    for (int i = 0; i < wav_reader.num_samples(); i++)
+    {
+        input_wav[i] = static_cast<float>(data[i]) / 32768;
+    }
+
+    // ===== Test configs =====
+    std::string path = "../files/silero_vad.onnx";
+    int test_sr = 8000;
+    int test_frame_ms = 64;
+    float test_threshold = 0.5f;
+    int test_min_silence_duration_ms = 0;
+    int test_speech_pad_ms = 0;
+    int test_window_samples = test_frame_ms * (test_sr / 1000);
+
+    VadIterator vad(
+        path, test_sr, test_frame_ms, test_threshold,
+        test_min_silence_duration_ms, test_speech_pad_ms);
+
+    // Feed the wav to the model window by window; the last partial window is dropped
+    for (int j = 0; j + test_window_samples <= wav_reader.num_samples(); j += test_window_samples)
+    {
+        std::vector<float> r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples};
+        auto start = std::chrono::high_resolution_clock::now();
+        // Predict and measure the per-window processing time
+        vad.predict(r);
+        auto end = std::chrono::high_resolution_clock::now();
+        auto elapsed_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
+        // std::cout << "== Elapsed time: " << 1.0 * elapsed_time.count() / 1000000 << " ms ==" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/examples/cpp/wav.h b/examples/cpp/wav.h
new file mode 100644
--- /dev/null
+++ b/examples/cpp/wav.h
@@ -0,0 +1,205 @@
+#ifndef FRONTEND_WAV_H_
+#define FRONTEND_WAV_H_
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <iostream>
+#include <string>
+
+// #include "utils/log.h"
+
+namespace wav {
+
+struct WavHeader {
+  char riff[4];  // "riff"
+  unsigned int size;
+  char wav[4];   // "WAVE"
+  char fmt[4];   // "fmt "
+  unsigned int fmt_size;
+  uint16_t format;
+  uint16_t channels;
+  unsigned int sample_rate;
+  unsigned int bytes_per_second;
+  uint16_t block_size;
+  uint16_t bit;
+  char data[4];  // "data"
+  unsigned int data_size;
+};
+
+class WavReader {
+ public:
+  WavReader() : data_(nullptr) {}
+  explicit WavReader(const std::string& filename) { Open(filename); }
+
+  bool Open(const std::string& filename) {
+    FILE* fp = fopen(filename.c_str(), "rb");  // open the wav file for reading
+    if (NULL == fp) {
+      std::cout << "Error in reading " << filename << std::endl;
+      return false;
+    }
+
+    WavHeader header;
+    fread(&header, 1, sizeof(header), fp);
+    if (header.fmt_size < 16) {
+      fprintf(stderr,
+              "WaveData: expect PCM format data "
+              "to have fmt chunk of at least size 16.\n");
+      return false;
+    } else if (header.fmt_size > 16) {
+      int offset = 44 - 8 + header.fmt_size - 16;
+      fseek(fp, offset, SEEK_SET);
+      fread(header.data, 8, sizeof(char), fp);
+    }
+    // check "riff" "WAVE" "fmt " "data"
+
+    // Skip any sub-chunks between "fmt" and "data". Usually there will
+    // be a single "fact" sub chunk, but on Windows there can also be a
+    // "list" sub chunk.
+    while (0 != strncmp(header.data, "data", 4)) {
+      // We will just ignore the data in these chunks.
+      fseek(fp, header.data_size, SEEK_CUR);
+      // read next sub chunk
+      fread(header.data, 8, sizeof(char), fp);
+    }
+
+    num_channel_ = header.channels;
+    sample_rate_ = header.sample_rate;
+    bits_per_sample_ = header.bit;
+    int num_data = header.data_size / (bits_per_sample_ / 8);
+    data_ = new float[num_data];  // Create 1-dim array
+    num_samples_ = num_data / num_channel_;
+
+    for (int i = 0; i < num_data; ++i) {
+      switch (bits_per_sample_) {
+        case 8: {
+          char sample;
+          fread(&sample, 1, sizeof(char), fp);
+          data_[i] = static_cast<float>(sample);
+          break;
+        }
+        case 16: {
+          int16_t sample;
+          fread(&sample, 1, sizeof(int16_t), fp);
+          data_[i] = static_cast<float>(sample);
+          break;
+        }
+        case 32: {
+          int sample;
+          fread(&sample, 1, sizeof(int), fp);
+          data_[i] = static_cast<float>(sample);
+          break;
+        }
+        default:
+          fprintf(stderr, "unsupported quantization bits\n");
+          exit(1);
+      }
+    }
+    fclose(fp);
+    return true;
+  }
+
+  int num_channel() const { return num_channel_; }
+  int sample_rate() const { return sample_rate_; }
+  int bits_per_sample() const { return bits_per_sample_; }
+  int num_samples() const { return num_samples_; }
+
+  ~WavReader() {
+    delete[] data_;
+  }
+
+  const float* data() const { return data_; }
+
+ private:
+  int num_channel_;
+  int sample_rate_;
+  int bits_per_sample_;
+  int num_samples_;  // sample points per channel
+  float* data_;
+};
+
+class WavWriter {
+ public:
+  WavWriter(const float* data, int num_samples, int num_channel,
+            int sample_rate, int bits_per_sample)
+      : data_(data),
+        num_samples_(num_samples),
+        num_channel_(num_channel),
+        sample_rate_(sample_rate),
+        bits_per_sample_(bits_per_sample) {}
+
+  void Write(const std::string& filename) {
+    FILE* fp = fopen(filename.c_str(), "wb");  // write in binary mode
+    // init char 'riff' 'WAVE' 'fmt ' 'data'
+    WavHeader header;
+    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
+                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
+                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
+    memcpy(&header, wav_header, sizeof(header));
+    header.channels = num_channel_;
+    header.bit = bits_per_sample_;
+    header.sample_rate = sample_rate_;
+    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
+    header.size = sizeof(header) - 8 + header.data_size;
+    header.bytes_per_second =
+        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
+    header.block_size = num_channel_ * (bits_per_sample_ / 8);
+
+    fwrite(&header, 1, sizeof(header), fp);
+
+    for (int i = 0; i < num_samples_; ++i) {
+      for (int j = 0; j < num_channel_; ++j) {
+        switch (bits_per_sample_) {
+          case 8: {
+            char sample = static_cast<char>(data_[i * num_channel_ + j]);
+            fwrite(&sample, 1, sizeof(sample), fp);
+            break;
+          }
+          case 16: {
+            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
+            fwrite(&sample, 1, sizeof(sample), fp);
+            break;
+          }
+          case 32: {
+            int sample = static_cast<int>(data_[i * num_channel_ + j]);
+            fwrite(&sample, 1, sizeof(sample), fp);
+            break;
+          }
+        }
+      }
+    }
+    fclose(fp);
+  }
+
+ private:
+  const float* data_;
+  int num_samples_;  // number of sample points per channel
+  int num_channel_;
+  int sample_rate_;
+  int bits_per_sample_;
+};
+
+}  // namespace wav
+
+#endif  // FRONTEND_WAV_H_
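+
+// Usage sketch (illustrative only; nothing in this header depends on it):
+//
+//   wav::WavReader reader("in.wav");
+//   wav::WavWriter writer(reader.data(), reader.num_samples(),
+//                         reader.num_channel(), reader.sample_rate(),
+//                         reader.bits_per_sample());
+//   writer.Write("out.wav");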