From f5ea01bfdacc8a096f355dd70a805f80775808e2 Mon Sep 17 00:00:00 2001 From: GH Date: Sat, 11 Oct 2025 16:21:03 +0800 Subject: [PATCH 1/4] Update pom.xml --- examples/java-example/pom.xml | 49 ++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/examples/java-example/pom.xml b/examples/java-example/pom.xml index 32ba720..88dc906 100644 --- a/examples/java-example/pom.xml +++ b/examples/java-example/pom.xml @@ -1,30 +1,31 @@ - 4.0.0 + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 - org.example - java-example - 1.0-SNAPSHOT - jar + org.example + java-example + 1.0-SNAPSHOT + jar - sliero-vad-example - http://maven.apache.org + sliero-vad-example + http://maven.apache.org - - UTF-8 - + + UTF-8 + - - - junit - junit - 3.8.1 - test - - - com.microsoft.onnxruntime - onnxruntime - 1.16.0-rc1 - - + + + junit + junit + 3.8.1 + test + + + + com.microsoft.onnxruntime + onnxruntime + 1.23.1 + + From 3d860e6acef3f2405687c4a2e32a867bdb87e82a Mon Sep 17 00:00:00 2001 From: GH Date: Sat, 11 Oct 2025 16:21:32 +0800 Subject: [PATCH 2/4] Update App.java --- .../src/main/java/org/example/App.java | 285 +++++++++++++++--- 1 file changed, 240 insertions(+), 45 deletions(-) diff --git a/examples/java-example/src/main/java/org/example/App.java b/examples/java-example/src/main/java/org/example/App.java index 7b58f17..56aca29 100644 --- a/examples/java-example/src/main/java/org/example/App.java +++ b/examples/java-example/src/main/java/org/example/App.java @@ -2,68 +2,263 @@ package org.example; import ai.onnxruntime.OrtException; import javax.sound.sampled.*; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; +/** + * Silero VAD Java Example + * Voice Activity Detection using ONNX model + * + * @author VvvvvGH + */ public class App { - private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx"; + // ONNX model path - using the model file from the project + private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx"; + // Test audio file path + private static final String AUDIO_FILE_PATH = "../../en_example.wav"; + // Sampling rate private static final int SAMPLE_RATE = 16000; - private static final float START_THRESHOLD = 0.6f; - private static final float END_THRESHOLD = 0.45f; - private static final int MIN_SILENCE_DURATION_MS = 600; - private static final int SPEECH_PAD_MS = 500; - private static final int WINDOW_SIZE_SAMPLES = 2048; + // Speech threshold (consistent with Python default) + private static final float THRESHOLD = 0.5f; + // Negative threshold (used to determine speech end) + private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15 + // Minimum speech duration (milliseconds) + private static final int MIN_SPEECH_DURATION_MS = 250; + // Minimum silence duration (milliseconds) + private static final int MIN_SILENCE_DURATION_MS = 100; + // Speech padding (milliseconds) + private static final int SPEECH_PAD_MS = 30; + // Window size (samples) - 512 samples for 16kHz + private static final int WINDOW_SIZE_SAMPLES = 512; public static void main(String[] args) { - // Initialize the Voice Activity Detector - SlieroVadDetector vadDetector; + System.out.println("=".repeat(60)); + System.out.println("Silero VAD Java ONNX Example"); + System.out.println("=".repeat(60)); + + // Load ONNX model + SlieroVadOnnxModel model; try { - vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); + System.out.println("Loading ONNX model: " + MODEL_PATH); + model = new SlieroVadOnnxModel(MODEL_PATH); + System.out.println("Model loaded successfully!"); } catch (OrtException e) { - System.err.println("Error initializing the VAD detector: " + e.getMessage()); + System.err.println("Failed to load model: " + e.getMessage()); + e.printStackTrace(); return; } - // Set audio format - AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false); - DataLine.Info info = new DataLine.Info(TargetDataLine.class, format); - - // Get the target data line and open it with the specified format - TargetDataLine targetDataLine; + // Read WAV file + float[] audioData; try { - targetDataLine = (TargetDataLine) AudioSystem.getLine(info); - targetDataLine.open(format); - targetDataLine.start(); - } catch (LineUnavailableException e) { - System.err.println("Error opening target data line: " + e.getMessage()); + System.out.println("\nReading audio file: " + AUDIO_FILE_PATH); + audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH); + System.out.println("Audio file read successfully, samples: " + audioData.length); + System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds"); + } catch (Exception e) { + System.err.println("Failed to read audio file: " + e.getMessage()); + e.printStackTrace(); return; } - // Main loop to continuously read data and apply Voice Activity Detection - while (targetDataLine.isOpen()) { - byte[] data = new byte[WINDOW_SIZE_SAMPLES]; - - int numBytesRead = targetDataLine.read(data, 0, data.length); - if (numBytesRead <= 0) { - System.err.println("Error reading data from target data line."); - continue; - } - - // Apply the Voice Activity Detector to the data and get the result - Map detectResult; - try { - detectResult = vadDetector.apply(data, true); - } catch (Exception e) { - System.err.println("Error applying VAD detector: " + e.getMessage()); - continue; - } - - if (!detectResult.isEmpty()) { - System.out.println(detectResult); - } + // Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps) + System.out.println("\nDetecting speech segments..."); + List> speechTimestamps; + try { + speechTimestamps = getSpeechTimestamps( + audioData, + model, + THRESHOLD, + SAMPLE_RATE, + MIN_SPEECH_DURATION_MS, + MIN_SILENCE_DURATION_MS, + SPEECH_PAD_MS, + NEG_THRESHOLD + ); + } catch (OrtException e) { + System.err.println("Failed to detect speech timestamps: " + e.getMessage()); + e.printStackTrace(); + return; } - // Close the target data line to release audio resources - targetDataLine.close(); + // Output detection results + System.out.println("\nDetected speech timestamps (in samples):"); + for (Map timestamp : speechTimestamps) { + System.out.println(timestamp); + } + + // Output summary + System.out.println("\n" + "=".repeat(60)); + System.out.println("Detection completed!"); + System.out.println("Total detected " + speechTimestamps.size() + " speech segments"); + System.out.println("=".repeat(60)); + + // Close model + try { + model.close(); + } catch (OrtException e) { + System.err.println("Error closing model: " + e.getMessage()); + } } + + /** + * Get speech timestamps + * Implements the same logic as Python's get_speech_timestamps + * + * @param audio Audio data (float array) + * @param model ONNX model + * @param threshold Speech threshold + * @param samplingRate Sampling rate + * @param minSpeechDurationMs Minimum speech duration (milliseconds) + * @param minSilenceDurationMs Minimum silence duration (milliseconds) + * @param speechPadMs Speech padding (milliseconds) + * @param negThreshold Negative threshold (used to determine speech end) + * @return List of speech timestamps + */ + private static List> getSpeechTimestamps( + float[] audio, + SlieroVadOnnxModel model, + float threshold, + int samplingRate, + int minSpeechDurationMs, + int minSilenceDurationMs, + int speechPadMs, + float negThreshold) throws OrtException { + + // Reset model states + model.resetStates(); + + // Calculate parameters + int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000; + int speechPadSamples = samplingRate * speechPadMs / 1000; + int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000; + int windowSizeSamples = samplingRate == 16000 ? 512 : 256; + int audioLengthSamples = audio.length; + + // Calculate speech probabilities for all audio chunks + List speechProbs = new ArrayList<>(); + for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) { + float[] chunk = new float[windowSizeSamples]; + int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart); + System.arraycopy(audio, currentStart, chunk, 0, chunkLength); + + // Pad with zeros if chunk is shorter than window size + if (chunkLength < windowSizeSamples) { + for (int i = chunkLength; i < windowSizeSamples; i++) { + chunk[i] = 0.0f; + } + } + + float speechProb = model.call(new float[][]{chunk}, samplingRate)[0]; + speechProbs.add(speechProb); + } + + // Detect speech segments using the same algorithm as Python + boolean triggered = false; + List> speeches = new ArrayList<>(); + Map currentSpeech = null; + int tempEnd = 0; + + for (int i = 0; i < speechProbs.size(); i++) { + float speechProb = speechProbs.get(i); + + // Reset temporary end if speech probability exceeds threshold + if (speechProb >= threshold && tempEnd != 0) { + tempEnd = 0; + } + + // Detect speech start + if (speechProb >= threshold && !triggered) { + triggered = true; + currentSpeech = new HashMap<>(); + currentSpeech.put("start", windowSizeSamples * i); + continue; + } + + // Detect speech end + if (speechProb < negThreshold && triggered) { + if (tempEnd == 0) { + tempEnd = windowSizeSamples * i; + } + if (windowSizeSamples * i - tempEnd < minSilenceSamples) { + continue; + } else { + currentSpeech.put("end", tempEnd); + if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) { + speeches.add(currentSpeech); + } + currentSpeech = null; + tempEnd = 0; + triggered = false; + } + } + } + + // Handle the last speech segment + if (currentSpeech != null && + (audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) { + currentSpeech.put("end", audioLengthSamples); + speeches.add(currentSpeech); + } + + // Add speech padding - same logic as Python + for (int i = 0; i < speeches.size(); i++) { + Map speech = speeches.get(i); + if (i == 0) { + speech.put("start", Math.max(0, speech.get("start") - speechPadSamples)); + } + if (i != speeches.size() - 1) { + int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end"); + if (silenceDuration < 2 * speechPadSamples) { + speech.put("end", speech.get("end") + silenceDuration / 2); + speeches.get(i + 1).put("start", + Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2)); + } else { + speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples)); + speeches.get(i + 1).put("start", + Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples)); + } + } else { + speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples)); + } + } + + return speeches; + } + + /** + * Read WAV file and return as float array + * + * @param filePath WAV file path + * @return Audio data as float array (normalized to -1.0 to 1.0) + */ + private static float[] readWavFileAsFloatArray(String filePath) + throws UnsupportedAudioFileException, IOException { + File audioFile = new File(filePath); + AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile); + + // Get audio format information + AudioFormat format = audioStream.getFormat(); + System.out.println("Audio format: " + format); + + // Read all audio data + byte[] audioBytes = audioStream.readAllBytes(); + audioStream.close(); + + // Convert to float array + float[] audioData = new float[audioBytes.length / 2]; + for (int i = 0; i < audioData.length; i++) { + // 16-bit PCM: two bytes per sample (little-endian) + short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8)); + audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0 + } + + return audioData; + } + } From 25a778c79897e734342a430f1715233925a928a3 Mon Sep 17 00:00:00 2001 From: GH Date: Sat, 11 Oct 2025 16:21:45 +0800 Subject: [PATCH 3/4] Update SlieroVadDetector.java --- .../java/org/example/SlieroVadDetector.java | 69 +++++++++++-------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/examples/java-example/src/main/java/org/example/SlieroVadDetector.java b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java index dd2eecd..e03113c 100644 --- a/examples/java-example/src/main/java/org/example/SlieroVadDetector.java +++ b/examples/java-example/src/main/java/org/example/SlieroVadDetector.java @@ -8,25 +8,30 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; - +/** + * Silero VAD Detector + * Real-time voice activity detection + * + * @author VvvvvGH + */ public class SlieroVadDetector { - // OnnxModel model used for speech processing + // ONNX model for speech processing private final SlieroVadOnnxModel model; - // Threshold for speech start + // Speech start threshold private final float startThreshold; - // Threshold for speech end + // Speech end threshold private final float endThreshold; // Sampling rate private final int samplingRate; - // Minimum number of silence samples to determine the end threshold of speech + // Minimum silence samples to determine speech end private final float minSilenceSamples; - // Additional number of samples for speech start or end to calculate speech start or end time + // Speech padding samples for calculating speech boundaries private final float speechPadSamples; - // Whether in the triggered state (i.e. whether speech is being detected) + // Triggered state (whether speech is being detected) private boolean triggered; - // Temporarily stored number of speech end samples + // Temporary speech end sample position private int tempEnd; - // Number of samples currently being processed + // Current sample position private int currentSample; @@ -36,23 +41,25 @@ public class SlieroVadDetector { int samplingRate, int minSilenceDurationMs, int speechPadMs) throws OrtException { - // Check if the sampling rate is 8000 or 16000, if not, throw an exception + // Validate sampling rate if (samplingRate != 8000 && samplingRate != 16000) { - throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]"); + throw new IllegalArgumentException("Does not support sampling rates other than [8000, 16000]"); } - // Initialize the parameters + // Initialize parameters this.model = new SlieroVadOnnxModel(modelPath); this.startThreshold = startThreshold; this.endThreshold = endThreshold; this.samplingRate = samplingRate; this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; this.speechPadSamples = samplingRate * speechPadMs / 1000f; - // Reset the state + // Reset state reset(); } - // Method to reset the state, including the model state, trigger state, temporary end time, and current sample count + /** + * Reset detector state + */ public void reset() { model.resetStates(); triggered = false; @@ -60,21 +67,27 @@ public class SlieroVadDetector { currentSample = 0; } - // apply method for processing the audio array, returning possible speech start or end times + /** + * Process audio data and detect speech events + * + * @param data Audio data as byte array + * @param returnSeconds Whether to return timestamps in seconds + * @return Speech event (start or end) or empty map if no event + */ public Map apply(byte[] data, boolean returnSeconds) { - // Convert the byte array to a float array + // Convert byte array to float array float[] audioData = new float[data.length / 2]; for (int i = 0; i < audioData.length; i++) { audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f; } - // Get the length of the audio array as the window size + // Get window size from audio data length int windowSizeSamples = audioData.length; - // Update the current sample count + // Update current sample position currentSample += windowSizeSamples; - // Call the model to get the prediction probability of speech + // Get speech probability from model float speechProb = 0; try { speechProb = model.call(new float[][]{audioData}, samplingRate)[0]; @@ -82,19 +95,18 @@ public class SlieroVadDetector { throw new RuntimeException(e); } - // If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time - // This indicates that the speech duration has exceeded expectations and needs to recalculate the end time + // Reset temporary end if speech probability exceeds threshold if (speechProb >= startThreshold && tempEnd != 0) { tempEnd = 0; } - // If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time + // Detect speech start if (speechProb >= startThreshold && !triggered) { triggered = true; int speechStart = (int) (currentSample - speechPadSamples); speechStart = Math.max(speechStart, 0); Map result = new HashMap<>(); - // Decide whether to return the result in seconds or sample count based on the returnSeconds parameter + // Return in seconds or samples based on returnSeconds parameter if (returnSeconds) { double speechStartSeconds = speechStart / (double) samplingRate; double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue(); @@ -106,18 +118,17 @@ public class SlieroVadDetector { return result; } - // If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time + // Detect speech end if (speechProb < endThreshold && triggered) { - // Initialize or update the temporary end time + // Initialize or update temporary end position if (tempEnd == 0) { tempEnd = currentSample; } - // If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null - // This indicates that it is not yet possible to determine whether the speech has ended + // Wait for minimum silence duration before confirming speech end if (currentSample - tempEnd < minSilenceSamples) { return Collections.emptyMap(); } else { - // Calculate the speech end time, reset the trigger state and temporary end time + // Calculate speech end time and reset state int speechEnd = (int) (tempEnd + speechPadSamples); tempEnd = 0; triggered = false; @@ -134,7 +145,7 @@ public class SlieroVadDetector { } } - // If the above conditions are not met, return null by default + // No speech event detected return Collections.emptyMap(); } From b90f8c012f8d3132e66e1ba52bf892fb3b0f3755 Mon Sep 17 00:00:00 2001 From: GH Date: Sat, 11 Oct 2025 16:21:57 +0800 Subject: [PATCH 4/4] Update SlieroVadOnnxModel.java --- .../java/org/example/SlieroVadOnnxModel.java | 128 ++++++++++++------ 1 file changed, 86 insertions(+), 42 deletions(-) diff --git a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java index e9fd8b8..e509364 100644 --- a/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java +++ b/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java @@ -9,42 +9,58 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +/** + * Silero VAD ONNX Model Wrapper + * + * @author VvvvvGH + */ public class SlieroVadOnnxModel { - // Define private variable OrtSession + // ONNX runtime session private final OrtSession session; - private float[][][] h; - private float[][][] c; - // Define the last sample rate + // Model state - dimensions: [2, batch_size, 128] + private float[][][] state; + // Context - stores the tail of the previous audio chunk + private float[][] context; + // Last sample rate private int lastSr = 0; - // Define the last batch size + // Last batch size private int lastBatchSize = 0; - // Define a list of supported sample rates + // Supported sample rates private static final List SAMPLE_RATES = Arrays.asList(8000, 16000); // Constructor public SlieroVadOnnxModel(String modelPath) throws OrtException { // Get the ONNX runtime environment OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Create an ONNX session options object + // Create ONNX session options OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); - // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations + // Set InterOp thread count to 1 (for parallel processing of different graph operations) opts.setInterOpNumThreads(1); - // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation + // Set IntraOp thread count to 1 (for parallel processing within a single operation) opts.setIntraOpNumThreads(1); - // Add a CPU device, setting to false disables CPU execution optimization + // Enable CPU execution optimization opts.addCPU(true); - // Create an ONNX session using the environment, model path, and options + // Create ONNX session with the environment, model path, and options session = env.createSession(modelPath, opts); // Reset states resetStates(); } /** - * Reset states + * Reset states with default batch size */ void resetStates() { - h = new float[2][1][64]; - c = new float[2][1][64]; + resetStates(1); + } + + /** + * Reset states with specific batch size + * + * @param batchSize Batch size for state initialization + */ + void resetStates(int batchSize) { + state = new float[2][batchSize][128]; + context = new float[0][]; // Empty context lastSr = 0; lastBatchSize = 0; } @@ -54,13 +70,12 @@ public class SlieroVadOnnxModel { } /** - * Define inner class ValidationResult + * Inner class for validation result */ public static class ValidationResult { public final float[][] x; public final int sr; - // Constructor public ValidationResult(float[][] x, int sr) { this.x = x; this.sr = sr; @@ -68,19 +83,23 @@ public class SlieroVadOnnxModel { } /** - * Function to validate input data + * Validate input data + * + * @param x Audio data array + * @param sr Sample rate + * @return Validated input data and sample rate */ private ValidationResult validateInput(float[][] x, int sr) { - // Process the input data with dimension 1 + // Ensure input is at least 2D if (x.length == 1) { x = new float[][]{x[0]}; } - // Throw an exception when the input data dimension is greater than 2 + // Check if input dimension is valid if (x.length > 2) { throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); } - // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 + // Downsample if sample rate is a multiple of 16000 if (sr != 16000 && (sr % 16000 == 0)) { int step = sr / 16000; float[][] reducedX = new float[x.length][]; @@ -100,22 +119,26 @@ public class SlieroVadOnnxModel { sr = 16000; } - // If the sample rate is not in the list of supported sample rates, throw an exception + // Validate sample rate if (!SAMPLE_RATES.contains(sr)) { throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); } - // If the input audio block is too short, throw an exception + // Check if audio chunk is too short if (((float) sr) / x[0].length > 31.25) { throw new IllegalArgumentException("Input audio is too short"); } - // Return the validated result return new ValidationResult(x, sr); } /** - * Method to call the ONNX model + * Call the ONNX model for inference + * + * @param x Audio data array + * @param sr Sample rate + * @return Speech probability output + * @throws OrtException If ONNX runtime error occurs */ public float[] call(float[][] x, int sr) throws OrtException { ValidationResult result = validateInput(x, sr); @@ -123,38 +146,62 @@ public class SlieroVadOnnxModel { sr = result.sr; int batchSize = x.length; + int numSamples = sr == 16000 ? 512 : 256; + int contextSize = sr == 16000 ? 64 : 32; - if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) { - resetStates(); + // Reset states only when sample rate or batch size changes + if (lastSr != 0 && lastSr != sr) { + resetStates(batchSize); + } else if (lastBatchSize != 0 && lastBatchSize != batchSize) { + resetStates(batchSize); + } else if (lastBatchSize == 0) { + // First call - state is already initialized, just set batch size + lastBatchSize = batchSize; + } + + // Initialize context if needed + if (context.length == 0) { + context = new float[batchSize][contextSize]; + } + + // Concatenate context and input + float[][] xWithContext = new float[batchSize][contextSize + numSamples]; + for (int i = 0; i < batchSize; i++) { + // Copy context + System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize); + // Copy input + System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples); } OrtEnvironment env = OrtEnvironment.getEnvironment(); OnnxTensor inputTensor = null; - OnnxTensor hTensor = null; - OnnxTensor cTensor = null; + OnnxTensor stateTensor = null; OnnxTensor srTensor = null; OrtSession.Result ortOutputs = null; try { // Create input tensors - inputTensor = OnnxTensor.createTensor(env, x); - hTensor = OnnxTensor.createTensor(env, h); - cTensor = OnnxTensor.createTensor(env, c); + inputTensor = OnnxTensor.createTensor(env, xWithContext); + stateTensor = OnnxTensor.createTensor(env, state); srTensor = OnnxTensor.createTensor(env, new long[]{sr}); Map inputs = new HashMap<>(); inputs.put("input", inputTensor); inputs.put("sr", srTensor); - inputs.put("h", hTensor); - inputs.put("c", cTensor); + inputs.put("state", stateTensor); - // Call the ONNX model for calculation + // Run ONNX model inference ortOutputs = session.run(inputs); - // Get the output results + // Get output results float[][] output = (float[][]) ortOutputs.get(0).getValue(); - h = (float[][][]) ortOutputs.get(1).getValue(); - c = (float[][][]) ortOutputs.get(2).getValue(); + state = (float[][][]) ortOutputs.get(1).getValue(); + + // Update context - save the last contextSize samples from input + for (int i = 0; i < batchSize; i++) { + System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize, + context[i], 0, contextSize); + } lastSr = sr; lastBatchSize = batchSize; @@ -163,11 +210,8 @@ public class SlieroVadOnnxModel { if (inputTensor != null) { inputTensor.close(); } - if (hTensor != null) { - hTensor.close(); - } - if (cTensor != null) { - cTensor.close(); + if (stateTensor != null) { + stateTensor.close(); } if (srTensor != null) { srTensor.close();