Merge pull request #746 from d-e-s-o/topic/fix-rust

Fix `rust-example`
This commit is contained in:
Alexander Veysov
2025-12-29 09:34:47 +03:00
committed by GitHub
5 changed files with 32 additions and 21 deletions

View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "adler" name = "adler"
@@ -206,16 +206,6 @@ version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "libloading"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
"cfg-if",
"windows-targets",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.14" version = "0.4.14"
@@ -301,7 +291,6 @@ checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
dependencies = [ dependencies = [
"half", "half",
"js-sys", "js-sys",
"libloading",
"ndarray", "ndarray",
"ort-sys", "ort-sys",
"thiserror", "thiserror",

View File

@@ -4,6 +4,6 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] } ort = { version = "2.0.0-rc.2", features = ["ndarray"] }
ndarray = "0.15" ndarray = "0.15"
hound = "3" hound = "3"

View File

@@ -4,7 +4,7 @@ mod vad_iter;
fn main() { fn main() {
let model_path = std::env::var("SILERO_MODEL_PATH") let model_path = std::env::var("SILERO_MODEL_PATH")
.unwrap_or_else(|_| String::from("../../files/silero_vad.onnx")); .unwrap_or_else(|_| String::from("../../src/silero_vad/data/silero_vad.onnx"));
let audio_path = std::env::args() let audio_path = std::env::args()
.nth(1) .nth(1)
.unwrap_or_else(|| String::from("recorder.wav")); .unwrap_or_else(|| String::from("recorder.wav"));

View File

@@ -1,5 +1,5 @@
use crate::utils; use crate::utils;
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use std::path::Path; use std::path::Path;
#[derive(Debug)] #[derive(Debug)]
@@ -7,6 +7,8 @@ pub struct Silero {
session: ort::Session, session: ort::Session,
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>, sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>, state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
context: Array1<f32>,
context_size: usize,
} }
impl Silero { impl Silero {
@@ -16,16 +18,22 @@ impl Silero {
) -> Result<Self, ort::Error> { ) -> Result<Self, ort::Error> {
let session = ort::Session::builder()?.commit_from_file(model_path)?; let session = ort::Session::builder()?.commit_from_file(model_path)?;
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice()); let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); let sample_rate_val: i64 = sample_rate.into();
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
let context = Array1::<f32>::zeros(context_size);
let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();
Ok(Self { Ok(Self {
session, session,
sample_rate, sample_rate,
state, state,
context,
context_size,
}) })
} }
pub fn reset(&mut self) { pub fn reset(&mut self) {
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice()); self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
self.context = Array1::<f32>::zeros(self.context_size);
} }
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> { pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
@@ -33,8 +41,14 @@ impl Silero {
.iter() .iter()
.map(|x| (*x as f32) / (i16::MAX as f32)) .map(|x| (*x as f32) / (i16::MAX as f32))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
frame = frame.slice(s![.., ..480]).to_owned(); // Concatenate context with input
let mut input_with_context = Vec::with_capacity(self.context_size + data.len());
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
input_with_context.extend_from_slice(&data);
let frame = Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context).unwrap();
let inps = ort::inputs![ let inps = ort::inputs![
frame, frame,
std::mem::take(&mut self.state), std::mem::take(&mut self.state),
@@ -43,12 +57,20 @@ impl Silero {
let res = self let res = self
.session .session
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?; .run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
Ok(*res["output"]
// Update context with last context_size samples from the input
if data.len() >= self.context_size {
self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());
}
let prob = *res["output"]
.try_extract_raw_tensor::<f32>() .try_extract_raw_tensor::<f32>()
.unwrap() .unwrap()
.1 .1
.first() .first()
.unwrap()) .unwrap();
Ok(prob)
} }
} }

View File

@@ -36,7 +36,7 @@ pub struct VadParams {
impl Default for VadParams { impl Default for VadParams {
fn default() -> Self { fn default() -> Self {
Self { Self {
frame_size: 64, frame_size: 32, // 32ms for 512 samples at 16kHz
threshold: 0.5, threshold: 0.5,
min_silence_duration_ms: 0, min_silence_duration_ms: 0,
speech_pad_ms: 64, speech_pad_ms: 64,