From cfe63384f0283ef597fed7007997b71d98652e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20M=C3=BCller?= Date: Sun, 28 Dec 2025 07:15:01 -0800 Subject: [PATCH] Update model plumbing for Rust example The v6.2 models broke the Rust example. Update the logic for driving them to reflect what the reference Python code does. Fixes: #745 Co-Authored-By: Claude --- examples/rust-example/src/silero.rs | 34 ++++++++++++++++++++++++----- examples/rust-example/src/utils.rs | 2 +- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/rust-example/src/silero.rs b/examples/rust-example/src/silero.rs index fce8808..a4d7103 100644 --- a/examples/rust-example/src/silero.rs +++ b/examples/rust-example/src/silero.rs @@ -1,5 +1,5 @@ use crate::utils; -use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; +use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; use std::path::Path; #[derive(Debug)] @@ -7,6 +7,8 @@ pub struct Silero { session: ort::Session, sample_rate: ArrayBase, Dim<[usize; 1]>>, state: ArrayBase, Dim>, + context: Array1, + context_size: usize, } impl Silero { @@ -16,16 +18,22 @@ impl Silero { ) -> Result { let session = ort::Session::builder()?.commit_from_file(model_path)?; let state = ArrayD::::zeros([2, 1, 128].as_slice()); - let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); + let sample_rate_val: i64 = sample_rate.into(); + let context_size = if sample_rate_val == 16000 { 64 } else { 32 }; + let context = Array1::::zeros(context_size); + let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap(); Ok(Self { session, sample_rate, state, + context, + context_size, }) } pub fn reset(&mut self) { self.state = ArrayD::::zeros([2, 1, 128].as_slice()); + self.context = Array1::::zeros(self.context_size); } pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result { @@ -33,8 +41,14 @@ impl Silero { .iter() .map(|x| (*x as f32) / (i16::MAX as f32)) .collect::>(); - let mut frame = Array2::::from_shape_vec([1, data.len()], data).unwrap(); - frame = frame.slice(s![.., ..480]).to_owned(); + + // Concatenate context with input + let mut input_with_context = Vec::with_capacity(self.context_size + data.len()); + input_with_context.extend_from_slice(self.context.as_slice().unwrap()); + input_with_context.extend_from_slice(&data); + + let frame = Array2::::from_shape_vec([1, input_with_context.len()], input_with_context).unwrap(); + let inps = ort::inputs![ frame, std::mem::take(&mut self.state), @@ -43,12 +57,20 @@ impl Silero { let res = self .session .run(ort::SessionInputs::ValueSlice::<3>(&inps))?; + self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); - Ok(*res["output"] + + // Update context with last context_size samples from the input + if data.len() >= self.context_size { + self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec()); + } + + let prob = *res["output"] .try_extract_raw_tensor::() .unwrap() .1 .first() - .unwrap()) + .unwrap(); + Ok(prob) } } diff --git a/examples/rust-example/src/utils.rs b/examples/rust-example/src/utils.rs index 8207920..b37c33a 100644 --- a/examples/rust-example/src/utils.rs +++ b/examples/rust-example/src/utils.rs @@ -36,7 +36,7 @@ pub struct VadParams { impl Default for VadParams { fn default() -> Self { Self { - frame_size: 64, + frame_size: 32, // 32ms for 512 samples at 16kHz threshold: 0.5, min_silence_duration_ms: 0, speech_pad_ms: 64,