diff --git a/examples/java-example/pom.xml b/examples/java-example/pom.xml
new file mode 100644
index 0000000..32ba720
--- /dev/null
+++ b/examples/java-example/pom.xml
@@ -0,0 +1,30 @@
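+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>org.example</groupId>
+  <artifactId>java-example</artifactId>
+  <version>1.0-SNAPSHOT</version>
+  <packaging>jar</packaging>
+
+  <name>sliero-vad-example</name>
+  <url>http://maven.apache.org</url>
+
+  <properties>
+    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>3.8.1</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.microsoft.onnxruntime</groupId>
+      <artifactId>onnxruntime</artifactId>
+      <version>1.16.0-rc1</version>
+    </dependency>
+  </dependencies>
+</project>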
diff --git a/examples/java-example/src/main/java/org/example/App.java b/examples/java-example/src/main/java/org/example/App.java
new file mode 100644
index 0000000..7b58f17
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/App.java
@@ -0,0 +1,102 @@
+package org.example;
+
+import ai.onnxruntime.OrtException;
+import javax.sound.sampled.*;
+import java.util.Map;
+
+/**
+ * Command-line demo: captures 16-bit mono PCM audio from an input line and prints the
+ * speech start/end events reported by {@link SlieroVadDetector}.
+ */
+public class App {
+
+    private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
+    private static final int SAMPLE_RATE = 16000;
+    private static final float START_THRESHOLD = 0.6f;
+    private static final float END_THRESHOLD = 0.45f;
+    private static final int MIN_SILENCE_DURATION_MS = 600;
+    private static final int SPEECH_PAD_MS = 500;
+    // Number of samples handed to the detector per call (the capture buffer holds this many).
+    private static final int WINDOW_SIZE_SAMPLES = 1024;
+    // The capture format is 16-bit PCM, so every sample occupies two bytes.
+    private static final int BYTES_PER_SAMPLE = 2;
+
+    /**
+     * Loads the model, then runs detection on captured audio until the line closes or a read fails.
+     *
+     * @param args unused
+     */
+    public static void main(String[] args) {
+        // Initialize the Voice Activity Detector
+        SlieroVadDetector vadDetector;
+        try {
+            vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
+        } catch (OrtException e) {
+            System.err.println("Error initializing the VAD detector: " + e.getMessage());
+            return;
+        }
+
+        try {
+            captureAndDetect(vadDetector);
+        } finally {
+            // Release the detector's model resources on every exit path.
+            try {
+                vadDetector.close();
+            } catch (OrtException e) {
+                System.err.println("Error closing the VAD detector: " + e.getMessage());
+            }
+        }
+    }
+
+    /**
+     * Opens the capture line and feeds fixed-size windows to the detector, printing each event.
+     *
+     * @param vadDetector detector that receives the captured audio; not closed by this method
+     */
+    private static void captureAndDetect(SlieroVadDetector vadDetector) {
+        // Set audio format: 16-bit signed little-endian mono PCM
+        AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
+        DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
+
+        // Get the target data line and open it with the specified format
+        TargetDataLine targetDataLine;
+        try {
+            targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
+            targetDataLine.open(format);
+            targetDataLine.start();
+        } catch (LineUnavailableException e) {
+            System.err.println("Error opening target data line: " + e.getMessage());
+            return;
+        }
+
+        try {
+            // Main loop to continuously read data and apply Voice Activity Detection
+            while (targetDataLine.isOpen()) {
+                byte[] data = new byte[WINDOW_SIZE_SAMPLES * BYTES_PER_SAMPLE];
+
+                int numBytesRead = targetDataLine.read(data, 0, data.length);
+                if (numBytesRead <= 0) {
+                    // Nothing was read; stop rather than retrying in a tight loop.
+                    System.err.println("Error reading data from target data line.");
+                    break;
+                }
+
+                // Apply the Voice Activity Detector to the data and get the result
+                Map<String, Double> detectResult;
+                try {
+                    detectResult = vadDetector.apply(data, true);
+                } catch (Exception e) {
+                    System.err.println("Error applying VAD detector: " + e.getMessage());
+                    continue;
+                }
+
+                if (!detectResult.isEmpty()) {
+                    System.out.println(detectResult);
+                }
+            }
+        } finally {
+            // Close the target data line to release audio resources
+            targetDataLine.close();
+        }
+    }
+}
diff --git a/examples/java-example/src/main/java/org/example/SlieroVadDetector.java b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java
new file mode 100644
index 0000000..dd2eecd
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java
@@ -0,0 +1,175 @@
+package org.example;
+
+import ai.onnxruntime.OrtException;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+
+/**
+ * Streaming voice activity detector: feeds consecutive audio windows to a
+ * {@link SlieroVadOnnxModel} and reports when speech starts and ends.
+ *
+ * <p>Instances keep per-stream state and must be {@link #close() closed} to release the model.
+ */
+public class SlieroVadDetector implements AutoCloseable {
+    // OnnxModel model used for speech processing
+    private final SlieroVadOnnxModel model;
+    // Threshold for speech start
+    private final float startThreshold;
+    // Threshold for speech end
+    private final float endThreshold;
+    // Sampling rate
+    private final int samplingRate;
+    // Minimum number of silence samples required before the end of speech is reported
+    private final long minSilenceSamples;
+    // Number of padding samples subtracted from a speech start / added to a speech end
+    private final long speechPadSamples;
+    // Whether in the triggered state (i.e. whether speech is being detected)
+    private boolean triggered;
+    // Sample position at which the current candidate silence began (0 means none is pending)
+    private long tempEnd;
+    // Total number of samples processed so far
+    private long currentSample;
+
+    /**
+     * Creates a detector backed by the ONNX model at {@code modelPath}.
+     *
+     * @param modelPath path of the ONNX model file
+     * @param startThreshold speech probability at or above which speech is considered started
+     * @param endThreshold speech probability below which silence is counted towards a speech end
+     * @param samplingRate sampling rate of the audio in Hz; must be 8000 or 16000
+     * @param minSilenceDurationMs silence duration, in milliseconds, required to end a speech segment
+     * @param speechPadMs padding, in milliseconds, applied to reported start and end positions
+     * @throws OrtException if the model cannot be loaded
+     * @throws IllegalArgumentException if {@code samplingRate} is neither 8000 nor 16000
+     */
+    public SlieroVadDetector(String modelPath,
+                             float startThreshold,
+                             float endThreshold,
+                             int samplingRate,
+                             int minSilenceDurationMs,
+                             int speechPadMs) throws OrtException {
+        // Check if the sampling rate is 8000 or 16000, if not, throw an exception
+        if (samplingRate != 8000 && samplingRate != 16000) {
+            throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
+        }
+
+        // Initialize the parameters
+        this.model = new SlieroVadOnnxModel(modelPath);
+        this.startThreshold = startThreshold;
+        this.endThreshold = endThreshold;
+        this.samplingRate = samplingRate;
+        // Both supported rates are multiples of 1000, so these integer conversions are exact.
+        this.minSilenceSamples = (long) samplingRate * minSilenceDurationMs / 1000;
+        this.speechPadSamples = (long) samplingRate * speechPadMs / 1000;
+        // Reset the state
+        reset();
+    }
+
+    /**
+     * Resets the model state, the trigger state, the pending end position and the sample counter.
+     */
+    public void reset() {
+        model.resetStates();
+        triggered = false;
+        tempEnd = 0;
+        currentSample = 0;
+    }
+
+    /**
+     * Processes the next window of audio and reports a speech boundary if one was detected.
+     *
+     * @param data audio window as 16-bit signed little-endian PCM; a trailing odd byte is ignored
+     * @param returnSeconds {@code true} to report positions in seconds rounded to one decimal,
+     *                      {@code false} to report them as sample counts
+     * @return a map holding a single {@code "start"} or {@code "end"} entry, or an empty map when
+     *         no boundary was detected in this window
+     * @throws RuntimeException if model inference fails (wraps the {@link OrtException})
+     */
+    public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
+        // Convert the byte array to a float array
+        float[] audioData = toFloatSamples(data);
+
+        // The window size is the number of samples in this call; advance the stream position
+        currentSample += audioData.length;
+
+        // Call the model to get the prediction probability of speech
+        float speechProb;
+        try {
+            speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
+        } catch (OrtException e) {
+            throw new RuntimeException(e);
+        }
+
+        // Speech resumed while a candidate end was pending: discard the candidate
+        if (speechProb >= startThreshold && tempEnd != 0) {
+            tempEnd = 0;
+        }
+
+        // Speech detected while idle: enter the triggered state and report the padded start
+        if (speechProb >= startThreshold && !triggered) {
+            triggered = true;
+            long speechStart = Math.max(currentSample - speechPadSamples, 0);
+            return boundary("start", speechStart, returnSeconds);
+        }
+
+        // Silence detected while triggered: report the end once it has lasted long enough
+        if (speechProb < endThreshold && triggered) {
+            // Remember where the silence began
+            if (tempEnd == 0) {
+                tempEnd = currentSample;
+            }
+            // Not enough silence yet to decide that the speech has ended
+            if (currentSample - tempEnd < minSilenceSamples) {
+                return Collections.emptyMap();
+            }
+            // Report the padded end and return to the idle state
+            long speechEnd = tempEnd + speechPadSamples;
+            tempEnd = 0;
+            triggered = false;
+            return boundary("end", speechEnd, returnSeconds);
+        }
+
+        // No boundary in this window
+        return Collections.emptyMap();
+    }
+
+    /**
+     * Releases the underlying model after resetting the detector state.
+     *
+     * @throws OrtException if closing the model fails
+     */
+    @Override
+    public void close() throws OrtException {
+        reset();
+        model.close();
+    }
+
+    // Decodes 16-bit signed little-endian PCM into floats in [-1, 1).
+    private static float[] toFloatSamples(byte[] data) {
+        float[] samples = new float[data.length / 2];
+        for (int i = 0; i < samples.length; i++) {
+            // The high byte keeps its sign; the low byte is masked so it is treated as unsigned.
+            int pcm = (data[i * 2] & 0xff) | (data[i * 2 + 1] << 8);
+            // Dividing by 32768 keeps the most negative sample (-32768) inside the range.
+            samples[i] = pcm / 32768.0f;
+        }
+        return samples;
+    }
+
+    // Builds a single-entry result holding the position in seconds or in samples.
+    private Map<String, Double> boundary(String key, long samplePosition, boolean returnSeconds) {
+        Map<String, Double> result = new HashMap<>();
+        if (returnSeconds) {
+            double seconds = samplePosition / (double) samplingRate;
+            result.put(key, BigDecimal.valueOf(seconds).setScale(1, RoundingMode.HALF_UP).doubleValue());
+        } else {
+            result.put(key, (double) samplePosition);
+        }
+        return result;
+    }
+}
diff --git a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
new file mode 100644
index 0000000..e9fd8b8
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
@@ -0,0 +1,203 @@
+package org.example;
+
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Wrapper around the VAD ONNX model that carries the recurrent {@code h}/{@code c} state
+ * from one call to the next.
+ */
+public class SlieroVadOnnxModel implements AutoCloseable {
+    // Define private variable OrtSession
+    private final OrtSession session;
+    // Options the session was created with; kept so they can be closed after the session
+    private final OrtSession.SessionOptions sessionOptions;
+    // Recurrent state fed to and read back from the model on every call
+    private float[][][] h;
+    private float[][][] c;
+    // Define the last sample rate
+    private int lastSr = 0;
+    // Define the last batch size
+    private int lastBatchSize = 0;
+    // Define a list of supported sample rates
+    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
+
+    /**
+     * Loads the model and creates a single-threaded CPU inference session for it.
+     *
+     * @param modelPath path of the ONNX model file
+     * @throws OrtException if the session cannot be created
+     */
+    public SlieroVadOnnxModel(String modelPath) throws OrtException {
+        // Get the ONNX runtime environment
+        OrtEnvironment env = OrtEnvironment.getEnvironment();
+        // Create an ONNX session options object
+        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
+        OrtSession created;
+        try {
+            // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
+            opts.setInterOpNumThreads(1);
+            // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
+            opts.setIntraOpNumThreads(1);
+            // Add the CPU execution provider; the flag enables its arena memory allocator
+            opts.addCPU(true);
+            // Create an ONNX session using the environment, model path, and options
+            created = env.createSession(modelPath, opts);
+        } catch (OrtException | RuntimeException e) {
+            // No session owns the options, so release them before propagating the failure
+            opts.close();
+            throw e;
+        }
+        this.session = created;
+        this.sessionOptions = opts;
+        // Reset states
+        resetStates(1);
+    }
+
+    /**
+     * Reset states
+     */
+    void resetStates() {
+        resetStates(1);
+    }
+
+    // Re-creates zeroed h/c state for the given batch size and forgets the last call's shape.
+    private void resetStates(int batchSize) {
+        h = new float[2][batchSize][64];
+        c = new float[2][batchSize][64];
+        lastSr = 0;
+        lastBatchSize = 0;
+    }
+
+    /**
+     * Closes the session and then the options it was created with.
+     *
+     * @throws OrtException if closing the session fails
+     */
+    @Override
+    public void close() throws OrtException {
+        try {
+            session.close();
+        } finally {
+            sessionOptions.close();
+        }
+    }
+
+    /**
+     * Define inner class ValidationResult
+     */
+    public static class ValidationResult {
+        public final float[][] x;
+        public final int sr;
+
+        // Constructor
+        public ValidationResult(float[][] x, int sr) {
+            this.x = x;
+            this.sr = sr;
+        }
+    }
+
+    /**
+     * Validates the input and, for sample rates that are multiples of 16000, decimates it to 16000.
+     *
+     * @param x audio batch, one row of samples per batch entry
+     * @param sr sample rate of {@code x} in Hz
+     * @return the (possibly decimated) audio together with its effective sample rate
+     * @throws IllegalArgumentException if the batch is empty, the sample rate is unsupported,
+     *                                  or the audio is too short
+     */
+    private ValidationResult validateInput(float[][] x, int sr) {
+        // x[0] is inspected below, so an empty batch must be rejected up front
+        if (x == null || x.length == 0 || x[0] == null) {
+            throw new IllegalArgumentException("Input audio must contain at least one row of samples");
+        }
+
+        // Decimate when the sample rate is a multiple of 16000 above 16000 (this also keeps step >= 2)
+        if (sr > 16000 && (sr % 16000 == 0)) {
+            int step = sr / 16000;
+            float[][] reducedX = new float[x.length][];
+
+            for (int i = 0; i < x.length; i++) {
+                float[] current = x[i];
+                // Keep every step-th sample
+                float[] newArr = new float[(current.length + step - 1) / step];
+
+                for (int j = 0, index = 0; j < current.length; j += step, index++) {
+                    newArr[index] = current[j];
+                }
+
+                reducedX[i] = newArr;
+            }
+
+            x = reducedX;
+            sr = 16000;
+        }
+
+        // If the sample rate is not in the list of supported sample rates, throw an exception
+        if (!SAMPLE_RATES.contains(sr)) {
+            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
+        }
+
+        // If the input audio block is too short, throw an exception
+        if (((float) sr) / x[0].length > 31.25) {
+            throw new IllegalArgumentException("Input audio is too short");
+        }
+
+        // Return the validated result
+        return new ValidationResult(x, sr);
+    }
+
+    /**
+     * Runs the model on one batch of audio and updates the recurrent state.
+     *
+     * @param x audio batch, one row of samples per batch entry
+     * @param sr sample rate of {@code x} in Hz
+     * @return the first row of the model's first output
+     * @throws OrtException if tensor creation or inference fails
+     * @throws IllegalArgumentException if the input fails validation
+     */
+    public float[] call(float[][] x, int sr) throws OrtException {
+        ValidationResult result = validateInput(x, sr);
+        x = result.x;
+        sr = result.sr;
+
+        int batchSize = x.length;
+
+        // Start from fresh state on the first call or when the sample rate or batch size changes
+        if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
+            resetStates(batchSize);
+        }
+
+        OrtEnvironment env = OrtEnvironment.getEnvironment();
+
+        // try-with-resources closes every tensor and the result, even when inference fails
+        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, x);
+             OnnxTensor hTensor = OnnxTensor.createTensor(env, h);
+             OnnxTensor cTensor = OnnxTensor.createTensor(env, c);
+             OnnxTensor srTensor = OnnxTensor.createTensor(env, new long[]{sr})) {
+            Map<String, OnnxTensor> inputs = new HashMap<>();
+            inputs.put("input", inputTensor);
+            inputs.put("sr", srTensor);
+            inputs.put("h", hTensor);
+            inputs.put("c", cTensor);
+
+            // Call the ONNX model for calculation
+            try (OrtSession.Result ortOutputs = session.run(inputs)) {
+                // Get the output results
+                float[][] output = (float[][]) ortOutputs.get(0).getValue();
+                h = (float[][][]) ortOutputs.get(1).getValue();
+                c = (float[][][]) ortOutputs.get(2).getValue();
+
+                lastSr = sr;
+                lastBatchSize = batchSize;
+                return output[0];
+            }
+        }
+    }
+}