From 7440bc46894815ee8b490ccc72f69f95c96398bd Mon Sep 17 00:00:00 2001 From: Ojuro Yokoyama Date: Mon, 17 Feb 2025 16:02:24 +0900 Subject: [PATCH] Update silero-vad-onnx.cpp I fixed bug of silero-vad-onnx.cpp --- examples/cpp/silero-vad-onnx.cpp | 491 ++++++++++++------------------- 1 file changed, 190 insertions(+), 301 deletions(-) diff --git a/examples/cpp/silero-vad-onnx.cpp b/examples/cpp/silero-vad-onnx.cpp index dd2bf4e..380d76d 100644 --- a/examples/cpp/silero-vad-onnx.cpp +++ b/examples/cpp/silero-vad-onnx.cpp @@ -1,211 +1,227 @@ +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif + #include #include #include #include #include #include +#include #include #include #include -#include -#include -#include "onnxruntime_cxx_api.h" -#include "wav.h" #include #include +#include // for std::rint #if __cplusplus < 201703L #include #endif //#define __DEBUG_SPEECH_PROB___ -class timestamp_t -{ +#include "onnxruntime_cxx_api.h" +#include "wav.h" // For reading WAV files + +// timestamp_t class: stores the start and end (in samples) of a speech segment. +class timestamp_t { public: int start; int end; - // default + parameterized constructor timestamp_t(int start = -1, int end = -1) - : start(start), end(end) - { - }; + : start(start), end(end) { } - // assignment operator modifies object, therefore non-const - timestamp_t& operator=(const timestamp_t& a) - { + timestamp_t& operator=(const timestamp_t& a) { start = a.start; end = a.end; return *this; - }; + } - // equality comparison. doesn't modify object. therefore const. - bool operator==(const timestamp_t& a) const - { + bool operator==(const timestamp_t& a) const { return (start == a.start && end == a.end); - }; - std::string c_str() - { - //return std::format("timestamp {:08d}, {:08d}", start, end); - return format("{start:%08d,end:%08d}", start, end); - }; + } + + // Returns a formatted string of the timestamp. + std::string c_str() const { + return format("{start:%08d, end:%08d}", start, end); + } private: - - std::string format(const char* fmt, ...) - { + // Helper function for formatting. + std::string format(const char* fmt, ...) const { char buf[256]; - va_list args; va_start(args, fmt); - const auto r = std::vsnprintf(buf, sizeof buf, fmt, args); + const auto r = std::vsnprintf(buf, sizeof(buf), fmt, args); va_end(args); - if (r < 0) - // conversion failed return {}; - const size_t len = r; - if (len < sizeof buf) - // we fit in the buffer - return { buf, len }; - + if (len < sizeof(buf)) + return std::string(buf, len); #if __cplusplus >= 201703L - // C++17: Create a string and write to its underlying array std::string s(len, '\0'); va_start(args, fmt); std::vsnprintf(s.data(), len + 1, fmt, args); va_end(args); - return s; #else - // C++11 or C++14: We need to allocate scratch memory auto vbuf = std::unique_ptr(new char[len + 1]); va_start(args, fmt); std::vsnprintf(vbuf.get(), len + 1, fmt, args); va_end(args); - - return { vbuf.get(), len }; + return std::string(vbuf.get(), len); #endif - }; + } }; - -class VadIterator -{ +// VadIterator class: uses ONNX Runtime to detect speech segments. +class VadIterator { private: - // OnnxRuntime resources + // ONNX Runtime resources Ort::Env env; Ort::SessionOptions session_options; std::shared_ptr session = nullptr; Ort::AllocatorWithDefaultOptions allocator; Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); -private: - void init_engine_threads(int inter_threads, int intra_threads) - { - // The method should be called in each thread/proc in multi-thread/proc work + // ----- Context-related additions ----- + const int context_samples = 64; // For 16kHz, 64 samples are added as context. + std::vector _context; // Holds the last 64 samples from the previous chunk (initialized to zero). + + // Original window size (e.g., 32ms corresponds to 512 samples) + int window_size_samples; + // Effective window size = window_size_samples + context_samples + int effective_window_size; + + // Additional declaration: samples per millisecond + int sr_per_ms; + + // ONNX Runtime input/output buffers + std::vector ort_inputs; + std::vector input_node_names = { "input", "state", "sr" }; + std::vector input; + unsigned int size_state = 2 * 1 * 128; + std::vector _state; + std::vector sr; + int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = { 2, 1, 128 }; + const int64_t sr_node_dims[1] = { 1 }; + std::vector ort_outputs; + std::vector output_node_names = { "output", "stateN" }; + + // Model configuration parameters + int sample_rate; + float threshold; + int min_silence_samples; + int min_silence_samples_at_max_speech; + int min_speech_samples; + float max_speech_samples; + int speech_pad_samples; + int audio_length_samples; + + // State management + bool triggered = false; + unsigned int temp_end = 0; + unsigned int current_sample = 0; + int prev_end; + int next_start = 0; + std::vector speeches; + timestamp_t current_speech; + + // Loads the ONNX model. + void init_onnx_model(const std::wstring& model_path) { + init_engine_threads(1, 1); + session = std::make_shared(env, model_path.c_str(), session_options); + } + + // Initializes threading settings. + void init_engine_threads(int inter_threads, int intra_threads) { session_options.SetIntraOpNumThreads(intra_threads); session_options.SetInterOpNumThreads(inter_threads); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - }; + } - void init_onnx_model(const std::wstring& model_path) - { - // Init threads = 1 for - init_engine_threads(1, 1); - // Load model - session = std::make_shared(env, model_path.c_str(), session_options); - }; - - void reset_states() - { - // Call reset before each audio start - std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); + // Resets internal state (_state, _context, etc.) + void reset_states() { + std::memset(_state.data(), 0, _state.size() * sizeof(float)); triggered = false; temp_end = 0; current_sample = 0; - prev_end = next_start = 0; - speeches.clear(); current_speech = timestamp_t(); - }; + std::fill(_context.begin(), _context.end(), 0.0f); + } - void predict(const std::vector &data) - { - // Infer - // Create ort tensors - input.assign(data.begin(), data.end()); + // Inference: runs inference on one chunk of input data. + // data_chunk is expected to have window_size_samples samples. + void predict(const std::vector& data_chunk) { + // Build new input: first context_samples from _context, followed by the current chunk (window_size_samples). + std::vector new_data(effective_window_size, 0.0f); + std::copy(_context.begin(), _context.end(), new_data.begin()); + std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples); + input = new_data; + + // Create input tensor (input_node_dims[1] is already set to effective_window_size). Ort::Value input_ort = Ort::Value::CreateTensor( memory_info, input.data(), input.size(), input_node_dims, 2); Ort::Value state_ort = Ort::Value::CreateTensor( memory_info, _state.data(), _state.size(), state_node_dims, 3); Ort::Value sr_ort = Ort::Value::CreateTensor( memory_info, sr.data(), sr.size(), sr_node_dims, 1); - - // Clear and add inputs ort_inputs.clear(); ort_inputs.emplace_back(std::move(input_ort)); ort_inputs.emplace_back(std::move(state_ort)); ort_inputs.emplace_back(std::move(sr_ort)); - // Infer + // Run inference. ort_outputs = session->Run( - Ort::RunOptions{nullptr}, + Ort::RunOptions{ nullptr }, input_node_names.data(), ort_inputs.data(), ort_inputs.size(), output_node_names.data(), output_node_names.size()); - // Output probability & update h,c recursively float speech_prob = ort_outputs[0].GetTensorMutableData()[0]; - float *stateN = ort_outputs[1].GetTensorMutableData(); + float* stateN = ort_outputs[1].GetTensorMutableData(); std::memcpy(_state.data(), stateN, size_state * sizeof(float)); + current_sample += static_cast(window_size_samples); // Advance by the original window size. - // Push forward sample index - current_sample += window_size_samples; - - // Reset temp_end when > threshold - if ((speech_prob >= threshold)) - { + // If speech is detected (probability >= threshold) + if (speech_prob >= threshold) { #ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - if (temp_end != 0) - { + float speech = current_sample - window_size_samples; + printf("{ start: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples); +#endif + if (temp_end != 0) { temp_end = 0; if (next_start < prev_end) next_start = current_sample - window_size_samples; } - if (triggered == false) - { + if (!triggered) { triggered = true; - current_speech.start = current_sample - window_size_samples; } + // Update context: copy the last context_samples from new_data. + std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); return; } - if ( - (triggered == true) - && ((current_sample - current_speech.start) > max_speech_samples) - ) { + // If the speech segment becomes too long. + if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) { if (prev_end > 0) { current_speech.end = prev_end; speeches.push_back(current_speech); current_speech = timestamp_t(); - - // previously reached silence(< neg_thres) and is still not speech(< thres) if (next_start < prev_end) triggered = false; - else{ + else current_speech.start = next_start; - } prev_end = 0; next_start = 0; temp_end = 0; - } - else{ + else { current_speech.end = current_sample; speeches.push_back(current_speech); current_speech = timestamp_t(); @@ -214,53 +230,29 @@ private: temp_end = 0; triggered = false; } + std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); return; - } - if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) - { + + if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) { + // When the speech probability temporarily drops but is still in speech, update context without changing state. + std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); + return; + } + + if (speech_prob < (threshold - 0.15)) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_size_samples - speech_pad_samples; + printf("{ end: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples); +#endif if (triggered) { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - } - else { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - } - return; - } - - - // 4) End - if ((speech_prob < (threshold - 0.15))) - { -#ifdef __DEBUG_SPEECH_PROB___ - float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. - printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); -#endif //__DEBUG_SPEECH_PROB___ - if (triggered == true) - { if (temp_end == 0) - { temp_end = current_sample; - } if (current_sample - temp_end > min_silence_samples_at_max_speech) prev_end = temp_end; - // a. silence < min_slience_samples, continue speaking - if ((current_sample - temp_end) < min_silence_samples) - { - - } - // b. silence >= min_slience_samples, end speaking - else - { + if ((current_sample - temp_end) >= min_silence_samples) { current_speech.end = temp_end; - if (current_speech.end - current_speech.start > min_speech_samples) - { + if (current_speech.end - current_speech.start > min_speech_samples) { speeches.push_back(current_speech); current_speech = timestamp_t(); prev_end = 0; @@ -270,27 +262,23 @@ private: } } } - else { - // may first windows see end state. - } + std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); return; } - }; + } + public: - void process(const std::vector& input_wav) - { + // Process the entire audio input. + void process(const std::vector& input_wav) { reset_states(); - - audio_length_samples = input_wav.size(); - - for (int j = 0; j < audio_length_samples; j += window_size_samples) - { - if (j + window_size_samples > audio_length_samples) + audio_length_samples = static_cast(input_wav.size()); + // Process audio in chunks of window_size_samples (e.g., 512 samples) + for (size_t j = 0; j < static_cast(audio_length_samples); j += static_cast(window_size_samples)) { + if (j + static_cast(window_size_samples) > static_cast(audio_length_samples)) break; - std::vector r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples }; - predict(r); + std::vector chunk(&input_wav[j], &input_wav[j] + window_size_samples); + predict(chunk); } - if (current_speech.start >= 0) { current_speech.end = audio_length_samples; speeches.push_back(current_speech); @@ -300,179 +288,80 @@ public: temp_end = 0; triggered = false; } - }; - - void process(const std::vector& input_wav, std::vector& output_wav) - { - process(input_wav); - collect_chunks(input_wav, output_wav); } - void collect_chunks(const std::vector& input_wav, std::vector& output_wav) - { - output_wav.clear(); - for (int i = 0; i < speeches.size(); i++) { -#ifdef __DEBUG_SPEECH_PROB___ - std::cout << speeches[i].c_str() << std::endl; -#endif //#ifdef __DEBUG_SPEECH_PROB___ - std::vector slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]); - output_wav.insert(output_wav.end(),slice.begin(),slice.end()); - } - }; - - const std::vector get_speech_timestamps() const - { + // Returns the detected speech timestamps. + const std::vector get_speech_timestamps() const { return speeches; } - void drop_chunks(const std::vector& input_wav, std::vector& output_wav) - { - output_wav.clear(); - int current_start = 0; - for (int i = 0; i < speeches.size(); i++) { - - std::vector slice(&input_wav[current_start],&input_wav[speeches[i].start]); - output_wav.insert(output_wav.end(), slice.begin(), slice.end()); - current_start = speeches[i].end; - } - - std::vector slice(&input_wav[current_start], &input_wav[input_wav.size()]); - output_wav.insert(output_wav.end(), slice.begin(), slice.end()); - }; - -private: - // model config - int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k. - int sample_rate; //Assign when init support 16000 or 8000 - int sr_per_ms; // Assign when init, support 8 or 16 - float threshold; - int min_silence_samples; // sr_per_ms * #ms - int min_silence_samples_at_max_speech; // sr_per_ms * #98 - int min_speech_samples; // sr_per_ms * #ms - float max_speech_samples; - int speech_pad_samples; // usually a - int audio_length_samples; - - // model states - bool triggered = false; - unsigned int temp_end = 0; - unsigned int current_sample = 0; - // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes - int prev_end; - int next_start = 0; - - //Output timestamp - std::vector speeches; - timestamp_t current_speech; - - - // Onnx model - // Inputs - std::vector ort_inputs; - - std::vector input_node_names = {"input", "state", "sr"}; - std::vector input; - unsigned int size_state = 2 * 1 * 128; // It's FIXED. - std::vector _state; - std::vector sr; - - int64_t input_node_dims[2] = {}; - const int64_t state_node_dims[3] = {2, 1, 128}; - const int64_t sr_node_dims[1] = {1}; - - // Outputs - std::vector ort_outputs; - std::vector output_node_names = {"output", "stateN"}; + // Public method to reset the internal state. + void reset() { + reset_states(); + } public: - // Construction + // Constructor: sets model path, sample rate, window size (ms), and other parameters. + // The parameters are set to match the Python version. VadIterator(const std::wstring ModelPath, int Sample_rate = 16000, int windows_frame_size = 32, - float Threshold = 0.5, int min_silence_duration_ms = 0, - int speech_pad_ms = 32, int min_speech_duration_ms = 32, + float Threshold = 0.5, int min_silence_duration_ms = 100, + int speech_pad_ms = 30, int min_speech_duration_ms = 250, float max_speech_duration_s = std::numeric_limits::infinity()) + : sample_rate(Sample_rate), threshold(Threshold), speech_pad_samples(speech_pad_ms), prev_end(0) { - init_onnx_model(ModelPath); - threshold = Threshold; - sample_rate = Sample_rate; - sr_per_ms = sample_rate / 1000; - - window_size_samples = windows_frame_size * sr_per_ms; - - min_speech_samples = sr_per_ms * min_speech_duration_ms; - speech_pad_samples = sr_per_ms * speech_pad_ms; - - max_speech_samples = ( - sample_rate * max_speech_duration_s - - window_size_samples - - 2 * speech_pad_samples - ); - - min_silence_samples = sr_per_ms * min_silence_duration_ms; - min_silence_samples_at_max_speech = sr_per_ms * 98; - - input.resize(window_size_samples); + sr_per_ms = sample_rate / 1000; // e.g., 16000 / 1000 = 16 + window_size_samples = windows_frame_size * sr_per_ms; // e.g., 32ms * 16 = 512 samples + effective_window_size = window_size_samples + context_samples; // e.g., 512 + 64 = 576 samples input_node_dims[0] = 1; - input_node_dims[1] = window_size_samples; - + input_node_dims[1] = effective_window_size; _state.resize(size_state); sr.resize(1); sr[0] = sample_rate; - }; + _context.assign(context_samples, 0.0f); + min_speech_samples = sr_per_ms * min_speech_duration_ms; + max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples); + min_silence_samples = sr_per_ms * min_silence_duration_ms; + min_silence_samples_at_max_speech = sr_per_ms * 98; + init_onnx_model(ModelPath); + } }; -int main() -{ - std::vector stamps; - - // Read wav - wav::WavReader wav_reader("recorder.wav"); //16000,1,32float - std::vector input_wav(wav_reader.num_samples()); - std::vector output_wav; - - for (int i = 0; i < wav_reader.num_samples(); i++) - { +int main() { + // Read the WAV file (expects 16000 Hz, mono, PCM). + wav::WavReader wav_reader("audio/recorder.wav"); // File located in the "audio" folder. + int numSamples = wav_reader.num_samples(); + std::vector input_wav(static_cast(numSamples)); + for (size_t i = 0; i < static_cast(numSamples); i++) { input_wav[i] = static_cast(*(wav_reader.data() + i)); } + // Set the ONNX model path (file located in the "model" folder). + std::wstring model_path = L"model/silero_vad.onnx"; + // Initialize the VadIterator. + VadIterator vad(model_path); - // ===== Test configs ===== - std::wstring path = L"silero_vad.onnx"; - VadIterator vad(path); - - // ============================================== - // ==== = Example 1 of full function ===== - // ============================================== + // Process the audio. vad.process(input_wav); - // 1.a get_speech_timestamps - stamps = vad.get_speech_timestamps(); - for (int i = 0; i < stamps.size(); i++) { + // Retrieve the speech timestamps (in samples). + std::vector stamps = vad.get_speech_timestamps(); - std::cout << stamps[i].c_str() << std::endl; + // Convert timestamps to seconds and round to one decimal place (for 16000 Hz). + const float sample_rate_float = 16000.0f; + for (size_t i = 0; i < stamps.size(); i++) { + float start_sec = std::rint((stamps[i].start / sample_rate_float) * 10.0f) / 10.0f; + float end_sec = std::rint((stamps[i].end / sample_rate_float) * 10.0f) / 10.0f; + std::cout << "Speech detected from " + << std::fixed << std::setprecision(1) << start_sec + << " s to " + << std::fixed << std::setprecision(1) << end_sec + << " s" << std::endl; } - // 1.b collect_chunks output wav - vad.collect_chunks(input_wav, output_wav); + // Optionally, reset the internal state. + vad.reset(); - // 1.c drop_chunks output wav - vad.drop_chunks(input_wav, output_wav); - - // ============================================== - // ===== Example 2 of simple full function ===== - // ============================================== - vad.process(input_wav, output_wav); - - stamps = vad.get_speech_timestamps(); - for (int i = 0; i < stamps.size(); i++) { - - std::cout << stamps[i].c_str() << std::endl; - } - - // ============================================== - // ===== Example 3 of full function ===== - // ============================================== - for(int i = 0; i<2; i++) - vad.process(input_wav, output_wav); + return 0; }