Update SlieroVadDetector.java

This commit is contained in:
GH
2025-10-11 16:21:45 +08:00
committed by GitHub
parent 3d860e6ace
commit 25a778c798

View File

@@ -8,25 +8,30 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
/**
* Silero VAD Detector
* Real-time voice activity detection
*
* @author VvvvvGH
*/
public class SlieroVadDetector { public class SlieroVadDetector {
// OnnxModel model used for speech processing // ONNX model for speech processing
private final SlieroVadOnnxModel model; private final SlieroVadOnnxModel model;
// Threshold for speech start // Speech start threshold
private final float startThreshold; private final float startThreshold;
// Threshold for speech end // Speech end threshold
private final float endThreshold; private final float endThreshold;
// Sampling rate // Sampling rate
private final int samplingRate; private final int samplingRate;
// Minimum number of silence samples to determine the end threshold of speech // Minimum silence samples to determine speech end
private final float minSilenceSamples; private final float minSilenceSamples;
// Additional number of samples for speech start or end to calculate speech start or end time // Speech padding samples for calculating speech boundaries
private final float speechPadSamples; private final float speechPadSamples;
// Whether in the triggered state (i.e. whether speech is being detected) // Triggered state (whether speech is being detected)
private boolean triggered; private boolean triggered;
// Temporarily stored number of speech end samples // Temporary speech end sample position
private int tempEnd; private int tempEnd;
// Number of samples currently being processed // Current sample position
private int currentSample; private int currentSample;
@@ -36,23 +41,25 @@ public class SlieroVadDetector {
int samplingRate, int samplingRate,
int minSilenceDurationMs, int minSilenceDurationMs,
int speechPadMs) throws OrtException { int speechPadMs) throws OrtException {
// Check if the sampling rate is 8000 or 16000, if not, throw an exception // Validate sampling rate
if (samplingRate != 8000 && samplingRate != 16000) { if (samplingRate != 8000 && samplingRate != 16000) {
throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]"); throw new IllegalArgumentException("Does not support sampling rates other than [8000, 16000]");
} }
// Initialize the parameters // Initialize parameters
this.model = new SlieroVadOnnxModel(modelPath); this.model = new SlieroVadOnnxModel(modelPath);
this.startThreshold = startThreshold; this.startThreshold = startThreshold;
this.endThreshold = endThreshold; this.endThreshold = endThreshold;
this.samplingRate = samplingRate; this.samplingRate = samplingRate;
this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
this.speechPadSamples = samplingRate * speechPadMs / 1000f; this.speechPadSamples = samplingRate * speechPadMs / 1000f;
// Reset the state // Reset state
reset(); reset();
} }
// Method to reset the state, including the model state, trigger state, temporary end time, and current sample count /**
* Reset detector state
*/
public void reset() { public void reset() {
model.resetStates(); model.resetStates();
triggered = false; triggered = false;
@@ -60,21 +67,27 @@ public class SlieroVadDetector {
currentSample = 0; currentSample = 0;
} }
// apply method for processing the audio array, returning possible speech start or end times /**
* Process audio data and detect speech events
*
* @param data Audio data as byte array
* @param returnSeconds Whether to return timestamps in seconds
* @return Speech event (start or end) or empty map if no event
*/
public Map<String, Double> apply(byte[] data, boolean returnSeconds) { public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
// Convert the byte array to a float array // Convert byte array to float array
float[] audioData = new float[data.length / 2]; float[] audioData = new float[data.length / 2];
for (int i = 0; i < audioData.length; i++) { for (int i = 0; i < audioData.length; i++) {
audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f; audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
} }
// Get the length of the audio array as the window size // Get window size from audio data length
int windowSizeSamples = audioData.length; int windowSizeSamples = audioData.length;
// Update the current sample count // Update current sample position
currentSample += windowSizeSamples; currentSample += windowSizeSamples;
// Call the model to get the prediction probability of speech // Get speech probability from model
float speechProb = 0; float speechProb = 0;
try { try {
speechProb = model.call(new float[][]{audioData}, samplingRate)[0]; speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
@@ -82,19 +95,18 @@ public class SlieroVadDetector {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
// If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time // Reset temporary end if speech probability exceeds threshold
// This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
if (speechProb >= startThreshold && tempEnd != 0) { if (speechProb >= startThreshold && tempEnd != 0) {
tempEnd = 0; tempEnd = 0;
} }
// If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time // Detect speech start
if (speechProb >= startThreshold && !triggered) { if (speechProb >= startThreshold && !triggered) {
triggered = true; triggered = true;
int speechStart = (int) (currentSample - speechPadSamples); int speechStart = (int) (currentSample - speechPadSamples);
speechStart = Math.max(speechStart, 0); speechStart = Math.max(speechStart, 0);
Map<String, Double> result = new HashMap<>(); Map<String, Double> result = new HashMap<>();
// Decide whether to return the result in seconds or sample count based on the returnSeconds parameter // Return in seconds or samples based on returnSeconds parameter
if (returnSeconds) { if (returnSeconds) {
double speechStartSeconds = speechStart / (double) samplingRate; double speechStartSeconds = speechStart / (double) samplingRate;
double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue(); double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
@@ -106,18 +118,17 @@ public class SlieroVadDetector {
return result; return result;
} }
// If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time // Detect speech end
if (speechProb < endThreshold && triggered) { if (speechProb < endThreshold && triggered) {
// Initialize or update the temporary end time // Initialize or update temporary end position
if (tempEnd == 0) { if (tempEnd == 0) {
tempEnd = currentSample; tempEnd = currentSample;
} }
// If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null // Wait for minimum silence duration before confirming speech end
// This indicates that it is not yet possible to determine whether the speech has ended
if (currentSample - tempEnd < minSilenceSamples) { if (currentSample - tempEnd < minSilenceSamples) {
return Collections.emptyMap(); return Collections.emptyMap();
} else { } else {
// Calculate the speech end time, reset the trigger state and temporary end time // Calculate speech end time and reset state
int speechEnd = (int) (tempEnd + speechPadSamples); int speechEnd = (int) (tempEnd + speechPadSamples);
tempEnd = 0; tempEnd = 0;
triggered = false; triggered = false;
@@ -134,7 +145,7 @@ public class SlieroVadDetector {
} }
} }
// If the above conditions are not met, return null by default // No speech event detected
return Collections.emptyMap(); return Collections.emptyMap();
} }