mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
@@ -120,8 +120,7 @@ private:
|
|||||||
void reset_states()
|
void reset_states()
|
||||||
{
|
{
|
||||||
// Call reset before each audio start
|
// Call reset before each audio start
|
||||||
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
|
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
|
||||||
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
|
|
||||||
triggered = false;
|
triggered = false;
|
||||||
temp_end = 0;
|
temp_end = 0;
|
||||||
current_sample = 0;
|
current_sample = 0;
|
||||||
@@ -139,19 +138,16 @@ private:
|
|||||||
input.assign(data.begin(), data.end());
|
input.assign(data.begin(), data.end());
|
||||||
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
|
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
|
||||||
memory_info, input.data(), input.size(), input_node_dims, 2);
|
memory_info, input.data(), input.size(), input_node_dims, 2);
|
||||||
|
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
|
||||||
|
memory_info, _state.data(), _state.size(), state_node_dims, 3);
|
||||||
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
|
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
|
||||||
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
|
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
|
||||||
Ort::Value h_ort = Ort::Value::CreateTensor<float>(
|
|
||||||
memory_info, _h.data(), _h.size(), hc_node_dims, 3);
|
|
||||||
Ort::Value c_ort = Ort::Value::CreateTensor<float>(
|
|
||||||
memory_info, _c.data(), _c.size(), hc_node_dims, 3);
|
|
||||||
|
|
||||||
// Clear and add inputs
|
// Clear and add inputs
|
||||||
ort_inputs.clear();
|
ort_inputs.clear();
|
||||||
ort_inputs.emplace_back(std::move(input_ort));
|
ort_inputs.emplace_back(std::move(input_ort));
|
||||||
|
ort_inputs.emplace_back(std::move(state_ort));
|
||||||
ort_inputs.emplace_back(std::move(sr_ort));
|
ort_inputs.emplace_back(std::move(sr_ort));
|
||||||
ort_inputs.emplace_back(std::move(h_ort));
|
|
||||||
ort_inputs.emplace_back(std::move(c_ort));
|
|
||||||
|
|
||||||
// Infer
|
// Infer
|
||||||
ort_outputs = session->Run(
|
ort_outputs = session->Run(
|
||||||
@@ -161,10 +157,8 @@ private:
|
|||||||
|
|
||||||
// Output probability & update h,c recursively
|
// Output probability & update h,c recursively
|
||||||
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
|
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
|
||||||
float *hn = ort_outputs[1].GetTensorMutableData<float>();
|
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
|
||||||
std::memcpy(_h.data(), hn, size_hc * sizeof(float));
|
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
|
||||||
float *cn = ort_outputs[2].GetTensorMutableData<float>();
|
|
||||||
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
|
|
||||||
|
|
||||||
// Push forward sample index
|
// Push forward sample index
|
||||||
current_sample += window_size_samples;
|
current_sample += window_size_samples;
|
||||||
@@ -376,27 +370,26 @@ private:
|
|||||||
// Inputs
|
// Inputs
|
||||||
std::vector<Ort::Value> ort_inputs;
|
std::vector<Ort::Value> ort_inputs;
|
||||||
|
|
||||||
std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
|
std::vector<const char *> input_node_names = {"input", "state", "sr"};
|
||||||
std::vector<float> input;
|
std::vector<float> input;
|
||||||
|
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
|
||||||
|
std::vector<float> _state;
|
||||||
std::vector<int64_t> sr;
|
std::vector<int64_t> sr;
|
||||||
unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
|
|
||||||
std::vector<float> _h;
|
|
||||||
std::vector<float> _c;
|
|
||||||
|
|
||||||
int64_t input_node_dims[2] = {};
|
int64_t input_node_dims[2] = {};
|
||||||
|
const int64_t state_node_dims[3] = {2, 1, 128};
|
||||||
const int64_t sr_node_dims[1] = {1};
|
const int64_t sr_node_dims[1] = {1};
|
||||||
const int64_t hc_node_dims[3] = {2, 1, 64};
|
|
||||||
|
|
||||||
// Outputs
|
// Outputs
|
||||||
std::vector<Ort::Value> ort_outputs;
|
std::vector<Ort::Value> ort_outputs;
|
||||||
std::vector<const char *> output_node_names = {"output", "hn", "cn"};
|
std::vector<const char *> output_node_names = {"output", "stateN"};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Construction
|
// Construction
|
||||||
VadIterator(const std::wstring ModelPath,
|
VadIterator(const std::wstring ModelPath,
|
||||||
int Sample_rate = 16000, int windows_frame_size = 64,
|
int Sample_rate = 16000, int windows_frame_size = 32,
|
||||||
float Threshold = 0.5, int min_silence_duration_ms = 0,
|
float Threshold = 0.5, int min_silence_duration_ms = 0,
|
||||||
int speech_pad_ms = 64, int min_speech_duration_ms = 64,
|
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
|
||||||
float max_speech_duration_s = std::numeric_limits<float>::infinity())
|
float max_speech_duration_s = std::numeric_limits<float>::infinity())
|
||||||
{
|
{
|
||||||
init_onnx_model(ModelPath);
|
init_onnx_model(ModelPath);
|
||||||
@@ -422,8 +415,7 @@ public:
|
|||||||
input_node_dims[0] = 1;
|
input_node_dims[0] = 1;
|
||||||
input_node_dims[1] = window_size_samples;
|
input_node_dims[1] = window_size_samples;
|
||||||
|
|
||||||
_h.resize(size_hc);
|
_state.resize(size_state);
|
||||||
_c.resize(size_hc);
|
|
||||||
sr.resize(1);
|
sr.resize(1);
|
||||||
sr[0] = sample_rate;
|
sr[0] = sample_rate;
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user