From 3780baf49f38b53c234ec90478171596c8a0c43f Mon Sep 17 00:00:00 2001
From: VvvvvGH
Date: Wed, 18 Oct 2023 13:57:18 +0800
Subject: [PATCH] add java onnx example

---
 examples/java-example/pom.xml                 |  30 +++
 .../src/main/java/org/example/App.java        |  69 +++++++
 .../java/org/example/SlieroVadDetector.java   | 145 ++++++++++++++
 .../java/org/example/SlieroVadOnnxModel.java  | 180 ++++++++++++++++++
 4 files changed, 424 insertions(+)
 create mode 100644 examples/java-example/pom.xml
 create mode 100644 examples/java-example/src/main/java/org/example/App.java
 create mode 100644 examples/java-example/src/main/java/org/example/SlieroVadDetector.java
 create mode 100644 examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java

diff --git a/examples/java-example/pom.xml b/examples/java-example/pom.xml
new file mode 100644
index 0000000..32ba720
--- /dev/null
+++ b/examples/java-example/pom.xml
@@ -0,0 +1,30 @@
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>org.example</groupId>
+    <artifactId>java-example</artifactId>
+    <version>1.0-SNAPSHOT</version>
+    <packaging>jar</packaging>
+
+    <name>sliero-vad-example</name>
+    <url>http://maven.apache.org</url>
+
+    <properties>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <version>3.8.1</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.microsoft.onnxruntime</groupId>
+            <artifactId>onnxruntime</artifactId>
+            <version>1.16.0-rc1</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/examples/java-example/src/main/java/org/example/App.java b/examples/java-example/src/main/java/org/example/App.java
new file mode 100644
index 0000000..7b58f17
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/App.java
@@ -0,0 +1,69 @@
+package org.example;
+
+import ai.onnxruntime.OrtException;
+import javax.sound.sampled.*;
+import java.util.Map;
+
+public class App {
+
+    private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
+    private static final int SAMPLE_RATE = 16000;
+    private static final float START_THRESHOLD = 0.6f;
+    private static final float END_THRESHOLD = 0.45f;
+    private static final int MIN_SILENCE_DURATION_MS = 600;
+    private static final int SPEECH_PAD_MS = 500;
+    private static final int WINDOW_SIZE_SAMPLES = 2048;
+
+    public static void main(String[] args) {
+        // Initialize the Voice Activity Detector
+        SlieroVadDetector vadDetector;
+        try {
+            vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
+        } catch (OrtException e) {
+            System.err.println("Error initializing the VAD detector: " + e.getMessage());
+            return;
+        }
+
+        // Set the audio format: 16 kHz, 16-bit, mono, signed, little-endian PCM
+        AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
+        DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
+
+        // Get the target data line and open it with the specified format
+        TargetDataLine targetDataLine;
+        try {
+            targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
+            targetDataLine.open(format);
+            targetDataLine.start();
+        } catch (LineUnavailableException e) {
+            System.err.println("Error opening target data line: " + e.getMessage());
+            return;
+        }
+
+        // Main loop: continuously read microphone data and apply Voice Activity Detection
+        while (targetDataLine.isOpen()) {
+            byte[] data = new byte[WINDOW_SIZE_SAMPLES];
+
+            int numBytesRead = targetDataLine.read(data, 0, data.length);
+            if (numBytesRead <= 0) {
+                System.err.println("Error reading data from target data line.");
+                continue;
+            }
+
+            // Apply the Voice Activity Detector to the data and get the result
+            Map<String, Double> detectResult;
+            try {
+                detectResult = vadDetector.apply(data, true);
+            } catch (Exception e) {
+                System.err.println("Error applying VAD detector: " + e.getMessage());
+                continue;
+            }
+
+            if (!detectResult.isEmpty()) {
+                System.out.println(detectResult);
+            }
+        }
+
+        // Close the target data line to release audio resources
+        targetDataLine.close();
+    }
+}
diff --git a/examples/java-example/src/main/java/org/example/SlieroVadDetector.java b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java
new file mode 100644
index 0000000..dd2eecd
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java
@@ -0,0 +1,145 @@
+package org.example;
+
+import ai.onnxruntime.OrtException;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+
+public class SlieroVadDetector {
+    // ONNX model used for speech probability inference
+    private final SlieroVadOnnxModel model;
+    // Probability threshold above which speech is considered to start
+    private final float startThreshold;
+    // Probability threshold below which speech is considered to end
+    private final float endThreshold;
+    // Sampling rate
+    private final int samplingRate;
+    // Minimum number of silence samples required before a speech end is reported
+    private final float minSilenceSamples;
+    // Number of padding samples applied to the reported speech start and end
+    private final float speechPadSamples;
+    // Whether in the triggered state (i.e. whether speech is currently being detected)
+    private boolean triggered;
+    // Temporarily stored sample index of a potential speech end
+    private int tempEnd;
+    // Number of samples processed so far
+    private int currentSample;
+
+
+    public SlieroVadDetector(String modelPath,
+                             float startThreshold,
+                             float endThreshold,
+                             int samplingRate,
+                             int minSilenceDurationMs,
+                             int speechPadMs) throws OrtException {
+        // Check that the sampling rate is 8000 or 16000; if not, throw an exception
+        if (samplingRate != 8000 && samplingRate != 16000) {
+            throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
+        }
+
+        // Initialize the parameters
+        this.model = new SlieroVadOnnxModel(modelPath);
+        this.startThreshold = startThreshold;
+        this.endThreshold = endThreshold;
+        this.samplingRate = samplingRate;
+        this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
+        this.speechPadSamples = samplingRate * speechPadMs / 1000f;
+        // Reset the state
+        reset();
+    }
+
+    // Reset the state, including the model state, trigger state, temporary end sample, and current sample count
+    public void reset() {
+        model.resetStates();
+        triggered = false;
+        tempEnd = 0;
+        currentSample = 0;
+    }
+
+    // Process one audio chunk and return a possible speech start or end time
+    public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
+
+        // Convert the 16-bit little-endian PCM byte array to a float array in [-1, 1]
+        float[] audioData = new float[data.length / 2];
+        for (int i = 0; i < audioData.length; i++) {
+            audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
+        }
+
+        // Use the length of the audio array as the window size
+        int windowSizeSamples = audioData.length;
+        // Update the current sample count
+        currentSample += windowSizeSamples;
+
+        // Call the model to get the speech probability for this chunk
+        float speechProb = 0;
+        try {
+            speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
+        } catch (OrtException e) {
+            throw new RuntimeException(e);
+        }
+
+        // If the speech probability is above the start threshold while a temporary end is pending, clear the temporary end:
+        // speech has resumed, so the end time has to be recalculated
+        if (speechProb >= startThreshold && tempEnd != 0) {
+            tempEnd = 0;
+        }
+
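+        // Note: startThreshold and endThreshold act as a hysteresis pair. A segment is opened only once the
+        // probability rises above startThreshold, and it is closed only after the probability has stayed below
+        // endThreshold for at least minSilenceSamples, so brief dips in probability do not split an utterance.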
+        // If the speech probability is above the start threshold and we are not yet triggered, enter the triggered state and compute the speech start time
+        if (speechProb >= startThreshold && !triggered) {
+            triggered = true;
+            int speechStart = (int) (currentSample - speechPadSamples);
+            speechStart = Math.max(speechStart, 0);
+            Map<String, Double> result = new HashMap<>();
+            // Return the result in seconds or in samples depending on the returnSeconds parameter
+            if (returnSeconds) {
+                double speechStartSeconds = speechStart / (double) samplingRate;
+                double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
+                result.put("start", roundedSpeechStart);
+            } else {
+                result.put("start", (double) speechStart);
+            }
+
+            return result;
+        }
+
+        // If the speech probability is below the end threshold while triggered, compute the speech end time
+        if (speechProb < endThreshold && triggered) {
+            // Initialize the temporary end sample if it is not set yet
+            if (tempEnd == 0) {
+                tempEnd = currentSample;
+            }
+            // If the silence since the temporary end is shorter than the minimum silence duration, return an empty map:
+            // it is not yet possible to tell whether the speech has really ended
+            if (currentSample - tempEnd < minSilenceSamples) {
+                return Collections.emptyMap();
+            } else {
+                // Compute the speech end time and reset the trigger state and temporary end sample
+                int speechEnd = (int) (tempEnd + speechPadSamples);
+                tempEnd = 0;
+                triggered = false;
+                Map<String, Double> result = new HashMap<>();
+
+                if (returnSeconds) {
+                    double speechEndSeconds = speechEnd / (double) samplingRate;
+                    double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
+                    result.put("end", roundedSpeechEnd);
+                } else {
+                    result.put("end", (double) speechEnd);
+                }
+                return result;
+            }
+        }
+
+        // If none of the conditions above are met, return an empty map
+        return Collections.emptyMap();
+    }
+
+    public void close() throws OrtException {
+        reset();
+        model.close();
+    }
+}
diff --git a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
new file mode 100644
index 0000000..e9fd8b8
--- /dev/null
+++ b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
@@ -0,0 +1,180 @@
+package org.example;
+
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class SlieroVadOnnxModel {
+    // The ONNX Runtime inference session
+    private final OrtSession session;
+    private float[][][] h;
+    private float[][][] c;
+    // The last sample rate used for inference
+    private int lastSr = 0;
+    // The last batch size used for inference
+    private int lastBatchSize = 0;
+    // List of supported sample rates
+    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
+
+    // Constructor
+    public SlieroVadOnnxModel(String modelPath) throws OrtException {
+        // Get the ONNX Runtime environment
+        OrtEnvironment env = OrtEnvironment.getEnvironment();
+        // Create an ONNX session options object
+        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
+        // Set the InterOp thread count to 1; InterOp threads are used to run independent graph operations in parallel
+        opts.setInterOpNumThreads(1);
+        // Set the IntraOp thread count to 1; IntraOp threads are used to parallelize work within a single operation
+        opts.setIntraOpNumThreads(1);
+        // Add the CPU execution provider; the flag controls whether the arena memory allocator is used
+        opts.addCPU(true);
+        // Create an ONNX session using the environment, model path, and options
+        session = env.createSession(modelPath, opts);
+        // Reset states
+        resetStates();
+    }
+
+    /**
+     * Reset states
+     */
+    void resetStates() {
+        h = new float[2][1][64];
+        c = new float[2][1][64];
+        lastSr = 0;
+        lastBatchSize = 0;
+    }
+
+    public void close() throws OrtException {
+        session.close();
+    }
+
+    /**
+     * Inner class holding the validated input and sample rate
+     */
+    public static class ValidationResult {
+        public final float[][] x;
+        public final int sr;
+
+        // Constructor
+        public ValidationResult(float[][] x, int sr) {
+            this.x = x;
+            this.sr = sr;
+        }
+    }
+
+    /**
+     * Validate the input data
+     */
+    private ValidationResult validateInput(float[][] x, int sr) {
+        // A single-row input is already in the expected 2D shape
+        if (x.length == 1) {
+            x = new float[][]{x[0]};
+        }
+        // Throw an exception when the input data has more than two rows
+        if (x.length > 2) {
+            throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
+        }
+
+        // Downsample by simple decimation when the sample rate is a multiple of 16000 (other than 16000 itself)
+        if (sr != 16000 && (sr % 16000 == 0)) {
+            int step = sr / 16000;
+            float[][] reducedX = new float[x.length][];
+
+            for (int i = 0; i < x.length; i++) {
+                float[] current = x[i];
+                float[] newArr = new float[(current.length + step - 1) / step];
+
+                for (int j = 0, index = 0; j < current.length; j += step, index++) {
+                    newArr[index] = current[j];
+                }
+
+                reducedX[i] = newArr;
+            }
+
+            x = reducedX;
+            sr = 16000;
+        }
+
+        // If the sample rate is not in the list of supported sample rates, throw an exception
+        if (!SAMPLE_RATES.contains(sr)) {
+            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
+        }
+
+        // If the input audio chunk is too short, throw an exception
+        if (((float) sr) / x[0].length > 31.25) {
+            throw new IllegalArgumentException("Input audio is too short");
+        }
+
+        // Return the validated result
+        return new ValidationResult(x, sr);
+    }
+
+    /**
+     * Run the ONNX model on a batch of audio
+     */
+    public float[] call(float[][] x, int sr) throws OrtException {
+        ValidationResult result = validateInput(x, sr);
+        x = result.x;
+        sr = result.sr;
+
+        int batchSize = x.length;
+
+        if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
+            resetStates();
+        }
+
+        OrtEnvironment env = OrtEnvironment.getEnvironment();
+
+        OnnxTensor inputTensor = null;
+        OnnxTensor hTensor = null;
+        OnnxTensor cTensor = null;
+        OnnxTensor srTensor = null;
+        OrtSession.Result ortOutputs = null;
+
+        try {
+            // Create the input tensors
+            inputTensor = OnnxTensor.createTensor(env, x);
+            hTensor = OnnxTensor.createTensor(env, h);
+            cTensor = OnnxTensor.createTensor(env, c);
+            srTensor = OnnxTensor.createTensor(env, new long[]{sr});
+
+            Map<String, OnnxTensor> inputs = new HashMap<>();
+            inputs.put("input", inputTensor);
+            inputs.put("sr", srTensor);
+            inputs.put("h", hTensor);
+            inputs.put("c", cTensor);
+
+            // Run the ONNX model
+            ortOutputs = session.run(inputs);
+            // Read the outputs: speech probabilities plus the updated h and c states
+            float[][] output = (float[][]) ortOutputs.get(0).getValue();
+            h = (float[][][]) ortOutputs.get(1).getValue();
+            c = (float[][][]) ortOutputs.get(2).getValue();
+
+            lastSr = sr;
+            lastBatchSize = batchSize;
+            return output[0];
+        } finally {
+            if (inputTensor != null) {
+                inputTensor.close();
+            }
+            if (hTensor != null) {
+                hTensor.close();
+            }
+            if (cTensor != null) {
+                cTensor.close();
+            }
+            if (srTensor != null) {
+                srTensor.close();
+            }
+            if (ortOutputs != null) {
+                ortOutputs.close();
+            }
+        }
+    }
+}