diff --git a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java index e9fd8b8..e509364 100644 --- a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java +++ b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java @@ -9,42 +9,58 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +/** + * Silero VAD ONNX Model Wrapper + * + * @author VvvvvGH + */ public class SlieroVadOnnxModel { - // Define private variable OrtSession + // ONNX runtime session private final OrtSession session; - private float[][][] h; - private float[][][] c; - // Define the last sample rate + // Model state - dimensions: [2, batch_size, 128] + private float[][][] state; + // Context - stores the tail of the previous audio chunk + private float[][] context; + // Last sample rate private int lastSr = 0; - // Define the last batch size + // Last batch size private int lastBatchSize = 0; - // Define a list of supported sample rates + // Supported sample rates private static final List SAMPLE_RATES = Arrays.asList(8000, 16000); // Constructor public SlieroVadOnnxModel(String modelPath) throws OrtException { // Get the ONNX runtime environment OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Create an ONNX session options object + // Create ONNX session options OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); - // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations + // Set InterOp thread count to 1 (for parallel processing of different graph operations) opts.setInterOpNumThreads(1); - // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation + // Set IntraOp thread count to 1 (for parallel processing within a single operation) opts.setIntraOpNumThreads(1); - // Add a CPU device, setting to false disables CPU execution optimization + // Enable CPU execution optimization opts.addCPU(true); - // Create an ONNX session using the environment, model path, and options + // Create ONNX session with the environment, model path, and options session = env.createSession(modelPath, opts); // Reset states resetStates(); } /** - * Reset states + * Reset states with default batch size */ void resetStates() { - h = new float[2][1][64]; - c = new float[2][1][64]; + resetStates(1); + } + + /** + * Reset states with specific batch size + * + * @param batchSize Batch size for state initialization + */ + void resetStates(int batchSize) { + state = new float[2][batchSize][128]; + context = new float[0][]; // Empty context lastSr = 0; lastBatchSize = 0; } @@ -54,13 +70,12 @@ public class SlieroVadOnnxModel { } /** - * Define inner class ValidationResult + * Inner class for validation result */ public static class ValidationResult { public final float[][] x; public final int sr; - // Constructor public ValidationResult(float[][] x, int sr) { this.x = x; this.sr = sr; @@ -68,19 +83,23 @@ public class SlieroVadOnnxModel { } /** - * Function to validate input data + * Validate input data + * + * @param x Audio data array + * @param sr Sample rate + * @return Validated input data and sample rate */ private ValidationResult validateInput(float[][] x, int sr) { - // Process the input data with dimension 1 + // Ensure input is at least 2D if (x.length == 1) { x = new float[][]{x[0]}; } - // Throw an exception when the input data dimension is greater than 2 + // Check if input dimension is valid if (x.length > 2) { throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); } - // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 + // Downsample if sample rate is a multiple of 16000 if (sr != 16000 && (sr % 16000 == 0)) { int step = sr / 16000; float[][] reducedX = new float[x.length][]; @@ -100,22 +119,26 @@ public class SlieroVadOnnxModel { sr = 16000; } - // If the sample rate is not in the list of supported sample rates, throw an exception + // Validate sample rate if (!SAMPLE_RATES.contains(sr)) { throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); } - // If the input audio block is too short, throw an exception + // Check if audio chunk is too short if (((float) sr) / x[0].length > 31.25) { throw new IllegalArgumentException("Input audio is too short"); } - // Return the validated result return new ValidationResult(x, sr); } /** - * Method to call the ONNX model + * Call the ONNX model for inference + * + * @param x Audio data array + * @param sr Sample rate + * @return Speech probability output + * @throws OrtException If ONNX runtime error occurs */ public float[] call(float[][] x, int sr) throws OrtException { ValidationResult result = validateInput(x, sr); @@ -123,38 +146,62 @@ public class SlieroVadOnnxModel { sr = result.sr; int batchSize = x.length; + int numSamples = sr == 16000 ? 512 : 256; + int contextSize = sr == 16000 ? 64 : 32; - if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) { - resetStates(); + // Reset states only when sample rate or batch size changes + if (lastSr != 0 && lastSr != sr) { + resetStates(batchSize); + } else if (lastBatchSize != 0 && lastBatchSize != batchSize) { + resetStates(batchSize); + } else if (lastBatchSize == 0) { + // First call - state is already initialized, just set batch size + lastBatchSize = batchSize; + } + + // Initialize context if needed + if (context.length == 0) { + context = new float[batchSize][contextSize]; + } + + // Concatenate context and input + float[][] xWithContext = new float[batchSize][contextSize + numSamples]; + for (int i = 0; i < batchSize; i++) { + // Copy context + System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize); + // Copy input + System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples); } OrtEnvironment env = OrtEnvironment.getEnvironment(); OnnxTensor inputTensor = null; - OnnxTensor hTensor = null; - OnnxTensor cTensor = null; + OnnxTensor stateTensor = null; OnnxTensor srTensor = null; OrtSession.Result ortOutputs = null; try { // Create input tensors - inputTensor = OnnxTensor.createTensor(env, x); - hTensor = OnnxTensor.createTensor(env, h); - cTensor = OnnxTensor.createTensor(env, c); + inputTensor = OnnxTensor.createTensor(env, xWithContext); + stateTensor = OnnxTensor.createTensor(env, state); srTensor = OnnxTensor.createTensor(env, new long[]{sr}); Map inputs = new HashMap<>(); inputs.put("input", inputTensor); inputs.put("sr", srTensor); - inputs.put("h", hTensor); - inputs.put("c", cTensor); + inputs.put("state", stateTensor); - // Call the ONNX model for calculation + // Run ONNX model inference ortOutputs = session.run(inputs); - // Get the output results + // Get output results float[][] output = (float[][]) ortOutputs.get(0).getValue(); - h = (float[][][]) ortOutputs.get(1).getValue(); - c = (float[][][]) ortOutputs.get(2).getValue(); + state = (float[][][]) ortOutputs.get(1).getValue(); + + // Update context - save the last contextSize samples from input + for (int i = 0; i < batchSize; i++) { + System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize, + context[i], 0, contextSize); + } lastSr = sr; lastBatchSize = batchSize; @@ -163,11 +210,8 @@ public class SlieroVadOnnxModel { if (inputTensor != null) { inputTensor.close(); } - if (hTensor != null) { - hTensor.close(); - } - if (cTensor != null) { - cTensor.close(); + if (stateTensor != null) { + stateTensor.close(); } if (srTensor != null) { srTensor.close();