From 60ae7abfb70373ab90d92fd9daf8632e1f0d3dca Mon Sep 17 00:00:00 2001 From: Stefan Miletic Date: Mon, 1 Jul 2024 15:32:40 +0100 Subject: [PATCH] v5 model cpp example --- examples/cpp/silero-vad-onnx.cpp | 438 ++++++++++++++++++++++++------- examples/cpp/wav.h | 75 ++++-- 2 files changed, 388 insertions(+), 125 deletions(-) diff --git a/examples/cpp/silero-vad-onnx.cpp b/examples/cpp/silero-vad-onnx.cpp index 2d53469..dd2bf4e 100644 --- a/examples/cpp/silero-vad-onnx.cpp +++ b/examples/cpp/silero-vad-onnx.cpp @@ -2,13 +2,97 @@ #include #include #include +#include #include - +#include +#include +#include +#include +#include #include "onnxruntime_cxx_api.h" #include "wav.h" +#include +#include +#if __cplusplus < 201703L +#include +#endif + +//#define __DEBUG_SPEECH_PROB___ + +class timestamp_t +{ +public: + int start; + int end; + + // default + parameterized constructor + timestamp_t(int start = -1, int end = -1) + : start(start), end(end) + { + }; + + // assignment operator modifies object, therefore non-const + timestamp_t& operator=(const timestamp_t& a) + { + start = a.start; + end = a.end; + return *this; + }; + + // equality comparison. doesn't modify object. therefore const. + bool operator==(const timestamp_t& a) const + { + return (start == a.start && end == a.end); + }; + std::string c_str() + { + //return std::format("timestamp {:08d}, {:08d}", start, end); + return format("{start:%08d,end:%08d}", start, end); + }; +private: + + std::string format(const char* fmt, ...) + { + char buf[256]; + + va_list args; + va_start(args, fmt); + const auto r = std::vsnprintf(buf, sizeof buf, fmt, args); + va_end(args); + + if (r < 0) + // conversion failed + return {}; + + const size_t len = r; + if (len < sizeof buf) + // we fit in the buffer + return { buf, len }; + +#if __cplusplus >= 201703L + // C++17: Create a string and write to its underlying array + std::string s(len, '\0'); + va_start(args, fmt); + std::vsnprintf(s.data(), len + 1, fmt, args); + va_end(args); + + return s; +#else + // C++11 or C++14: We need to allocate scratch memory + auto vbuf = std::unique_ptr(new char[len + 1]); + va_start(args, fmt); + std::vsnprintf(vbuf.get(), len + 1, fmt, args); + va_end(args); + + return { vbuf.get(), len }; +#endif + }; +}; + class VadIterator { +private: // OnnxRuntime resources Ort::Env env; Ort::SessionOptions session_options; @@ -16,47 +100,39 @@ class VadIterator Ort::AllocatorWithDefaultOptions allocator; Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); -public: +private: void init_engine_threads(int inter_threads, int intra_threads) - { + { // The method should be called in each thread/proc in multi-thread/proc work session_options.SetIntraOpNumThreads(intra_threads); session_options.SetInterOpNumThreads(inter_threads); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - } + }; - void init_onnx_model(const std::string &model_path) - { + void init_onnx_model(const std::wstring& model_path) + { // Init threads = 1 for init_engine_threads(1, 1); // Load model session = std::make_shared(env, model_path.c_str(), session_options); - } + }; void reset_states() { // Call reset before each audio start std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); - triggerd = false; + triggered = false; temp_end = 0; current_sample = 0; - } - // Call it in predict func. if you prefer raw bytes input. - void bytes_to_float_tensor(const char *pcm_bytes) - { - std::memcpy(input.data(), pcm_bytes, window_size_samples * sizeof(int16_t)); - for (int i = 0; i < window_size_samples; i++) - { - input[i] = static_cast(input[i]) / 32768; // int16_t normalized to float - } - } + prev_end = next_start = 0; + speeches.clear(); + current_speech = timestamp_t(); + }; void predict(const std::vector &data) { - // bytes_to_float_tensor(data); - // Infer // Create ort tensors input.assign(data.begin(), data.end()); @@ -80,81 +156,215 @@ public: output_node_names.data(), output_node_names.size()); // Output probability & update h,c recursively - float output = ort_outputs[0].GetTensorMutableData()[0]; + float speech_prob = ort_outputs[0].GetTensorMutableData()[0]; float *stateN = ort_outputs[1].GetTensorMutableData(); std::memcpy(_state.data(), stateN, size_state * sizeof(float)); // Push forward sample index current_sample += window_size_samples; - + // Reset temp_end when > threshold - if ((output >= threshold) && (temp_end != 0)) + if ((speech_prob >= threshold)) { - temp_end = 0; - } - // 1) Silence - if ((output < threshold) && (triggerd == false)) - { - // printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate); - } - // 2) Speaking - if ((output >= (threshold - 0.15)) && (triggerd == true)) - { - // printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate); +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. + printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + if (temp_end != 0) + { + temp_end = 0; + if (next_start < prev_end) + next_start = current_sample - window_size_samples; + } + if (triggered == false) + { + triggered = true; + + current_speech.start = current_sample - window_size_samples; + } + return; } - // 3) Start - if ((output >= threshold) && (triggerd == false)) - { - triggerd = true; - speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. - printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate); + if ( + (triggered == true) + && ((current_sample - current_speech.start) > max_speech_samples) + ) { + if (prev_end > 0) { + current_speech.end = prev_end; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + + // previously reached silence(< neg_thres) and is still not speech(< thres) + if (next_start < prev_end) + triggered = false; + else{ + current_speech.start = next_start; + } + prev_end = 0; + next_start = 0; + temp_end = 0; + + } + else{ + current_speech.end = current_sample; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + return; + } + if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) + { + if (triggered) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. + printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + } + else { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. + printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + } + return; + } + // 4) End - if ((output < (threshold - 0.15)) && (triggerd == true)) + if ((speech_prob < (threshold - 0.15))) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. + printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + if (triggered == true) + { + if (temp_end == 0) + { + temp_end = current_sample; + } + if (current_sample - temp_end > min_silence_samples_at_max_speech) + prev_end = temp_end; + // a. silence < min_slience_samples, continue speaking + if ((current_sample - temp_end) < min_silence_samples) + { - if (temp_end == 0) - { - temp_end = current_sample; + } + // b. silence >= min_slience_samples, end speaking + else + { + current_speech.end = temp_end; + if (current_speech.end - current_speech.start > min_speech_samples) + { + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + } } - // a. silence < min_slience_samples, continue speaking - if ((current_sample - temp_end) < min_silence_samples) - { - // printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate); - // printf(""); - } - // b. silence >= min_slience_samples, end speaking - else - { - speech_end = temp_end ? temp_end + speech_pad_samples : current_sample + speech_pad_samples; - temp_end = 0; - triggerd = false; - printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate); + else { + // may first windows see end state. } + return; + } + }; +public: + void process(const std::vector& input_wav) + { + reset_states(); + + audio_length_samples = input_wav.size(); + + for (int j = 0; j < audio_length_samples; j += window_size_samples) + { + if (j + window_size_samples > audio_length_samples) + break; + std::vector r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples }; + predict(r); } + if (current_speech.start >= 0) { + current_speech.end = audio_length_samples; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + }; + void process(const std::vector& input_wav, std::vector& output_wav) + { + process(input_wav); + collect_chunks(input_wav, output_wav); } + void collect_chunks(const std::vector& input_wav, std::vector& output_wav) + { + output_wav.clear(); + for (int i = 0; i < speeches.size(); i++) { +#ifdef __DEBUG_SPEECH_PROB___ + std::cout << speeches[i].c_str() << std::endl; +#endif //#ifdef __DEBUG_SPEECH_PROB___ + std::vector slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]); + output_wav.insert(output_wav.end(),slice.begin(),slice.end()); + } + }; + + const std::vector get_speech_timestamps() const + { + return speeches; + } + + void drop_chunks(const std::vector& input_wav, std::vector& output_wav) + { + output_wav.clear(); + int current_start = 0; + for (int i = 0; i < speeches.size(); i++) { + + std::vector slice(&input_wav[current_start],&input_wav[speeches[i].start]); + output_wav.insert(output_wav.end(), slice.begin(), slice.end()); + current_start = speeches[i].end; + } + + std::vector slice(&input_wav[current_start], &input_wav[input_wav.size()]); + output_wav.insert(output_wav.end(), slice.begin(), slice.end()); + }; + private: // model config int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k. - int sample_rate; - int sr_per_ms; // Assign when init, support 8 or 16 - float threshold; + int sample_rate; //Assign when init support 16000 or 8000 + int sr_per_ms; // Assign when init, support 8 or 16 + float threshold; int min_silence_samples; // sr_per_ms * #ms + int min_silence_samples_at_max_speech; // sr_per_ms * #98 + int min_speech_samples; // sr_per_ms * #ms + float max_speech_samples; int speech_pad_samples; // usually a + int audio_length_samples; // model states - bool triggerd = false; - unsigned int speech_start = 0; - unsigned int speech_end = 0; + bool triggered = false; unsigned int temp_end = 0; unsigned int current_sample = 0; // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes - float output; + int prev_end; + int next_start = 0; + + //Output timestamp + std::vector speeches; + timestamp_t current_speech; + // Onnx model // Inputs @@ -166,79 +376,103 @@ private: std::vector _state; std::vector sr; - int64_t input_node_dims[2] = {}; - const int64_t state_node_dims[3] = {2, 1, 128}; + int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = {2, 1, 128}; const int64_t sr_node_dims[1] = {1}; // Outputs std::vector ort_outputs; std::vector output_node_names = {"output", "stateN"}; - public: // Construction - VadIterator(const std::string ModelPath, - int Sample_rate, int frame_size, - float Threshold, int min_silence_duration_ms, int speech_pad_ms) + VadIterator(const std::wstring ModelPath, + int Sample_rate = 16000, int windows_frame_size = 32, + float Threshold = 0.5, int min_silence_duration_ms = 0, + int speech_pad_ms = 32, int min_speech_duration_ms = 32, + float max_speech_duration_s = std::numeric_limits::infinity()) { init_onnx_model(ModelPath); + threshold = Threshold; sample_rate = Sample_rate; sr_per_ms = sample_rate / 1000; - threshold = Threshold; - min_silence_samples = sr_per_ms * min_silence_duration_ms; + + window_size_samples = windows_frame_size * sr_per_ms; + + min_speech_samples = sr_per_ms * min_speech_duration_ms; speech_pad_samples = sr_per_ms * speech_pad_ms; - window_size_samples = frame_size * sr_per_ms; - + + max_speech_samples = ( + sample_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ); + + min_silence_samples = sr_per_ms * min_silence_duration_ms; + min_silence_samples_at_max_speech = sr_per_ms * 98; + input.resize(window_size_samples); input_node_dims[0] = 1; input_node_dims[1] = window_size_samples; + _state.resize(size_state); sr.resize(1); sr[0] = sample_rate; - } - + }; }; int main() { + std::vector stamps; // Read wav - wav::WavReader wav_reader("./recorder.wav"); - std::vector data(wav_reader.num_samples()); + wav::WavReader wav_reader("recorder.wav"); //16000,1,32float std::vector input_wav(wav_reader.num_samples()); + std::vector output_wav; for (int i = 0; i < wav_reader.num_samples(); i++) { - data[i] = static_cast(*(wav_reader.data() + i)); + input_wav[i] = static_cast(*(wav_reader.data() + i)); } - for (int i = 0; i < wav_reader.num_samples(); i++) - { - input_wav[i] = static_cast(data[i]) / 32768; - } + // ===== Test configs ===== - std::string path = "../../files/silero_vad.onnx"; - int test_sr = 16000; - int test_frame_ms = 32; - float test_threshold = 0.5f; - int test_min_silence_duration_ms = 0; - int test_speech_pad_ms = 0; - int test_window_samples = test_frame_ms * (test_sr/1000); + std::wstring path = L"silero_vad.onnx"; + VadIterator vad(path); - VadIterator vad( - path, test_sr, test_frame_ms, test_threshold, - test_min_silence_duration_ms, test_speech_pad_ms); + // ============================================== + // ==== = Example 1 of full function ===== + // ============================================== + vad.process(input_wav); - for (int j = 0; j < wav_reader.num_samples(); j += test_window_samples) - { - std::vector r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples}; - auto start = std::chrono::high_resolution_clock::now(); - // Predict and print throughout process time - vad.predict(r); - auto end = std::chrono::high_resolution_clock::now(); - auto elapsed_time = std::chrono::duration_cast(end-start); - std::cout << "== Elapsed time: " << 1.0*elapsed_time.count()/1000000 << "ms" << " ==" < #include #include #include @@ -60,8 +59,7 @@ class WavReader { WavHeader header; fread(&header, 1, sizeof(header), fp); if (header.fmt_size < 16) { - fprintf(stderr, - "WaveData: expect PCM format data " + printf("WaveData: expect PCM format data " "to have fmt chunk of at least size 16.\n"); return false; } else if (header.fmt_size > 16) { @@ -81,6 +79,13 @@ class WavReader { fread(header.data, 8, sizeof(char), fp); } + if (header.data_size == 0) { + int offset = ftell(fp); + fseek(fp, 0, SEEK_END); + header.data_size = ftell(fp) - offset; + fseek(fp, offset, SEEK_SET); + } + num_channel_ = header.channels; sample_rate_ = header.sample_rate; bits_per_sample_ = header.bit; @@ -88,33 +93,57 @@ class WavReader { data_ = new float[num_data]; // Create 1-dim array num_samples_ = num_data / num_channel_; - for (int i = 0; i < num_data; ++i) { - switch (bits_per_sample_) { + std::cout << "num_channel_ :" << num_channel_ << std::endl; + std::cout << "sample_rate_ :" << sample_rate_ << std::endl; + std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl; + std::cout << "num_samples :" << num_data << std::endl; + std::cout << "num_data_size :" << header.data_size << std::endl; + + switch (bits_per_sample_) { case 8: { - char sample; - fread(&sample, 1, sizeof(char), fp); - data_[i] = static_cast(sample); - break; + char sample; + for (int i = 0; i < num_data; ++i) { + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample) / 32768; + } + break; } case 16: { - int16_t sample; - fread(&sample, 1, sizeof(int16_t), fp); - // std::cout << sample; - data_[i] = static_cast(sample); - // std::cout << data_[i]; - break; + int16_t sample; + for (int i = 0; i < num_data; ++i) { + fread(&sample, 1, sizeof(int16_t), fp); + data_[i] = static_cast(sample) / 32768; + } + break; } - case 32: { - int sample; - fread(&sample, 1, sizeof(int), fp); - data_[i] = static_cast(sample); - break; + case 32: + { + if (header.format == 1) //S32 + { + int sample; + for (int i = 0; i < num_data; ++i) { + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample) / 32768; + } + } + else if (header.format == 3) // IEEE-float + { + float sample; + for (int i = 0; i < num_data; ++i) { + fread(&sample, 1, sizeof(float), fp); + data_[i] = static_cast(sample); + } + } + else { + printf("unsupported quantization bits\n"); + } + break; } default: - fprintf(stderr, "unsupported quantization bits"); - exit(1); - } + printf("unsupported quantization bits\n"); + break; } + fclose(fp); return true; }