mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
fix: rust example
This commit is contained in:
@@ -1,13 +1,12 @@
|
|||||||
use crate::utils;
|
use crate::utils;
|
||||||
use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Silero {
|
pub struct Silero {
|
||||||
session: ort::Session,
|
session: ort::Session,
|
||||||
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
||||||
h: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
||||||
c: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Silero {
|
impl Silero {
|
||||||
@@ -16,20 +15,17 @@ impl Silero {
|
|||||||
model_path: impl AsRef<Path>,
|
model_path: impl AsRef<Path>,
|
||||||
) -> Result<Self, ort::Error> {
|
) -> Result<Self, ort::Error> {
|
||||||
let session = ort::Session::builder()?.commit_from_file(model_path)?;
|
let session = ort::Session::builder()?.commit_from_file(model_path)?;
|
||||||
let h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
|
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||||
let c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
|
|
||||||
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
|
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
session,
|
session,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
h,
|
state,
|
||||||
c,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
pub fn reset(&mut self) {
|
||||||
self.h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
|
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||||
self.c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
|
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
|
||||||
@@ -37,18 +33,17 @@ impl Silero {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|x| (*x as f32) / (i16::MAX as f32))
|
.map(|x| (*x as f32) / (i16::MAX as f32))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
|
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
|
||||||
|
frame = frame.slice(s![.., ..480]).to_owned();
|
||||||
let inps = ort::inputs![
|
let inps = ort::inputs![
|
||||||
frame,
|
frame,
|
||||||
|
std::mem::take(&mut self.state),
|
||||||
self.sample_rate.clone(),
|
self.sample_rate.clone(),
|
||||||
std::mem::take(&mut self.h),
|
|
||||||
std::mem::take(&mut self.c)
|
|
||||||
]?;
|
]?;
|
||||||
let res = self
|
let res = self
|
||||||
.session
|
.session
|
||||||
.run(ort::SessionInputs::ValueSlice::<4>(&inps))?;
|
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
|
||||||
self.h = res["hn"].try_extract_tensor().unwrap().to_owned();
|
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
|
||||||
self.c = res["cn"].try_extract_tensor().unwrap().to_owned();
|
|
||||||
Ok(*res["output"]
|
Ok(*res["output"]
|
||||||
.try_extract_raw_tensor::<f32>()
|
.try_extract_raw_tensor::<f32>()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ impl VadIter {
|
|||||||
pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
|
pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
|
||||||
self.reset_states();
|
self.reset_states();
|
||||||
for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
|
for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
|
||||||
let speech_prob = self.silero.calc_level(audio_frame)?;
|
let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
|
||||||
self.state.update(&self.params, speech_prob);
|
self.state.update(&self.params, speech_prob);
|
||||||
}
|
}
|
||||||
self.state.check_for_last_speech(samples.len());
|
self.state.check_for_last_speech(samples.len());
|
||||||
|
|||||||
Reference in New Issue
Block a user