mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Update SlieroVadOnnxModel.java
This commit is contained in:
@@ -9,42 +9,58 @@ import java.util.HashMap;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Silero VAD ONNX Model Wrapper
|
||||||
|
*
|
||||||
|
* @author VvvvvGH
|
||||||
|
*/
|
||||||
public class SlieroVadOnnxModel {
|
public class SlieroVadOnnxModel {
|
||||||
// Define private variable OrtSession
|
// ONNX runtime session
|
||||||
private final OrtSession session;
|
private final OrtSession session;
|
||||||
private float[][][] h;
|
// Model state - dimensions: [2, batch_size, 128]
|
||||||
private float[][][] c;
|
private float[][][] state;
|
||||||
// Define the last sample rate
|
// Context - stores the tail of the previous audio chunk
|
||||||
|
private float[][] context;
|
||||||
|
// Last sample rate
|
||||||
private int lastSr = 0;
|
private int lastSr = 0;
|
||||||
// Define the last batch size
|
// Last batch size
|
||||||
private int lastBatchSize = 0;
|
private int lastBatchSize = 0;
|
||||||
// Define a list of supported sample rates
|
// Supported sample rates
|
||||||
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
||||||
|
|
||||||
// Constructor
|
// Constructor
|
||||||
public SlieroVadOnnxModel(String modelPath) throws OrtException {
|
public SlieroVadOnnxModel(String modelPath) throws OrtException {
|
||||||
// Get the ONNX runtime environment
|
// Get the ONNX runtime environment
|
||||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||||
// Create an ONNX session options object
|
// Create ONNX session options
|
||||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||||
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
|
// Set InterOp thread count to 1 (for parallel processing of different graph operations)
|
||||||
opts.setInterOpNumThreads(1);
|
opts.setInterOpNumThreads(1);
|
||||||
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
|
// Set IntraOp thread count to 1 (for parallel processing within a single operation)
|
||||||
opts.setIntraOpNumThreads(1);
|
opts.setIntraOpNumThreads(1);
|
||||||
// Add a CPU device, setting to false disables CPU execution optimization
|
// Enable CPU execution optimization
|
||||||
opts.addCPU(true);
|
opts.addCPU(true);
|
||||||
// Create an ONNX session using the environment, model path, and options
|
// Create ONNX session with the environment, model path, and options
|
||||||
session = env.createSession(modelPath, opts);
|
session = env.createSession(modelPath, opts);
|
||||||
// Reset states
|
// Reset states
|
||||||
resetStates();
|
resetStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reset states
|
* Reset states with default batch size
|
||||||
*/
|
*/
|
||||||
void resetStates() {
|
void resetStates() {
|
||||||
h = new float[2][1][64];
|
resetStates(1);
|
||||||
c = new float[2][1][64];
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset states with specific batch size
|
||||||
|
*
|
||||||
|
* @param batchSize Batch size for state initialization
|
||||||
|
*/
|
||||||
|
void resetStates(int batchSize) {
|
||||||
|
state = new float[2][batchSize][128];
|
||||||
|
context = new float[0][]; // Empty context
|
||||||
lastSr = 0;
|
lastSr = 0;
|
||||||
lastBatchSize = 0;
|
lastBatchSize = 0;
|
||||||
}
|
}
|
||||||
@@ -54,13 +70,12 @@ public class SlieroVadOnnxModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define inner class ValidationResult
|
* Inner class for validation result
|
||||||
*/
|
*/
|
||||||
public static class ValidationResult {
|
public static class ValidationResult {
|
||||||
public final float[][] x;
|
public final float[][] x;
|
||||||
public final int sr;
|
public final int sr;
|
||||||
|
|
||||||
// Constructor
|
|
||||||
public ValidationResult(float[][] x, int sr) {
|
public ValidationResult(float[][] x, int sr) {
|
||||||
this.x = x;
|
this.x = x;
|
||||||
this.sr = sr;
|
this.sr = sr;
|
||||||
@@ -68,19 +83,23 @@ public class SlieroVadOnnxModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Function to validate input data
|
* Validate input data
|
||||||
|
*
|
||||||
|
* @param x Audio data array
|
||||||
|
* @param sr Sample rate
|
||||||
|
* @return Validated input data and sample rate
|
||||||
*/
|
*/
|
||||||
private ValidationResult validateInput(float[][] x, int sr) {
|
private ValidationResult validateInput(float[][] x, int sr) {
|
||||||
// Process the input data with dimension 1
|
// Ensure input is at least 2D
|
||||||
if (x.length == 1) {
|
if (x.length == 1) {
|
||||||
x = new float[][]{x[0]};
|
x = new float[][]{x[0]};
|
||||||
}
|
}
|
||||||
// Throw an exception when the input data dimension is greater than 2
|
// Check if input dimension is valid
|
||||||
if (x.length > 2) {
|
if (x.length > 2) {
|
||||||
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
|
// Downsample if sample rate is a multiple of 16000
|
||||||
if (sr != 16000 && (sr % 16000 == 0)) {
|
if (sr != 16000 && (sr % 16000 == 0)) {
|
||||||
int step = sr / 16000;
|
int step = sr / 16000;
|
||||||
float[][] reducedX = new float[x.length][];
|
float[][] reducedX = new float[x.length][];
|
||||||
@@ -100,22 +119,26 @@ public class SlieroVadOnnxModel {
|
|||||||
sr = 16000;
|
sr = 16000;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the sample rate is not in the list of supported sample rates, throw an exception
|
// Validate sample rate
|
||||||
if (!SAMPLE_RATES.contains(sr)) {
|
if (!SAMPLE_RATES.contains(sr)) {
|
||||||
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the input audio block is too short, throw an exception
|
// Check if audio chunk is too short
|
||||||
if (((float) sr) / x[0].length > 31.25) {
|
if (((float) sr) / x[0].length > 31.25) {
|
||||||
throw new IllegalArgumentException("Input audio is too short");
|
throw new IllegalArgumentException("Input audio is too short");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the validated result
|
|
||||||
return new ValidationResult(x, sr);
|
return new ValidationResult(x, sr);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Method to call the ONNX model
|
* Call the ONNX model for inference
|
||||||
|
*
|
||||||
|
* @param x Audio data array
|
||||||
|
* @param sr Sample rate
|
||||||
|
* @return Speech probability output
|
||||||
|
* @throws OrtException If ONNX runtime error occurs
|
||||||
*/
|
*/
|
||||||
public float[] call(float[][] x, int sr) throws OrtException {
|
public float[] call(float[][] x, int sr) throws OrtException {
|
||||||
ValidationResult result = validateInput(x, sr);
|
ValidationResult result = validateInput(x, sr);
|
||||||
@@ -123,38 +146,62 @@ public class SlieroVadOnnxModel {
|
|||||||
sr = result.sr;
|
sr = result.sr;
|
||||||
|
|
||||||
int batchSize = x.length;
|
int batchSize = x.length;
|
||||||
|
int numSamples = sr == 16000 ? 512 : 256;
|
||||||
|
int contextSize = sr == 16000 ? 64 : 32;
|
||||||
|
|
||||||
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
|
// Reset states only when sample rate or batch size changes
|
||||||
resetStates();
|
if (lastSr != 0 && lastSr != sr) {
|
||||||
|
resetStates(batchSize);
|
||||||
|
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
|
||||||
|
resetStates(batchSize);
|
||||||
|
} else if (lastBatchSize == 0) {
|
||||||
|
// First call - state is already initialized, just set batch size
|
||||||
|
lastBatchSize = batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize context if needed
|
||||||
|
if (context.length == 0) {
|
||||||
|
context = new float[batchSize][contextSize];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenate context and input
|
||||||
|
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
|
||||||
|
for (int i = 0; i < batchSize; i++) {
|
||||||
|
// Copy context
|
||||||
|
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
|
||||||
|
// Copy input
|
||||||
|
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||||
|
|
||||||
OnnxTensor inputTensor = null;
|
OnnxTensor inputTensor = null;
|
||||||
OnnxTensor hTensor = null;
|
OnnxTensor stateTensor = null;
|
||||||
OnnxTensor cTensor = null;
|
|
||||||
OnnxTensor srTensor = null;
|
OnnxTensor srTensor = null;
|
||||||
OrtSession.Result ortOutputs = null;
|
OrtSession.Result ortOutputs = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Create input tensors
|
// Create input tensors
|
||||||
inputTensor = OnnxTensor.createTensor(env, x);
|
inputTensor = OnnxTensor.createTensor(env, xWithContext);
|
||||||
hTensor = OnnxTensor.createTensor(env, h);
|
stateTensor = OnnxTensor.createTensor(env, state);
|
||||||
cTensor = OnnxTensor.createTensor(env, c);
|
|
||||||
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
||||||
|
|
||||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||||
inputs.put("input", inputTensor);
|
inputs.put("input", inputTensor);
|
||||||
inputs.put("sr", srTensor);
|
inputs.put("sr", srTensor);
|
||||||
inputs.put("h", hTensor);
|
inputs.put("state", stateTensor);
|
||||||
inputs.put("c", cTensor);
|
|
||||||
|
|
||||||
// Call the ONNX model for calculation
|
// Run ONNX model inference
|
||||||
ortOutputs = session.run(inputs);
|
ortOutputs = session.run(inputs);
|
||||||
// Get the output results
|
// Get output results
|
||||||
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
||||||
h = (float[][][]) ortOutputs.get(1).getValue();
|
state = (float[][][]) ortOutputs.get(1).getValue();
|
||||||
c = (float[][][]) ortOutputs.get(2).getValue();
|
|
||||||
|
// Update context - save the last contextSize samples from input
|
||||||
|
for (int i = 0; i < batchSize; i++) {
|
||||||
|
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
|
||||||
|
context[i], 0, contextSize);
|
||||||
|
}
|
||||||
|
|
||||||
lastSr = sr;
|
lastSr = sr;
|
||||||
lastBatchSize = batchSize;
|
lastBatchSize = batchSize;
|
||||||
@@ -163,11 +210,8 @@ public class SlieroVadOnnxModel {
|
|||||||
if (inputTensor != null) {
|
if (inputTensor != null) {
|
||||||
inputTensor.close();
|
inputTensor.close();
|
||||||
}
|
}
|
||||||
if (hTensor != null) {
|
if (stateTensor != null) {
|
||||||
hTensor.close();
|
stateTensor.close();
|
||||||
}
|
|
||||||
if (cTensor != null) {
|
|
||||||
cTensor.close();
|
|
||||||
}
|
}
|
||||||
if (srTensor != null) {
|
if (srTensor != null) {
|
||||||
srTensor.close();
|
srTensor.close();
|
||||||
|
|||||||
Reference in New Issue
Block a user