Update SlieroVadOnnxModel.java

This commit is contained in:
GH
2025-10-11 16:21:57 +08:00
committed by GitHub
parent 25a778c798
commit b90f8c012f

View File

@@ -9,42 +9,58 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
/**
* Silero VAD ONNX Model Wrapper
*
* @author VvvvvGH
*/
public class SlieroVadOnnxModel { public class SlieroVadOnnxModel {
// Define private variable OrtSession // ONNX runtime session
private final OrtSession session; private final OrtSession session;
private float[][][] h; // Model state - dimensions: [2, batch_size, 128]
private float[][][] c; private float[][][] state;
// Define the last sample rate // Context - stores the tail of the previous audio chunk
private float[][] context;
// Last sample rate
private int lastSr = 0; private int lastSr = 0;
// Define the last batch size // Last batch size
private int lastBatchSize = 0; private int lastBatchSize = 0;
// Define a list of supported sample rates // Supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000); private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
/**
 * Loads the Silero VAD ONNX model from disk and initialises the model state.
 *
 * @param modelPath Path to the ONNX model file
 * @throws OrtException If the ONNX runtime cannot create the session
 */
public SlieroVadOnnxModel(String modelPath) throws OrtException {
    // Get the ONNX runtime environment
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    // SessionOptions is AutoCloseable and backed by native memory; it is only needed
    // while the session is being created, so release it right afterwards.
    try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
        // Set InterOp thread count to 1 (for parallel processing of different graph operations)
        opts.setInterOpNumThreads(1);
        // Set IntraOp thread count to 1 (for parallel processing within a single operation)
        opts.setIntraOpNumThreads(1);
        // Register the CPU execution backend; 'true' asks it to use the arena allocator
        opts.addCPU(true);
        // Create ONNX session with the environment, model path, and options
        session = env.createSession(modelPath, opts);
    }
    // Start from a clean state with the default batch size
    resetStates();
}
/** /**
* Reset states * Reset states with default batch size
*/ */
void resetStates() { void resetStates() {
h = new float[2][1][64]; resetStates(1);
c = new float[2][1][64]; }
/**
* Reset states with specific batch size
*
* @param batchSize Batch size for state initialization
*/
void resetStates(int batchSize) {
state = new float[2][batchSize][128];
context = new float[0][]; // Empty context
lastSr = 0; lastSr = 0;
lastBatchSize = 0; lastBatchSize = 0;
} }
@@ -54,13 +70,12 @@ public class SlieroVadOnnxModel {
} }
/** /**
* Define inner class ValidationResult * Inner class for validation result
*/ */
public static class ValidationResult { public static class ValidationResult {
public final float[][] x; public final float[][] x;
public final int sr; public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) { public ValidationResult(float[][] x, int sr) {
this.x = x; this.x = x;
this.sr = sr; this.sr = sr;
@@ -68,19 +83,23 @@ public class SlieroVadOnnxModel {
} }
/** /**
* Function to validate input data * Validate input data
*
* @param x Audio data array
* @param sr Sample rate
* @return Validated input data and sample rate
*/ */
private ValidationResult validateInput(float[][] x, int sr) { private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1 // Ensure input is at least 2D
if (x.length == 1) { if (x.length == 1) {
x = new float[][]{x[0]}; x = new float[][]{x[0]};
} }
// Throw an exception when the input data dimension is greater than 2 // Check if input dimension is valid
if (x.length > 2) { if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
} }
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 // Downsample if sample rate is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) { if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000; int step = sr / 16000;
float[][] reducedX = new float[x.length][]; float[][] reducedX = new float[x.length][];
@@ -100,22 +119,26 @@ public class SlieroVadOnnxModel {
sr = 16000; sr = 16000;
} }
// If the sample rate is not in the list of supported sample rates, throw an exception // Validate sample rate
if (!SAMPLE_RATES.contains(sr)) { if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
} }
// If the input audio block is too short, throw an exception // Check if audio chunk is too short
if (((float) sr) / x[0].length > 31.25) { if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short"); throw new IllegalArgumentException("Input audio is too short");
} }
// Return the validated result
return new ValidationResult(x, sr); return new ValidationResult(x, sr);
} }
/** /**
* Method to call the ONNX model * Call the ONNX model for inference
*
* @param x Audio data array
* @param sr Sample rate
* @return Speech probability output
* @throws OrtException If ONNX runtime error occurs
*/ */
public float[] call(float[][] x, int sr) throws OrtException { public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr); ValidationResult result = validateInput(x, sr);
@@ -123,38 +146,62 @@ public class SlieroVadOnnxModel {
sr = result.sr; sr = result.sr;
int batchSize = x.length; int batchSize = x.length;
int numSamples = sr == 16000 ? 512 : 256;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) { // Reset states only when sample rate or batch size changes
resetStates(); if (lastSr != 0 && lastSr != sr) {
resetStates(batchSize);
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates(batchSize);
} else if (lastBatchSize == 0) {
// First call - state is already initialized, just set batch size
lastBatchSize = batchSize;
}
// Initialize context if needed
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
// Concatenate context and input
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
for (int i = 0; i < batchSize; i++) {
// Copy context
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
// Copy input
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
} }
OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null; OnnxTensor inputTensor = null;
OnnxTensor hTensor = null; OnnxTensor stateTensor = null;
OnnxTensor cTensor = null;
OnnxTensor srTensor = null; OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null; OrtSession.Result ortOutputs = null;
try { try {
// Create input tensors // Create input tensors
inputTensor = OnnxTensor.createTensor(env, x); inputTensor = OnnxTensor.createTensor(env, xWithContext);
hTensor = OnnxTensor.createTensor(env, h); stateTensor = OnnxTensor.createTensor(env, state);
cTensor = OnnxTensor.createTensor(env, c);
srTensor = OnnxTensor.createTensor(env, new long[]{sr}); srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>(); Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor); inputs.put("input", inputTensor);
inputs.put("sr", srTensor); inputs.put("sr", srTensor);
inputs.put("h", hTensor); inputs.put("state", stateTensor);
inputs.put("c", cTensor);
// Call the ONNX model for calculation // Run ONNX model inference
ortOutputs = session.run(inputs); ortOutputs = session.run(inputs);
// Get the output results // Get output results
float[][] output = (float[][]) ortOutputs.get(0).getValue(); float[][] output = (float[][]) ortOutputs.get(0).getValue();
h = (float[][][]) ortOutputs.get(1).getValue(); state = (float[][][]) ortOutputs.get(1).getValue();
c = (float[][][]) ortOutputs.get(2).getValue();
// Update context - save the last contextSize samples from input
for (int i = 0; i < batchSize; i++) {
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
context[i], 0, contextSize);
}
lastSr = sr; lastSr = sr;
lastBatchSize = batchSize; lastBatchSize = batchSize;
@@ -163,11 +210,8 @@ public class SlieroVadOnnxModel {
if (inputTensor != null) { if (inputTensor != null) {
inputTensor.close(); inputTensor.close();
} }
if (hTensor != null) { if (stateTensor != null) {
hTensor.close(); stateTensor.close();
}
if (cTensor != null) {
cTensor.close();
} }
if (srTensor != null) { if (srTensor != null) {
srTensor.close(); srTensor.close();