mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00

Compare commits: 38 commits, author adamnsandl

SHA1s:
64b863d2ff 8a3600665b 9c2c90aa1c 1d48167271 d0139d94d9 46f94b7d60 3de3ee3abe
e680ea6633 199de226e5 4109b107c1 36854a90db 827e86e685 e706ec6fee 88df0ce1dd
d18b91e037 1e3f343767 6a8ee81ee0 cb25c0c047 7af8628a27 3682cb189c 57c0b51f9b
dd0b143803 181cdf92b6 a7bd2dd38f df7de797a5 87ed11b508 84768cefdf 6de3660f25
d9a6941852 dfdc9a484e f2e3a23d96 2b97f61160 657dac8736 412a478e29 9adf6d2192
8a2a73c14f 3e0305559d f0d880d79c

README.md (65 lines changed)
@@ -25,6 +25,55 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 </details>
 
 <br/>
+
+<h2 align="center">Fast start</h2>
+<br/>
+
+<details>
+<summary>Dependencies</summary>
+
+System requirements to run python examples:
+
+- `python 3.8+`
+- 1G+ RAM
+- not too outdated cpu
+
+Dependencies:
+
+- `torch>=1.12.0`
+- `torchaudio>=0.12.0` (for I/O functionalities only)
+- `onnxruntime>=1.16.1` (for ONNX model usage)
+
+Silero VAD uses the torchaudio library for audio file I/O (torchaudio.info, torchaudio.load and torchaudio.save), so a proper audio backend is required:
+
+- Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`
+- Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2.
+- Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`
+
+</details>
+
+**Using pip**:
+
+`pip install silero-vad`
+
+```python3
+from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
+
+model = load_silero_vad()
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(wav, model)
+```
+
+**Using torch.hub**:
+
+```python3
+import torch
+torch.set_num_threads(1)
+
+model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+(get_speech_timestamps, _, read_audio, _, _) = utils
+
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(wav, model)
+```
+
+<br/>
+
 <h2 align="center">Key Features</h2>
 <br/>
 
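Both quickstarts return timestamps as sample offsets into the 16 kHz waveform. A minimal sketch of getting seconds instead, assuming the `return_seconds` flag exposed by `get_speech_timestamps` in recent releases of the pip package:

```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio('path_to_audio_file')

# return_seconds=True yields {'start': ..., 'end': ...} in seconds instead of sample offsets
for ts in get_speech_timestamps(wav, model, return_seconds=True):
    print(f"speech from {ts['start']}s to {ts['end']}s")
```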
@@ -57,21 +106,7 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.
 
 <br/>
-<h2 align="center">Fast start</h2>
-<br/>
-
-```python3
-import torch
-torch.set_num_threads(1)
-
-model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
-(get_speech_timestamps, _, read_audio, _, _) = utils
-
-wav = read_audio('path_to_audio_file')
-speech_timestamps = get_speech_timestamps(wav, model)
-```
-
-<br/>
 <h2 align="center">Typical Use Cases</h2>
 <br/>
 
@@ -106,7 +141,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
 @misc{Silero VAD,
   author = {Silero Team},
   title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
-  year = {2021},
+  year = {2024},
   publisher = {GitHub},
   journal = {GitHub repository},
   howpublished = {\url{https://github.com/snakers4/silero-vad}},
examples/csharp/Program.cs (new file, 35 lines)
@@ -0,0 +1,35 @@
using System.Text;

namespace VadDotNet;

class Program
{
    private const string MODEL_PATH = "./resources/silero_vad.onnx";
    private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
    private const int SAMPLE_RATE = 16000;
    private const float THRESHOLD = 0.5f;
    private const int MIN_SPEECH_DURATION_MS = 250;
    private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
    private const int MIN_SILENCE_DURATION_MS = 100;
    private const int SPEECH_PAD_MS = 30;

    public static void Main(string[] args)
    {
        var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
            MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
        List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
        //Console.WriteLine(speechTimeList.ToJson());
        StringBuilder sb = new StringBuilder();
        foreach (var speechSegment in speechTimeList)
        {
            sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
        }
        Console.WriteLine(sb.ToString());
    }
}
examples/csharp/SileroSpeechSegment.cs (new file, 21 lines)
@@ -0,0 +1,21 @@
namespace VadDotNet;

public class SileroSpeechSegment
{
    public int? StartOffset { get; set; }
    public int? EndOffset { get; set; }
    public float? StartSecond { get; set; }
    public float? EndSecond { get; set; }

    public SileroSpeechSegment()
    {
    }

    public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
    {
        StartOffset = startOffset;
        EndOffset = endOffset;
        StartSecond = startSecond;
        EndSecond = endSecond;
    }
}
examples/csharp/SileroVadDetector.cs (new file, 250 lines)
@@ -0,0 +1,250 @@
using NAudio.Wave;
using VADdotnet;

namespace VadDotNet;

public class SileroVadDetector
{
    private readonly SileroVadOnnxModel _model;
    private readonly float _threshold;
    private readonly float _negThreshold;
    private readonly int _samplingRate;
    private readonly int _windowSizeSample;
    private readonly float _minSpeechSamples;
    private readonly float _speechPadSamples;
    private readonly float _maxSpeechSamples;
    private readonly float _minSilenceSamples;
    private readonly float _minSilenceSamplesAtMaxSpeech;
    private int _audioLengthSamples;
    private const float THRESHOLD_GAP = 0.15f;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_8K = 8000;
    // ReSharper disable once InconsistentNaming
    private const int SAMPLING_RATE_16K = 16000;

    public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
        int minSpeechDurationMs, float maxSpeechDurationSeconds,
        int minSilenceDurationMs, int speechPadMs)
    {
        if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
        {
            throw new ArgumentException("Sampling rate not supported, only available for [8000, 16000]");
        }

        this._model = new SileroVadOnnxModel(onnxModelPath);
        this._samplingRate = samplingRate;
        this._threshold = threshold;
        // Hysteresis: a segment only closes once probability falls below threshold - THRESHOLD_GAP
        this._negThreshold = threshold - THRESHOLD_GAP;
        this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
        this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
        this._speechPadSamples = samplingRate * speechPadMs / 1000f;
        this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
        this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
        this.Reset();
    }

    public void Reset()
    {
        _model.ResetStates();
    }

    public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
    {
        Reset();

        using (var audioFile = new AudioFileReader(wavFile.FullName))
        {
            List<float> speechProbList = new List<float>();
            this._audioLengthSamples = (int)(audioFile.Length / 2);
            float[] buffer = new float[this._windowSizeSample];

            while (audioFile.Read(buffer, 0, buffer.Length) > 0)
            {
                float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
                speechProbList.Add(speechProb);
            }

            return CalculateProb(speechProbList);
        }
    }

    // Two-threshold hysteresis over the per-window speech probabilities
    private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        bool triggered = false;
        int tempEnd = 0, prevEnd = 0, nextStart = 0;
        SileroSpeechSegment segment = new SileroSpeechSegment();

        for (int i = 0; i < speechProbList.Count; i++)
        {
            float speechProb = speechProbList[i];
            if (speechProb >= _threshold && (tempEnd != 0))
            {
                tempEnd = 0;
                if (nextStart < prevEnd)
                {
                    nextStart = _windowSizeSample * i;
                }
            }

            if (speechProb >= _threshold && !triggered)
            {
                triggered = true;
                segment.StartOffset = _windowSizeSample * i;
                continue;
            }

            if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
            {
                if (prevEnd != 0)
                {
                    segment.EndOffset = prevEnd;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    if (nextStart < prevEnd)
                    {
                        triggered = false;
                    }
                    else
                    {
                        segment.StartOffset = nextStart;
                    }

                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                }
                else
                {
                    segment.EndOffset = _windowSizeSample * i;
                    result.Add(segment);
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }

            if (speechProb < _negThreshold && triggered)
            {
                if (tempEnd == 0)
                {
                    tempEnd = _windowSizeSample * i;
                }

                if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
                {
                    prevEnd = tempEnd;
                }

                if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
                {
                    continue;
                }
                else
                {
                    segment.EndOffset = tempEnd;
                    if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
                    {
                        result.Add(segment);
                    }

                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }
        }

        if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
        {
            segment.EndOffset = _audioLengthSamples;
            result.Add(segment);
        }

        // Pad segment boundaries, splitting short silences evenly between neighbours
        for (int i = 0; i < result.Count; i++)
        {
            SileroSpeechSegment item = result[i];
            if (i == 0)
            {
                item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
            }

            if (i != result.Count - 1)
            {
                SileroSpeechSegment nextItem = result[i + 1];
                int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
                if (silenceDuration < 2 * _speechPadSamples)
                {
                    item.EndOffset = item.EndOffset + (silenceDuration / 2);
                    nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
                }
                else
                {
                    item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
                    nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
                }
            }
            else
            {
                item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
            }
        }

        return MergeListAndCalculateSecond(result, _samplingRate);
    }

    private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
    {
        List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
        if (original == null || original.Count == 0)
        {
            return result;
        }

        int left = original[0].StartOffset.Value;
        int right = original[0].EndOffset.Value;
        if (original.Count > 1)
        {
            original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
            for (int i = 1; i < original.Count; i++)
            {
                SileroSpeechSegment segment = original[i];

                if (segment.StartOffset > right)
                {
                    result.Add(new SileroSpeechSegment(left, right,
                        CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
                    left = segment.StartOffset.Value;
                    right = segment.EndOffset.Value;
                }
                else
                {
                    right = Math.Max(right, segment.EndOffset.Value);
                }
            }

            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }
        else
        {
            result.Add(new SileroSpeechSegment(left, right,
                CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
        }

        return result;
    }

    private float CalculateSecondByOffset(int offset, int samplingRate)
    {
        float secondValue = offset * 1.0f / samplingRate;
        return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
    }
}
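A note on the detection loop above: `CalculateProb` implements a two-threshold hysteresis. A window opens a segment when its probability reaches `threshold`, but the segment only closes after the probability has stayed below `threshold - THRESHOLD_GAP` for at least `_minSilenceSamples`. A simplified Python sketch of just that triggering logic (the max-speech splitting and padding passes are omitted; names are mine):

```python
def hysteresis_segments(probs, window, threshold=0.5, min_silence_windows=4):
    """Simplified sketch of the triggering logic in CalculateProb above."""
    neg_threshold = threshold - 0.15  # THRESHOLD_GAP in the C# code
    segments, start, temp_end = [], None, None
    for i, p in enumerate(probs):
        if p >= threshold:
            temp_end = None           # speech resumed: cancel any pending end
            if start is None:
                start = i * window    # open a new segment
        elif start is not None and p < neg_threshold:
            if temp_end is None:
                temp_end = i * window                # remember where silence began
            elif i * window - temp_end >= min_silence_windows * window:
                segments.append((start, temp_end))   # enough silence: close the segment
                start, temp_end = None, None
    if start is not None:
        segments.append((start, len(probs) * window))
    return segments

# hysteresis_segments([0.1, 0.9, 0.8, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1], window=512)
# -> [(512, 1536)]
```

Probabilities that fall between the two thresholds neither close the segment nor cancel the pending end, which is what makes the detector robust to brief dips mid-utterance.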
examples/csharp/SileroVadOnnxModel.cs (new file, 220 lines)
@@ -0,0 +1,220 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Linq;

namespace VADdotnet;

public class SileroVadOnnxModel : IDisposable
{
    private readonly InferenceSession session;
    private float[][][] state;
    private float[][] context;
    private int lastSr = 0;
    private int lastBatchSize = 0;
    private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };

    public SileroVadOnnxModel(string modelPath)
    {
        var sessionOptions = new SessionOptions();
        sessionOptions.InterOpNumThreads = 1;
        sessionOptions.IntraOpNumThreads = 1;
        sessionOptions.EnableCpuMemArena = true;

        session = new InferenceSession(modelPath, sessionOptions);
        ResetStates();
    }

    public void ResetStates()
    {
        state = new float[2][][];
        state[0] = new float[1][];
        state[1] = new float[1][];
        state[0][0] = new float[128];
        state[1][0] = new float[128];
        context = Array.Empty<float[]>();
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void Dispose()
    {
        session?.Dispose();
    }

    public class ValidationResult
    {
        public float[][] X { get; }
        public int Sr { get; }

        public ValidationResult(float[][] x, int sr)
        {
            X = x;
            Sr = sr;
        }
    }

    private ValidationResult ValidateInput(float[][] x, int sr)
    {
        if (x.Length == 1)
        {
            x = new float[][] { x[0] };
        }
        if (x.Length > 2)
        {
            throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
        }

        // Downsample by simple striding when the rate is a multiple of 16000
        if (sr != 16000 && (sr % 16000 == 0))
        {
            int step = sr / 16000;
            float[][] reducedX = new float[x.Length][];

            for (int i = 0; i < x.Length; i++)
            {
                float[] current = x[i];
                float[] newArr = new float[(current.Length + step - 1) / step];

                for (int j = 0, index = 0; j < current.Length; j += step, index++)
                {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        if (!SAMPLE_RATES.Contains(sr))
        {
            throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
        }

        if (((float)sr) / x[0].Length > 31.25)
        {
            throw new ArgumentException("Input audio is too short");
        }

        return new ValidationResult(x, sr);
    }

    private static float[][] Concatenate(float[][] a, float[][] b)
    {
        if (a.Length != b.Length)
        {
            throw new ArgumentException("The number of rows in both arrays must be the same.");
        }

        int rows = a.Length;
        int colsA = a[0].Length;
        int colsB = b[0].Length;
        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[colsA + colsB];
            Array.Copy(a[i], 0, result[i], 0, colsA);
            Array.Copy(b[i], 0, result[i], colsA, colsB);
        }

        return result;
    }

    private static float[][] GetLastColumns(float[][] array, int contextSize)
    {
        int rows = array.Length;
        int cols = array[0].Length;

        if (contextSize > cols)
        {
            throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
        }

        float[][] result = new float[rows][];

        for (int i = 0; i < rows; i++)
        {
            result[i] = new float[contextSize];
            Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
        }

        return result;
    }

    public float[] Call(float[][] x, int sr)
    {
        var result = ValidateInput(x, sr);
        x = result.X;
        sr = result.Sr;
        int numberSamples = sr == 16000 ? 512 : 256;

        if (x[0].Length != numberSamples)
        {
            throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
        }

        int batchSize = x.Length;
        int contextSize = sr == 16000 ? 64 : 32;

        if (lastBatchSize == 0)
        {
            ResetStates();
        }
        if (lastSr != 0 && lastSr != sr)
        {
            ResetStates();
        }
        if (lastBatchSize != 0 && lastBatchSize != batchSize)
        {
            ResetStates();
        }

        if (context.Length == 0)
        {
            context = new float[batchSize][];
            for (int i = 0; i < batchSize; i++)
            {
                context[i] = new float[contextSize];
            }
        }

        // Prepend the previous chunk's tail (context) before inference
        x = Concatenate(context, x);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
            NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
            NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
        };

        using (var outputs = session.Run(inputs))
        {
            var output = outputs.First(o => o.Name == "output").AsTensor<float>();
            var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();

            context = GetLastColumns(x, contextSize);
            lastSr = sr;
            lastBatchSize = batchSize;

            state = new float[newState.Dimensions[0]][][];
            for (int i = 0; i < newState.Dimensions[0]; i++)
            {
                state[i] = new float[newState.Dimensions[1]][];
                for (int j = 0; j < newState.Dimensions[1]; j++)
                {
                    state[i][j] = new float[newState.Dimensions[2]];
                    for (int k = 0; k < newState.Dimensions[2]; k++)
                    {
                        state[i][j][k] = newState[i, j, k];
                    }
                }
            }

            return output.ToArray();
        }
    }
}
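The I/O contract visible in `Call` above (inputs named `input`, `sr`, `state`; outputs `output`, `stateN`; a [2, 1, 128] recurrent state; 512-sample windows plus 64 samples of carried context at 16 kHz) is the same regardless of runtime. A minimal Python sketch of the same call against onnxruntime, assuming a local `silero_vad.onnx` and feeding one window of silence:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("silero_vad.onnx")     # model path assumed
state = np.zeros((2, 1, 128), dtype=np.float32)    # same shape ResetStates() builds
context = np.zeros((1, 64), dtype=np.float32)      # 64-sample context at 16 kHz
chunk = np.zeros((1, 512), dtype=np.float32)       # one 512-sample window (silence here)

x = np.concatenate([context, chunk], axis=1)       # context + window, as in Call()
out, state = sess.run(
    ["output", "stateN"],
    {"input": x, "sr": np.array([16000], dtype=np.int64), "state": state},
)
context = x[:, -64:]                               # carry the last 64 samples forward
print(float(out.squeeze()))                        # speech probability for this window
```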
examples/csharp/VadDotNet.csproj (new file, 25 lines)
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.18.1" />
    <PackageReference Include="NAudio" Version="2.2.1" />
  </ItemGroup>

  <ItemGroup>
    <Folder Include="resources\" />
  </ItemGroup>

  <ItemGroup>
    <Content Include="resources\**">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>

</Project>

examples/csharp/resources/put_model_here.txt (new file, 1 line)
@@ -0,0 +1 @@
place onnx model file and example.wav file in this folder
Go example
@@ -11,11 +11,11 @@ import (
 
 func main() {
     sd, err := speech.NewDetector(speech.DetectorConfig{
-        ModelPath:            "../../files/silero_vad.onnx",
+        ModelPath:            "../../src/silero_vad/data/silero_vad.onnx",
         SampleRate:           16000,
         Threshold:            0.5,
-        MinSilenceDurationMs: 0,
-        SpeechPadMs:          0,
+        MinSilenceDurationMs: 100,
+        SpeechPadMs:          30,
     })
     if err != nil {
         log.Fatalf("failed to create speech detector: %s", err)
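The new Go defaults (100 ms of silence to close a segment, 30 ms of padding) line up with the constants used in the C# and Java examples above. For comparison, a rough equivalent through the Python utilities, using keyword names I believe `get_speech_timestamps` accepts:

```python
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps(
    wav, model,
    threshold=0.5,                 # Threshold
    min_silence_duration_ms=100,   # MinSilenceDurationMs
    speech_pad_ms=30,              # SpeechPadMs
)
```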
go.mod
@@ -4,7 +4,7 @@ go 1.21.4
 
 require (
     github.com/go-audio/wav v1.1.0
-    github.com/streamer45/silero-vad-go v0.2.0
+    github.com/streamer45/silero-vad-go v0.2.1
 )
 
 require (
go.sum
@@ -10,6 +10,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
 github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
+github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
+github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
App.java (new file, 37 lines)
@@ -0,0 +1,37 @@
package org.example;

import ai.onnxruntime.OrtException;
import java.io.File;
import java.util.List;

public class App {

    private static final String MODEL_PATH = "/path/silero_vad.onnx";
    private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
    private static final int SAMPLE_RATE = 16000;
    private static final float THRESHOLD = 0.5f;
    private static final int MIN_SPEECH_DURATION_MS = 250;
    private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
    private static final int MIN_SILENCE_DURATION_MS = 100;
    private static final int SPEECH_PAD_MS = 30;

    public static void main(String[] args) {
        // Initialize the Voice Activity Detector
        SileroVadDetector vadDetector;
        try {
            vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
                    MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
            fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
        } catch (OrtException e) {
            System.err.println("Error initializing the VAD detector: " + e.getMessage());
        }
    }

    public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
        List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
        for (SileroSpeechSegment speechSegment : speechTimeList) {
            System.out.println(String.format("start second: %f, end second: %f",
                    speechSegment.getStartSecond(), speechSegment.getEndSecond()));
        }
    }
}
SileroSpeechSegment.java (new file, 51 lines)
@@ -0,0 +1,51 @@
package org.example;

public class SileroSpeechSegment {
    private Integer startOffset;
    private Integer endOffset;
    private Float startSecond;
    private Float endSecond;

    public SileroSpeechSegment() {
    }

    public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
        this.startOffset = startOffset;
        this.endOffset = endOffset;
        this.startSecond = startSecond;
        this.endSecond = endSecond;
    }

    public Integer getStartOffset() {
        return startOffset;
    }

    public Integer getEndOffset() {
        return endOffset;
    }

    public Float getStartSecond() {
        return startSecond;
    }

    public Float getEndSecond() {
        return endSecond;
    }

    public void setStartOffset(Integer startOffset) {
        this.startOffset = startOffset;
    }

    public void setEndOffset(Integer endOffset) {
        this.endOffset = endOffset;
    }

    public void setStartSecond(Float startSecond) {
        this.startSecond = startSecond;
    }

    public void setEndSecond(Float endSecond) {
        this.endSecond = endSecond;
    }
}
SileroVadDetector.java (new file, 244 lines)
@@ -0,0 +1,244 @@
package org.example;

import ai.onnxruntime.OrtException;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SileroVadDetector {
    private final SileroVadOnnxModel model;
    private final float threshold;
    private final float negThreshold;
    private final int samplingRate;
    private final int windowSizeSample;
    private final float minSpeechSamples;
    private final float speechPadSamples;
    private final float maxSpeechSamples;
    private final float minSilenceSamples;
    private final float minSilenceSamplesAtMaxSpeech;
    private int audioLengthSamples;
    private static final float THRESHOLD_GAP = 0.15f;
    private static final Integer SAMPLING_RATE_8K = 8000;
    private static final Integer SAMPLING_RATE_16K = 16000;

    /**
     * Constructor
     * @param onnxModelPath            path of the silero-vad onnx model
     * @param threshold                threshold for speech start
     * @param samplingRate             audio sampling rate, only available for [8000, 16000]
     * @param minSpeechDurationMs      minimum speech length in millis; any speech segment shorter than this is not considered speech
     * @param maxSpeechDurationSeconds maximum speech length in seconds; recommended to be set to Float.POSITIVE_INFINITY
     * @param minSilenceDurationMs     minimum silence length in millis; any silence shorter than this is not considered silence
     * @param speechPadMs              additional padding in millis for speech start and end
     * @throws OrtException if the ONNX session cannot be created
     */
    public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
                             int minSpeechDurationMs, float maxSpeechDurationSeconds,
                             int minSilenceDurationMs, int speechPadMs) throws OrtException {
        if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
            throw new IllegalArgumentException("Sampling rate not supported, only available for [8000, 16000]");
        }
        this.model = new SileroVadOnnxModel(onnxModelPath);
        this.samplingRate = samplingRate;
        this.threshold = threshold;
        this.negThreshold = threshold - THRESHOLD_GAP;
        if (samplingRate == SAMPLING_RATE_16K) {
            this.windowSizeSample = 512;
        } else {
            this.windowSizeSample = 256;
        }
        this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
        this.speechPadSamples = samplingRate * speechPadMs / 1000f;
        this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
        this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
        this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
        this.reset();
    }

    /**
     * Method to reset the state
     */
    public void reset() {
        model.resetStates();
    }

    /**
     * Get the speech segment list from a given wav-format file
     * @param wavFile wav file
     * @return list of speech segments
     */
    public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
        reset();
        try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)) {
            List<Float> speechProbList = new ArrayList<>();
            this.audioLengthSamples = audioInputStream.available() / 2;
            byte[] data = new byte[this.windowSizeSample * 2];
            int numBytesRead = 0;

            while ((numBytesRead = audioInputStream.read(data)) != -1) {
                if (numBytesRead <= 0) {
                    break;
                }
                // Convert the 16-bit little-endian byte array to a float array
                float[] audioData = new float[data.length / 2];
                for (int i = 0; i < audioData.length; i++) {
                    audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
                }

                float speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
                speechProbList.add(speechProb);
            }
            return calculateProb(speechProbList);
        } catch (Exception e) {
            throw new RuntimeException("SileroVadDetector getSpeechSegmentList failed", e);
        }
    }

    /**
     * Calculate speech segments from the probability list
     * @param speechProbList speech probability list
     * @return list of speech segments
     */
    private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
        List<SileroSpeechSegment> result = new ArrayList<>();
        boolean triggered = false;
        int tempEnd = 0, prevEnd = 0, nextStart = 0;
        SileroSpeechSegment segment = new SileroSpeechSegment();

        for (int i = 0; i < speechProbList.size(); i++) {
            Float speechProb = speechProbList.get(i);
            if (speechProb >= threshold && (tempEnd != 0)) {
                tempEnd = 0;
                if (nextStart < prevEnd) {
                    nextStart = windowSizeSample * i;
                }
            }

            if (speechProb >= threshold && !triggered) {
                triggered = true;
                segment.setStartOffset(windowSizeSample * i);
                continue;
            }

            if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
                if (prevEnd != 0) {
                    segment.setEndOffset(prevEnd);
                    result.add(segment);
                    segment = new SileroSpeechSegment();
                    if (nextStart < prevEnd) {
                        triggered = false;
                    } else {
                        segment.setStartOffset(nextStart);
                    }
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                } else {
                    segment.setEndOffset(windowSizeSample * i);
                    result.add(segment);
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }

            if (speechProb < negThreshold && triggered) {
                if (tempEnd == 0) {
                    tempEnd = windowSizeSample * i;
                }
                if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
                    prevEnd = tempEnd;
                }
                if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
                    continue;
                } else {
                    segment.setEndOffset(tempEnd);
                    if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
                        result.add(segment);
                    }
                    segment = new SileroSpeechSegment();
                    prevEnd = 0;
                    nextStart = 0;
                    tempEnd = 0;
                    triggered = false;
                    continue;
                }
            }
        }

        if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
            segment.setEndOffset(audioLengthSamples);
            result.add(segment);
        }

        for (int i = 0; i < result.size(); i++) {
            SileroSpeechSegment item = result.get(i);
            if (i == 0) {
                item.setStartOffset((int) (Math.max(0, item.getStartOffset() - speechPadSamples)));
            }
            if (i != result.size() - 1) {
                SileroSpeechSegment nextItem = result.get(i + 1);
                Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
                if (silenceDuration < 2 * speechPadSamples) {
                    item.setEndOffset(item.getEndOffset() + (silenceDuration / 2));
                    nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
                } else {
                    item.setEndOffset((int) (Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
                    nextItem.setStartOffset((int) (Math.max(0, nextItem.getStartOffset() - speechPadSamples)));
                }
            } else {
                item.setEndOffset((int) (Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
            }
        }

        return mergeListAndCalculateSecond(result, samplingRate);
    }

    private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
        List<SileroSpeechSegment> result = new ArrayList<>();
        if (original == null || original.size() == 0) {
            return result;
        }
        Integer left = original.get(0).getStartOffset();
        Integer right = original.get(0).getEndOffset();
        if (original.size() > 1) {
            original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
            for (int i = 1; i < original.size(); i++) {
                SileroSpeechSegment segment = original.get(i);

                if (segment.getStartOffset() > right) {
                    result.add(new SileroSpeechSegment(left, right,
                            calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
                    left = segment.getStartOffset();
                    right = segment.getEndOffset();
                } else {
                    right = Math.max(right, segment.getEndOffset());
                }
            }
            result.add(new SileroSpeechSegment(left, right,
                    calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
        } else {
            result.add(new SileroSpeechSegment(left, right,
                    calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
        }
        return result;
    }

    private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
        float secondValue = offset * 1.0f / samplingRate;
        return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
    }
}
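For orientation, the millisecond parameters documented above are converted to sample counts once in the constructor; with App.java's defaults at 16 kHz the thresholds work out as follows (plain arithmetic, variable names mine):

```python
sr = 16000
window = 512                                # one probability per 512 samples = 32 ms
min_speech_samples = sr * 250 / 1000        # 4000 samples (MIN_SPEECH_DURATION_MS)
min_silence_samples = sr * 100 / 1000       # 1600 samples (MIN_SILENCE_DURATION_MS)
speech_pad_samples = sr * 30 / 1000         # 480 samples  (SPEECH_PAD_MS)
min_silence_at_max_speech = sr * 98 / 1000  # 1568 samples (hard-coded 98 ms)
```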
SileroVadOnnxModel.java (new file, 234 lines)
@@ -0,0 +1,234 @@
package org.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SileroVadOnnxModel {
    // The ONNX runtime session
    private final OrtSession session;
    private float[][][] state;
    private float[][] context;
    // The last sample rate
    private int lastSr = 0;
    // The last batch size
    private int lastBatchSize = 0;
    // List of supported sample rates
    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);

    // Constructor
    public SileroVadOnnxModel(String modelPath) throws OrtException {
        // Get the ONNX runtime environment
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // Create an ONNX session options object
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // Set the InterOp thread count to 1; InterOp threads are used for parallel processing of different computation graph operations
        opts.setInterOpNumThreads(1);
        // Set the IntraOp thread count to 1; IntraOp threads are used for parallel processing within a single operation
        opts.setIntraOpNumThreads(1);
        // Add a CPU device; setting to false disables CPU execution optimization
        opts.addCPU(true);
        // Create an ONNX session using the environment, model path, and options
        session = env.createSession(modelPath, opts);
        // Reset states
        resetStates();
    }

    /**
     * Reset states
     */
    void resetStates() {
        state = new float[2][1][128];
        context = new float[0][];
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void close() throws OrtException {
        session.close();
    }

    /**
     * Inner class holding validated input
     */
    public static class ValidationResult {
        public final float[][] x;
        public final int sr;

        // Constructor
        public ValidationResult(float[][] x, int sr) {
            this.x = x;
            this.sr = sr;
        }
    }

    /**
     * Validate the input data
     */
    private ValidationResult validateInput(float[][] x, int sr) {
        // Process input data with dimension 1
        if (x.length == 1) {
            x = new float[][]{x[0]};
        }
        // Throw an exception when the input data dimension is greater than 2
        if (x.length > 2) {
            throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
        }

        // Process the input data when the sample rate is not equal to 16000 but is a multiple of 16000
        if (sr != 16000 && (sr % 16000 == 0)) {
            int step = sr / 16000;
            float[][] reducedX = new float[x.length][];

            for (int i = 0; i < x.length; i++) {
                float[] current = x[i];
                float[] newArr = new float[(current.length + step - 1) / step];

                for (int j = 0, index = 0; j < current.length; j += step, index++) {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        // If the sample rate is not in the list of supported sample rates, throw an exception
        if (!SAMPLE_RATES.contains(sr)) {
            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
        }

        // If the input audio block is too short, throw an exception
        if (((float) sr) / x[0].length > 31.25) {
            throw new IllegalArgumentException("Input audio is too short");
        }

        // Return the validated result
        return new ValidationResult(x, sr);
    }

    private static float[][] concatenate(float[][] a, float[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
        }

        int rows = a.length;
        int colsA = a[0].length;
        int colsB = b[0].length;
        float[][] result = new float[rows][colsA + colsB];

        for (int i = 0; i < rows; i++) {
            System.arraycopy(a[i], 0, result[i], 0, colsA);
            System.arraycopy(b[i], 0, result[i], colsA, colsB);
        }

        return result;
    }

    private static float[][] getLastColumns(float[][] array, int contextSize) {
        int rows = array.length;
        int cols = array[0].length;

        if (contextSize > cols) {
            throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
        }

        float[][] result = new float[rows][contextSize];

        for (int i = 0; i < rows; i++) {
            System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
        }

        return result;
    }

    /**
     * Method to call the ONNX model
     */
    public float[] call(float[][] x, int sr) throws OrtException {
        ValidationResult result = validateInput(x, sr);
        x = result.x;
        sr = result.sr;
        int numberSamples = 256;
        if (sr == 16000) {
            numberSamples = 512;
        }

        if (x[0].length != numberSamples) {
            throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
        }

        int batchSize = x.length;

        int contextSize = 32;
        if (sr == 16000) {
            contextSize = 64;
        }

        if (lastBatchSize == 0) {
            resetStates();
        }
        if (lastSr != 0 && lastSr != sr) {
            resetStates();
        }
        if (lastBatchSize != 0 && lastBatchSize != batchSize) {
            resetStates();
        }

        if (context.length == 0) {
            context = new float[batchSize][contextSize];
        }

        x = concatenate(context, x);

        OrtEnvironment env = OrtEnvironment.getEnvironment();

        OnnxTensor inputTensor = null;
        OnnxTensor stateTensor = null;
        OnnxTensor srTensor = null;
        OrtSession.Result ortOutputs = null;

        try {
            // Create input tensors
            inputTensor = OnnxTensor.createTensor(env, x);
            stateTensor = OnnxTensor.createTensor(env, state);
            srTensor = OnnxTensor.createTensor(env, new long[]{sr});

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", inputTensor);
            inputs.put("sr", srTensor);
            inputs.put("state", stateTensor);

            // Call the ONNX model for calculation
            ortOutputs = session.run(inputs);
            // Get the output results
            float[][] output = (float[][]) ortOutputs.get(0).getValue();
            state = (float[][][]) ortOutputs.get(1).getValue();

            context = getLastColumns(x, contextSize);
            lastSr = sr;
            lastBatchSize = batchSize;
            return output[0];
        } finally {
            if (inputTensor != null) {
                inputTensor.close();
            }
            if (stateTensor != null) {
                stateTensor.close();
            }
            if (srTensor != null) {
                srTensor.close();
            }
            if (ortOutputs != null) {
                ortOutputs.close();
            }
        }
    }
}
pyaudio streaming example notebook (.ipynb JSON source)
@@ -31,11 +31,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"#!pip install numpy==1.20.2\n",
-"#!pip install torch==1.9.0\n",
-"#!pip install matplotlib==3.4.2\n",
-"#!pip install torchaudio==0.9.0\n",
-"#!pip install soundfile==0.10.3.post1\n",
+"#!pip install numpy==2.0.2\n",
+"#!pip install torch==2.4.1\n",
+"#!pip install matplotlib==3.9.2\n",
+"#!pip install torchaudio==2.4.1\n",
+"#!pip install soundfile==0.12.1\n",
 "#!pip install pyaudio==0.2.11"
 ]
 },
@@ -61,7 +61,6 @@
 "import torchaudio\n",
 "import matplotlib\n",
 "import matplotlib.pylab as plt\n",
-"torchaudio.set_audio_backend(\"soundfile\")\n",
 "import pyaudio"
 ]
 },
@@ -162,7 +161,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"num_samples = 1536"
+"num_samples = 512"
 ]
 },
 {
@@ -180,6 +179,8 @@
 "data = []\n",
 "voiced_confidences = []\n",
 "\n",
+"frames_to_record = 50\n",
+"\n",
 "print(\"Started Recording\")\n",
 "for i in range(0, frames_to_record):\n",
 " \n",
@@ -296,7 +297,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -310,7 +311,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.10"
+"version": "3.9.10"
 },
 "toc": {
 "base_numbering": 1,
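The `num_samples` change from 1536 to 512 tracks the current model, which only accepts 512-sample chunks at 16 kHz (256 at 8 kHz), as enforced by the ONNX wrappers above. A minimal streaming sketch over a file using the `VADIterator` utility from the hub entry point (audio path assumed):

```python
import torch

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(_, _, read_audio, VADIterator, _) = utils

vad_iterator = VADIterator(model, sampling_rate=16000)
wav = read_audio('path_to_audio_file', sampling_rate=16000)

for i in range(0, len(wav) - 512, 512):              # exactly 512 samples per step
    speech_dict = vad_iterator(wav[i:i + 512], return_seconds=True)
    if speech_dict:                                  # {'start': ...} or {'end': ...}
        print(speech_dict)
vad_iterator.reset_states()
```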
Rust example
@@ -1,13 +1,12 @@
 use crate::utils;
-use ndarray::{Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
+use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
 use std::path::Path;
 
 #[derive(Debug)]
 pub struct Silero {
     session: ort::Session,
     sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
-    h: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
-    c: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
+    state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
 }
 
 impl Silero {
@@ -16,20 +15,17 @@ impl Silero {
         model_path: impl AsRef<Path>,
     ) -> Result<Self, ort::Error> {
         let session = ort::Session::builder()?.commit_from_file(model_path)?;
-        let h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
-        let c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
+        let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
         let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
         Ok(Self {
             session,
             sample_rate,
-            h,
-            c,
+            state,
         })
     }
 
     pub fn reset(&mut self) {
-        self.h = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
-        self.c = ArrayD::<f32>::zeros([2, 1, 64].as_slice());
+        self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
     }
 
     pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
@@ -37,18 +33,17 @@ impl Silero {
             .iter()
             .map(|x| (*x as f32) / (i16::MAX as f32))
             .collect::<Vec<_>>();
-        let frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
+        let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
+        frame = frame.slice(s![.., ..480]).to_owned();
         let inps = ort::inputs![
             frame,
+            std::mem::take(&mut self.state),
             self.sample_rate.clone(),
-            std::mem::take(&mut self.h),
-            std::mem::take(&mut self.c)
         ]?;
         let res = self
             .session
-            .run(ort::SessionInputs::ValueSlice::<4>(&inps))?;
-        self.h = res["hn"].try_extract_tensor().unwrap().to_owned();
-        self.c = res["cn"].try_extract_tensor().unwrap().to_owned();
+            .run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
+        self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
         Ok(*res["output"]
             .try_extract_raw_tensor::<f32>()
             .unwrap()
@@ -20,7 +20,7 @@ impl VadIter {
     pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
         self.reset_states();
         for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
-            let speech_prob = self.silero.calc_level(audio_frame)?;
+            let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
             self.state.update(&self.params, speech_prob);
         }
         self.state.check_for_last_speech(samples.len());
Deleted file (language dictionary)
@@ -1 +0,0 @@
-{"59": "mg, Malagasy", "76": "tk, Turkmen", "20": "lb, Luxembourgish, Letzeburgesch", "62": "or, Oriya", "30": "en, English", "26": "oc, Occitan", "69": "no, Norwegian", "77": "sr, Serbian", "90": "bs, Bosnian", "71": "el, Greek, Modern (1453\u2013)", "15": "az, Azerbaijani", "12": "lo, Lao", "85": "zh-HK, Chinese", "79": "cs, Czech", "43": "sv, Swedish", "37": "mn, Mongolian", "32": "fi, Finnish", "51": "tg, Tajik", "46": "am, Amharic", "17": "nn, Norwegian Nynorsk", "40": "ja, Japanese", "8": "it, Italian", "21": "ha, Hausa", "11": "as, Assamese", "29": "fa, Persian", "82": "bn, Bengali", "54": "mk, Macedonian", "31": "sw, Swahili", "45": "vi, Vietnamese", "41": "ur, Urdu", "74": "bo, Tibetan", "4": "hi, Hindi", "86": "mr, Marathi", "3": "fy-NL, Western Frisian", "65": "sk, Slovak", "2": "ln, Lingala", "92": "gl, Galician", "53": "sn, Shona", "87": "su, Sundanese", "35": "tt, Tatar", "93": "kn, Kannada", "6": "yo, Yoruba", "27": "ps, Pashto, Pushto", "34": "hy, Armenian", "25": "pa-IN, Punjabi, Panjabi", "23": "nl, Dutch, Flemish", "48": "th, Thai", "73": "mt, Maltese", "55": "ar, Arabic", "89": "ba, Bashkir", "78": "bg, Bulgarian", "42": "yi, Yiddish", "5": "ru, Russian", "84": "sv-SE, Swedish", "80": "tr, Turkish", "33": "sq, Albanian", "38": "kk, Kazakh", "50": "pl, Polish", "9": "hr, Croatian", "66": "ky, Kirghiz, Kyrgyz", "49": "hu, Hungarian", "10": "si, Sinhala, Sinhalese", "56": "la, Latin", "75": "de, German", "14": "ko, Korean", "22": "id, Indonesian", "47": "sl, Slovenian", "57": "be, Belarusian", "36": "ta, Tamil", "7": "da, Danish", "91": "sd, Sindhi", "28": "et, Estonian", "63": "pt, Portuguese", "60": "ne, Nepali", "94": "zh-TW, Chinese", "18": "zh-CN, Chinese", "88": "rw, Kinyarwanda", "19": "es, Spanish, Castilian", "39": "ht, Haitian, Haitian Creole", "64": "tl, Tagalog", "83": "ms, Malay", "70": "ro, Romanian, Moldavian, Moldovan", "68": "pa, Punjabi, Panjabi", "52": "uz, Uzbek", "58": "km, Central Khmer", "67": "my, Burmese", "0": "fr, French", "24": "af, Afrikaans", "16": "gu, Gujarati", "81": "so, Somali", "13": "uk, Ukrainian", "44": "ca, Catalan, Valencian", "72": "ml, Malayalam", "61": "te, Telugu", "1": "zh, Chinese"}

Deleted file (language-group dictionary)
@@ -1 +0,0 @@
-{"0": ["Afrikaans", "Dutch, Flemish", "Western Frisian"], "1": ["Turkish", "Azerbaijani"], "2": ["Russian", "Slovak", "Ukrainian", "Czech", "Polish", "Belarusian"], "3": ["Bulgarian", "Macedonian", "Serbian", "Croatian", "Bosnian", "Slovenian"], "4": ["Norwegian Nynorsk", "Swedish", "Danish", "Norwegian"], "5": ["English"], "6": ["Finnish", "Estonian"], "7": ["Yiddish", "Luxembourgish, Letzeburgesch", "German"], "8": ["Spanish", "Occitan", "Portuguese", "Catalan, Valencian", "Galician", "Spanish, Castilian", "Italian"], "9": ["Maltese", "Arabic"], "10": ["Marathi"], "11": ["Hindi", "Urdu"], "12": ["Lao", "Thai"], "13": ["Malay", "Indonesian"], "14": ["Romanian, Moldavian, Moldovan"], "15": ["Tagalog"], "16": ["Tajik", "Persian"], "17": ["Kazakh", "Uzbek", "Kirghiz, Kyrgyz"], "18": ["Kinyarwanda"], "19": ["Tatar", "Bashkir"], "20": ["French"], "21": ["Chinese"], "22": ["Lingala"], "23": ["Yoruba"], "24": ["Sinhala, Sinhalese"], "25": ["Assamese"], "26": ["Korean"], "27": ["Gujarati"], "28": ["Hausa"], "29": ["Punjabi, Panjabi"], "30": ["Pashto, Pushto"], "31": ["Swahili"], "32": ["Albanian"], "33": ["Armenian"], "34": ["Mongolian"], "35": ["Tamil"], "36": ["Haitian, Haitian Creole"], "37": ["Japanese"], "38": ["Vietnamese"], "39": ["Amharic"], "40": ["Hungarian"], "41": ["Shona"], "42": ["Latin"], "43": ["Central Khmer"], "44": ["Malagasy"], "45": ["Nepali"], "46": ["Telugu"], "47": ["Oriya"], "48": ["Burmese"], "49": ["Greek, Modern (1453\u2013)"], "50": ["Malayalam"], "51": ["Tibetan"], "52": ["Turkmen"], "53": ["Somali"], "54": ["Bengali"], "55": ["Sundanese"], "56": ["Sindhi"], "57": ["Kannada"]}
21 hubconf.py
@@ -1,16 +1,15 @@
 dependencies = ['torch', 'torchaudio']
 import torch
-import json
 import os
-from utils_vad import (init_jit_model,
-                       get_speech_timestamps,
-                       save_audio,
-                       read_audio,
-                       VADIterator,
-                       collect_chunks,
-                       drop_chunks,
-                       Validator,
-                       OnnxWrapper)
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+from silero_vad.utils_vad import (init_jit_model,
+                                  get_speech_timestamps,
+                                  save_audio,
+                                  read_audio,
+                                  VADIterator,
+                                  collect_chunks,
+                                  OnnxWrapper)


 def versiontuple(v):

@@ -36,7 +35,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False):
     if versiontuple(installed_version) < versiontuple(supported_version):
         raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')

-    model_dir = os.path.join(os.path.dirname(__file__), 'files')
+    model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
     if onnx:
         model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
     else:
35 pyproject.toml Normal file
@@ -0,0 +1,35 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "silero-vad"
version = "5.1"
authors = [
  {name="Silero Team", email="hello@silero.ai"},
]
description = "Voice Activity Detector (VAD) by Silero"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Development Status :: 5 - Production/Stable",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering",
]
dependencies = [
    "torch>=1.12.0",
    "torchaudio>=0.12.0",
    "onnxruntime>=1.16.1",
]

[project.urls]
Homepage = "https://github.com/snakers4/silero-vad"
Issues = "https://github.com/snakers4/silero-vad/issues"
@@ -43,20 +43,30 @@
     },
     "outputs": [],
     "source": [
+    "USE_PIP = True # download model using pip package or torch.hub\n",
     "USE_ONNX = False # change this to True if you want to test onnx model\n",
     "if USE_ONNX:\n",
     "  !pip install -q onnxruntime\n",
+    "if USE_PIP:\n",
+    "  !pip install -q silero-vad\n",
+    "  from silero_vad import (load_silero_vad,\n",
+    "                          read_audio,\n",
+    "                          get_speech_timestamps,\n",
+    "                          save_audio,\n",
+    "                          VADIterator,\n",
+    "                          collect_chunks)\n",
+    "  model = load_silero_vad(onnx=USE_ONNX)\n",
+    "else:\n",
+    "  model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
+    "                                model='silero_vad',\n",
+    "                                force_reload=True,\n",
+    "                                onnx=USE_ONNX)\n",
     "\n",
-    "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-    "                              model='silero_vad',\n",
-    "                              force_reload=True,\n",
-    "                              onnx=USE_ONNX)\n",
-    "\n",
-    "(get_speech_timestamps,\n",
-    " save_audio,\n",
-    " read_audio,\n",
-    " VADIterator,\n",
-    " collect_chunks) = utils"
+    "  (get_speech_timestamps,\n",
+    "   save_audio,\n",
+    "   read_audio,\n",
+    "   VADIterator,\n",
+    "   collect_chunks) = utils"
 ]
 },
 {
12 src/silero_vad/__init__.py Normal file
@@ -0,0 +1,12 @@
from importlib.metadata import version
try:
    __version__ = version(__name__)
except:
    pass

from silero_vad.model import load_silero_vad
from silero_vad.utils_vad import (get_speech_timestamps,
                                  save_audio,
                                  read_audio,
                                  VADIterator,
                                  collect_chunks)
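Since the package now ships proper metadata, the version string resolved in `__init__.py` above can be checked directly after installation:

```python3
from importlib.metadata import version

print(version('silero-vad'))  # e.g. '5.1', matching the pyproject.toml above
```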
0 src/silero_vad/data/__init__.py Normal file
BIN src/silero_vad/data/silero_vad_half.onnx Normal file
Binary file not shown.
25 src/silero_vad/model.py Normal file
@@ -0,0 +1,25 @@
from .utils_vad import init_jit_model, OnnxWrapper
import torch
torch.set_num_threads(1)


def load_silero_vad(onnx=False):
    model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
    package_path = "silero_vad.data"

    try:
        import importlib_resources as impresources
        model_file_path = str(impresources.files(package_path).joinpath(model_name))
    except:
        from importlib import resources as impresources
        try:
            with impresources.path(package_path, model_name) as f:
                model_file_path = f
        except:
            model_file_path = str(impresources.files(package_path).joinpath(model_name))

    if onnx:
        model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
    else:
        model = init_jit_model(model_file_path)

    return model
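`load_silero_vad` above resolves the packaged weights via importlib resources. The same lookup can be done by hand; a sketch assuming Python 3.9+ and an installed silero-vad wheel:

```python3
from importlib import resources

onnx_path = resources.files('silero_vad.data').joinpath('silero_vad.onnx')
print(onnx_path)  # resolves to the model file inside site-packages
```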
@@ -53,10 +53,10 @@ class OnnxWrapper():

         x, sr = self._validate_input(x, sr)
         num_samples = 512 if sr == 16000 else 256

         if x.shape[-1] != num_samples:
             raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

         batch_size = x.shape[0]
         context_size = 64 if sr == 16000 else 32
@@ -132,18 +132,19 @@ class Validator():

 def read_audio(path: str,
                sampling_rate: int = 16000):
-    sox_backends = set(['sox', 'sox_io'])
-    audio_backends = torchaudio.list_audio_backends()
+    list_backends = torchaudio.list_audio_backends()
+
+    assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
+        \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'

-    if len(sox_backends.intersection(audio_backends)) > 0:
+    try:
         effects = [
             ['channels', '1'],
             ['rate', str(sampling_rate)]
         ]

         wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
-    else:
+    except:
         wav, sr = torchaudio.load(path)

     if wav.size(0) > 1:
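The rewritten `read_audio` asserts that at least one torchaudio backend exists and then simply falls back from sox effects to a plain `torchaudio.load`. A quick way to check what the fallback will see:

```python3
import torchaudio

print(torchaudio.list_audio_backends())  # e.g. ['ffmpeg', 'soundfile'], depending on what is installed
```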
@@ -194,6 +195,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                           return_seconds: bool = False,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
+                          neg_threshold: float = None,
                           window_size_samples: int = 512,):

     """

@@ -236,6 +238,9 @@ def get_speech_timestamps(audio: torch.Tensor,
     progress_tracking_callback: Callable[[float], None] (default - None)
         callback function taking progress in percents as an argument

+    neg_threshold: float (default = threshold - 0.15)
+        Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
+
     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
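The new `neg_threshold` parameter makes the enter/exit hysteresis explicit: speech starts above `threshold` and only ends below `neg_threshold`. A usage sketch with hypothetical values:

```python3
# Enter speech above 0.6, but drop back to non-speech only below 0.4
speech_timestamps = get_speech_timestamps(wav, model,
                                          threshold=0.6,
                                          neg_threshold=0.4)
```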
@@ -297,15 +302,17 @@ def get_speech_timestamps(audio: torch.Tensor,
     triggered = False
     speeches = []
     current_speech = {}
-    neg_threshold = threshold - 0.15
-    temp_end = 0  # to save potential segment end (and tolerate some silence)
-    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+
+    if neg_threshold is None:
+        neg_threshold = threshold - 0.15
+    temp_end = 0  # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached

     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
             temp_end = 0
             if next_start < prev_end:
                 next_start = window_size_samples * i

         if (speech_prob >= threshold) and not triggered:
             triggered = True
@@ -317,7 +324,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
                 if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                     triggered = False
                 else:
                     current_speech['start'] = next_start

@@ -333,7 +340,7 @@ def get_speech_timestamps(audio: torch.Tensor,
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
                 temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech :  # condition to avoid cutting in very short silence
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                 prev_end = temp_end
             if (window_size_samples * i) - temp_end < min_silence_samples:
                 continue
@@ -453,7 +460,7 @@ class VADIterator:

         if (speech_prob >= self.threshold) and not self.triggered:
             self.triggered = True
-            speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
+            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
             return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}

         if (speech_prob < self.threshold - 0.15) and self.triggered:
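The `max(0, ...)` fix above matters for streaming, where the very first detected segment could previously get a negative start. A minimal streaming sketch around `VADIterator` (names as in `utils_vad`; the chunking values assume the 16 kHz head):

```python3
vad_iterator = VADIterator(model)
window_size_samples = 512  # one 512-sample step at 16 kHz

for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict)  # {'start': ...} or {'end': ...} in seconds
vad_iterator.reset_states()  # reset between audio streams
```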
74 tuning/README.md Normal file
@@ -0,0 +1,74 @@
# Tuning the Silero-VAD model

> The tuning code was created with the support of the Innovation Assistance Foundation (Фонд содействия инновациям) within the federal project "Artificial Intelligence" of the national program "Digital Economy of the Russian Federation".

Tuning is used to improve the speech-detection quality of the Silero-VAD model on custom data.

## Dependencies
The following dependencies are used when tuning the VAD model:
- `torchaudio>=0.12.0`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`

## Data preparation

Dataframes for tuning must be prepared and saved in the `.feather` format. The following columns are required in the training and validation `.feather` files:
- **audio_path** - absolute path to the audio file on disk. Audio files must contain `PCM` data, preferably in `.wav` or `.opus` format (other popular audio formats are supported as well). To speed up tuning, it is recommended to resample the audio files to 16000 Hz beforehand;
- **speech_ts** - annotation for the corresponding audio file: a list of dicts of the form `{'start': START_SEC, 'end': END_SEC}`, where `START_SEC` and `END_SEC` are the start and end of a speech segment in seconds. For high-quality tuning, annotation accurate to about 30 milliseconds is recommended.

The more data is used for tuning, the better the adapted model performs on the target domain. Audio length is not limited, since each audio is cut to `max_train_length_sec` seconds before being fed to the network; it is better to pre-cut long audio into chunks of `max_train_length_sec` length.

An example `.feather` dataframe can be found in `example_dataframe.feather`.

## Configuration file `config.yml`

The configuration file `config.yml` contains the paths to the training and validation sets as well as the tuning parameters:
- `train_dataset_path` - absolute path to the training dataframe in `.feather` format. Must contain the `audio_path` and `speech_ts` columns described in "Data preparation". See `example_dataframe.feather` for an example;
- `val_dataset_path` - absolute path to the validation dataframe in `.feather` format. Must contain the `audio_path` and `speech_ts` columns described in "Data preparation". See `example_dataframe.feather` for an example;
- `jit_model_path` - absolute path to a Silero-VAD model in `.jit` format. If this field is left empty, the model is downloaded from the repository depending on the `use_torchhub` flag;
- `use_torchhub` - if `True`, the model to tune is downloaded via torch.hub; if `False`, it is loaded via the silero-vad library (install it beforehand with `pip install silero-vad`);
- `tune_8k` - selects which Silero-VAD head to tune. If `True`, the 8000 Hz head is tuned, otherwise the 16000 Hz one;
- `model_save_path` - path where the tuned model is saved;
- `noise_loss` - loss coefficient applied to non-speech audio windows;
- `max_train_length_sec` - maximum audio length in seconds during tuning; longer audio is cut to this value;
- `aug_prob` - probability of applying augmentations to an audio file during tuning;
- `learning_rate` - learning rate;
- `batch_size` - batch size for tuning and validation;
- `num_workers` - number of workers used for data loading;
- `num_epochs` - number of tuning epochs; one epoch is a full pass over the training data;
- `device` - `cpu` or `cuda`.

## Tuning

Tuning is started with

`python tune.py`

It runs for `num_epochs` epochs; the checkpoint with the best ROC-AUC on the validation set is saved to `model_save_path` in jit format.

## Threshold search

The enter and exit thresholds can be selected with

`python search_thresholds.py`

This script uses the configuration file described above; the model specified there is used to search for optimal thresholds on the validation dataset.

## Citation

```
@misc{Silero VAD,
  author = {Silero Team},
  title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/snakers4/silero-vad}},
  commit = {insert_some_commit_here},
  email = {hello@silero.ai}
}
```
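To make the "Data preparation" section above concrete, a sketch that builds a tiny training dataframe with the two required columns (paths and timings are hypothetical; the feather round-trip assumes pyarrow is installed):

```python3
import pandas as pd

df = pd.DataFrame({
    'audio_path': ['/data/audio/sample_0.wav', '/data/audio/sample_1.wav'],
    'speech_ts': [
        [{'start': 0.30, 'end': 2.15}],
        [{'start': 0.00, 'end': 1.40}, {'start': 3.20, 'end': 5.05}],
    ],
})
df.to_feather('train_dataset_path.feather')
```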
0 tuning/__init__.py Normal file
17 tuning/config.yml Normal file
@@ -0,0 +1,17 @@
jit_model_path: ''  # path to a Silero-VAD model in jit format, used for tuning; if left empty, the model is downloaded automatically
use_torchhub: True  # the jit model is downloaded via torch.hub if True, or via pip if False

tune_8k: False  # tunes the 16k head if False, the 8k head if True
train_dataset_path: 'train_dataset_path.feather'  # path to the tuning dataset in feather format, see README for details
val_dataset_path: 'val_dataset_path.feather'  # path to the validation dataset in feather format, see README for details
model_save_path: 'model_save_path.jit'  # path where the tuned model is saved

noise_loss: 0.5  # coefficient applied to the loss on non-speech windows
max_train_length_sec: 8  # audio longer than this value is cut to it during tuning
aug_prob: 0.4  # probability of applying augmentations to an audio during tuning

learning_rate: 5e-4  # model learning rate
batch_size: 128  # batch size for tuning and validation
num_workers: 4  # number of workers used for the dataloaders
num_epochs: 20  # number of tuning epochs, 1 epoch = a full pass over the training data
device: 'cuda'  # cpu or cuda, the device tuning runs on
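Both tuning scripts read this file with OmegaConf, so every field is available as an attribute:

```python3
from omegaconf import OmegaConf

config = OmegaConf.load('config.yml')
print(config.device, config.num_epochs)  # cuda 20 with the defaults above
```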
BIN tuning/example_dataframe.feather Normal file
Binary file not shown.
36 tuning/search_thresholds.py Normal file
@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)


if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
                                         batch_size=config.batch_size,
                                         collate_fn=SileroVadPadder,
                                         num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)

    print('Making predicts...')
    all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
    print('Calculating thresholds...')
    best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')
65 tuning/tune.py Normal file
@@ -0,0 +1,65 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch.nn as nn
import torch


if __name__ == '__main__':
    config = OmegaConf.load('config.yml')

    train_dataset = SileroVadDataset(config, mode='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               collate_fn=SileroVadPadder,
                                               num_workers=config.num_workers)

    val_dataset = SileroVadDataset(config, mode='val')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.batch_size,
                                             collate_fn=SileroVadPadder,
                                             num_workers=config.num_workers)

    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    else:
        if config.use_torchhub:
            print('Loading model using torch.hub')
            model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                      model='silero_vad',
                                      onnx=False,
                                      force_reload=True)
        else:
            print('Loading model using silero-vad library')
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)
    decoder = VADDecoderRNNJIT().to(config.device)
    decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
    decoder.train()
    params = decoder.parameters()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
                                 lr=config.learning_rate)
    criterion = nn.BCELoss(reduction='none')

    best_val_roc = 0
    for i in range(config.num_epochs):
        print(f'Starting epoch {i + 1}')
        train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
        val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
        print(f'Metrics after epoch {i + 1}:\n'
              f'\tTrain loss: {round(train_loss, 3)}\n',
              f'\tValidation loss: {round(val_loss, 3)}\n'
              f'\tValidation ROC-AUC: {round(val_roc, 3)}')

        if val_roc > best_val_roc:
            print('New best ROC-AUC, saving model')
            best_val_roc = val_roc
            if config.tune_8k:
                model._model_8k.decoder.load_state_dict(decoder.state_dict())
            else:
                model._model.decoder.load_state_dict(decoder.state_dict())
            torch.jit.save(model, config.model_save_path)
    print('Done')
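`tune.py` trains only the decoder head: the STFT and encoder layers of the jit model stay frozen, and the tuned decoder weights are copied back into the jit model before saving. A small sketch of what is actually trainable:

```python3
from utils import VADDecoderRNNJIT

decoder = VADDecoderRNNJIT()
n_trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
print(n_trainable)  # ~132k params: LSTMCell(128, 128) plus a 1x1 Conv1d
```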
357 tuning/utils.py Normal file
@@ -0,0 +1,357 @@
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')


def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):

    wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sampling_rate:
        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    if normalize and wav.abs().max() != 0:
        wav = wav / wav.abs().max()

    return wav.squeeze(0)
def build_audiomentations_augs(p):
    from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
        LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
        Aliasing, AddGaussianNoise
    transforms = [Aliasing(p=1),
                  AddGaussianNoise(p=1),
                  AirAbsorption(p=1),
                  BandPassFilter(p=1),
                  BandStopFilter(p=1),
                  ClippingDistortion(p=1),
                  HighPassFilter(p=1),
                  HighShelfFilter(p=1),
                  LowPassFilter(p=1),
                  LowShelfFilter(p=1),
                  Mp3Compression(p=1),
                  PeakingFilter(p=1),
                  PitchShift(p=1),
                  RoomSimulator(p=1, leave_length_unchanged=True),
                  SevenBandParametricEQ(p=1)]
    tr = SomeOf((1, 3), transforms=transforms, p=p)
    return tr
class SileroVadDataset(Dataset):
    def __init__(self,
                 config,
                 mode='train'):

        self.num_samples = 512  # constant, do not change
        self.sr = 16000  # constant, do not change

        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec
        self.max_train_length_samples = config.max_train_length_sec * self.sr

        assert self.max_train_length_samples % self.num_samples == 0
        assert mode in ['train', 'val']

        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE : {len(self.dataframe)}')

        if mode == 'train':
            self.augs = build_audiomentations_augs(p=config.aug_prob)
        else:
            self.augs = None

    def __getitem__(self, idx):
        idx = None if self.mode == 'train' else idx
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
            if len(wav) > self.max_train_length_samples:
                wav = wav[:self.max_train_length_samples]
                gt = gt[:int(self.max_train_length_samples / self.num_samples)]
                mask = mask[:int(self.max_train_length_samples / self.num_samples)]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            transform = torchaudio.transforms.Resample(orig_freq=self.sr,
                                                       new_freq=8000)
            wav = transform(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()

        if len(wav) % self.num_samples != 0:
            pad_num = self.num_samples - (len(wav) % (self.num_samples))
            wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))

        assert len(gt) == len(wav) / self.num_samples

        mask[gt == 0]

        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        gt = np.zeros(audio_length_samples)

        for i in annotation:
            gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1

        squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
        squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
        mask = np.ones(len(squeezed_predicts))
        mask[squeezed_predicts == 0] = self.noise_loss
        return squeezed_predicts, mask

    def add_augs(self, wav):
        while True:
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
                    return wav
                return wav_aug
            except Exception as e:
                continue


def SileroVadPadder(batch):
    wavs = [batch[i][0] for i in range(len(batch))]
    labels = [batch[i][1] for i in range(len(batch))]
    masks = [batch[i][2] for i in range(len(batch))]

    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0)

    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)

    masks = torch.nn.utils.rnn.pad_sequence(
        masks, batch_first=True, padding_value=0)

    return wavs, labels, masks
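`get_ground_truth_annotated` above turns sample-level annotation into one label per 512-sample window: annotated samples are marked 1, averaged within each window, and thresholded at 0.5. A self-contained worked example:

```python3
import numpy as np

gt = np.zeros(2048)                    # 4 windows of 512 samples
gt[600:1500] = 1                       # speech annotated between samples 600 and 1500
squeezed = np.average(gt.reshape(-1, 512), axis=1)
print(squeezed.round(2))               # [0.   0.83 0.93 0.  ] speech fraction per window
print((squeezed > 0.5).astype(int))    # [0 1 1 0] window-level labels
```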
class VADDecoderRNNJIT(nn.Module):

    def __init__(self):
        super(VADDecoderRNNJIT, self).__init__()

        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(nn.Dropout(0.1),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 1, kernel_size=1),
                                     nn.Sigmoid())

    def forward(self, x, state=torch.zeros(0)):
        x = x.squeeze(-1)
        if len(state):
            h, c = self.rnn(x, (state[0], state[1]))
        else:
            h, c = self.rnn(x)

        x = h.unsqueeze(-1).float()
        state = torch.stack([h, c])
        x = self.decoder(x)
        return x, state


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):

    losses = AverageMeter()
    decoder.train()

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.enable_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg
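`train` above weights the per-window BCE loss by the mask, so non-speech windows contribute `noise_loss` times less. The pattern in isolation, with hypothetical values:

```python3
import torch

criterion = torch.nn.BCELoss(reduction='none')
probs = torch.tensor([[0.9, 0.2, 0.7]])    # hypothetical per-window speech probabilities
targets = torch.tensor([[1.0, 0.0, 1.0]])
masks = torch.tensor([[1.0, 0.5, 1.0]])    # 0.5 = noise_loss on the non-speech window
loss = (criterion(probs, targets) * masks).mean()
print(loss.item())
```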
def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):

    losses = AverageMeter()
    decoder.eval()

    predicts = []
    gts = []

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.no_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i-context_size:i+num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            losses.update(loss.item(), masks.numel())
    score = roc_auc_score(gts, predicts)

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, round(score, 3)


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model
def predict(model, loader, device, sr):
    with torch.no_grad():
        all_predicts = []
        all_gts = []
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            x = x.to(device)
            out = model.audio_forward(x, sr=sr)

            for i, out_chunk in enumerate(out):
                predict = out_chunk[masks[i] != 0].cpu().tolist()
                gt = targets[i, masks[i] != 0].cpu().tolist()

                all_predicts.append(predict)
                all_gts.append(gt)
    return all_predicts, all_gts


def calculate_best_thresholds(all_predicts, all_gts):
    best_acc = 0
    for ths_enter in tqdm(np.linspace(0, 1, 20)):
        for ths_exit in np.linspace(0, 1, 20):
            if ths_exit >= ths_enter:
                continue

            accs = []
            for j, predict in enumerate(all_predicts):
                predict_bool = []
                is_speech = False
                for i in predict:
                    if i >= ths_enter:
                        is_speech = True
                        predict_bool.append(1)
                    elif i <= ths_exit:
                        is_speech = False
                        predict_bool.append(0)
                    else:
                        val = 1 if is_speech else 0
                        predict_bool.append(val)

                score = round(accuracy_score(all_gts[j], predict_bool), 4)
                accs.append(score)

            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc
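`calculate_best_thresholds` grid-searches an enter threshold and a strictly lower exit threshold; between the two, the previous decision is kept. The hysteresis in isolation, with hypothetical values:

```python3
probs = [0.1, 0.6, 0.4, 0.3, 0.2, 0.7]   # hypothetical per-window probabilities
ths_enter, ths_exit = 0.5, 0.25
is_speech, labels = False, []
for p in probs:
    if p >= ths_enter:
        is_speech = True
    elif p <= ths_exit:
        is_speech = False
    labels.append(int(is_speech))
print(labels)  # [0, 1, 1, 1, 0, 1]
```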