diff --git a/examples/rust-example/src/silero.rs b/examples/rust-example/src/silero.rs index 22f154c..fce8808 100644 --- a/examples/rust-example/src/silero.rs +++ b/examples/rust-example/src/silero.rs @@ -1,13 +1,12 @@ use crate::utils; -use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; +use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; use std::path::Path; #[derive(Debug)] pub struct Silero { session: ort::Session, sample_rate: ArrayBase, Dim<[usize; 1]>>, - h: ArrayBase, Dim>, - c: ArrayBase, Dim>, + state: ArrayBase, Dim>, } impl Silero { @@ -16,20 +15,17 @@ impl Silero { model_path: impl AsRef, ) -> Result { let session = ort::Session::builder()?.commit_from_file(model_path)?; - let h = ArrayD::::zeros([2, 1, 64].as_slice()); - let c = ArrayD::::zeros([2, 1, 64].as_slice()); + let state = ArrayD::::zeros([2, 1, 128].as_slice()); let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); Ok(Self { session, sample_rate, - h, - c, + state, }) } pub fn reset(&mut self) { - self.h = ArrayD::::zeros([2, 1, 64].as_slice()); - self.c = ArrayD::::zeros([2, 1, 64].as_slice()); + self.state = ArrayD::::zeros([2, 1, 128].as_slice()); } pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result { @@ -37,18 +33,17 @@ impl Silero { .iter() .map(|x| (*x as f32) / (i16::MAX as f32)) .collect::>(); - let frame = Array2::::from_shape_vec([1, data.len()], data).unwrap(); + let mut frame = Array2::::from_shape_vec([1, data.len()], data).unwrap(); + frame = frame.slice(s![.., ..480]).to_owned(); let inps = ort::inputs![ frame, + std::mem::take(&mut self.state), self.sample_rate.clone(), - std::mem::take(&mut self.h), - std::mem::take(&mut self.c) ]?; let res = self .session - .run(ort::SessionInputs::ValueSlice::<4>(&inps))?; - self.h = res["hn"].try_extract_tensor().unwrap().to_owned(); - self.c = res["cn"].try_extract_tensor().unwrap().to_owned(); + .run(ort::SessionInputs::ValueSlice::<3>(&inps))?; + self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); Ok(*res["output"] .try_extract_raw_tensor::() .unwrap() diff --git a/examples/rust-example/src/vad_iter.rs b/examples/rust-example/src/vad_iter.rs index 1b8ae6d..3bcb00b 100644 --- a/examples/rust-example/src/vad_iter.rs +++ b/examples/rust-example/src/vad_iter.rs @@ -20,7 +20,7 @@ impl VadIter { pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> { self.reset_states(); for audio_frame in samples.chunks_exact(self.params.frame_size_samples) { - let speech_prob = self.silero.calc_level(audio_frame)?; + let speech_prob: f32 = self.silero.calc_level(audio_frame)?; self.state.update(&self.params, speech_prob); } self.state.check_for_last_speech(samples.len());