Merge pull request #506 from yuguanqin/master

Add java example for wav file & support V5 model
This commit is contained in:
Alexander Veysov
2024-07-18 07:34:40 +03:00
committed by GitHub
4 changed files with 566 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
package org.example;
import ai.onnxruntime.OrtException;
import java.io.File;
import java.util.List;
public class App {
private static final String MODEL_PATH = "/path/silero_vad.onnx";
private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
private static final int SAMPLE_RATE = 16000;
private static final float THRESHOLD = 0.5f;
private static final int MIN_SPEECH_DURATION_MS = 250;
private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
private static final int MIN_SILENCE_DURATION_MS = 100;
private static final int SPEECH_PAD_MS = 30;
public static void main(String[] args) {
// Initialize the Voice Activity Detector
SileroVadDetector vadDetector;
try {
vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
} catch (OrtException e) {
System.err.println("Error initializing the VAD detector: " + e.getMessage());
}
}
public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
for (SileroSpeechSegment speechSegment : speechTimeList) {
System.out.println(String.format("start second: %f, end second: %f",
speechSegment.getStartSecond(), speechSegment.getEndSecond()));
}
}
}

View File

@@ -0,0 +1,51 @@
package org.example;
public class SileroSpeechSegment {
private Integer startOffset;
private Integer endOffset;
private Float startSecond;
private Float endSecond;
public SileroSpeechSegment() {
}
public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
this.startOffset = startOffset;
this.endOffset = endOffset;
this.startSecond = startSecond;
this.endSecond = endSecond;
}
public Integer getStartOffset() {
return startOffset;
}
public Integer getEndOffset() {
return endOffset;
}
public Float getStartSecond() {
return startSecond;
}
public Float getEndSecond() {
return endSecond;
}
public void setStartOffset(Integer startOffset) {
this.startOffset = startOffset;
}
public void setEndOffset(Integer endOffset) {
this.endOffset = endOffset;
}
public void setStartSecond(Float startSecond) {
this.startSecond = startSecond;
}
public void setEndSecond(Float endSecond) {
this.endSecond = endSecond;
}
}

View File

@@ -0,0 +1,244 @@
package org.example;
import ai.onnxruntime.OrtException;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
public class SileroVadDetector {
private final SileroVadOnnxModel model;
private final float threshold;
private final float negThreshold;
private final int samplingRate;
private final int windowSizeSample;
private final float minSpeechSamples;
private final float speechPadSamples;
private final float maxSpeechSamples;
private final float minSilenceSamples;
private final float minSilenceSamplesAtMaxSpeech;
private int audioLengthSamples;
private static final float THRESHOLD_GAP = 0.15f;
private static final Integer SAMPLING_RATE_8K = 8000;
private static final Integer SAMPLING_RATE_16K = 16000;
/**
* Constructor
* @param onnxModelPath the path of silero-vad onnx model
* @param threshold threshold for speech start
* @param samplingRate audio sampling rate, only available for [8k, 16k]
* @param minSpeechDurationMs Minimum speech length in millis, any speech duration that smaller than this value would not be considered as speech
* @param maxSpeechDurationSeconds Maximum speech length in millis, recommend to be set as Float.POSITIVE_INFINITY
* @param minSilenceDurationMs Minimum silence length in millis, any silence duration that smaller than this value would not be considered as silence
* @param speechPadMs Additional pad millis for speech start and end
* @throws OrtException
*/
public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
int minSpeechDurationMs, float maxSpeechDurationSeconds,
int minSilenceDurationMs, int speechPadMs) throws OrtException {
if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]");
}
this.model = new SileroVadOnnxModel(onnxModelPath);
this.samplingRate = samplingRate;
this.threshold = threshold;
this.negThreshold = threshold - THRESHOLD_GAP;
if (samplingRate == SAMPLING_RATE_16K) {
this.windowSizeSample = 512;
} else {
this.windowSizeSample = 256;
}
this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f;
this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
this.reset();
}
/**
* Method to reset the state
*/
public void reset() {
model.resetStates();
}
/**
* Get speech segment list by given wav-format file
* @param wavFile wav file
* @return list of speech segment
*/
public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
reset();
try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
List<Float> speechProbList = new ArrayList<>();
this.audioLengthSamples = audioInputStream.available() / 2;
byte[] data = new byte[this.windowSizeSample * 2];
int numBytesRead = 0;
while ((numBytesRead = audioInputStream.read(data)) != -1) {
if (numBytesRead <= 0) {
break;
}
// Convert the byte array to a float array
float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
}
float speechProb = 0;
try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
speechProbList.add(speechProb);
} catch (OrtException e) {
throw e;
}
}
return calculateProb(speechProbList);
} catch (Exception e) {
throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
}
}
/**
* Calculate speech segement by probability
* @param speechProbList speech probability list
* @return list of speech segment
*/
private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
List<SileroSpeechSegment> result = new ArrayList<>();
boolean triggered = false;
int tempEnd = 0, prevEnd = 0, nextStart = 0;
SileroSpeechSegment segment = new SileroSpeechSegment();
for (int i = 0; i < speechProbList.size(); i++) {
Float speechProb = speechProbList.get(i);
if (speechProb >= threshold && (tempEnd != 0)) {
tempEnd = 0;
if (nextStart < prevEnd) {
nextStart = windowSizeSample * i;
}
}
if (speechProb >= threshold && !triggered) {
triggered = true;
segment.setStartOffset(windowSizeSample * i);
continue;
}
if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
if (prevEnd != 0) {
segment.setEndOffset(prevEnd);
result.add(segment);
segment = new SileroSpeechSegment();
if (nextStart < prevEnd) {
triggered = false;
}else {
segment.setStartOffset(nextStart);
}
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
}else {
segment.setEndOffset(windowSizeSample * i);
result.add(segment);
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
if (speechProb < negThreshold && triggered) {
if (tempEnd == 0) {
tempEnd = windowSizeSample * i;
}
if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
prevEnd = tempEnd;
}
if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
continue;
}else {
segment.setEndOffset(tempEnd);
if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
result.add(segment);
}
segment = new SileroSpeechSegment();
prevEnd = 0;
nextStart = 0;
tempEnd = 0;
triggered = false;
continue;
}
}
}
if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
segment.setEndOffset(audioLengthSamples);
result.add(segment);
}
for (int i = 0; i < result.size(); i++) {
SileroSpeechSegment item = result.get(i);
if (i == 0) {
item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
}
if (i != result.size() - 1) {
SileroSpeechSegment nextItem = result.get(i + 1);
Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
if(silenceDuration < 2 * speechPadSamples){
item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
} else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
}
}else {
item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
}
}
return mergeListAndCalculateSecond(result, samplingRate);
}
private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
List<SileroSpeechSegment> result = new ArrayList<>();
if (original == null || original.size() == 0) {
return result;
}
Integer left = original.get(0).getStartOffset();
Integer right = original.get(0).getEndOffset();
if (original.size() > 1) {
original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
for (int i = 1; i < original.size(); i++) {
SileroSpeechSegment segment = original.get(i);
if (segment.getStartOffset() > right) {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
left = segment.getStartOffset();
right = segment.getEndOffset();
} else {
right = Math.max(right, segment.getEndOffset());
}
}
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}else {
result.add(new SileroSpeechSegment(left, right,
calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
}
return result;
}
private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
float secondValue = offset * 1.0f / samplingRate;
return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
}
}

View File

@@ -0,0 +1,234 @@
package org.example;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SileroVadOnnxModel {
// Define private variable OrtSession
private final OrtSession session;
private float[][][] state;
private float[][] context;
// Define the last sample rate
private int lastSr = 0;
// Define the last batch size
private int lastBatchSize = 0;
// Define a list of supported sample rates
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
// Constructor
public SileroVadOnnxModel(String modelPath) throws OrtException {
// Get the ONNX runtime environment
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create an ONNX session options object
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
opts.setInterOpNumThreads(1);
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
opts.setIntraOpNumThreads(1);
// Add a CPU device, setting to false disables CPU execution optimization
opts.addCPU(true);
// Create an ONNX session using the environment, model path, and options
session = env.createSession(modelPath, opts);
// Reset states
resetStates();
}
/**
* Reset states
*/
void resetStates() {
state = new float[2][1][128];
context = new float[0][];
lastSr = 0;
lastBatchSize = 0;
}
public void close() throws OrtException {
session.close();
}
/**
* Define inner class ValidationResult
*/
public static class ValidationResult {
public final float[][] x;
public final int sr;
// Constructor
public ValidationResult(float[][] x, int sr) {
this.x = x;
this.sr = sr;
}
}
/**
* Function to validate input data
*/
private ValidationResult validateInput(float[][] x, int sr) {
// Process the input data with dimension 1
if (x.length == 1) {
x = new float[][]{x[0]};
}
// Throw an exception when the input data dimension is greater than 2
if (x.length > 2) {
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
}
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
if (sr != 16000 && (sr % 16000 == 0)) {
int step = sr / 16000;
float[][] reducedX = new float[x.length][];
for (int i = 0; i < x.length; i++) {
float[] current = x[i];
float[] newArr = new float[(current.length + step - 1) / step];
for (int j = 0, index = 0; j < current.length; j += step, index++) {
newArr[index] = current[j];
}
reducedX[i] = newArr;
}
x = reducedX;
sr = 16000;
}
// If the sample rate is not in the list of supported sample rates, throw an exception
if (!SAMPLE_RATES.contains(sr)) {
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
}
// If the input audio block is too short, throw an exception
if (((float) sr) / x[0].length > 31.25) {
throw new IllegalArgumentException("Input audio is too short");
}
// Return the validated result
return new ValidationResult(x, sr);
}
private static float[][] concatenate(float[][] a, float[][] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
}
int rows = a.length;
int colsA = a[0].length;
int colsB = b[0].length;
float[][] result = new float[rows][colsA + colsB];
for (int i = 0; i < rows; i++) {
System.arraycopy(a[i], 0, result[i], 0, colsA);
System.arraycopy(b[i], 0, result[i], colsA, colsB);
}
return result;
}
private static float[][] getLastColumns(float[][] array, int contextSize) {
int rows = array.length;
int cols = array[0].length;
if (contextSize > cols) {
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
}
float[][] result = new float[rows][contextSize];
for (int i = 0; i < rows; i++) {
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
}
return result;
}
/**
* Method to call the ONNX model
*/
public float[] call(float[][] x, int sr) throws OrtException {
ValidationResult result = validateInput(x, sr);
x = result.x;
sr = result.sr;
int numberSamples = 256;
if (sr == 16000) {
numberSamples = 512;
}
if (x[0].length != numberSamples) {
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
}
int batchSize = x.length;
int contextSize = 32;
if (sr == 16000) {
contextSize = 64;
}
if (lastBatchSize == 0) {
resetStates();
}
if (lastSr != 0 && lastSr != sr) {
resetStates();
}
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
resetStates();
}
if (context.length == 0) {
context = new float[batchSize][contextSize];
}
x = concatenate(context, x);
OrtEnvironment env = OrtEnvironment.getEnvironment();
OnnxTensor inputTensor = null;
OnnxTensor stateTensor = null;
OnnxTensor srTensor = null;
OrtSession.Result ortOutputs = null;
try {
// Create input tensors
inputTensor = OnnxTensor.createTensor(env, x);
stateTensor = OnnxTensor.createTensor(env, state);
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
inputs.put("sr", srTensor);
inputs.put("state", stateTensor);
// Call the ONNX model for calculation
ortOutputs = session.run(inputs);
// Get the output results
float[][] output = (float[][]) ortOutputs.get(0).getValue();
state = (float[][][]) ortOutputs.get(1).getValue();
context = getLastColumns(x, contextSize);
lastSr = sr;
lastBatchSize = batchSize;
return output[0];
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (stateTensor != null) {
stateTensor.close();
}
if (srTensor != null) {
srTensor.close();
}
if (ortOutputs != null) {
ortOutputs.close();
}
}
}
}