mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 01:49:22 +08:00
add java onnx example
This commit is contained in:
69
examples/java-example/src/main/java/org/example/App.java
Normal file
69
examples/java-example/src/main/java/org/example/App.java
Normal file
@@ -0,0 +1,69 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import javax.sound.sampled.*;
|
||||
import java.util.Map;
|
||||
|
||||
public class App {
|
||||
|
||||
private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
|
||||
private static final int SAMPLE_RATE = 16000;
|
||||
private static final float START_THRESHOLD = 0.6f;
|
||||
private static final float END_THRESHOLD = 0.45f;
|
||||
private static final int MIN_SILENCE_DURATION_MS = 600;
|
||||
private static final int SPEECH_PAD_MS = 500;
|
||||
private static final int WINDOW_SIZE_SAMPLES = 2048;
|
||||
|
||||
public static void main(String[] args) {
|
||||
// Initialize the Voice Activity Detector
|
||||
SlieroVadDetector vadDetector;
|
||||
try {
|
||||
vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
||||
} catch (OrtException e) {
|
||||
System.err.println("Error initializing the VAD detector: " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
// Set audio format
|
||||
AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
|
||||
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
|
||||
|
||||
// Get the target data line and open it with the specified format
|
||||
TargetDataLine targetDataLine;
|
||||
try {
|
||||
targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
|
||||
targetDataLine.open(format);
|
||||
targetDataLine.start();
|
||||
} catch (LineUnavailableException e) {
|
||||
System.err.println("Error opening target data line: " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
// Main loop to continuously read data and apply Voice Activity Detection
|
||||
while (targetDataLine.isOpen()) {
|
||||
byte[] data = new byte[WINDOW_SIZE_SAMPLES];
|
||||
|
||||
int numBytesRead = targetDataLine.read(data, 0, data.length);
|
||||
if (numBytesRead <= 0) {
|
||||
System.err.println("Error reading data from target data line.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply the Voice Activity Detector to the data and get the result
|
||||
Map<String, Double> detectResult;
|
||||
try {
|
||||
detectResult = vadDetector.apply(data, true);
|
||||
} catch (Exception e) {
|
||||
System.err.println("Error applying VAD detector: " + e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!detectResult.isEmpty()) {
|
||||
System.out.println(detectResult);
|
||||
}
|
||||
}
|
||||
|
||||
// Close the target data line to release audio resources
|
||||
targetDataLine.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class SlieroVadDetector {
|
||||
// OnnxModel model used for speech processing
|
||||
private final SlieroVadOnnxModel model;
|
||||
// Threshold for speech start
|
||||
private final float startThreshold;
|
||||
// Threshold for speech end
|
||||
private final float endThreshold;
|
||||
// Sampling rate
|
||||
private final int samplingRate;
|
||||
// Minimum number of silence samples to determine the end threshold of speech
|
||||
private final float minSilenceSamples;
|
||||
// Additional number of samples for speech start or end to calculate speech start or end time
|
||||
private final float speechPadSamples;
|
||||
// Whether in the triggered state (i.e. whether speech is being detected)
|
||||
private boolean triggered;
|
||||
// Temporarily stored number of speech end samples
|
||||
private int tempEnd;
|
||||
// Number of samples currently being processed
|
||||
private int currentSample;
|
||||
|
||||
|
||||
public SlieroVadDetector(String modelPath,
|
||||
float startThreshold,
|
||||
float endThreshold,
|
||||
int samplingRate,
|
||||
int minSilenceDurationMs,
|
||||
int speechPadMs) throws OrtException {
|
||||
// Check if the sampling rate is 8000 or 16000, if not, throw an exception
|
||||
if (samplingRate != 8000 && samplingRate != 16000) {
|
||||
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
|
||||
}
|
||||
|
||||
// Initialize the parameters
|
||||
this.model = new SlieroVadOnnxModel(modelPath);
|
||||
this.startThreshold = startThreshold;
|
||||
this.endThreshold = endThreshold;
|
||||
this.samplingRate = samplingRate;
|
||||
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
|
||||
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
|
||||
// Reset the state
|
||||
reset();
|
||||
}
|
||||
|
||||
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
|
||||
public void reset() {
|
||||
model.resetStates();
|
||||
triggered = false;
|
||||
tempEnd = 0;
|
||||
currentSample = 0;
|
||||
}
|
||||
|
||||
// apply method for processing the audio array, returning possible speech start or end times
|
||||
public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
|
||||
|
||||
// Convert the byte array to a float array
|
||||
float[] audioData = new float[data.length / 2];
|
||||
for (int i = 0; i < audioData.length; i++) {
|
||||
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
|
||||
}
|
||||
|
||||
// Get the length of the audio array as the window size
|
||||
int windowSizeSamples = audioData.length;
|
||||
// Update the current sample count
|
||||
currentSample += windowSizeSamples;
|
||||
|
||||
// Call the model to get the prediction probability of speech
|
||||
float speechProb = 0;
|
||||
try {
|
||||
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
|
||||
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
|
||||
if (speechProb >= startThreshold && tempEnd != 0) {
|
||||
tempEnd = 0;
|
||||
}
|
||||
|
||||
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
|
||||
if (speechProb >= startThreshold && !triggered) {
|
||||
triggered = true;
|
||||
int speechStart = (int) (currentSample - speechPadSamples);
|
||||
speechStart = Math.max(speechStart, 0);
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
|
||||
if (returnSeconds) {
|
||||
double speechStartSeconds = speechStart / (double) samplingRate;
|
||||
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
|
||||
result.put("start", roundedSpeechStart);
|
||||
} else {
|
||||
result.put("start", (double) speechStart);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
|
||||
if (speechProb < endThreshold && triggered) {
|
||||
// Initialize or update the temporary end time
|
||||
if (tempEnd == 0) {
|
||||
tempEnd = currentSample;
|
||||
}
|
||||
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
|
||||
// This indicates that it is not yet possible to determine whether the speech has ended
|
||||
if (currentSample - tempEnd < minSilenceSamples) {
|
||||
return Collections.emptyMap();
|
||||
} else {
|
||||
// Calculate the speech end time, reset the trigger state and temporary end time
|
||||
int speechEnd = (int) (tempEnd + speechPadSamples);
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
|
||||
if (returnSeconds) {
|
||||
double speechEndSeconds = speechEnd / (double) samplingRate;
|
||||
double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
|
||||
result.put("end", roundedSpeechEnd);
|
||||
} else {
|
||||
result.put("end", (double) speechEnd);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// If the above conditions are not met, return null by default
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
reset();
|
||||
model.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SlieroVadOnnxModel {
|
||||
// Define private variable OrtSession
|
||||
private final OrtSession session;
|
||||
private float[][][] h;
|
||||
private float[][][] c;
|
||||
// Define the last sample rate
|
||||
private int lastSr = 0;
|
||||
// Define the last batch size
|
||||
private int lastBatchSize = 0;
|
||||
// Define a list of supported sample rates
|
||||
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
||||
|
||||
// Constructor
|
||||
public SlieroVadOnnxModel(String modelPath) throws OrtException {
|
||||
// Get the ONNX runtime environment
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
// Create an ONNX session options object
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
|
||||
opts.setInterOpNumThreads(1);
|
||||
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
|
||||
opts.setIntraOpNumThreads(1);
|
||||
// Add a CPU device, setting to false disables CPU execution optimization
|
||||
opts.addCPU(true);
|
||||
// Create an ONNX session using the environment, model path, and options
|
||||
session = env.createSession(modelPath, opts);
|
||||
// Reset states
|
||||
resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset states
|
||||
*/
|
||||
void resetStates() {
|
||||
h = new float[2][1][64];
|
||||
c = new float[2][1][64];
|
||||
lastSr = 0;
|
||||
lastBatchSize = 0;
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
session.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Define inner class ValidationResult
|
||||
*/
|
||||
public static class ValidationResult {
|
||||
public final float[][] x;
|
||||
public final int sr;
|
||||
|
||||
// Constructor
|
||||
public ValidationResult(float[][] x, int sr) {
|
||||
this.x = x;
|
||||
this.sr = sr;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Function to validate input data
|
||||
*/
|
||||
private ValidationResult validateInput(float[][] x, int sr) {
|
||||
// Process the input data with dimension 1
|
||||
if (x.length == 1) {
|
||||
x = new float[][]{x[0]};
|
||||
}
|
||||
// Throw an exception when the input data dimension is greater than 2
|
||||
if (x.length > 2) {
|
||||
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
||||
}
|
||||
|
||||
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
|
||||
if (sr != 16000 && (sr % 16000 == 0)) {
|
||||
int step = sr / 16000;
|
||||
float[][] reducedX = new float[x.length][];
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
float[] current = x[i];
|
||||
float[] newArr = new float[(current.length + step - 1) / step];
|
||||
|
||||
for (int j = 0, index = 0; j < current.length; j += step, index++) {
|
||||
newArr[index] = current[j];
|
||||
}
|
||||
|
||||
reducedX[i] = newArr;
|
||||
}
|
||||
|
||||
x = reducedX;
|
||||
sr = 16000;
|
||||
}
|
||||
|
||||
// If the sample rate is not in the list of supported sample rates, throw an exception
|
||||
if (!SAMPLE_RATES.contains(sr)) {
|
||||
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
||||
}
|
||||
|
||||
// If the input audio block is too short, throw an exception
|
||||
if (((float) sr) / x[0].length > 31.25) {
|
||||
throw new IllegalArgumentException("Input audio is too short");
|
||||
}
|
||||
|
||||
// Return the validated result
|
||||
return new ValidationResult(x, sr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to call the ONNX model
|
||||
*/
|
||||
public float[] call(float[][] x, int sr) throws OrtException {
|
||||
ValidationResult result = validateInput(x, sr);
|
||||
x = result.x;
|
||||
sr = result.sr;
|
||||
|
||||
int batchSize = x.length;
|
||||
|
||||
if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
|
||||
resetStates();
|
||||
}
|
||||
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
OnnxTensor inputTensor = null;
|
||||
OnnxTensor hTensor = null;
|
||||
OnnxTensor cTensor = null;
|
||||
OnnxTensor srTensor = null;
|
||||
OrtSession.Result ortOutputs = null;
|
||||
|
||||
try {
|
||||
// Create input tensors
|
||||
inputTensor = OnnxTensor.createTensor(env, x);
|
||||
hTensor = OnnxTensor.createTensor(env, h);
|
||||
cTensor = OnnxTensor.createTensor(env, c);
|
||||
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input", inputTensor);
|
||||
inputs.put("sr", srTensor);
|
||||
inputs.put("h", hTensor);
|
||||
inputs.put("c", cTensor);
|
||||
|
||||
// Call the ONNX model for calculation
|
||||
ortOutputs = session.run(inputs);
|
||||
// Get the output results
|
||||
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
||||
h = (float[][][]) ortOutputs.get(1).getValue();
|
||||
c = (float[][][]) ortOutputs.get(2).getValue();
|
||||
|
||||
lastSr = sr;
|
||||
lastBatchSize = batchSize;
|
||||
return output[0];
|
||||
} finally {
|
||||
if (inputTensor != null) {
|
||||
inputTensor.close();
|
||||
}
|
||||
if (hTensor != null) {
|
||||
hTensor.close();
|
||||
}
|
||||
if (cTensor != null) {
|
||||
cTensor.close();
|
||||
}
|
||||
if (srTensor != null) {
|
||||
srTensor.close();
|
||||
}
|
||||
if (ortOutputs != null) {
|
||||
ortOutputs.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user