Update ort dependency to 2.0.0-rc.10

Update the ort dependency from 2.0.0-rc.2 to 2.0.0-rc.10 and adapt the code
to work with the new API. This includes:
- Updating ndarray to 0.16 to match ort's requirements
- Using Session and Value from their new module locations
- Adapting to the new Value::from_array() and try_extract_tensor() APIs
- Converting Value references into session inputs via `.into()`

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Daniel Müller
2025-12-29 19:29:22 -08:00
parent 4725c40105
commit c5542cd4a8
3 changed files with 349 additions and 288 deletions

View File

@@ -1,10 +1,13 @@
use crate::utils;
use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use ort::session::Session;
use ort::value::Value;
use std::mem::take;
use std::path::Path;
#[derive(Debug)]
pub struct Silero {
session: ort::Session,
session: Session,
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
context: Array1<f32>,
@@ -16,7 +19,7 @@ impl Silero {
sample_rate: utils::SampleRate,
model_path: impl AsRef<Path>,
) -> Result<Self, ort::Error> {
let session = ort::Session::builder()?.commit_from_file(model_path)?;
let session = Session::builder()?.commit_from_file(model_path)?;
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
let sample_rate_val: i64 = sample_rate.into();
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
@@ -47,18 +50,23 @@ impl Silero {
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
input_with_context.extend_from_slice(&data);
let frame = Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context).unwrap();
let frame =
Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)
.unwrap();
let inps = ort::inputs![
frame,
std::mem::take(&mut self.state),
self.sample_rate.clone(),
]?;
let res = self
.session
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
let frame_value = Value::from_array(frame)?;
let state_value = Value::from_array(take(&mut self.state))?;
let sr_value = Value::from_array(self.sample_rate.clone())?;
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
let res = self.session.run([
(&frame_value).into(),
(&state_value).into(),
(&sr_value).into(),
])?;
let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;
let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();
self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();
// Update context with last context_size samples from the input
if data.len() >= self.context_size {
@@ -66,7 +74,7 @@ impl Silero {
}
let prob = *res["output"]
.try_extract_raw_tensor::<f32>()
.try_extract_tensor::<f32>()
.unwrap()
.1
.first()