mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 09:59:20 +08:00
Update ort dependency to 2.0.0-rc.10
Update the ort dependency from 2.0.0-rc.2 to 2.0.0-rc.10 and adapt the code to work with the new API. This includes: - Updating ndarray to 0.16 to match ort's requirements - Using Session and Value from their new module locations - Adapting to the new Value::from_array() and try_extract_tensor() APIs - Converting SessionInputs from Value references Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
use crate::utils;
|
||||
use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
|
||||
use ort::session::Session;
|
||||
use ort::value::Value;
|
||||
use std::mem::take;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Silero {
|
||||
session: ort::Session,
|
||||
session: Session,
|
||||
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
|
||||
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
|
||||
context: Array1<f32>,
|
||||
@@ -16,7 +19,7 @@ impl Silero {
|
||||
sample_rate: utils::SampleRate,
|
||||
model_path: impl AsRef<Path>,
|
||||
) -> Result<Self, ort::Error> {
|
||||
let session = ort::Session::builder()?.commit_from_file(model_path)?;
|
||||
let session = Session::builder()?.commit_from_file(model_path)?;
|
||||
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
|
||||
let sample_rate_val: i64 = sample_rate.into();
|
||||
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
|
||||
@@ -47,18 +50,23 @@ impl Silero {
|
||||
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
|
||||
input_with_context.extend_from_slice(&data);
|
||||
|
||||
let frame = Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context).unwrap();
|
||||
let frame =
|
||||
Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)
|
||||
.unwrap();
|
||||
|
||||
let inps = ort::inputs![
|
||||
frame,
|
||||
std::mem::take(&mut self.state),
|
||||
self.sample_rate.clone(),
|
||||
]?;
|
||||
let res = self
|
||||
.session
|
||||
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
|
||||
let frame_value = Value::from_array(frame)?;
|
||||
let state_value = Value::from_array(take(&mut self.state))?;
|
||||
let sr_value = Value::from_array(self.sample_rate.clone())?;
|
||||
|
||||
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
|
||||
let res = self.session.run([
|
||||
(&frame_value).into(),
|
||||
(&state_value).into(),
|
||||
(&sr_value).into(),
|
||||
])?;
|
||||
|
||||
let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;
|
||||
let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();
|
||||
self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();
|
||||
|
||||
// Update context with last context_size samples from the input
|
||||
if data.len() >= self.context_size {
|
||||
@@ -66,7 +74,7 @@ impl Silero {
|
||||
}
|
||||
|
||||
let prob = *res["output"]
|
||||
.try_extract_raw_tensor::<f32>()
|
||||
.try_extract_tensor::<f32>()
|
||||
.unwrap()
|
||||
.1
|
||||
.first()
|
||||
|
||||
Reference in New Issue
Block a user