From 0b3d43d43263b6bd85b35d954cd8e7d24ab4b69e Mon Sep 17 00:00:00 2001
From: Stefan Miletic
Date: Mon, 1 Jul 2024 15:04:48 +0100
Subject: [PATCH] cpp example v5 model

---
 examples/cpp/silero-vad-onnx.cpp | 464 ++++++++-----------------------
 examples/cpp/wav.h               |  75 ++---
 2 files changed, 134 insertions(+), 405 deletions(-)

diff --git a/examples/cpp/silero-vad-onnx.cpp b/examples/cpp/silero-vad-onnx.cpp
index eb92296..2d53469 100644
--- a/examples/cpp/silero-vad-onnx.cpp
+++ b/examples/cpp/silero-vad-onnx.cpp
@@ -2,97 +2,13 @@
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
+
 #include "onnxruntime_cxx_api.h"
 #include "wav.h"
-#include 
-#include 
-#if __cplusplus < 201703L
-#include 
-#endif
-
-//#define __DEBUG_SPEECH_PROB___
-
-class timestamp_t
-{
-public:
-    int start;
-    int end;
-
-    // default + parameterized constructor
-    timestamp_t(int start = -1, int end = -1)
-        : start(start), end(end)
-    {
-    };
-
-    // assignment operator modifies object, therefore non-const
-    timestamp_t& operator=(const timestamp_t& a)
-    {
-        start = a.start;
-        end = a.end;
-        return *this;
-    };
-
-    // equality comparison. doesn't modify object. therefore const.
-    bool operator==(const timestamp_t& a) const
-    {
-        return (start == a.start && end == a.end);
-    };
-    std::string c_str()
-    {
-        //return std::format("timestamp {:08d}, {:08d}", start, end);
-        return format("{start:%08d,end:%08d}", start, end);
-    };
-private:
-
-    std::string format(const char* fmt, ...)
-    {
-        char buf[256];
-
-        va_list args;
-        va_start(args, fmt);
-        const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
-        va_end(args);
-
-        if (r < 0)
-            // conversion failed
-            return {};
-
-        const size_t len = r;
-        if (len < sizeof buf)
-            // we fit in the buffer
-            return { buf, len };
-
-#if __cplusplus >= 201703L
-        // C++17: Create a string and write to its underlying array
-        std::string s(len, '\0');
-        va_start(args, fmt);
-        std::vsnprintf(s.data(), len + 1, fmt, args);
-        va_end(args);
-
-        return s;
-#else
-        // C++11 or C++14: We need to allocate scratch memory
-        auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
-        va_start(args, fmt);
-        std::vsnprintf(vbuf.get(), len + 1, fmt, args);
-        va_end(args);
-
-        return { vbuf.get(), len };
-#endif
-    };
-};
-
 class VadIterator
 {
-private:
     // OnnxRuntime resources
     Ort::Env env;
     Ort::SessionOptions session_options;
@@ -100,58 +16,62 @@ private:
     Ort::AllocatorWithDefaultOptions allocator;
     Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

-private:
+public:
     void init_engine_threads(int inter_threads, int intra_threads)
-    {
+    {   // The method should be called in each thread/proc in multi-thread/proc work
         session_options.SetIntraOpNumThreads(intra_threads);
         session_options.SetInterOpNumThreads(inter_threads);
         session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
-    };
+    }

-    void init_onnx_model(const std::wstring& model_path)
-    {
+    void init_onnx_model(const std::string &model_path)
+    {
         // Init threads = 1 for 
         init_engine_threads(1, 1);
         // Load model
         session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
-    };
+    }

     void reset_states()
     {
         // Call reset before each audio start
-        std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
-        std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
-        triggered = false;
+        std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
+        triggerd = false;
         temp_end = 0;
         current_sample = 0;
+    }
-        prev_end = next_start = 0;
+    // Call it in predict func. if you prefer raw bytes input.
+    void bytes_to_float_tensor(const char *pcm_bytes)
+    {
+        std::memcpy(input.data(), pcm_bytes, window_size_samples * sizeof(int16_t));
+        for (int i = 0; i < window_size_samples; i++)
+        {
+            input[i] = static_cast<float>(input[i]) / 32768; // int16_t normalized to float
+        }
+    }

-        speeches.clear();
-        current_speech = timestamp_t();
-    };
     void predict(const std::vector<float> &data)
     {
+        // bytes_to_float_tensor(data);
+        // Infer
         // Create ort tensors
         input.assign(data.begin(), data.end());
         Ort::Value input_ort = Ort::Value::CreateTensor<float>(
             memory_info, input.data(), input.size(), input_node_dims, 2);
+        Ort::Value state_ort = Ort::Value::CreateTensor<float>(
+            memory_info, _state.data(), _state.size(), state_node_dims, 3);
         Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
             memory_info, sr.data(), sr.size(), sr_node_dims, 1);
-        Ort::Value h_ort = Ort::Value::CreateTensor<float>(
-            memory_info, _h.data(), _h.size(), hc_node_dims, 3);
-        Ort::Value c_ort = Ort::Value::CreateTensor<float>(
-            memory_info, _c.data(), _c.size(), hc_node_dims, 3);

         // Clear and add inputs
         ort_inputs.clear();
         ort_inputs.emplace_back(std::move(input_ort));
+        ort_inputs.emplace_back(std::move(state_ort));
         ort_inputs.emplace_back(std::move(sr_ort));
-        ort_inputs.emplace_back(std::move(h_ort));
-        ort_inputs.emplace_back(std::move(c_ort));

         // Infer
         ort_outputs = session->Run(
@@ -160,327 +80,165 @@ private:
             output_node_names.data(), output_node_names.size());

         // Output probability & update h,c recursively
-        float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
-        float *hn = ort_outputs[1].GetTensorMutableData<float>();
-        std::memcpy(_h.data(), hn, size_hc * sizeof(float));
-        float *cn = ort_outputs[2].GetTensorMutableData<float>();
-        std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+        float output = ort_outputs[0].GetTensorMutableData<float>()[0];
+        float *stateN = ort_outputs[1].GetTensorMutableData<float>();
+        std::memcpy(_state.data(), stateN, size_state * sizeof(float));

         // Push forward sample index
         current_sample += window_size_samples;
-        
+
         // Reset temp_end when > threshold
-        if ((speech_prob >= threshold))
+        if ((output >= threshold) && (temp_end != 0))
         {
-#ifdef __DEBUG_SPEECH_PROB___
-            float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
- printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - if (temp_end != 0) - { - temp_end = 0; - if (next_start < prev_end) - next_start = current_sample - window_size_samples; - } - if (triggered == false) - { - triggered = true; - - current_speech.start = current_sample - window_size_samples; - } - return; + temp_end = 0; } - - if ( - (triggered == true) - && ((current_sample - current_speech.start) > max_speech_samples) - ) { - if (prev_end > 0) { - current_speech.end = prev_end; - speeches.push_back(current_speech); - current_speech = timestamp_t(); - - // previously reached silence(< neg_thres) and is still not speech(< thres) - if (next_start < prev_end) - triggered = false; - else{ - current_speech.start = next_start; - } - prev_end = 0; - next_start = 0; - temp_end = 0; - - } - else{ - current_speech.end = current_sample; - speeches.push_back(current_speech); - current_speech = timestamp_t(); - prev_end = 0; - next_start = 0; - temp_end = 0; - triggered = false; - } - return; - - } - if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) + // 1) Silence + if ((output < threshold) && (triggerd == false)) { - if (triggered) { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - } - else { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - } - return; + // printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate); + } + // 2) Speaking + if ((output >= (threshold - 0.15)) && (triggerd == true)) + { + // printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate); } + // 3) Start + if ((output >= threshold) && (triggerd == false)) + { + triggerd = true; + speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. + printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate); + } // 4) End - if ((speech_prob < (threshold - 0.15))) + if ((output < (threshold - 0.15)) && (triggerd == true)) { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. - printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - if (triggered == true) + + if (temp_end == 0) { - if (temp_end == 0) - { - temp_end = current_sample; - } - if (current_sample - temp_end > min_silence_samples_at_max_speech) - prev_end = temp_end; - // a. silence < min_slience_samples, continue speaking - if ((current_sample - temp_end) < min_silence_samples) - { - - } - // b. 
-                else
-                {
-                    current_speech.end = temp_end;
-                    if (current_speech.end - current_speech.start > min_speech_samples)
-                    {
-                        speeches.push_back(current_speech);
-                        current_speech = timestamp_t();
-                        prev_end = 0;
-                        next_start = 0;
-                        temp_end = 0;
-                        triggered = false;
-                    }
-                }
+                temp_end = current_sample;
             }
-            else {
-                // may first windows see end state.
+            // a. silence < min_slience_samples, continue speaking
+            if ((current_sample - temp_end) < min_silence_samples)
+            {
+                // printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate);
+                // printf("");
+            }
+            // b. silence >= min_slience_samples, end speaking
+            else
+            {
+                speech_end = temp_end ? temp_end + speech_pad_samples : current_sample + speech_pad_samples;
+                temp_end = 0;
+                triggerd = false;
+                printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate);
             }
-            return;
-        }
-    };
+        }
+    }
-public:
-    void process(const std::vector<float>& input_wav)
-    {
-        reset_states();
-
-        audio_length_samples = input_wav.size();
-
-        for (int j = 0; j < audio_length_samples; j += window_size_samples)
-        {
-            if (j + window_size_samples > audio_length_samples)
-                break;
-            std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
-            predict(r);
-        }
-        if (current_speech.start >= 0) {
-            current_speech.end = audio_length_samples;
-            speeches.push_back(current_speech);
-            current_speech = timestamp_t();
-            prev_end = 0;
-            next_start = 0;
-            temp_end = 0;
-            triggered = false;
-        }
-    };
-    void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        process(input_wav);
-        collect_chunks(input_wav, output_wav);
-    }
-    void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        output_wav.clear();
-        for (int i = 0; i < speeches.size(); i++) {
-#ifdef __DEBUG_SPEECH_PROB___
-            std::cout << speeches[i].c_str() << std::endl;
-#endif //#ifdef __DEBUG_SPEECH_PROB___
-            std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
-            output_wav.insert(output_wav.end(),slice.begin(),slice.end());
-        }
-    };
-
-    const std::vector<timestamp_t> get_speech_timestamps() const
-    {
-        return speeches;
-    }
-
-    void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
-    {
-        output_wav.clear();
-        int current_start = 0;
-        for (int i = 0; i < speeches.size(); i++) {
-
-            std::vector<float> slice(&input_wav[current_start],&input_wav[speeches[i].start]);
-            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
-            current_start = speeches[i].end;
-        }
-
-        std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
-        output_wav.insert(output_wav.end(), slice.begin(), slice.end());
-    };
-
 private:
     // model config
     int64_t window_size_samples;  // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
-    int sample_rate;  //Assign when init support 16000 or 8000
-    int sr_per_ms;  // Assign when init, support 8 or 16
-    float threshold;
+    int sample_rate;
+    int sr_per_ms;  // Assign when init, support 8 or 16
+    float threshold;
     int min_silence_samples; // sr_per_ms * #ms
-    int min_silence_samples_at_max_speech; // sr_per_ms * #98
-    int min_speech_samples; // sr_per_ms * #ms
-    float max_speech_samples;
     int speech_pad_samples; // usually a
-    int audio_length_samples;

     // model states
-    bool triggered = false;
+    bool triggerd = false;
+    unsigned int speech_start = 0;
+    unsigned int speech_end = 0;
     unsigned int temp_end = 0;
     unsigned int current_sample = 0;  // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
-    int prev_end;
-    int next_start = 0;
-
-    //Output timestamp
-    std::vector<timestamp_t> speeches;
-    timestamp_t current_speech;
-
+    float output;

    // Onnx model
    // Inputs
    std::vector<Ort::Value> ort_inputs;
-    std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
+    std::vector<const char *> input_node_names = {"input", "state", "sr"};
    std::vector<float> input;
+    unsigned int size_state = 2 * 1 * 128; // It's FIXED.
+    std::vector<float> _state;
    std::vector<int64_t> sr;
-    unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
-    std::vector<float> _h;
-    std::vector<float> _c;
    int64_t input_node_dims[2] = {};
+    const int64_t state_node_dims[3] = {2, 1, 128};
    const int64_t sr_node_dims[1] = {1};
-    const int64_t hc_node_dims[3] = {2, 1, 64};

    // Outputs
    std::vector<Ort::Value> ort_outputs;
-    std::vector<const char *> output_node_names = {"output", "hn", "cn"};
+    std::vector<const char *> output_node_names = {"output", "stateN"};
+
 public:
    // Construction
-    VadIterator(const std::wstring ModelPath,
-        int Sample_rate = 16000, int windows_frame_size = 64,
-        float Threshold = 0.5, int min_silence_duration_ms = 0,
-        int speech_pad_ms = 64, int min_speech_duration_ms = 64,
-        float max_speech_duration_s = std::numeric_limits<float>::infinity())
+    VadIterator(const std::string ModelPath,
+                int Sample_rate, int frame_size,
+                float Threshold, int min_silence_duration_ms, int speech_pad_ms)
    {
        init_onnx_model(ModelPath);
-        threshold = Threshold;
        sample_rate = Sample_rate;
        sr_per_ms = sample_rate / 1000;
-
-        window_size_samples = windows_frame_size * sr_per_ms;
-
-        min_speech_samples = sr_per_ms * min_speech_duration_ms;
-        speech_pad_samples = sr_per_ms * speech_pad_ms;
-
-        max_speech_samples = (
-            sample_rate * max_speech_duration_s
-            - window_size_samples
-            - 2 * speech_pad_samples
-            );
-
+        threshold = Threshold;
        min_silence_samples = sr_per_ms * min_silence_duration_ms;
-        min_silence_samples_at_max_speech = sr_per_ms * 98;
-
+        speech_pad_samples = sr_per_ms * speech_pad_ms;
+        window_size_samples = frame_size * sr_per_ms;
+
        input.resize(window_size_samples);
        input_node_dims[0] = 1;
        input_node_dims[1] = window_size_samples;
-
-        _h.resize(size_hc);
-        _c.resize(size_hc);
+        _state.resize(size_state);
        sr.resize(1);
        sr[0] = sample_rate;
-    };
+    }
+};

 int main()
 {
-    std::vector<timestamp_t> stamps;
    // Read wav
-    wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
+    wav::WavReader wav_reader("./recorder.wav");
+    std::vector<int16_t> data(wav_reader.num_samples());
    std::vector<float> input_wav(wav_reader.num_samples());
-    std::vector<float> output_wav;

    for (int i = 0; i < wav_reader.num_samples(); i++)
    {
-        input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
+        data[i] = static_cast<int16_t>(*(wav_reader.data() + i));
    }
-
+    for (int i = 0; i < wav_reader.num_samples(); i++)
+    {
+        input_wav[i] = static_cast<float>(data[i]) / 32768;
+    }

    // ===== Test configs =====
-    std::wstring path = L"silero_vad.onnx";
-    VadIterator vad(path);
+    std::string path = "../../files/silero_vad.onnx";
"../../files/silero_vad.onnx"; + int test_sr = 16000; + int test_frame_ms = 32; + float test_threshold = 0.5f; + int test_min_silence_duration_ms = 0; + int test_speech_pad_ms = 0; + int test_window_samples = test_frame_ms * (test_sr/1000); - // ============================================== - // ==== = Example 1 of full function ===== - // ============================================== - vad.process(input_wav); + VadIterator vad( + path, test_sr, test_frame_ms, test_threshold, + test_min_silence_duration_ms, test_speech_pad_ms); - // 1.a get_speech_timestamps - stamps = vad.get_speech_timestamps(); - for (int i = 0; i < stamps.size(); i++) { + for (int j = 0; j < wav_reader.num_samples(); j += test_window_samples) + { + std::vector r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples}; + auto start = std::chrono::high_resolution_clock::now(); + // Predict and print throughout process time + vad.predict(r); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_time = std::chrono::duration_cast(end-start); + std::cout << "== Elapsed time: " << 1.0*elapsed_time.count()/1000000 << "ms" << " ==" < #include #include #include @@ -59,7 +60,8 @@ class WavReader { WavHeader header; fread(&header, 1, sizeof(header), fp); if (header.fmt_size < 16) { - printf("WaveData: expect PCM format data " + fprintf(stderr, + "WaveData: expect PCM format data " "to have fmt chunk of at least size 16.\n"); return false; } else if (header.fmt_size > 16) { @@ -79,13 +81,6 @@ class WavReader { fread(header.data, 8, sizeof(char), fp); } - if (header.data_size == 0) { - int offset = ftell(fp); - fseek(fp, 0, SEEK_END); - header.data_size = ftell(fp) - offset; - fseek(fp, offset, SEEK_SET); - } - num_channel_ = header.channels; sample_rate_ = header.sample_rate; bits_per_sample_ = header.bit; @@ -93,57 +88,33 @@ class WavReader { data_ = new float[num_data]; // Create 1-dim array num_samples_ = num_data / num_channel_; - std::cout << "num_channel_ :" << num_channel_ << std::endl; - std::cout << "sample_rate_ :" << sample_rate_ << std::endl; - std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl; - std::cout << "num_samples :" << num_data << std::endl; - std::cout << "num_data_size :" << header.data_size << std::endl; - - switch (bits_per_sample_) { + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { case 8: { - char sample; - for (int i = 0; i < num_data; ++i) { - fread(&sample, 1, sizeof(char), fp); - data_[i] = static_cast(sample) / 32768; - } - break; + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample); + break; } case 16: { - int16_t sample; - for (int i = 0; i < num_data; ++i) { - fread(&sample, 1, sizeof(int16_t), fp); - data_[i] = static_cast(sample) / 32768; - } - break; + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + // std::cout << sample; + data_[i] = static_cast(sample); + // std::cout << data_[i]; + break; } - case 32: - { - if (header.format == 1) //S32 - { - int sample; - for (int i = 0; i < num_data; ++i) { - fread(&sample, 1, sizeof(int), fp); - data_[i] = static_cast(sample) / 32768; - } - } - else if (header.format == 3) // IEEE-float - { - float sample; - for (int i = 0; i < num_data; ++i) { - fread(&sample, 1, sizeof(float), fp); - data_[i] = static_cast(sample); - } - } - else { - printf("unsupported quantization bits\n"); - } - break; + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample); + break; } default: - printf("unsupported quantization bits\n"); 
-        break;
+          fprintf(stderr, "unsupported quantization bits");
+          exit(1);
+      }
     }
-
     fclose(fp);
     return true;
   }
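
For reference, the v5 model driven by this patch takes three inputs per window: "input" (shape 1 x window_size), "state" (shape 2 x 1 x 128), and "sr" (shape 1), and returns "output" (speech probability) plus the updated "stateN". The following is a minimal standalone sketch of one inference step outside the VadIterator class, using the same tensor names and shapes as predict() above. The model path, a 512-sample (32 ms at 16 kHz) window of zeros, and the narrow-char Ort::Session constructor (Linux-style build) are assumptions for illustration, not part of the patch.

// vad_single_step.cpp - illustrative sketch, not part of the example in this patch
#include <cstring>
#include <iostream>
#include <vector>
#include "onnxruntime_cxx_api.h"

int main()
{
    // Assumed model location; adjust to wherever silero_vad.onnx actually lives.
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "vad-sketch");
    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(1);
    Ort::Session session(env, "silero_vad.onnx", opts);
    Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

    // One 32 ms window at 16 kHz = 512 float samples normalized to [-1, 1].
    std::vector<float> window(512, 0.0f);
    std::vector<float> state(2 * 1 * 128, 0.0f); // v5 uses a single recurrent state tensor
    std::vector<int64_t> sr{16000};

    int64_t input_dims[2] = {1, 512};
    int64_t state_dims[3] = {2, 1, 128};
    int64_t sr_dims[1] = {1};

    std::vector<Ort::Value> inputs;
    inputs.emplace_back(Ort::Value::CreateTensor<float>(
        mem, window.data(), window.size(), input_dims, 2));
    inputs.emplace_back(Ort::Value::CreateTensor<float>(
        mem, state.data(), state.size(), state_dims, 3));
    inputs.emplace_back(Ort::Value::CreateTensor<int64_t>(
        mem, sr.data(), sr.size(), sr_dims, 1));

    const char *input_names[] = {"input", "state", "sr"};
    const char *output_names[] = {"output", "stateN"};

    auto outputs = session.Run(Ort::RunOptions{nullptr},
                               input_names, inputs.data(), inputs.size(),
                               output_names, 2);

    // Speech probability for this window, plus the updated state to feed back in.
    float speech_prob = outputs[0].GetTensorMutableData<float>()[0];
    std::memcpy(state.data(), outputs[1].GetTensorMutableData<float>(),
                state.size() * sizeof(float));

    std::cout << "speech probability: " << speech_prob << std::endl;
    return 0;
}

In a streaming loop the updated state is carried from one window to the next, and zeroed between utterances, which is exactly what reset_states() in the patch does with std::memset on _state.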