mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Add Language Detector
This commit is contained in:
16
hubconf.py
16
hubconf.py
@@ -3,6 +3,7 @@ import torch
|
|||||||
from utils import (init_jit_model,
|
from utils import (init_jit_model,
|
||||||
get_speech_ts,
|
get_speech_ts,
|
||||||
get_number_ts,
|
get_number_ts,
|
||||||
|
get_language,
|
||||||
save_audio,
|
save_audio,
|
||||||
read_audio,
|
read_audio,
|
||||||
state_generator,
|
state_generator,
|
||||||
@@ -29,7 +30,7 @@ def silero_vad(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def silero_number_detector(**kwargs):
|
def silero_number_detector(**kwargs):
|
||||||
"""Silero Number Detector and Language Classifier
|
"""Silero Number Detector
|
||||||
Returns a model with a set of utils
|
Returns a model with a set of utils
|
||||||
Please see https://github.com/snakers4/silero-vad for usage examples
|
Please see https://github.com/snakers4/silero-vad for usage examples
|
||||||
"""
|
"""
|
||||||
@@ -42,3 +43,16 @@ def silero_number_detector(**kwargs):
|
|||||||
drop_chunks)
|
drop_chunks)
|
||||||
|
|
||||||
return model, utils
|
return model, utils
|
||||||
|
|
||||||
|
|
||||||
|
def silero_lang_detector(**kwargs):
|
||||||
|
"""Silero Language Classifier
|
||||||
|
Returns a model with a set of utils
|
||||||
|
Please see https://github.com/snakers4/silero-vad for usage examples
|
||||||
|
"""
|
||||||
|
hub_dir = torch.hub.get_dir()
|
||||||
|
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit')
|
||||||
|
utils = (get_language,
|
||||||
|
read_audio)
|
||||||
|
|
||||||
|
return model, utils
|
||||||
|
|||||||
164
silero-vad.ipynb
164
silero-vad.ipynb
@@ -312,6 +312,78 @@
|
|||||||
"Audio('no_numbers.wav')"
|
"Audio('no_numbers.wav')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Language detector"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true,
|
||||||
|
"hidden": true
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Install Dependencies"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"hidden": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#@title Install and Import Dependencies\n",
|
||||||
|
"\n",
|
||||||
|
"# this assumes that you have a relevant version of PyTorch installed\n",
|
||||||
|
"!pip install -q torchaudio soundfile\n",
|
||||||
|
"\n",
|
||||||
|
"import glob\n",
|
||||||
|
"import torch\n",
|
||||||
|
"torch.set_num_threads(1)\n",
|
||||||
|
"\n",
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"from pprint import pprint\n",
|
||||||
|
"\n",
|
||||||
|
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||||
|
" model='silero_lang_detector',\n",
|
||||||
|
" force_reload=True)\n",
|
||||||
|
"\n",
|
||||||
|
"(get_language,\n",
|
||||||
|
" read_audio) = utils\n",
|
||||||
|
"\n",
|
||||||
|
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true,
|
||||||
|
"hidden": true
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Full audio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"hidden": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"wav = read_audio(f'{files_dir}/en.wav')\n",
|
||||||
|
"lang = get_language(wav, model)\n",
|
||||||
|
"print(lang)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -677,6 +749,98 @@
|
|||||||
" drop_chunks(number_timestamps, wav), 16000) \n",
|
" drop_chunks(number_timestamps, wav), 16000) \n",
|
||||||
"Audio('no_numbers.wav')"
|
"Audio('no_numbers.wav')"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Language detector"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true,
|
||||||
|
"hidden": true,
|
||||||
|
"id": "bL4kn4KJrlyL"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Install Dependencies"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2020-12-30T17:25:19.107534Z",
|
||||||
|
"start_time": "2020-12-30T17:24:51.853293Z"
|
||||||
|
},
|
||||||
|
"cellView": "form",
|
||||||
|
"hidden": true,
|
||||||
|
"id": "Q4QIfSpprnkI"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#@title Install and Import Dependencies\n",
|
||||||
|
"\n",
|
||||||
|
"# this assumes that you have a relevant version of PyTorch installed\n",
|
||||||
|
"!pip install -q torchaudio soundfile onnxruntime\n",
|
||||||
|
"\n",
|
||||||
|
"import glob\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import onnxruntime\n",
|
||||||
|
"from pprint import pprint\n",
|
||||||
|
"\n",
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"\n",
|
||||||
|
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
|
||||||
|
" model='silero_lang_detector',\n",
|
||||||
|
" force_reload=True)\n",
|
||||||
|
"\n",
|
||||||
|
"(get_language,\n",
|
||||||
|
" read_audio) = utils\n",
|
||||||
|
"\n",
|
||||||
|
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
|
||||||
|
"\n",
|
||||||
|
"def init_onnx_model(model_path: str):\n",
|
||||||
|
" return onnxruntime.InferenceSession(model_path)\n",
|
||||||
|
"\n",
|
||||||
|
"def validate_onnx(model, inputs):\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
|
||||||
|
" outs = model.run(None, ort_inputs)\n",
|
||||||
|
" outs = [torch.Tensor(x) for x in outs]\n",
|
||||||
|
" return outs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"heading_collapsed": true,
|
||||||
|
"hidden": true,
|
||||||
|
"id": "5JHErdB7jsr0"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Full Audio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"hidden": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n",
|
||||||
|
"wav = read_audio(f'{files_dir}/en.wav')\n",
|
||||||
|
"\n",
|
||||||
|
"lang = get_language(wav, model, run_function=validate_onnx)\n",
|
||||||
|
"print(lang)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
13
utils.py
13
utils.py
@@ -9,6 +9,9 @@ import torch.nn.functional as F
|
|||||||
torchaudio.set_audio_backend("soundfile") # switch backend
|
torchaudio.set_audio_backend("soundfile") # switch backend
|
||||||
|
|
||||||
|
|
||||||
|
languages = ['ru', 'en', 'de', 'es']
|
||||||
|
|
||||||
|
|
||||||
def validate(model,
|
def validate(model,
|
||||||
inputs: torch.Tensor):
|
inputs: torch.Tensor):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -138,6 +141,16 @@ def get_number_ts(wav: torch.Tensor,
|
|||||||
return timings
|
return timings
|
||||||
|
|
||||||
|
|
||||||
|
def get_language(wav: torch.Tensor,
|
||||||
|
model,
|
||||||
|
run_function=validate):
|
||||||
|
wav = torch.unsqueeze(wav, dim=0)
|
||||||
|
lang_logits = run_function(model, wav)[2]
|
||||||
|
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
|
||||||
|
assert lang_pred < len(languages)
|
||||||
|
return languages[lang_pred]
|
||||||
|
|
||||||
|
|
||||||
class VADiterator:
|
class VADiterator:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
trig_sum: float = 0.26,
|
trig_sum: float = 0.26,
|
||||||
|
|||||||
Reference in New Issue
Block a user