Mirror of https://github.com/snakers4/silero-vad.git (synced 2026-02-05 18:09:22 +08:00)

README.md (25 changed lines)
@@ -13,7 +13,7 @@
 <br/>

 <p align="center">
-<img src="https://user-images.githubusercontent.com/12515440/228639780-876f7801-8ec5-4daf-89f3-b45b22dd1a73.png" />
+<img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
 </p>


@@ -38,20 +38,16 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

 - **Lightweight**

-JIT model is around one megabyte in size.
+JIT model is around two megabytes in size.

 - **General**

-Silero VAD was trained on huge corpora that include over **100** languages and it performs well on audios from different domains with various background noise and quality levels.
+Silero VAD was trained on huge corpora that include over **6000** languages and it performs well on audios from different domains with various background noise and quality levels.

 - **Flexible sampling rate**

 Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).

-- **Flexible chunk size**
-
-Model was trained on **30 ms**. Longer chunks are supported directly, others may work as well.
-
 - **Highly Portable**

 Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX** running everywhere where these runtimes are available.
@@ -60,6 +56,21 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-

 Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.

+<br/>
+<h2 align="center">Fast start</h2>
+<br/>
+
+```python3
+import torch
+torch.set_num_threads(1)
+
+model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+(get_speech_timestamps, _, read_audio, _, _) = utils
+
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(wav, model)
+```
+
 <br/>
 <h2 align="center">Typical Use Cases</h2>
 <br/>
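For readers following the new "Fast start" snippet, a small continuation sketch: it assumes the same utils tuple order as above and uses the `return_seconds` flag documented in the `get_speech_timestamps` changes further down this diff; the `'start'`/`'end'` keys follow the "list of dicts" return value described there.

```python
import torch

torch.set_num_threads(1)

# Same loading pattern as the Fast start snippet above.
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils

wav = read_audio('path_to_audio_file')  # placeholder path, as in the README

# return_seconds=True reports timestamps in seconds instead of samples
speech_timestamps = get_speech_timestamps(wav, model, return_seconds=True)
for ts in speech_timestamps:
    print(f"speech from {ts['start']}s to {ts['end']}s")
```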
Binary file not shown.
Binary file not shown.
@@ -46,7 +46,7 @@
 "USE_ONNX = False # change this to True if you want to test onnx model\n",
 "if USE_ONNX:\n",
 " !pip install -q onnxruntime\n",
-" \n",
+"\n",
 "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
 " model='silero_vad',\n",
 " force_reload=True,\n",
@@ -65,16 +65,7 @@
 "id": "fXbbaUO3jsrw"
 },
 "source": [
-"## Full Audio"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "RAfJPb_a-Auj"
-},
-"source": [
-"**Speech timestapms from full audio**"
+"## Speech timestapms from full audio"
 ]
 },
 {
@@ -101,10 +92,33 @@
 "source": [
 "# merge all speech chunks to one audio\n",
 "save_audio('only_speech.wav',\n",
-" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
+" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
 "Audio('only_speech.wav')"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "zeO1xCqxUC6w"
+},
+"source": [
+"## Entire audio inference"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "LjZBcsaTT7Mk"
+},
+"outputs": [],
+"source": [
+"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
+"# audio is being splitted into 31.25 ms long pieces\n",
+"# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
+"predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
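A side note on the new "Entire audio inference" cell: at 16 kHz the model consumes 512-sample windows, i.e. 32 ms of audio (31.25 windows per second, which is what the cell's comment refers to). A rough sketch of the expected output length, assuming `model`, `read_audio` and `SAMPLING_RATE` from the earlier notebook cells; the exact output shape is an assumption, not shown in this hunk.

```python
import math

SAMPLING_RATE = 16000  # notebook constant
wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)

# audio_forward pads the tail and runs the model window by window (see utils_vad.py below)
predicts = model.audio_forward(wav, sr=SAMPLING_RATE)

expected_windows = math.ceil(len(wav) / 512)  # one speech probability per 512-sample window
print(predicts.shape, expected_windows)
```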
@@ -124,10 +138,10 @@
 "source": [
 "## using VADIterator class\n",
 "\n",
-"vad_iterator = VADIterator(model)\n",
+"vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
 "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "\n",
-"window_size_samples = 1536 # number of samples in a single audio chunk\n",
+"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
 "for i in range(0, len(wav), window_size_samples):\n",
 " chunk = wav[i: i+ window_size_samples]\n",
 " if len(chunk) < window_size_samples:\n",
@@ -150,7 +164,7 @@
 "\n",
 "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
 "speech_probs = []\n",
-"window_size_samples = 1536\n",
+"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
 "for i in range(0, len(wav), window_size_samples):\n",
 " chunk = wav[i: i+ window_size_samples]\n",
 " if len(chunk) < window_size_samples:\n",
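Both streaming cells switch from the old 1536-sample window to the fixed sizes the updated model accepts (512 samples at 16 kHz, 256 at 8 kHz). A condensed sketch of the VADIterator loop after this change, assuming the notebook's `model`, `read_audio`, `VADIterator`, `SAMPLING_RATE` and `en_example.wav`; dropping the trailing partial window is an assumption for illustration.

```python
SAMPLING_RATE = 16000  # notebook constant

vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)
wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)

window_size_samples = 512 if SAMPLING_RATE == 16000 else 256
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # the model now expects exactly 512/256 samples per call
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict)  # {'start': ...} or {'end': ...} events
vad_iterator.reset_states()  # reset internal state between audio files
```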
utils_vad.py (55 changed lines)
@@ -1,7 +1,6 @@
 import torch
 import torchaudio
 from typing import Callable, List
-import torch.nn.functional as F
 import warnings

 languages = ['ru', 'en', 'de', 'es']
@@ -39,22 +38,27 @@ class OnnxWrapper():

         if sr not in self.sample_rates:
             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")

         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")

         return x, sr

     def reset_states(self, batch_size=1):
-        self._h = np.zeros((2, batch_size, 64)).astype('float32')
-        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
         self._last_sr = 0
         self._last_batch_size = 0

     def __call__(self, x, sr: int):

         x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256

+        if x.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
         batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32

         if not self._last_batch_size:
             self.reset_states(batch_size)
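Two practical consequences of this hunk: the recurrent state is now a single `(2, batch, 128)` tensor plus a small context buffer, and `__call__` only accepts chunks of exactly 512 samples at 16 kHz or 256 at 8 kHz. A minimal sketch of calling the wrapper directly under the new contract, assuming `model` is a loaded Silero VAD model that behaves like the OnnxWrapper shown in this diff.

```python
import torch

sr = 16000
chunk = torch.zeros(1, 512)        # exactly 512 samples per call at 16 kHz (256 at 8 kHz)

speech_prob = model(chunk, sr)     # one speech probability for this window
model.reset_states()               # clears _state and _context before a new stream

try:
    model(torch.zeros(1, 1536), sr)  # the old 1536-sample window is rejected now
except ValueError as err:
    print(err)                     # "Provided number of samples is 1536 ..."
```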
@@ -63,28 +67,35 @@ class OnnxWrapper():
         if (self._last_batch_size) and (self._last_batch_size != batch_size):
             self.reset_states(batch_size)

+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
         if sr in [8000, 16000]:
-            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)}
             ort_outs = self.session.run(None, ort_inputs)
-            out, self._h, self._c = ort_outs
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
         else:
             raise ValueError()

+        self._context = x[..., -context_size:]
         self._last_sr = sr
         self._last_batch_size = batch_size

-        out = torch.tensor(out)
+        out = torch.from_numpy(out)
         return out

-    def audio_forward(self, x, sr: int, num_samples: int = 512):
+    def audio_forward(self, x, sr: int):
         outs = []
         x, sr = self._validate_input(x, sr)
+        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256

         if x.shape[1] % num_samples:
             pad_num = num_samples - (x.shape[1] % num_samples)
             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

-        self.reset_states(x.shape[0])
         for i in range(0, x.shape[1], num_samples):
             wavs_batch = x[:, i:i+num_samples]
             out_chunk = self.__call__(wavs_batch, sr)
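To make the new context handling concrete: before each ONNX call the 512-sample window (256 at 8 kHz) is prefixed with the last 64 samples (32 at 8 kHz) of the previous window, so the session actually sees 576 samples at 16 kHz, and `audio_forward` now resets state, derives the window size from `sr` and pads the tail itself. The following standalone sketch mirrors just that input-assembly logic from the lines above; it is an illustration, not the library API.

```python
import torch

sr = 16000
num_samples = 512 if sr == 16000 else 256
context_size = 64 if sr == 16000 else 32

wav = torch.randn(1, 10_000)  # dummy mono audio, shape (batch, samples)

# pad to a whole number of windows, as audio_forward does
if wav.shape[1] % num_samples:
    pad_num = num_samples - (wav.shape[1] % num_samples)
    wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0)

context = torch.zeros(wav.shape[0], context_size)  # fresh context, as after reset_states()
for i in range(0, wav.shape[1], num_samples):
    chunk = wav[:, i:i + num_samples]
    model_input = torch.cat([context, chunk], dim=1)            # what the ONNX session receives
    assert model_input.shape[1] == num_samples + context_size   # 576 samples at 16 kHz
    context = model_input[..., -context_size:]                  # carried into the next window
```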
@@ -179,11 +190,11 @@ def get_speech_timestamps(audio: torch.Tensor,
                           min_speech_duration_ms: int = 250,
                           max_speech_duration_s: float = float('inf'),
                           min_silence_duration_ms: int = 100,
-                          window_size_samples: int = 512,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
                           visualize_probs: bool = False,
-                          progress_tracking_callback: Callable[[float], None] = None):
+                          progress_tracking_callback: Callable[[float], None] = None,
+                          window_size_samples: int = 512,):

     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -193,14 +204,14 @@ def get_speech_timestamps(audio: torch.Tensor,
     audio: torch.Tensor, one dimensional
         One dimensional float torch.Tensor, other types are casted to torch if possible

-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model

     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
         It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

     sampling_rate: int (default - 16000)
-        Currently silero VAD models support 8000 and 16000 sample rates
+        Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates

     min_speech_duration_ms: int (default - 250 milliseconds)
         Final speech chunks shorter min_speech_duration_ms are thrown out
@@ -213,11 +224,6 @@ def get_speech_timestamps(audio: torch.Tensor,
     min_silence_duration_ms: int (default - 100 milliseconds)
         In the end of each speech chunk wait for min_silence_duration_ms before separating it

-    window_size_samples: int (default - 1536 samples)
-        Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
-        Values other than these may affect model perfomance!!
-
     speech_pad_ms: int (default - 30 milliseconds)
         Final speech chunks are padded by speech_pad_ms each side

@@ -230,6 +236,9 @@ def get_speech_timestamps(audio: torch.Tensor,
     progress_tracking_callback: Callable[[float], None] (default - None)
         callback function taking progress in percents as an argument

+    window_size_samples: int (default - 512 samples)
+        !!! DEPRECATED, DOES NOTHING !!!
+
     Returns
     ----------
     speeches: list of dicts
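With `window_size_samples` deprecated (and moved to the end of the signature so existing keyword calls keep working), the window size is now derived from `sampling_rate`. A hedged example of a typical call after this change, using only parameters that appear in the signature and docstring above; `wav` and `model` are assumed to be loaded as in the README snippet.

```python
speech_timestamps = get_speech_timestamps(
    wav,                           # 1-D float tensor
    model,                         # preloaded .jit/.onnx Silero VAD model
    threshold=0.5,                 # speech probability threshold
    sampling_rate=16000,           # 8000 or 16000 (or a multiple of 16000)
    min_speech_duration_ms=250,
    min_silence_duration_ms=100,
    speech_pad_ms=30,
    return_seconds=True,           # timestamps in seconds instead of samples
)
# window_size_samples can still be passed as a keyword, but it is deprecated and ignored.
```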
@@ -256,10 +265,10 @@ def get_speech_timestamps(audio: torch.Tensor,
     else:
         step = 1

-    if sampling_rate == 8000 and window_size_samples > 768:
-        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
-    if window_size_samples not in [256, 512, 768, 1024, 1536]:
-        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
+    if sampling_rate not in [8000, 16000]:
+        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
+
+    window_size_samples = 512 if sampling_rate == 16000 else 256

     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
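An unsupported sample rate is now a hard error rather than a window-size warning. A small sketch of the new failure mode, assuming `wav` and `model` as above.

```python
try:
    get_speech_timestamps(wav, model, sampling_rate=44100)
except ValueError as err:
    # "Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates"
    print(err)
```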
@@ -450,7 +459,7 @@ class VADIterator:

        Parameters
        ----------
-        model: preloaded .jit silero VAD model
+        model: preloaded .jit/.onnx silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.