v5 model cpp example

This commit is contained in:
Stefan Miletic
2024-07-01 15:32:40 +01:00
parent 0b3d43d432
commit 60ae7abfb7
2 changed files with 388 additions and 125 deletions

View File

@@ -2,13 +2,97 @@
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <cstring> #include <cstring>
#include <limits>
#include <chrono> #include <chrono>
#include <memory>
#include <string>
#include <stdexcept>
#include <iostream>
#include <string>
#include "onnxruntime_cxx_api.h" #include "onnxruntime_cxx_api.h"
#include "wav.h" #include "wav.h"
#include <cstdio>
#include <cstdarg>
#if __cplusplus < 201703L
#include <memory>
#endif
//#define __DEBUG_SPEECH_PROB___
class timestamp_t
{
public:
int start;
int end;
// default + parameterized constructor
timestamp_t(int start = -1, int end = -1)
: start(start), end(end)
{
};
// assignment operator modifies object, therefore non-const
timestamp_t& operator=(const timestamp_t& a)
{
start = a.start;
end = a.end;
return *this;
};
// equality comparison. doesn't modify object. therefore const.
bool operator==(const timestamp_t& a) const
{
return (start == a.start && end == a.end);
};
std::string c_str()
{
//return std::format("timestamp {:08d}, {:08d}", start, end);
return format("{start:%08d,end:%08d}", start, end);
};
private:
std::string format(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
va_end(args);
if (r < 0)
// conversion failed
return {};
const size_t len = r;
if (len < sizeof buf)
// we fit in the buffer
return { buf, len };
#if __cplusplus >= 201703L
// C++17: Create a string and write to its underlying array
std::string s(len, '\0');
va_start(args, fmt);
std::vsnprintf(s.data(), len + 1, fmt, args);
va_end(args);
return s;
#else
// C++11 or C++14: We need to allocate scratch memory
auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
va_start(args, fmt);
std::vsnprintf(vbuf.get(), len + 1, fmt, args);
va_end(args);
return { vbuf.get(), len };
#endif
};
};
class VadIterator class VadIterator
{ {
private:
// OnnxRuntime resources // OnnxRuntime resources
Ort::Env env; Ort::Env env;
Ort::SessionOptions session_options; Ort::SessionOptions session_options;
@@ -16,47 +100,39 @@ class VadIterator
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
public: private:
void init_engine_threads(int inter_threads, int intra_threads) void init_engine_threads(int inter_threads, int intra_threads)
{ {
// The method should be called in each thread/proc in multi-thread/proc work // The method should be called in each thread/proc in multi-thread/proc work
session_options.SetIntraOpNumThreads(intra_threads); session_options.SetIntraOpNumThreads(intra_threads);
session_options.SetInterOpNumThreads(inter_threads); session_options.SetInterOpNumThreads(inter_threads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
} };
void init_onnx_model(const std::string &model_path) void init_onnx_model(const std::wstring& model_path)
{ {
// Init threads = 1 for // Init threads = 1 for
init_engine_threads(1, 1); init_engine_threads(1, 1);
// Load model // Load model
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options); session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
} };
void reset_states() void reset_states()
{ {
// Call reset before each audio start // Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggerd = false; triggered = false;
temp_end = 0; temp_end = 0;
current_sample = 0; current_sample = 0;
}
// Call it in predict func. if you prefer raw bytes input. prev_end = next_start = 0;
void bytes_to_float_tensor(const char *pcm_bytes)
{
std::memcpy(input.data(), pcm_bytes, window_size_samples * sizeof(int16_t));
for (int i = 0; i < window_size_samples; i++)
{
input[i] = static_cast<float>(input[i]) / 32768; // int16_t normalized to float
}
}
speeches.clear();
current_speech = timestamp_t();
};
void predict(const std::vector<float> &data) void predict(const std::vector<float> &data)
{ {
// bytes_to_float_tensor(data);
// Infer // Infer
// Create ort tensors // Create ort tensors
input.assign(data.begin(), data.end()); input.assign(data.begin(), data.end());
@@ -80,81 +156,215 @@ public:
output_node_names.data(), output_node_names.size()); output_node_names.data(), output_node_names.size());
// Output probability & update h,c recursively // Output probability & update h,c recursively
float output = ort_outputs[0].GetTensorMutableData<float>()[0]; float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float *stateN = ort_outputs[1].GetTensorMutableData<float>(); float *stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float)); std::memcpy(_state.data(), stateN, size_state * sizeof(float));
// Push forward sample index // Push forward sample index
current_sample += window_size_samples; current_sample += window_size_samples;
// Reset temp_end when > threshold // Reset temp_end when > threshold
if ((output >= threshold) && (temp_end != 0)) if ((speech_prob >= threshold))
{ {
temp_end = 0; #ifdef __DEBUG_SPEECH_PROB___
} float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
// 1) Silence printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples);
if ((output < threshold) && (triggerd == false)) #endif //__DEBUG_SPEECH_PROB___
{ if (temp_end != 0)
// printf("{ silence: %.3f s }\n", 1.0 * current_sample / sample_rate); {
} temp_end = 0;
// 2) Speaking if (next_start < prev_end)
if ((output >= (threshold - 0.15)) && (triggerd == true)) next_start = current_sample - window_size_samples;
{ }
// printf("{ speaking_2: %.3f s }\n", 1.0 * current_sample / sample_rate); if (triggered == false)
{
triggered = true;
current_speech.start = current_sample - window_size_samples;
}
return;
} }
// 3) Start if (
if ((output >= threshold) && (triggerd == false)) (triggered == true)
{ && ((current_sample - current_speech.start) > max_speech_samples)
triggerd = true; ) {
speech_start = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. if (prev_end > 0) {
printf("{ start: %.3f s }\n", 1.0 * speech_start / sample_rate); current_speech.end = prev_end;
speeches.push_back(current_speech);
current_speech = timestamp_t();
// previously reached silence(< neg_thres) and is still not speech(< thres)
if (next_start < prev_end)
triggered = false;
else{
current_speech.start = next_start;
}
prev_end = 0;
next_start = 0;
temp_end = 0;
}
else{
current_speech.end = current_sample;
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
return;
} }
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
{
if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
else {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
return;
}
// 4) End // 4) End
if ((output < (threshold - 0.15)) && (triggerd == true)) if ((speech_prob < (threshold - 0.15)))
{ {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (triggered == true)
{
if (temp_end == 0)
{
temp_end = current_sample;
}
if (current_sample - temp_end > min_silence_samples_at_max_speech)
prev_end = temp_end;
// a. silence < min_slience_samples, continue speaking
if ((current_sample - temp_end) < min_silence_samples)
{
if (temp_end == 0) }
{ // b. silence >= min_slience_samples, end speaking
temp_end = current_sample; else
{
current_speech.end = temp_end;
if (current_speech.end - current_speech.start > min_speech_samples)
{
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
}
} }
// a. silence < min_slience_samples, continue speaking else {
if ((current_sample - temp_end) < min_silence_samples) // may first windows see end state.
{
// printf("{ speaking_4: %.3f s }\n", 1.0 * current_sample / sample_rate);
// printf("");
}
// b. silence >= min_slience_samples, end speaking
else
{
speech_end = temp_end ? temp_end + speech_pad_samples : current_sample + speech_pad_samples;
temp_end = 0;
triggerd = false;
printf("{ end: %.3f s }\n", 1.0 * speech_end / sample_rate);
} }
return;
}
};
public:
void process(const std::vector<float>& input_wav)
{
reset_states();
audio_length_samples = input_wav.size();
for (int j = 0; j < audio_length_samples; j += window_size_samples)
{
if (j + window_size_samples > audio_length_samples)
break;
std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
predict(r);
} }
if (current_speech.start >= 0) {
current_speech.end = audio_length_samples;
speeches.push_back(current_speech);
current_speech = timestamp_t();
prev_end = 0;
next_start = 0;
temp_end = 0;
triggered = false;
}
};
void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
process(input_wav);
collect_chunks(input_wav, output_wav);
} }
void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
for (int i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
output_wav.insert(output_wav.end(),slice.begin(),slice.end());
}
};
const std::vector<timestamp_t> get_speech_timestamps() const
{
return speeches;
}
void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
output_wav.clear();
int current_start = 0;
for (int i = 0; i < speeches.size(); i++) {
std::vector<float> slice(&input_wav[current_start],&input_wav[speeches[i].start]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
current_start = speeches[i].end;
}
std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
};
private: private:
// model config // model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k. int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
int sample_rate; int sample_rate; //Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16 int sr_per_ms; // Assign when init, support 8 or 16
float threshold; float threshold;
int min_silence_samples; // sr_per_ms * #ms int min_silence_samples; // sr_per_ms * #ms
int min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
float max_speech_samples;
int speech_pad_samples; // usually a int speech_pad_samples; // usually a
int audio_length_samples;
// model states // model states
bool triggerd = false; bool triggered = false;
unsigned int speech_start = 0;
unsigned int speech_end = 0;
unsigned int temp_end = 0; unsigned int temp_end = 0;
unsigned int current_sample = 0; unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
float output; int prev_end;
int next_start = 0;
//Output timestamp
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Onnx model // Onnx model
// Inputs // Inputs
@@ -166,79 +376,103 @@ private:
std::vector<float> _state; std::vector<float> _state;
std::vector<int64_t> sr; std::vector<int64_t> sr;
int64_t input_node_dims[2] = {}; int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128}; const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1}; const int64_t sr_node_dims[1] = {1};
// Outputs // Outputs
std::vector<Ort::Value> ort_outputs; std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "stateN"}; std::vector<const char *> output_node_names = {"output", "stateN"};
public: public:
// Construction // Construction
VadIterator(const std::string ModelPath, VadIterator(const std::wstring ModelPath,
int Sample_rate, int frame_size, int Sample_rate = 16000, int windows_frame_size = 32,
float Threshold, int min_silence_duration_ms, int speech_pad_ms) float Threshold = 0.5, int min_silence_duration_ms = 0,
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
float max_speech_duration_s = std::numeric_limits<float>::infinity())
{ {
init_onnx_model(ModelPath); init_onnx_model(ModelPath);
threshold = Threshold;
sample_rate = Sample_rate; sample_rate = Sample_rate;
sr_per_ms = sample_rate / 1000; sr_per_ms = sample_rate / 1000;
threshold = Threshold;
min_silence_samples = sr_per_ms * min_silence_duration_ms; window_size_samples = windows_frame_size * sr_per_ms;
min_speech_samples = sr_per_ms * min_speech_duration_ms;
speech_pad_samples = sr_per_ms * speech_pad_ms; speech_pad_samples = sr_per_ms * speech_pad_ms;
window_size_samples = frame_size * sr_per_ms;
max_speech_samples = (
sample_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
input.resize(window_size_samples); input.resize(window_size_samples);
input_node_dims[0] = 1; input_node_dims[0] = 1;
input_node_dims[1] = window_size_samples; input_node_dims[1] = window_size_samples;
_state.resize(size_state); _state.resize(size_state);
sr.resize(1); sr.resize(1);
sr[0] = sample_rate; sr[0] = sample_rate;
} };
}; };
int main() int main()
{ {
std::vector<timestamp_t> stamps;
// Read wav // Read wav
wav::WavReader wav_reader("./recorder.wav"); wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
std::vector<int16_t> data(wav_reader.num_samples());
std::vector<float> input_wav(wav_reader.num_samples()); std::vector<float> input_wav(wav_reader.num_samples());
std::vector<float> output_wav;
for (int i = 0; i < wav_reader.num_samples(); i++) for (int i = 0; i < wav_reader.num_samples(); i++)
{ {
data[i] = static_cast<int16_t>(*(wav_reader.data() + i)); input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
} }
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(data[i]) / 32768;
}
// ===== Test configs ===== // ===== Test configs =====
std::string path = "../../files/silero_vad.onnx"; std::wstring path = L"silero_vad.onnx";
int test_sr = 16000; VadIterator vad(path);
int test_frame_ms = 32;
float test_threshold = 0.5f;
int test_min_silence_duration_ms = 0;
int test_speech_pad_ms = 0;
int test_window_samples = test_frame_ms * (test_sr/1000);
VadIterator vad( // ==============================================
path, test_sr, test_frame_ms, test_threshold, // ==== = Example 1 of full function =====
test_min_silence_duration_ms, test_speech_pad_ms); // ==============================================
vad.process(input_wav);
for (int j = 0; j < wav_reader.num_samples(); j += test_window_samples) // 1.a get_speech_timestamps
{ stamps = vad.get_speech_timestamps();
std::vector<float> r{&input_wav[0] + j, &input_wav[0] + j + test_window_samples}; for (int i = 0; i < stamps.size(); i++) {
auto start = std::chrono::high_resolution_clock::now();
// Predict and print throughout process time
vad.predict(r);
auto end = std::chrono::high_resolution_clock::now();
auto elapsed_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end-start);
std::cout << "== Elapsed time: " << 1.0*elapsed_time.count()/1000000 << "ms" << " ==" <<std::endl;
std::cout << stamps[i].c_str() << std::endl;
} }
// 1.b collect_chunks output wav
vad.collect_chunks(input_wav, output_wav);
// 1.c drop_chunks output wav
vad.drop_chunks(input_wav, output_wav);
// ==============================================
// ===== Example 2 of simple full function =====
// ==============================================
vad.process(input_wav, output_wav);
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
std::cout << stamps[i].c_str() << std::endl;
}
// ==============================================
// ===== Example 3 of full function =====
// ==============================================
for(int i = 0; i<2; i++)
vad.process(input_wav, output_wav);
} }

View File

@@ -16,7 +16,6 @@
#ifndef FRONTEND_WAV_H_ #ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_ #define FRONTEND_WAV_H_
#include <iostream>
#include <assert.h> #include <assert.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
@@ -60,8 +59,7 @@ class WavReader {
WavHeader header; WavHeader header;
fread(&header, 1, sizeof(header), fp); fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) { if (header.fmt_size < 16) {
fprintf(stderr, printf("WaveData: expect PCM format data "
"WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n"); "to have fmt chunk of at least size 16.\n");
return false; return false;
} else if (header.fmt_size > 16) { } else if (header.fmt_size > 16) {
@@ -81,6 +79,13 @@ class WavReader {
fread(header.data, 8, sizeof(char), fp); fread(header.data, 8, sizeof(char), fp);
} }
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels; num_channel_ = header.channels;
sample_rate_ = header.sample_rate; sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit; bits_per_sample_ = header.bit;
@@ -88,33 +93,57 @@ class WavReader {
data_ = new float[num_data]; // Create 1-dim array data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_; num_samples_ = num_data / num_channel_;
for (int i = 0; i < num_data; ++i) { std::cout << "num_channel_ :" << num_channel_ << std::endl;
switch (bits_per_sample_) { std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: { case 8: {
char sample; char sample;
fread(&sample, 1, sizeof(char), fp); for (int i = 0; i < num_data; ++i) {
data_[i] = static_cast<float>(sample); fread(&sample, 1, sizeof(char), fp);
break; data_[i] = static_cast<float>(sample) / 32768;
}
break;
} }
case 16: { case 16: {
int16_t sample; int16_t sample;
fread(&sample, 1, sizeof(int16_t), fp); for (int i = 0; i < num_data; ++i) {
// std::cout << sample; fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample); data_[i] = static_cast<float>(sample) / 32768;
// std::cout << data_[i]; }
break; break;
} }
case 32: { case 32:
int sample; {
fread(&sample, 1, sizeof(int), fp); if (header.format == 1) //S32
data_[i] = static_cast<float>(sample); {
break; int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
} }
default: default:
fprintf(stderr, "unsupported quantization bits"); printf("unsupported quantization bits\n");
exit(1); break;
}
} }
fclose(fp); fclose(fp);
return true; return true;
} }