mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 09:29:22 +08:00
Update model plumbing for Rust example
The v6.2 models broke the Rust example. Update the logic for driving them to reflect what the reference Python code does. Fixes: #745 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::utils;
|
||||
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
||||
use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -7,6 +7,8 @@ pub struct Silero {
|
||||
session: ort::Session,
|
||||
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
||||
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
||||
context: Array1<f32>,
|
||||
context_size: usize,
|
||||
}
|
||||
|
||||
impl Silero {
|
||||
@@ -16,16 +18,22 @@ impl Silero {
|
||||
) -> Result<Self, ort::Error> {
|
||||
let session = ort::Session::builder()?.commit_from_file(model_path)?;
|
||||
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
|
||||
let sample_rate_val: i64 = sample_rate.into();
|
||||
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
|
||||
let context = Array1::<f32>::zeros(context_size);
|
||||
let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();
|
||||
Ok(Self {
|
||||
session,
|
||||
sample_rate,
|
||||
state,
|
||||
context,
|
||||
context_size,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||
self.context = Array1::<f32>::zeros(self.context_size);
|
||||
}
|
||||
|
||||
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
|
||||
@@ -33,8 +41,14 @@ impl Silero {
|
||||
.iter()
|
||||
.map(|x| (*x as f32) / (i16::MAX as f32))
|
||||
.collect::<Vec<_>>();
|
||||
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
|
||||
frame = frame.slice(s![.., ..480]).to_owned();
|
||||
|
||||
// Concatenate context with input
|
||||
let mut input_with_context = Vec::with_capacity(self.context_size + data.len());
|
||||
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
|
||||
input_with_context.extend_from_slice(&data);
|
||||
|
||||
let frame = Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context).unwrap();
|
||||
|
||||
let inps = ort::inputs![
|
||||
frame,
|
||||
std::mem::take(&mut self.state),
|
||||
@@ -43,12 +57,20 @@ impl Silero {
|
||||
let res = self
|
||||
.session
|
||||
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
|
||||
|
||||
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
|
||||
Ok(*res["output"]
|
||||
|
||||
// Update context with last context_size samples from the input
|
||||
if data.len() >= self.context_size {
|
||||
self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());
|
||||
}
|
||||
|
||||
let prob = *res["output"]
|
||||
.try_extract_raw_tensor::<f32>()
|
||||
.unwrap()
|
||||
.1
|
||||
.first()
|
||||
.unwrap())
|
||||
.unwrap();
|
||||
Ok(prob)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ pub struct VadParams {
|
||||
impl Default for VadParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
frame_size: 64,
|
||||
frame_size: 32, // 32ms for 512 samples at 16kHz
|
||||
threshold: 0.5,
|
||||
min_silence_duration_ms: 0,
|
||||
speech_pad_ms: 64,
|
||||
|
||||
Reference in New Issue
Block a user