Update SlieroVadOnnxModel.java

This commit is contained in:
GH
2025-10-11 16:21:57 +08:00
committed by GitHub
parent 25a778c798
commit b90f8c012f

View File

@@ -9,42 +9,58 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Silero VAD ONNX Model Wrapper
*
* @author VvvvvGH
*/
public class SlieroVadOnnxModel {
// Define private variable OrtSession
// ONNX runtime session
private final OrtSession session;
private float[][][] h;
private float[][][] c;
// Define the last sample rate
// Model state - dimensions: [2, batch_size, 128]
private float[][][] state;
// Context - stores the tail of the previous audio chunk
private float[][] context;
// Last sample rate
private int lastSr = 0;
// Define the last batch size
// Last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
// Supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SlieroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
// Create ONNX session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
// Set InterOp thread count to 1 (for parallel processing of different graph operations)
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
// Set IntraOp thread count to 1 (for parallel processing within a single operation)
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
// Enable CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
// Create ONNX session with the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
* Reset states with default batch size
*/
void resetStates() {
h = new float[2][1][64];
c = new float[2][1][64];
resetStates(1);
}
/**
* Reset states with specific batch size
*
* @param batchSize Batch size for state initialization
*/
void resetStates(int batchSize) {
state = new float[2][batchSize][128];
context = new float[0][]; // Empty context
lastSr = 0;
lastBatchSize = 0;
}
@@ -54,13 +70,12 @@ public class SlieroVadOnnxModel {
}
/**
* Define inner class ValidationResult
* Inner class for validation result
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
@@ -68,19 +83,23 @@ public class SlieroVadOnnxModel {
}
/**
* Function to validate input data
* Validate input data
*
* @param x Audio data array
* @param sr Sample rate
* @return Validated input data and sample rate
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
// Ensure input is at least 2D
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
// Check if input dimension is valid
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
// Downsample if sample rate is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
@@ -100,22 +119,26 @@ public class SlieroVadOnnxModel {
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
// Validate sample rate
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
// Check if audio chunk is too short
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
/**
* Method to call the ONNX model
* Call the ONNX model for inference
*
* @param x Audio data array
* @param sr Sample rate
* @return Speech probability output
* @throws OrtException If ONNX runtime error occurs
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
@@ -123,38 +146,62 @@ public class SlieroVadOnnxModel {
sr = result.sr;
int batchSize = x.length;
int numSamples = sr == 16000 ? 512 : 256;
int contextSize = sr == 16000 ? 64 : 32;
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
resetStates();
// Reset states only when sample rate or batch size changes
if (lastSr != 0 && lastSr != sr) {
resetStates(batchSize);
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates(batchSize);
} else if (lastBatchSize == 0) {
// First call - state is already initialized, just set batch size
lastBatchSize = batchSize;
}
// Initialize context if needed
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
// Concatenate context and input
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
for (int i = 0; i < batchSize; i++) {
// Copy context
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
// Copy input
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor hTensor = null;
OnnxTensor cTensor = null;
OnnxTensor stateTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
hTensor = OnnxTensor.createTensor(env, h);
cTensor = OnnxTensor.createTensor(env, c);
inputTensor = OnnxTensor.createTensor(env, xWithContext);
stateTensor = OnnxTensor.createTensor(env, state);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("h", hTensor);
inputs.put("c", cTensor);
inputs.put("state", stateTensor);
// Call the ONNX model for calculation
// Run ONNX model inference
ortOutputs = session.run(inputs);
// Get the output results
// Get output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
h = (float[][][]) ortOutputs.get(1).getValue();
c = (float[][][]) ortOutputs.get(2).getValue();
state = (float[][][]) ortOutputs.get(1).getValue();
// Update context - save the last contextSize samples from input
for (int i = 0; i < batchSize; i++) {
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
context[i], 0, contextSize);
}
lastSr = sr;
lastBatchSize = batchSize;
@@ -163,11 +210,8 @@ public class SlieroVadOnnxModel {
if (inputTensor != null) {
inputTensor.close();
}
if (hTensor != null) {
hTensor.close();
}
if (cTensor != null) {
cTensor.close();
if (stateTensor != null) {
stateTensor.close();
}
if (srTensor != null) {
srTensor.close();