mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
Merge pull request #506 from yuguanqin/master
Add java example for wav file & support V5 model
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
|
||||
public class App {
|
||||
|
||||
private static final String MODEL_PATH = "/path/silero_vad.onnx";
|
||||
private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
|
||||
private static final int SAMPLE_RATE = 16000;
|
||||
private static final float THRESHOLD = 0.5f;
|
||||
private static final int MIN_SPEECH_DURATION_MS = 250;
|
||||
private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
|
||||
private static final int MIN_SILENCE_DURATION_MS = 100;
|
||||
private static final int SPEECH_PAD_MS = 30;
|
||||
|
||||
public static void main(String[] args) {
|
||||
// Initialize the Voice Activity Detector
|
||||
SileroVadDetector vadDetector;
|
||||
try {
|
||||
vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
|
||||
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
|
||||
fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
|
||||
} catch (OrtException e) {
|
||||
System.err.println("Error initializing the VAD detector: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
|
||||
List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
|
||||
for (SileroSpeechSegment speechSegment : speechTimeList) {
|
||||
System.out.println(String.format("start second: %f, end second: %f",
|
||||
speechSegment.getStartSecond(), speechSegment.getEndSecond()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package org.example;
|
||||
|
||||
|
||||
public class SileroSpeechSegment {
|
||||
private Integer startOffset;
|
||||
private Integer endOffset;
|
||||
private Float startSecond;
|
||||
private Float endSecond;
|
||||
|
||||
public SileroSpeechSegment() {
|
||||
}
|
||||
|
||||
public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
|
||||
this.startOffset = startOffset;
|
||||
this.endOffset = endOffset;
|
||||
this.startSecond = startSecond;
|
||||
this.endSecond = endSecond;
|
||||
}
|
||||
|
||||
public Integer getStartOffset() {
|
||||
return startOffset;
|
||||
}
|
||||
|
||||
public Integer getEndOffset() {
|
||||
return endOffset;
|
||||
}
|
||||
|
||||
public Float getStartSecond() {
|
||||
return startSecond;
|
||||
}
|
||||
|
||||
public Float getEndSecond() {
|
||||
return endSecond;
|
||||
}
|
||||
|
||||
public void setStartOffset(Integer startOffset) {
|
||||
this.startOffset = startOffset;
|
||||
}
|
||||
|
||||
public void setEndOffset(Integer endOffset) {
|
||||
this.endOffset = endOffset;
|
||||
}
|
||||
|
||||
public void setStartSecond(Float startSecond) {
|
||||
this.startSecond = startSecond;
|
||||
}
|
||||
|
||||
public void setEndSecond(Float endSecond) {
|
||||
this.endSecond = endSecond;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package org.example;
|
||||
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
|
||||
import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
public class SileroVadDetector {
|
||||
private final SileroVadOnnxModel model;
|
||||
private final float threshold;
|
||||
private final float negThreshold;
|
||||
private final int samplingRate;
|
||||
private final int windowSizeSample;
|
||||
private final float minSpeechSamples;
|
||||
private final float speechPadSamples;
|
||||
private final float maxSpeechSamples;
|
||||
private final float minSilenceSamples;
|
||||
private final float minSilenceSamplesAtMaxSpeech;
|
||||
private int audioLengthSamples;
|
||||
private static final float THRESHOLD_GAP = 0.15f;
|
||||
private static final Integer SAMPLING_RATE_8K = 8000;
|
||||
private static final Integer SAMPLING_RATE_16K = 16000;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
* @param onnxModelPath the path of silero-vad onnx model
|
||||
* @param threshold threshold for speech start
|
||||
* @param samplingRate audio sampling rate, only available for [8k, 16k]
|
||||
* @param minSpeechDurationMs Minimum speech length in millis, any speech duration that smaller than this value would not be considered as speech
|
||||
* @param maxSpeechDurationSeconds Maximum speech length in millis, recommend to be set as Float.POSITIVE_INFINITY
|
||||
* @param minSilenceDurationMs Minimum silence length in millis, any silence duration that smaller than this value would not be considered as silence
|
||||
* @param speechPadMs Additional pad millis for speech start and end
|
||||
* @throws OrtException
|
||||
*/
|
||||
public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
|
||||
int minSpeechDurationMs, float maxSpeechDurationSeconds,
|
||||
int minSilenceDurationMs, int speechPadMs) throws OrtException {
|
||||
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
|
||||
throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]");
|
||||
}
|
||||
this.model = new SileroVadOnnxModel(onnxModelPath);
|
||||
this.samplingRate = samplingRate;
|
||||
this.threshold = threshold;
|
||||
this.negThreshold = threshold - THRESHOLD_GAP;
|
||||
if (samplingRate == SAMPLING_RATE_16K) {
|
||||
this.windowSizeSample = 512;
|
||||
} else {
|
||||
this.windowSizeSample = 256;
|
||||
}
|
||||
this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
|
||||
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
|
||||
this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
|
||||
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
|
||||
this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
|
||||
this.reset();
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to reset the state
|
||||
*/
|
||||
public void reset() {
|
||||
model.resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get speech segment list by given wav-format file
|
||||
* @param wavFile wav file
|
||||
* @return list of speech segment
|
||||
*/
|
||||
public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
|
||||
reset();
|
||||
try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
|
||||
List<Float> speechProbList = new ArrayList<>();
|
||||
this.audioLengthSamples = audioInputStream.available() / 2;
|
||||
byte[] data = new byte[this.windowSizeSample * 2];
|
||||
int numBytesRead = 0;
|
||||
|
||||
while ((numBytesRead = audioInputStream.read(data)) != -1) {
|
||||
if (numBytesRead <= 0) {
|
||||
break;
|
||||
}
|
||||
// Convert the byte array to a float array
|
||||
float[] audioData = new float[data.length / 2];
|
||||
for (int i = 0; i < audioData.length; i++) {
|
||||
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
|
||||
}
|
||||
|
||||
float speechProb = 0;
|
||||
try {
|
||||
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
|
||||
speechProbList.add(speechProb);
|
||||
} catch (OrtException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
return calculateProb(speechProbList);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate speech segement by probability
|
||||
* @param speechProbList speech probability list
|
||||
* @return list of speech segment
|
||||
*/
|
||||
private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
|
||||
List<SileroSpeechSegment> result = new ArrayList<>();
|
||||
boolean triggered = false;
|
||||
int tempEnd = 0, prevEnd = 0, nextStart = 0;
|
||||
SileroSpeechSegment segment = new SileroSpeechSegment();
|
||||
|
||||
for (int i = 0; i < speechProbList.size(); i++) {
|
||||
Float speechProb = speechProbList.get(i);
|
||||
if (speechProb >= threshold && (tempEnd != 0)) {
|
||||
tempEnd = 0;
|
||||
if (nextStart < prevEnd) {
|
||||
nextStart = windowSizeSample * i;
|
||||
}
|
||||
}
|
||||
|
||||
if (speechProb >= threshold && !triggered) {
|
||||
triggered = true;
|
||||
segment.setStartOffset(windowSizeSample * i);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
|
||||
if (prevEnd != 0) {
|
||||
segment.setEndOffset(prevEnd);
|
||||
result.add(segment);
|
||||
segment = new SileroSpeechSegment();
|
||||
if (nextStart < prevEnd) {
|
||||
triggered = false;
|
||||
}else {
|
||||
segment.setStartOffset(nextStart);
|
||||
}
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
}else {
|
||||
segment.setEndOffset(windowSizeSample * i);
|
||||
result.add(segment);
|
||||
segment = new SileroSpeechSegment();
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (speechProb < negThreshold && triggered) {
|
||||
if (tempEnd == 0) {
|
||||
tempEnd = windowSizeSample * i;
|
||||
}
|
||||
if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
|
||||
prevEnd = tempEnd;
|
||||
}
|
||||
if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
|
||||
continue;
|
||||
}else {
|
||||
segment.setEndOffset(tempEnd);
|
||||
if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
|
||||
result.add(segment);
|
||||
}
|
||||
segment = new SileroSpeechSegment();
|
||||
prevEnd = 0;
|
||||
nextStart = 0;
|
||||
tempEnd = 0;
|
||||
triggered = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
|
||||
segment.setEndOffset(audioLengthSamples);
|
||||
result.add(segment);
|
||||
}
|
||||
|
||||
for (int i = 0; i < result.size(); i++) {
|
||||
SileroSpeechSegment item = result.get(i);
|
||||
if (i == 0) {
|
||||
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
|
||||
}
|
||||
if (i != result.size() - 1) {
|
||||
SileroSpeechSegment nextItem = result.get(i + 1);
|
||||
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
|
||||
if(silenceDuration < 2 * speechPadSamples){
|
||||
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
|
||||
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
|
||||
} else {
|
||||
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
|
||||
}
|
||||
}else {
|
||||
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
|
||||
}
|
||||
}
|
||||
|
||||
return mergeListAndCalculateSecond(result, samplingRate);
|
||||
}
|
||||
|
||||
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
|
||||
List<SileroSpeechSegment> result = new ArrayList<>();
|
||||
if (original == null || original.size() == 0) {
|
||||
return result;
|
||||
}
|
||||
Integer left = original.get(0).getStartOffset();
|
||||
Integer right = original.get(0).getEndOffset();
|
||||
if (original.size() > 1) {
|
||||
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
|
||||
for (int i = 1; i < original.size(); i++) {
|
||||
SileroSpeechSegment segment = original.get(i);
|
||||
|
||||
if (segment.getStartOffset() > right) {
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
left = segment.getStartOffset();
|
||||
right = segment.getEndOffset();
|
||||
} else {
|
||||
right = Math.max(right, segment.getEndOffset());
|
||||
}
|
||||
}
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
}else {
|
||||
result.add(new SileroSpeechSegment(left, right,
|
||||
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
|
||||
float secondValue = offset * 1.0f / samplingRate;
|
||||
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package org.example;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SileroVadOnnxModel {
|
||||
// Define private variable OrtSession
|
||||
private final OrtSession session;
|
||||
private float[][][] state;
|
||||
private float[][] context;
|
||||
// Define the last sample rate
|
||||
private int lastSr = 0;
|
||||
// Define the last batch size
|
||||
private int lastBatchSize = 0;
|
||||
// Define a list of supported sample rates
|
||||
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
|
||||
|
||||
// Constructor
|
||||
public SileroVadOnnxModel(String modelPath) throws OrtException {
|
||||
// Get the ONNX runtime environment
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
// Create an ONNX session options object
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
|
||||
opts.setInterOpNumThreads(1);
|
||||
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
|
||||
opts.setIntraOpNumThreads(1);
|
||||
// Add a CPU device, setting to false disables CPU execution optimization
|
||||
opts.addCPU(true);
|
||||
// Create an ONNX session using the environment, model path, and options
|
||||
session = env.createSession(modelPath, opts);
|
||||
// Reset states
|
||||
resetStates();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset states
|
||||
*/
|
||||
void resetStates() {
|
||||
state = new float[2][1][128];
|
||||
context = new float[0][];
|
||||
lastSr = 0;
|
||||
lastBatchSize = 0;
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
session.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Define inner class ValidationResult
|
||||
*/
|
||||
public static class ValidationResult {
|
||||
public final float[][] x;
|
||||
public final int sr;
|
||||
|
||||
// Constructor
|
||||
public ValidationResult(float[][] x, int sr) {
|
||||
this.x = x;
|
||||
this.sr = sr;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Function to validate input data
|
||||
*/
|
||||
private ValidationResult validateInput(float[][] x, int sr) {
|
||||
// Process the input data with dimension 1
|
||||
if (x.length == 1) {
|
||||
x = new float[][]{x[0]};
|
||||
}
|
||||
// Throw an exception when the input data dimension is greater than 2
|
||||
if (x.length > 2) {
|
||||
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
|
||||
}
|
||||
|
||||
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
|
||||
if (sr != 16000 && (sr % 16000 == 0)) {
|
||||
int step = sr / 16000;
|
||||
float[][] reducedX = new float[x.length][];
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
float[] current = x[i];
|
||||
float[] newArr = new float[(current.length + step - 1) / step];
|
||||
|
||||
for (int j = 0, index = 0; j < current.length; j += step, index++) {
|
||||
newArr[index] = current[j];
|
||||
}
|
||||
|
||||
reducedX[i] = newArr;
|
||||
}
|
||||
|
||||
x = reducedX;
|
||||
sr = 16000;
|
||||
}
|
||||
|
||||
// If the sample rate is not in the list of supported sample rates, throw an exception
|
||||
if (!SAMPLE_RATES.contains(sr)) {
|
||||
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
|
||||
}
|
||||
|
||||
// If the input audio block is too short, throw an exception
|
||||
if (((float) sr) / x[0].length > 31.25) {
|
||||
throw new IllegalArgumentException("Input audio is too short");
|
||||
}
|
||||
|
||||
// Return the validated result
|
||||
return new ValidationResult(x, sr);
|
||||
}
|
||||
|
||||
private static float[][] concatenate(float[][] a, float[][] b) {
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
|
||||
}
|
||||
|
||||
int rows = a.length;
|
||||
int colsA = a[0].length;
|
||||
int colsB = b[0].length;
|
||||
float[][] result = new float[rows][colsA + colsB];
|
||||
|
||||
for (int i = 0; i < rows; i++) {
|
||||
System.arraycopy(a[i], 0, result[i], 0, colsA);
|
||||
System.arraycopy(b[i], 0, result[i], colsA, colsB);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static float[][] getLastColumns(float[][] array, int contextSize) {
|
||||
int rows = array.length;
|
||||
int cols = array[0].length;
|
||||
|
||||
if (contextSize > cols) {
|
||||
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
|
||||
}
|
||||
|
||||
float[][] result = new float[rows][contextSize];
|
||||
|
||||
for (int i = 0; i < rows; i++) {
|
||||
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to call the ONNX model
|
||||
*/
|
||||
public float[] call(float[][] x, int sr) throws OrtException {
|
||||
ValidationResult result = validateInput(x, sr);
|
||||
x = result.x;
|
||||
sr = result.sr;
|
||||
int numberSamples = 256;
|
||||
if (sr == 16000) {
|
||||
numberSamples = 512;
|
||||
}
|
||||
|
||||
if (x[0].length != numberSamples) {
|
||||
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
|
||||
}
|
||||
|
||||
int batchSize = x.length;
|
||||
|
||||
int contextSize = 32;
|
||||
if (sr == 16000) {
|
||||
contextSize = 64;
|
||||
}
|
||||
|
||||
if (lastBatchSize == 0) {
|
||||
resetStates();
|
||||
}
|
||||
if (lastSr != 0 && lastSr != sr) {
|
||||
resetStates();
|
||||
}
|
||||
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
|
||||
resetStates();
|
||||
}
|
||||
|
||||
if (context.length == 0) {
|
||||
context = new float[batchSize][contextSize];
|
||||
}
|
||||
|
||||
x = concatenate(context, x);
|
||||
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
OnnxTensor inputTensor = null;
|
||||
OnnxTensor stateTensor = null;
|
||||
OnnxTensor srTensor = null;
|
||||
OrtSession.Result ortOutputs = null;
|
||||
|
||||
try {
|
||||
// Create input tensors
|
||||
inputTensor = OnnxTensor.createTensor(env, x);
|
||||
stateTensor = OnnxTensor.createTensor(env, state);
|
||||
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input", inputTensor);
|
||||
inputs.put("sr", srTensor);
|
||||
inputs.put("state", stateTensor);
|
||||
|
||||
// Call the ONNX model for calculation
|
||||
ortOutputs = session.run(inputs);
|
||||
// Get the output results
|
||||
float[][] output = (float[][]) ortOutputs.get(0).getValue();
|
||||
state = (float[][][]) ortOutputs.get(1).getValue();
|
||||
|
||||
context = getLastColumns(x, contextSize);
|
||||
lastSr = sr;
|
||||
lastBatchSize = batchSize;
|
||||
return output[0];
|
||||
} finally {
|
||||
if (inputTensor != null) {
|
||||
inputTensor.close();
|
||||
}
|
||||
if (stateTensor != null) {
|
||||
stateTensor.close();
|
||||
}
|
||||
if (srTensor != null) {
|
||||
srTensor.close();
|
||||
}
|
||||
if (ortOutputs != null) {
|
||||
ortOutputs.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user