use crate::utils; use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; use std::path::Path; #[derive(Debug)] pub struct Silero { session: ort::Session, sample_rate: ArrayBase, Dim<[usize; 1]>>, state: ArrayBase, Dim>, } impl Silero { pub fn new( sample_rate: utils::SampleRate, model_path: impl AsRef, ) -> Result { let session = ort::Session::builder()?.commit_from_file(model_path)?; let state = ArrayD::::zeros([2, 1, 128].as_slice()); let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); Ok(Self { session, sample_rate, state, }) } pub fn reset(&mut self) { self.state = ArrayD::::zeros([2, 1, 128].as_slice()); } pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result { let data = audio_frame .iter() .map(|x| (*x as f32) / (i16::MAX as f32)) .collect::>(); let mut frame = Array2::::from_shape_vec([1, data.len()], data).unwrap(); frame = frame.slice(s![.., ..480]).to_owned(); let inps = ort::inputs![ frame, std::mem::take(&mut self.state), self.sample_rate.clone(), ]?; let res = self .session .run(ort::SessionInputs::ValueSlice::<3>(&inps))?; self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); Ok(*res["output"] .try_extract_raw_tensor::() .unwrap() .1 .first() .unwrap()) } }