41 Commits
v3.1 ... v4.0

Author SHA1 Message Date
adamnsandle
915dd3d639 v4.0stable force_onnx_cpu fx 2024-07-08 10:16:52 +00:00
adamnsandle
ac128b3c55 v4.0stable fx 2024-07-01 09:53:25 +00:00
Dimitrii Voronin
82d199ff22 Merge pull request #256 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:57:10 +03:00
adamnsandle
5ba388d894 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:55:59 +00:00
adamnsandle
790844ba0f revert to exception 2022-10-28 10:55:46 +00:00
Dimitrii Voronin
51b5245410 Merge pull request #255 from snakers4/adamnsandle
Adamnsandle
2022-10-28 13:33:18 +03:00
adamnsandle
888970e77d Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-10-28 10:32:08 +00:00
adamnsandle
cb6d308335 fx 2022-10-28 10:31:55 +00:00
adamnsandle
1b212c6e95 change exception to warning 2022-10-28 10:26:07 +00:00
adamnsandle
452060ad65 fx 2022-10-28 10:13:00 +00:00
Dimitrii Voronin
c7eab751b5 Merge pull request #253 from snakers4/adamnsandle
add torch version check
2022-10-28 13:09:18 +03:00
adamnsandle
d1714a9ff7 add torch version check 2022-10-28 10:08:07 +00:00
Dimitrii Voronin
94c79d899d Merge pull request #251 from snakers4/adamnsandle
v4 hotfix
2022-10-27 20:26:31 +03:00
adamnsandle
1baf307b35 v4 hotfix 2022-10-27 17:25:31 +00:00
Dimitrii Voronin
e324285cdc Merge pull request #247 from snakers4/adamnsandle
Adamnsandle
2022-10-26 19:17:44 +03:00
adamnsandle
13dce2d067 Merge branch 'MASTER' into adamnsandle 2022-10-26 16:13:37 +00:00
adamnsandle
081e6b9886 VAD v4 2022-10-26 16:10:20 +00:00
Alexander Veysov
572134fdf1 Update README.md 2022-10-25 05:52:53 +03:00
Dimitrii Voronin
a799dea837 Merge pull request #244 from owlsometech-kenyang/feature/support-force-onnx-cpu
Suggesting a new kwarg: force_onnx_cpu
2022-10-14 11:53:46 +03:00
ChiehKai Yang
17209e6c4f add new parameter: force_onnx_cpu 2022-10-12 01:56:43 +08:00
adamnsandle
6661cc9691 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:41:54 +00:00
Dimitrii Voronin
7c671a75c2 Merge pull request #199 from snakers4/adamnsandle
fx end of chunk may exceed audio length
2022-06-02 13:40:42 +03:00
adamnsandle
622016e672 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2022-06-02 10:40:11 +00:00
adamnsandle
8eba346bc9 fx end of chunk may exceed audio length 2022-06-02 10:39:16 +00:00
Dimitrii Voronin
900c71a109 Merge pull request #198 from snakers4/adamnsandle
fx get_speech ts start of an audio chunk pad
2022-06-02 13:33:36 +03:00
adamnsandle
bf0127e016 fx get_speech ts start of an audio chunk pad 2022-06-02 10:32:32 +00:00
Dimitrii Voronin
ea7af70fe9 Merge pull request #182 from snakers4/adamnsandle
Adamnsandle
2022-04-05 14:36:00 +03:00
adamnsandle
8cdc8d36c9 fx 2022-04-05 11:35:23 +00:00
adamnsandle
6e9fd77500 fx stram imitation example bug 2022-04-05 11:33:34 +00:00
Alexander Veysov
6cc08b1077 Merge pull request #170 from gabrielziegler3/169-fix-min-speech-duration-bug
Fix #169
2022-02-10 12:18:23 +03:00
Gabriel Ziegler
0e8e080894 Remove unnecessary if statement 2022-02-09 19:22:04 -03:00
Gabriel Ziegler
af6931d1de Fix bug where min_speech_duration_ms is not checked in the last speech segment
Signed-off-by: Gabriel Ziegler <gabrielziegler3@gmail.com>
2022-02-09 19:18:48 -03:00
Alexander Veysov
76687cbe25 Update README.md 2021-12-21 14:43:36 +03:00
Dimitrii Voronin
b2329fa5f2 Merge pull request #144 from snakers4/adamnsandle
Update README.md
2021-12-21 14:25:56 +03:00
Dimitrii Voronin
005886e7eb Update README.md 2021-12-21 13:25:14 +02:00
Dimitrii Voronin
f6b1294cb2 Merge pull request #143 from snakers4/adamnsandle
Adamnsandle
2021-12-21 14:02:25 +03:00
adamnsandle
2392ea33f4 Merge branch 'master' of github.com:snakers4/silero-vad into adamnsandle 2021-12-21 11:01:25 +00:00
adamnsandle
45d72863b6 add multiple of 16k sr support 2021-12-21 11:01:07 +00:00
Alexander Veysov
f40cc128a4 Update utils_vad.py 2021-12-21 08:24:48 +03:00
Alexander Veysov
0d61e4cee1 Update README.md 2021-12-17 22:03:49 +03:00
Alexander Veysov
011268e492 Polish the copy a bit 2021-12-17 22:00:36 +03:00
6 changed files with 135 additions and 39 deletions

View File

@@ -15,7 +15,7 @@ This repository also includes Number Detector and Language classifier [models](h
<br/>
<p align="center">
<img src="https://user-images.githubusercontent.com/36505480/145563071-681b57e3-06b5-4cd0-bdee-e2ade3d50a60.png" />
<img src="https://user-images.githubusercontent.com/36505480/198026365-8da383e0-5398-4a12-b7f8-22c2c0059512.png" />
</p>
<details>
@@ -29,17 +29,17 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
<h2 align="center">Key Features</h2>
<br/>
- **High accuracy**
- **Stellar accuracy**
Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks.
- **Fast**
One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) around **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably.
One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably. Under certain conditions ONNX may even run up to 4-5x faster.
- **Lightweight**
JIT model is less than one megabyte in size.
JIT model is around one megabyte in size.
- **General**
@@ -47,11 +47,19 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
- **Flexible sampling rate**
Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** (JIT) and **16000 Hz** (ONNX) [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
- **Flexible chunk size**
Model was trained on audio chunks of different lengths. **30 ms**, **60 ms** and **100 ms** long chunks are supported directly, others may work as well.
Model was trained on **30 ms**. Longer chunks are supported directly, others may work as well.
- **Highly Portable**
Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX** running everywhere where these runtimes are available.
- **No Strings Attached**
Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.
<br/>
<h2 align="center">Typical Use Cases</h2>
@@ -70,9 +78,10 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
- [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
- [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
- [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
- Number Detector and Language classifier [models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
- [Number Detector and Language classifier models](https://github.com/snakers4/silero-vad/wiki/Other-Models)
- [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
- [Further reading](https://github.com/snakers4/silero-models#further-reading)
- [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)
<br/>
<h2 align="center">Get In Touch</h2>
@@ -96,3 +105,9 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) and [tiers
email = {hello@silero.ai}
}
```
<br/>
<h2 align="center">VAD-based Community Apps</h2>
<br/>
- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web

Binary file not shown.

Binary file not shown.

View File

@@ -16,14 +16,25 @@ from utils_vad import (init_jit_model,
OnnxWrapper)
def silero_vad(onnx=False):
def versiontuple(v):
return tuple(map(int, (v.split('+')[0].split("."))))
def silero_vad(onnx=False, force_onnx_cpu=False):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
if not onnx:
installed_version = torch.__version__
supported_version = '1.12.0'
if versiontuple(installed_version) < versiontuple(supported_version):
raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
model_dir = os.path.join(os.path.dirname(__file__), 'files')
if onnx:
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'))
model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
else:
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
utils = (get_speech_timestamps,
@@ -35,7 +46,7 @@ def silero_vad(onnx=False):
return model, utils
def silero_number_detector(onnx=False):
def silero_number_detector(onnx=False, force_onnx_cpu=False):
"""Silero Number Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
@@ -44,7 +55,7 @@ def silero_number_detector(onnx=False):
url = 'https://models.silero.ai/vad_models/number_detector.onnx'
else:
url = 'https://models.silero.ai/vad_models/number_detector.jit'
model = Validator(url)
model = Validator(url, force_onnx_cpu)
utils = (get_number_ts,
save_audio,
read_audio,
@@ -54,7 +65,7 @@ def silero_number_detector(onnx=False):
return model, utils
def silero_lang_detector(onnx=False):
def silero_lang_detector(onnx=False, force_onnx_cpu=False):
"""Silero Language Classifier
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
@@ -63,14 +74,14 @@ def silero_lang_detector(onnx=False):
url = 'https://models.silero.ai/vad_models/number_detector.onnx'
else:
url = 'https://models.silero.ai/vad_models/number_detector.jit'
model = Validator(url)
model = Validator(url, force_onnx_cpu)
utils = (get_language,
read_audio)
return model, utils
def silero_lang_detector_95(onnx=False):
def silero_lang_detector_95(onnx=False, force_onnx_cpu=False):
"""Silero Language Classifier (95 languages)
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
@@ -80,8 +91,8 @@ def silero_lang_detector_95(onnx=False):
url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
else:
url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
model = Validator(url)
model = Validator(url, force_onnx_cpu)
model_dir = os.path.join(os.path.dirname(__file__), 'files')
with open(os.path.join(model_dir, 'lang_dict_95.json'), 'r') as f:
lang_dict = json.load(f)

View File

@@ -138,7 +138,10 @@
"\n",
"window_size_samples = 1536 # number of samples in a single audio chunk\n",
"for i in range(0, len(wav), window_size_samples):\n",
" speech_dict = vad_iterator(wav[i: i+ window_size_samples], return_seconds=True)\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
" break\n",
" speech_dict = vad_iterator(chunk, return_seconds=True)\n",
" if speech_dict:\n",
" print(speech_dict, end=' ')\n",
"vad_iterator.reset_states() # reset model states after each audio"
@@ -158,7 +161,10 @@
"speech_probs = []\n",
"window_size_samples = 1536\n",
"for i in range(0, len(wav), window_size_samples):\n",
" speech_prob = model(wav[i: i+ window_size_samples], SAMPLING_RATE).item()\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
" break\n",
" speech_prob = model(chunk, SAMPLING_RATE).item()\n",
" speech_probs.append(speech_prob)\n",
"vad_iterator.reset_states() # reset model states after each audio\n",
"\n",

View File

@@ -9,51 +9,98 @@ languages = ['ru', 'en', 'de', 'es']
class OnnxWrapper():
def __init__(self, path):
def __init__(self, path, force_onnx_cpu=False):
import numpy as np
global np
import onnxruntime
self.session = onnxruntime.InferenceSession(path)
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
else:
self.session = onnxruntime.InferenceSession(path)
self.session.intra_op_num_threads = 1
self.session.inter_op_num_threads = 1
self.reset_states()
self.sample_rates = [8000, 16000]
def reset_states(self):
self._h = np.zeros((2, 1, 64)).astype('float32')
self._c = np.zeros((2, 1, 64)).astype('float32')
def __call__(self, x, sr: int):
def _validate_input(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if x.shape[0] > 1:
raise ValueError("Onnx model does not support batching")
if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[::step]
sr = 16000
if sr not in [16000]:
raise ValueError(f"Supported sample rates: {[16000]}")
if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
ort_outs = self.session.run(None, ort_inputs)
out, self._h, self._c = ort_outs
return x, sr
out = torch.tensor(out).squeeze(2)[:, 1] # make output type match JIT analog
def reset_states(self, batch_size=1):
self._h = np.zeros((2, batch_size, 64)).astype('float32')
self._c = np.zeros((2, batch_size, 64)).astype('float32')
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
x, sr = self._validate_input(x, sr)
batch_size = x.shape[0]
if not self._last_batch_size:
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr)}
ort_outs = self.session.run(None, ort_inputs)
out, self._h, self._c = ort_outs
else:
raise ValueError()
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.tensor(out)
return out
def audio_forward(self, x, sr: int, num_samples: int = 512):
outs = []
x, sr = self._validate_input(x, sr)
if x.shape[1] % num_samples:
pad_num = num_samples - (x.shape[1] % num_samples)
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
self.reset_states(x.shape[0])
for i in range(0, x.shape[1], num_samples):
wavs_batch = x[:, i:i+num_samples]
out_chunk = self.__call__(wavs_batch, sr)
outs.append(out_chunk)
stacked = torch.cat(outs, dim=1)
return stacked.cpu()
class Validator():
def __init__(self, url):
def __init__(self, url, force_onnx_cpu):
self.onnx = True if url.endswith('.onnx') else False
torch.hub.download_url_to_file(url, 'inf.model')
if self.onnx:
import onnxruntime
self.model = onnxruntime.InferenceSession('inf.model')
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
else:
self.model = onnxruntime.InferenceSession('inf.model')
else:
self.model = init_jit_model(model_path='inf.model')
@@ -117,7 +164,7 @@ def get_speech_timestamps(audio: torch.Tensor,
sampling_rate: int = 16000,
min_speech_duration_ms: int = 250,
min_silence_duration_ms: int = 100,
window_size_samples: int = 1536,
window_size_samples: int = 512,
speech_pad_ms: int = 30,
return_seconds: bool = False,
visualize_probs: bool = False):
@@ -177,8 +224,16 @@ def get_speech_timestamps(audio: torch.Tensor,
if len(audio.shape) > 1:
raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
step = sampling_rate // 16000
sampling_rate = 16000
audio = audio[::step]
warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')
else:
step = 1
if sampling_rate == 8000 and window_size_samples > 768:
warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!')
warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
if window_size_samples not in [256, 512, 768, 1024, 1536]:
warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
@@ -226,7 +281,7 @@ def get_speech_timestamps(audio: torch.Tensor,
triggered = False
continue
if current_speech:
if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
current_speech['end'] = audio_length_samples
speeches.append(current_speech)
@@ -239,7 +294,8 @@ def get_speech_timestamps(audio: torch.Tensor,
speech['end'] += int(silence_duration // 2)
speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
else:
speech['end'] += int(speech_pad_samples)
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
else:
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
@@ -247,6 +303,10 @@ def get_speech_timestamps(audio: torch.Tensor,
for speech_dict in speeches:
speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
elif step > 1:
for speech_dict in speeches:
speech_dict['start'] *= step
speech_dict['end'] *= step
if visualize_probs:
make_visualization(speech_probs, window_size_samples / sampling_rate)
@@ -353,6 +413,10 @@ class VADIterator:
self.model = model
self.threshold = threshold
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
self.reset_states()