Mirror of https://github.com/snakers4/silero-vad.git, synced 2026-02-05 01:49:22 +08:00
Compare commits

23 Commits

| SHA1 |
|---|
| 3f0c9ead54 |
| 9623ce72da |
| b6dd0599fc |
| d8f88c9157 |
| b15a216b47 |
| 2389039408 |
| df22fcaec8 |
| 81e8a48e25 |
| a14a23faa7 |
| a30b5843c1 |
| a66c890188 |
| 77c91a91fa |
| 33093c6f1b |
| dc0b62e1e4 |
| 64fb49e1c8 |
| 55ba6e2825 |
| b90f8c012f |
| 25a778c798 |
| 3d860e6ace |
| f5ea01bfda |
| dd651a54a5 |
| f1175c902f |
| 7819fd911b |
README.md: two PyPI badge links are added to the second badge row (the badge images themselves were stripped in this capture, leaving empty link texts).

```diff
@@ -1,6 +1,6 @@
 [](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [](https://pypi.org/project/silero-vad/)
 
-[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [](https://github.com/snakers4/silero-vad/actions/workflows/test.yml)
+[](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [](https://github.com/snakers4/silero-vad/actions/workflows/test.yml) [](https://pypi.org/project/silero-vad/) [](https://pypi.org/project/silero-vad)
 
 
 
```
pom.xml: the onnxruntime dependency is bumped from 1.16.0-rc1 to 1.23.1 and gains a Maven Repository reference comment.

```diff
@@ -1,30 +1,31 @@
 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <modelVersion>4.0.0</modelVersion>
 
     <groupId>org.example</groupId>
     <artifactId>java-example</artifactId>
     <version>1.0-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>sliero-vad-example</name>
     <url>http://maven.apache.org</url>
 
     <properties>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     </properties>
 
     <dependencies>
         <dependency>
             <groupId>junit</groupId>
             <artifactId>junit</artifactId>
             <version>3.8.1</version>
             <scope>test</scope>
         </dependency>
+        <!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime -->
         <dependency>
             <groupId>com.microsoft.onnxruntime</groupId>
             <artifactId>onnxruntime</artifactId>
-            <version>1.16.0-rc1</version>
+            <version>1.23.1</version>
         </dependency>
     </dependencies>
 
 </project>
```
App.java (hunk @@ -2,68 +2,263 @@, package org.example): the example is rewritten from real-time microphone capture through SlieroVadDetector to batch processing of a WAV file with a Java port of Python's get_speech_timestamps.

Old version (removed):

```java
package org.example;

import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.util.Map;

public class App {

    private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
    private static final int SAMPLE_RATE = 16000;
    private static final float START_THRESHOLD = 0.6f;
    private static final float END_THRESHOLD = 0.45f;
    private static final int MIN_SILENCE_DURATION_MS = 600;
    private static final int SPEECH_PAD_MS = 500;
    private static final int WINDOW_SIZE_SAMPLES = 2048;

    public static void main(String[] args) {
        // Initialize the Voice Activity Detector
        SlieroVadDetector vadDetector;
        try {
            vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
        } catch (OrtException e) {
            System.err.println("Error initializing the VAD detector: " + e.getMessage());
            return;
        }

        // Set audio format
        AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
        DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);

        // Get the target data line and open it with the specified format
        TargetDataLine targetDataLine;
        try {
            targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
            targetDataLine.open(format);
            targetDataLine.start();
        } catch (LineUnavailableException e) {
            System.err.println("Error opening target data line: " + e.getMessage());
            return;
        }

        // Main loop to continuously read data and apply Voice Activity Detection
        while (targetDataLine.isOpen()) {
            byte[] data = new byte[WINDOW_SIZE_SAMPLES];

            int numBytesRead = targetDataLine.read(data, 0, data.length);
            if (numBytesRead <= 0) {
                System.err.println("Error reading data from target data line.");
                continue;
            }

            // Apply the Voice Activity Detector to the data and get the result
            Map<String, Double> detectResult;
            try {
                detectResult = vadDetector.apply(data, true);
            } catch (Exception e) {
                System.err.println("Error applying VAD detector: " + e.getMessage());
                continue;
            }

            if (!detectResult.isEmpty()) {
                System.out.println(detectResult);
            }
        }

        // Close the target data line to release audio resources
        targetDataLine.close();
    }
}
```

New version (added):

```java
package org.example;

import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Silero VAD Java Example
 * Voice Activity Detection using ONNX model
 *
 * @author VvvvvGH
 */
public class App {

    // ONNX model path - using the model file from the project
    private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx";
    // Test audio file path
    private static final String AUDIO_FILE_PATH = "../../en_example.wav";
    // Sampling rate
    private static final int SAMPLE_RATE = 16000;
    // Speech threshold (consistent with Python default)
    private static final float THRESHOLD = 0.5f;
    // Negative threshold (used to determine speech end)
    private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15
    // Minimum speech duration (milliseconds)
    private static final int MIN_SPEECH_DURATION_MS = 250;
    // Minimum silence duration (milliseconds)
    private static final int MIN_SILENCE_DURATION_MS = 100;
    // Speech padding (milliseconds)
    private static final int SPEECH_PAD_MS = 30;
    // Window size (samples) - 512 samples for 16kHz
    private static final int WINDOW_SIZE_SAMPLES = 512;

    public static void main(String[] args) {
        System.out.println("=".repeat(60));
        System.out.println("Silero VAD Java ONNX Example");
        System.out.println("=".repeat(60));

        // Load ONNX model
        SlieroVadOnnxModel model;
        try {
            System.out.println("Loading ONNX model: " + MODEL_PATH);
            model = new SlieroVadOnnxModel(MODEL_PATH);
            System.out.println("Model loaded successfully!");
        } catch (OrtException e) {
            System.err.println("Failed to load model: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Read WAV file
        float[] audioData;
        try {
            System.out.println("\nReading audio file: " + AUDIO_FILE_PATH);
            audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH);
            System.out.println("Audio file read successfully, samples: " + audioData.length);
            System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds");
        } catch (Exception e) {
            System.err.println("Failed to read audio file: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps)
        System.out.println("\nDetecting speech segments...");
        List<Map<String, Integer>> speechTimestamps;
        try {
            speechTimestamps = getSpeechTimestamps(
                    audioData,
                    model,
                    THRESHOLD,
                    SAMPLE_RATE,
                    MIN_SPEECH_DURATION_MS,
                    MIN_SILENCE_DURATION_MS,
                    SPEECH_PAD_MS,
                    NEG_THRESHOLD
            );
        } catch (OrtException e) {
            System.err.println("Failed to detect speech timestamps: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Output detection results
        System.out.println("\nDetected speech timestamps (in samples):");
        for (Map<String, Integer> timestamp : speechTimestamps) {
            System.out.println(timestamp);
        }

        // Output summary
        System.out.println("\n" + "=".repeat(60));
        System.out.println("Detection completed!");
        System.out.println("Total detected " + speechTimestamps.size() + " speech segments");
        System.out.println("=".repeat(60));

        // Close model
        try {
            model.close();
        } catch (OrtException e) {
            System.err.println("Error closing model: " + e.getMessage());
        }
    }

    /**
     * Get speech timestamps
     * Implements the same logic as Python's get_speech_timestamps
     *
     * @param audio                Audio data (float array)
     * @param model                ONNX model
     * @param threshold            Speech threshold
     * @param samplingRate         Sampling rate
     * @param minSpeechDurationMs  Minimum speech duration (milliseconds)
     * @param minSilenceDurationMs Minimum silence duration (milliseconds)
     * @param speechPadMs          Speech padding (milliseconds)
     * @param negThreshold         Negative threshold (used to determine speech end)
     * @return List of speech timestamps
     */
    private static List<Map<String, Integer>> getSpeechTimestamps(
            float[] audio,
            SlieroVadOnnxModel model,
            float threshold,
            int samplingRate,
            int minSpeechDurationMs,
            int minSilenceDurationMs,
            int speechPadMs,
            float negThreshold) throws OrtException {

        // Reset model states
        model.resetStates();

        // Calculate parameters
        int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000;
        int speechPadSamples = samplingRate * speechPadMs / 1000;
        int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000;
        int windowSizeSamples = samplingRate == 16000 ? 512 : 256;
        int audioLengthSamples = audio.length;

        // Calculate speech probabilities for all audio chunks
        List<Float> speechProbs = new ArrayList<>();
        for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) {
            float[] chunk = new float[windowSizeSamples];
            int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart);
            System.arraycopy(audio, currentStart, chunk, 0, chunkLength);

            // Pad with zeros if chunk is shorter than window size
            if (chunkLength < windowSizeSamples) {
                for (int i = chunkLength; i < windowSizeSamples; i++) {
                    chunk[i] = 0.0f;
                }
            }

            float speechProb = model.call(new float[][]{chunk}, samplingRate)[0];
            speechProbs.add(speechProb);
        }

        // Detect speech segments using the same algorithm as Python
        boolean triggered = false;
        List<Map<String, Integer>> speeches = new ArrayList<>();
        Map<String, Integer> currentSpeech = null;
        int tempEnd = 0;

        for (int i = 0; i < speechProbs.size(); i++) {
            float speechProb = speechProbs.get(i);

            // Reset temporary end if speech probability exceeds threshold
            if (speechProb >= threshold && tempEnd != 0) {
                tempEnd = 0;
            }

            // Detect speech start
            if (speechProb >= threshold && !triggered) {
                triggered = true;
                currentSpeech = new HashMap<>();
                currentSpeech.put("start", windowSizeSamples * i);
                continue;
            }

            // Detect speech end
            if (speechProb < negThreshold && triggered) {
                if (tempEnd == 0) {
                    tempEnd = windowSizeSamples * i;
                }
                if (windowSizeSamples * i - tempEnd < minSilenceSamples) {
                    continue;
                } else {
                    currentSpeech.put("end", tempEnd);
                    if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) {
                        speeches.add(currentSpeech);
                    }
                    currentSpeech = null;
                    tempEnd = 0;
                    triggered = false;
                }
            }
        }

        // Handle the last speech segment
        if (currentSpeech != null &&
                (audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) {
            currentSpeech.put("end", audioLengthSamples);
            speeches.add(currentSpeech);
        }

        // Add speech padding - same logic as Python
        for (int i = 0; i < speeches.size(); i++) {
            Map<String, Integer> speech = speeches.get(i);
            if (i == 0) {
                speech.put("start", Math.max(0, speech.get("start") - speechPadSamples));
            }
            if (i != speeches.size() - 1) {
                int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end");
                if (silenceDuration < 2 * speechPadSamples) {
                    speech.put("end", speech.get("end") + silenceDuration / 2);
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2));
                } else {
                    speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples));
                }
            } else {
                speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
            }
        }

        return speeches;
    }

    /**
     * Read WAV file and return as float array
     *
     * @param filePath WAV file path
     * @return Audio data as float array (normalized to -1.0 to 1.0)
     */
    private static float[] readWavFileAsFloatArray(String filePath)
            throws UnsupportedAudioFileException, IOException {
        File audioFile = new File(filePath);
        AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile);

        // Get audio format information
        AudioFormat format = audioStream.getFormat();
        System.out.println("Audio format: " + format);

        // Read all audio data
        byte[] audioBytes = audioStream.readAllBytes();
        audioStream.close();

        // Convert to float array
        float[] audioData = new float[audioBytes.length / 2];
        for (int i = 0; i < audioData.length; i++) {
            // 16-bit PCM: two bytes per sample (little-endian)
            short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8));
            audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0
        }

        return audioData;
    }
}
```
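The comments in the new example repeatedly point at Python's get_speech_timestamps; for comparison, this is the equivalent batch call through the published silero-vad package API (a sketch; the file name is reused from the example above):

```python
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad(onnx=True)               # same ONNX model the Java example loads
wav = read_audio("en_example.wav", sampling_rate=16000)
timestamps = get_speech_timestamps(wav, model)   # defaults: threshold=0.5, 512-sample windows at 16 kHz
print(timestamps)                                # [{'start': ..., 'end': ...}, ...] in samples
```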
SlieroVadDetector.java: comments are converted to concise English and javadoc; the streaming detection logic itself is unchanged.

```diff
@@ -8,25 +8,30 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+/**
+ * Silero VAD Detector
+ * Real-time voice activity detection
+ *
+ * @author VvvvvGH
+ */
 public class SlieroVadDetector {
-    // OnnxModel model used for speech processing
+    // ONNX model for speech processing
     private final SlieroVadOnnxModel model;
-    // Threshold for speech start
+    // Speech start threshold
    private final float startThreshold;
-    // Threshold for speech end
+    // Speech end threshold
    private final float endThreshold;
    // Sampling rate
    private final int samplingRate;
-    // Minimum number of silence samples to determine the end threshold of speech
+    // Minimum silence samples to determine speech end
    private final float minSilenceSamples;
-    // Additional number of samples for speech start or end to calculate speech start or end time
+    // Speech padding samples for calculating speech boundaries
    private final float speechPadSamples;
-    // Whether in the triggered state (i.e. whether speech is being detected)
+    // Triggered state (whether speech is being detected)
    private boolean triggered;
-    // Temporarily stored number of speech end samples
+    // Temporary speech end sample position
    private int tempEnd;
-    // Number of samples currently being processed
+    // Current sample position
    private int currentSample;
 
 
@@ -36,23 +41,25 @@
                               int samplingRate,
                               int minSilenceDurationMs,
                               int speechPadMs) throws OrtException {
-        // Check if the sampling rate is 8000 or 16000, if not, throw an exception
+        // Validate sampling rate
        if (samplingRate != 8000 && samplingRate != 16000) {
-            throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
+            throw new IllegalArgumentException("Does not support sampling rates other than [8000, 16000]");
        }
 
-        // Initialize the parameters
+        // Initialize parameters
        this.model = new SlieroVadOnnxModel(modelPath);
        this.startThreshold = startThreshold;
        this.endThreshold = endThreshold;
        this.samplingRate = samplingRate;
        this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this.speechPadSamples = samplingRate * speechPadMs / 1000f;
-        // Reset the state
+        // Reset state
        reset();
    }
 
-    // Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
+    /**
+     * Reset detector state
+     */
    public void reset() {
        model.resetStates();
        triggered = false;
@@ -60,21 +67,27 @@
        currentSample = 0;
    }
 
-    // apply method for processing the audio array, returning possible speech start or end times
+    /**
+     * Process audio data and detect speech events
+     *
+     * @param data          Audio data as byte array
+     * @param returnSeconds Whether to return timestamps in seconds
+     * @return Speech event (start or end) or empty map if no event
+     */
    public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
 
-        // Convert the byte array to a float array
+        // Convert byte array to float array
        float[] audioData = new float[data.length / 2];
        for (int i = 0; i < audioData.length; i++) {
            audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
        }
 
-        // Get the length of the audio array as the window size
+        // Get window size from audio data length
        int windowSizeSamples = audioData.length;
-        // Update the current sample count
+        // Update current sample position
        currentSample += windowSizeSamples;
 
-        // Call the model to get the prediction probability of speech
+        // Get speech probability from model
        float speechProb = 0;
        try {
            speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
@@ -82,19 +95,18 @@
            throw new RuntimeException(e);
        }
 
-        // If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
-        // This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
+        // Reset temporary end if speech probability exceeds threshold
        if (speechProb >= startThreshold && tempEnd != 0) {
            tempEnd = 0;
        }
 
-        // If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
+        // Detect speech start
        if (speechProb >= startThreshold && !triggered) {
            triggered = true;
            int speechStart = (int) (currentSample - speechPadSamples);
            speechStart = Math.max(speechStart, 0);
            Map<String, Double> result = new HashMap<>();
-            // Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
+            // Return in seconds or samples based on returnSeconds parameter
            if (returnSeconds) {
                double speechStartSeconds = speechStart / (double) samplingRate;
                double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
@@ -106,18 +118,17 @@
            return result;
        }
 
-        // If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
+        // Detect speech end
        if (speechProb < endThreshold && triggered) {
-            // Initialize or update the temporary end time
+            // Initialize or update temporary end position
            if (tempEnd == 0) {
                tempEnd = currentSample;
            }
-            // If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
-            // This indicates that it is not yet possible to determine whether the speech has ended
+            // Wait for minimum silence duration before confirming speech end
            if (currentSample - tempEnd < minSilenceSamples) {
                return Collections.emptyMap();
            } else {
-                // Calculate the speech end time, reset the trigger state and temporary end time
+                // Calculate speech end time and reset state
                int speechEnd = (int) (tempEnd + speechPadSamples);
                tempEnd = 0;
                triggered = false;
@@ -134,7 +145,7 @@
            }
        }
 
-        // If the above conditions are not met, return null by default
+        // No speech event detected
        return Collections.emptyMap();
    }
 
```
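SlieroVadDetector is the streaming counterpart of the batch example. In the Python package the same event-per-chunk behavior is provided by VADIterator; a sketch with 512-sample chunks at 16 kHz (assuming `wav` comes from `read_audio`):

```python
from silero_vad import load_silero_vad, VADIterator

model = load_silero_vad()
vad_iterator = VADIterator(model, threshold=0.5, sampling_rate=16000)

for chunk in wav.split(512):
    if len(chunk) < 512:
        break                                     # the model expects full windows
    event = vad_iterator(chunk, return_seconds=True)
    if event:                                     # {'start': s} or {'end': s}, like apply()
        print(event)
vad_iterator.reset_states()                       # mirrors SlieroVadDetector.reset()
```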
SlieroVadOnnxModel.java: the wrapper is migrated from the old h/c recurrent state pair (two [2, 1, 64] tensors) to the v5 model interface, which takes a single `state` tensor of shape [2, batch, 128] plus a rolling audio context prepended to each input window.

```diff
@@ -9,42 +9,58 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+/**
+ * Silero VAD ONNX Model Wrapper
+ *
+ * @author VvvvvGH
+ */
 public class SlieroVadOnnxModel {
-    // Define private variable OrtSession
+    // ONNX runtime session
    private final OrtSession session;
-    private float[][][] h;
-    private float[][][] c;
-    // Define the last sample rate
+    // Model state - dimensions: [2, batch_size, 128]
+    private float[][][] state;
+    // Context - stores the tail of the previous audio chunk
+    private float[][] context;
+    // Last sample rate
    private int lastSr = 0;
-    // Define the last batch size
+    // Last batch size
    private int lastBatchSize = 0;
-    // Define a list of supported sample rates
+    // Supported sample rates
    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
 
    // Constructor
    public SlieroVadOnnxModel(String modelPath) throws OrtException {
        // Get the ONNX runtime environment
        OrtEnvironment env = OrtEnvironment.getEnvironment();
-        // Create an ONNX session options object
+        // Create ONNX session options
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
-        // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
+        // Set InterOp thread count to 1 (for parallel processing of different graph operations)
        opts.setInterOpNumThreads(1);
-        // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
+        // Set IntraOp thread count to 1 (for parallel processing within a single operation)
        opts.setIntraOpNumThreads(1);
-        // Add a CPU device, setting to false disables CPU execution optimization
+        // Enable CPU execution optimization
        opts.addCPU(true);
-        // Create an ONNX session using the environment, model path, and options
+        // Create ONNX session with the environment, model path, and options
        session = env.createSession(modelPath, opts);
        // Reset states
        resetStates();
    }
 
    /**
-     * Reset states
+     * Reset states with default batch size
     */
    void resetStates() {
-        h = new float[2][1][64];
-        c = new float[2][1][64];
+        resetStates(1);
+    }
+
+    /**
+     * Reset states with specific batch size
+     *
+     * @param batchSize Batch size for state initialization
+     */
+    void resetStates(int batchSize) {
+        state = new float[2][batchSize][128];
+        context = new float[0][]; // Empty context
        lastSr = 0;
        lastBatchSize = 0;
    }
@@ -54,13 +70,12 @@
    }
 
    /**
-     * Define inner class ValidationResult
+     * Inner class for validation result
     */
    public static class ValidationResult {
        public final float[][] x;
        public final int sr;
 
-        // Constructor
        public ValidationResult(float[][] x, int sr) {
            this.x = x;
            this.sr = sr;
@@ -68,19 +83,23 @@
    }
 
    /**
-     * Function to validate input data
+     * Validate input data
+     *
+     * @param x  Audio data array
+     * @param sr Sample rate
+     * @return Validated input data and sample rate
     */
    private ValidationResult validateInput(float[][] x, int sr) {
-        // Process the input data with dimension 1
+        // Ensure input is at least 2D
        if (x.length == 1) {
            x = new float[][]{x[0]};
        }
-        // Throw an exception when the input data dimension is greater than 2
+        // Check if input dimension is valid
        if (x.length > 2) {
            throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
        }
 
-        // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
+        // Downsample if sample rate is a multiple of 16000
        if (sr != 16000 && (sr % 16000 == 0)) {
            int step = sr / 16000;
            float[][] reducedX = new float[x.length][];
@@ -100,22 +119,26 @@
            sr = 16000;
        }
 
-        // If the sample rate is not in the list of supported sample rates, throw an exception
+        // Validate sample rate
        if (!SAMPLE_RATES.contains(sr)) {
            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
        }
 
-        // If the input audio block is too short, throw an exception
+        // Check if audio chunk is too short
        if (((float) sr) / x[0].length > 31.25) {
            throw new IllegalArgumentException("Input audio is too short");
        }
 
-        // Return the validated result
        return new ValidationResult(x, sr);
    }
 
    /**
-     * Method to call the ONNX model
+     * Call the ONNX model for inference
+     *
+     * @param x  Audio data array
+     * @param sr Sample rate
+     * @return Speech probability output
+     * @throws OrtException If ONNX runtime error occurs
     */
    public float[] call(float[][] x, int sr) throws OrtException {
        ValidationResult result = validateInput(x, sr);
@@ -123,38 +146,62 @@
        sr = result.sr;
 
        int batchSize = x.length;
+        int numSamples = sr == 16000 ? 512 : 256;
+        int contextSize = sr == 16000 ? 64 : 32;
 
-        if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
-            resetStates();
+        // Reset states only when sample rate or batch size changes
+        if (lastSr != 0 && lastSr != sr) {
+            resetStates(batchSize);
+        } else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
+            resetStates(batchSize);
+        } else if (lastBatchSize == 0) {
+            // First call - state is already initialized, just set batch size
+            lastBatchSize = batchSize;
+        }
+
+        // Initialize context if needed
+        if (context.length == 0) {
+            context = new float[batchSize][contextSize];
+        }
+
+        // Concatenate context and input
+        float[][] xWithContext = new float[batchSize][contextSize + numSamples];
+        for (int i = 0; i < batchSize; i++) {
+            // Copy context
+            System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
+            // Copy input
+            System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
        }
 
        OrtEnvironment env = OrtEnvironment.getEnvironment();
 
        OnnxTensor inputTensor = null;
-        OnnxTensor hTensor = null;
-        OnnxTensor cTensor = null;
+        OnnxTensor stateTensor = null;
        OnnxTensor srTensor = null;
        OrtSession.Result ortOutputs = null;
 
        try {
            // Create input tensors
-            inputTensor = OnnxTensor.createTensor(env, x);
-            hTensor = OnnxTensor.createTensor(env, h);
-            cTensor = OnnxTensor.createTensor(env, c);
+            inputTensor = OnnxTensor.createTensor(env, xWithContext);
+            stateTensor = OnnxTensor.createTensor(env, state);
            srTensor = OnnxTensor.createTensor(env, new long[]{sr});
 
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", inputTensor);
            inputs.put("sr", srTensor);
-            inputs.put("h", hTensor);
-            inputs.put("c", cTensor);
+            inputs.put("state", stateTensor);
 
-            // Call the ONNX model for calculation
+            // Run ONNX model inference
            ortOutputs = session.run(inputs);
-            // Get the output results
+            // Get output results
            float[][] output = (float[][]) ortOutputs.get(0).getValue();
-            h = (float[][][]) ortOutputs.get(1).getValue();
-            c = (float[][][]) ortOutputs.get(2).getValue();
+            state = (float[][][]) ortOutputs.get(1).getValue();
+
+            // Update context - save the last contextSize samples from input
+            for (int i = 0; i < batchSize; i++) {
+                System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
+                        context[i], 0, contextSize);
+            }
 
            lastSr = sr;
            lastBatchSize = batchSize;
@@ -163,11 +210,8 @@
            if (inputTensor != null) {
                inputTensor.close();
            }
-            if (hTensor != null) {
-                hTensor.close();
+            if (stateTensor != null) {
+                stateTensor.close();
            }
-            if (cTensor != null) {
-                cTensor.close();
-            }
            if (srTensor != null) {
                srTensor.close();
```
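The essence of the wrapper change: each 512-sample window (256 at 8 kHz) is now prefixed with the last 64 samples (32 at 8 kHz) of the previous window, and the old h/c pair is a single `state` tensor. A minimal Python sketch of the same calling protocol with onnxruntime, using the model file named in the example:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("silero_vad.onnx")
state = np.zeros((2, 1, 128), dtype=np.float32)   # replaces the old h and c tensors
context = np.zeros((1, 64), dtype=np.float32)     # tail of the previous chunk
sr = np.array([16000], dtype=np.int64)

def vad_prob(chunk):
    """chunk: float32 array of shape (1, 512) at 16 kHz."""
    global state, context
    x = np.concatenate([context, chunk], axis=1)  # (1, 64 + 512), context first
    out, state = session.run(None, {"input": x, "state": state, "sr": sr})
    context = x[:, -64:]                          # carry the tail into the next call
    return float(out[0])
```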
pyproject.toml: version bump to 6.1.0 and a new `packaging` runtime dependency.

```diff
@@ -3,7 +3,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 [project]
 name = "silero-vad"
-version = "6.0.0"
+version = "6.1.0"
 authors = [
     {name="Silero Team", email="hello@silero.ai"},
 ]
@@ -28,6 +28,7 @@ classifiers = [
     "Topic :: Scientific/Engineering",
 ]
 dependencies = [
+    "packaging",
     "torch>=1.12.0",
     "torchaudio>=0.12.0",
     "onnxruntime>=1.16.1",
```
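The new `packaging` dependency backs the torchaudio version gate introduced in utils_vad.py below; a minimal sketch of the gate (names match the diff):

```python
from packaging import version
import torchaudio

# True on torchaudio >= 2.9, where sox_effects I/O is gone and torchcodec takes over
needs_torchcodec = version.parse(torchaudio.__version__) >= version.parse("2.9")
```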
utils_vad.py: audio I/O is reworked for torchaudio >= 2.9 (with a torchcodec fallback), the overlapping-window `hop_size_samples` experiment is reverted, the max-speech splitting logic is restructured, and `drop_chunks` no longer loses the audio tail.

```diff
@@ -2,6 +2,7 @@ import torch
 import torchaudio
 from typing import Callable, List
 import warnings
+from packaging import version
 
 languages = ['ru', 'en', 'de', 'es']
 
```
```diff
@@ -134,40 +135,60 @@ class Validator():
         return outs
 
 
-def read_audio(path: str,
-               sampling_rate: int = 16000):
-    list_backends = torchaudio.list_audio_backends()
-
-    assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
-\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
-
-    try:
-        effects = [
-            ['channels', '1'],
-            ['rate', str(sampling_rate)]
-        ]
-
-        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
-    except:
-        wav, sr = torchaudio.load(path)
-
-    if wav.size(0) > 1:
-        wav = wav.mean(dim=0, keepdim=True)
-
-    if sr != sampling_rate:
-        transform = torchaudio.transforms.Resample(orig_freq=sr,
-                                                   new_freq=sampling_rate)
-        wav = transform(wav)
-        sr = sampling_rate
-
-    assert sr == sampling_rate
+def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
+    ta_ver = version.parse(torchaudio.__version__)
+    if ta_ver < version.parse("2.9"):
+        try:
+            effects = [['channels', '1'],['rate', str(sampling_rate)]]
+            wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+        except:
+            wav, sr = torchaudio.load(path)
+    else:
+        try:
+            wav, sr = torchaudio.load(path)
+        except:
+            try:
+                from torchcodec.decoders import AudioDecoder
+                samples = AudioDecoder(path).get_all_samples()
+                wav = samples.data
+                sr = samples.sample_rate
+            except ImportError:
+                raise RuntimeError(
+                    f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
+                    + "Install torchcodec or pin torchaudio < 2.9"
+                )
+
+    if wav.ndim > 1 and wav.size(0) > 1:
+        wav = wav.mean(dim=0, keepdim=True)
+
+    if sr != sampling_rate:
+        wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
+
     return wav.squeeze(0)
 
 
-def save_audio(path: str,
-               tensor: torch.Tensor,
-               sampling_rate: int = 16000):
-    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
+def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
+    tensor = tensor.detach().cpu()
+    if tensor.ndim == 1:
+        tensor = tensor.unsqueeze(0)
+
+    ta_ver = version.parse(torchaudio.__version__)
+
+    try:
+        torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
+    except Exception:
+        if ta_ver >= version.parse("2.9"):
+            try:
+                from torchcodec.encoders import AudioEncoder
+                encoder = AudioEncoder(tensor, sample_rate=16000)
+                encoder.to_file(path)
+            except ImportError:
+                raise RuntimeError(
+                    f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
+                    + "Install torchcodec or pin torchaudio < 2.9"
+                )
+        else:
+            raise
 
 
 def init_jit_model(model_path: str,
```
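A quick usage sketch of the updated I/O helpers (assuming a 16 kHz mono target; on torchaudio >= 2.9 the torchcodec branch is only taken when `torchaudio.load` cannot decode the file):

```python
from silero_vad import read_audio, save_audio

wav = read_audio("en_example.wav", sampling_rate=16000)  # 1-D float tensor, mono, resampled
print(wav.shape)
save_audio("en_example_16k.wav", wav, sampling_rate=16000)  # 1-D input is unsqueezed internally
```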
```diff
@@ -227,7 +248,7 @@ def get_speech_timestamps(audio: torch.Tensor,
 
     max_speech_duration_s: int (default - inf)
         Maximum duration of speech chunks in seconds
-        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent agressive cutting.
+        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
         Otherwise, they will be split aggressively just before max_speech_duration_s.
 
     min_silence_duration_ms: int (default - 100 milliseconds)
```
```diff
@@ -289,7 +310,6 @@ def get_speech_timestamps(audio: torch.Tensor,
         raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
 
     window_size_samples = 512 if sampling_rate == 16000 else 256
-    hop_size_samples = int(window_size_samples)
 
     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
```
```diff
@@ -301,17 +321,14 @@
     audio_length_samples = len(audio)
 
     speech_probs = []
-    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
+    for current_start_sample in range(0, audio_length_samples, window_size_samples):
         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        try:
-            speech_prob = model(chunk, sampling_rate).item()
-        except Exception as e:
-            import ipdb; ipdb.set_trace()
+        speech_prob = model(chunk, sampling_rate).item()
         speech_probs.append(speech_prob)
-        # caculate progress and seng it to callback function
-        progress = current_start_sample + hop_size_samples
+        # calculate progress and send it to callback function
+        progress = current_start_sample + window_size_samples
         if progress > audio_length_samples:
             progress = audio_length_samples
         progress_percent = (progress / audio_length_samples) * 100
```
```diff
@@ -329,53 +346,70 @@
     possible_ends = []
 
     for i, speech_prob in enumerate(speech_probs):
-        if (speech_prob >= threshold) and temp_end:
-            if temp_end != 0:
-                sil_dur = (hop_size_samples * i) - temp_end
-                if sil_dur > min_silence_samples_at_max_speech:
-                    possible_ends.append((temp_end, sil_dur))
-            temp_end = 0
-            if next_start < prev_end:
-                next_start = hop_size_samples * i
+        cur_sample = window_size_samples * i
+
+        # If speech returns after a temp_end, record candidate silence if long enough and clear temp_end
+        if (speech_prob >= threshold) and temp_end:
+            sil_dur = cur_sample - temp_end
+            if sil_dur > min_silence_samples_at_max_speech:
+                possible_ends.append((temp_end, sil_dur))
+            temp_end = 0
+            if next_start < prev_end:
+                next_start = cur_sample
 
+        # Start of speech
         if (speech_prob >= threshold) and not triggered:
             triggered = True
-            current_speech['start'] = hop_size_samples * i
+            current_speech['start'] = cur_sample
             continue
 
-        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
-            if possible_ends:
-                if use_max_poss_sil_at_max_speech:
-                    prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk
-                else:
-                    prev_end, dur = possible_ends[-1] # use the last possible silence segement
+        # Max speech length reached: decide where to cut
+        if triggered and (cur_sample - current_speech['start'] > max_speech_samples):
+            if use_max_poss_sil_at_max_speech and possible_ends:
+                prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
                 next_start = prev_end + dur
-                if next_start < prev_end + hop_size_samples * i: # previously reached silence (< neg_thres) and is still not speech (< thres)
-                    #triggered = False
+                if next_start < prev_end + cur_sample: # previously reached silence (< neg_thres) and is still not speech (< thres)
                     current_speech['start'] = next_start
                 else:
                     triggered = False
-                    #current_speech['start'] = next_start
                 prev_end = next_start = temp_end = 0
                 possible_ends = []
             else:
-                current_speech['end'] = hop_size_samples * i
-                speeches.append(current_speech)
-                current_speech = {}
-                prev_end = next_start = temp_end = 0
-                triggered = False
-                possible_ends = []
-                continue
+                # Legacy max-speech cut (use_max_poss_sil_at_max_speech=False): prefer last valid silence (prev_end) if available
+                if prev_end:
+                    current_speech['end'] = prev_end
+                    speeches.append(current_speech)
+                    current_speech = {}
+                    if next_start < prev_end:
+                        triggered = False
+                    else:
+                        current_speech['start'] = next_start
+                    prev_end = next_start = temp_end = 0
+                    possible_ends = []
+                else:
+                    # No prev_end -> fallback to cutting at current sample
+                    current_speech['end'] = cur_sample
+                    speeches.append(current_speech)
+                    current_speech = {}
+                    prev_end = next_start = temp_end = 0
+                    triggered = False
+                    possible_ends = []
+                    continue
 
+        # Silence detection while in speech
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = hop_size_samples * i
-            # if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
-            #     prev_end = temp_end
-            if (hop_size_samples * i) - temp_end < min_silence_samples:
+                temp_end = cur_sample
+            sil_dur_now = cur_sample - temp_end
+
+            if not use_max_poss_sil_at_max_speech and sil_dur_now > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
+                prev_end = temp_end
+
+            if sil_dur_now < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
```
```diff
@@ -416,7 +450,7 @@ def get_speech_timestamps(audio: torch.Tensor,
             speech_dict['end'] *= step
 
     if visualize_probs:
-        make_visualization(speech_probs, hop_size_samples / sampling_rate)
+        make_visualization(speech_probs, window_size_samples / sampling_rate)
 
     return speeches
 
```
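For reference, a typical call into the updated function (parameter names as in the diff; defaults can differ across package versions):

```python
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio("en_example.wav", sampling_rate=16000)
speech_timestamps = get_speech_timestamps(
    wav, model,
    threshold=0.5,               # open a segment once prob >= 0.5
    neg_threshold=0.35,          # close it only after prob < 0.35 (~ threshold - 0.15)
    min_silence_duration_ms=100, # and at least 100 ms of such silence has passed
    speech_pad_ms=30,            # pad each segment on both sides
    return_seconds=True,
)
print(speech_timestamps)         # [{'start': ..., 'end': ...}, ...]
```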
```diff
@@ -607,6 +641,8 @@ def drop_chunks(tss: List[dict],
         chunks.append((wav[cur_start: i['start']]))
         cur_start = i['end']
 
+    chunks.append(wav[cur_start:])
+
     return torch.cat(chunks)
 
 
```
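The appended line fixes a tail-truncation bug in `drop_chunks`: any audio after the last dropped span used to be discarded. A tiny check with hypothetical values:

```python
import torch
from silero_vad.utils_vad import drop_chunks

wav = torch.arange(10.0)
tss = [{'start': 2, 'end': 4}, {'start': 6, 'end': 8}]
# drop_chunks keeps everything outside the spans; with the fix the trailing
# samples [8, 9] are now included, so the expected output is [0, 1, 4, 5, 8, 9]
print(drop_chunks(tss, wav))
```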
Tuning code (SileroVadDataset / train): a leftover no-op statement is removed from the dataset, and the missing `optimizer.zero_grad()` is added to the training loop.

```diff
@@ -118,8 +118,6 @@ class SileroVadDataset(Dataset):
 
         assert len(gt) == len(wav) / self.num_samples
 
-        mask[gt == 0]
-
         return wav, gt, mask
 
     def get_ground_truth_annotated(self, annotation, audio_length_samples):
```
```diff
@@ -240,6 +238,7 @@ def train(config,
 
         loss = criterion(stacked, targets)
         loss = (loss * masks).mean()
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         losses.update(loss.item(), masks.numel())
```
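The added `optimizer.zero_grad()` restores the standard PyTorch step order; without it, gradients accumulate across iterations and each update mixes in stale gradients from previous batches. The canonical pattern (the batch loop here is a hypothetical sketch around the names in the diff):

```python
for wavs, targets, masks in loader:
    stacked = model(wavs)
    loss = (criterion(stacked, targets) * masks).mean()

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # compute fresh gradients for this batch
    optimizer.step()       # apply the update
```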