diff --git a/files/lang_classifier_95.jit b/files/lang_classifier_95.jit deleted file mode 100644 index 0bf001d..0000000 Binary files a/files/lang_classifier_95.jit and /dev/null differ diff --git a/files/lang_classifier_95.onnx b/files/lang_classifier_95.onnx deleted file mode 100644 index 748d64f..0000000 Binary files a/files/lang_classifier_95.onnx and /dev/null differ diff --git a/files/number_detector.jit b/files/number_detector.jit deleted file mode 100644 index 3e9c9c2..0000000 Binary files a/files/number_detector.jit and /dev/null differ diff --git a/files/number_detector.onnx b/files/number_detector.onnx deleted file mode 100644 index 2a7e5e4..0000000 Binary files a/files/number_detector.onnx and /dev/null differ diff --git a/hubconf.py b/hubconf.py index 95ecd0b..f51c726 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,7 @@ dependencies = ['torch', 'torchaudio'] import torch import json + from utils_vad import (init_jit_model, get_speech_timestamps, get_number_ts, @@ -10,7 +11,8 @@ from utils_vad import (init_jit_model, read_audio, VADIterator, collect_chunks, - drop_chunks) + drop_chunks, + donwload_onnx_model) def silero_vad(**kwargs): @@ -34,13 +36,14 @@ def silero_number_detector(**kwargs): Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples """ - hub_dir = torch.hub.get_dir() - model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit') + torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit') + model = init_jit_model(model_path='number_detector.jit') utils = (get_number_ts, save_audio, read_audio, collect_chunks, - drop_chunks) + drop_chunks, + donwload_onnx_model) return model, utils @@ -50,10 +53,11 @@ def silero_lang_detector(**kwargs): Returns a model with a set of utils Please see https://github.com/snakers4/silero-vad for usage examples """ - hub_dir = torch.hub.get_dir() - model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/number_detector.jit') + torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit') + model = init_jit_model(model_path='number_detector.jit') utils = (get_language, - read_audio) + read_audio, + donwload_onnx_model) return model, utils @@ -65,7 +69,8 @@ def silero_lang_detector_95(**kwargs): """ hub_dir = torch.hub.get_dir() - model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/lang_classifier_95.jit') + torch.hub.download_url_to_file('https://models.silero.ai/vad_models/lang_classifier_95.jit', 'lang_classifier_95.jit') + model = init_jit_model(model_path='lang_classifier_95.jit') with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f: lang_dict = json.load(f) @@ -73,6 +78,6 @@ def silero_lang_detector_95(**kwargs): with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f: lang_group_dict = json.load(f) - utils = (get_language_and_group, read_audio) + utils = (get_language_and_group, read_audio, donwload_onnx_model) return model, lang_dict, lang_group_dict, utils diff --git a/silero-vad.ipynb b/silero-vad.ipynb index 9fb5d8a..59346e8 100644 --- a/silero-vad.ipynb +++ b/silero-vad.ipynb @@ -73,6 +73,28 @@ "### Full Audio" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RJRBksv39xf5" + }, + "outputs": [], + "source": [ + "wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tEKb0YF_9y-i" + }, + "outputs": [], + "source": [ + "wav" + ] + }, { "cell_type": "markdown", "metadata": { @@ -210,7 +232,8 @@ " save_audio,\n", " read_audio,\n", " collect_chunks,\n", - " drop_chunks) = utils\n", + " drop_chunks,\n", + " _) = utils\n", "\n", "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'" ] @@ -334,7 +357,8 @@ " force_reload=True)\n", "\n", "(get_language,\n", - " read_audio) = utils\n", + " read_audio,\n", + " _) = utils\n", "\n", "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'" ] @@ -416,7 +440,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "hidden": true, "id": "PdjGd56R2Fw5" }, @@ -442,9 +465,11 @@ " save_audio,\n", " read_audio,\n", " collect_chunks,\n", - " drop_chunks) = utils\n", + " drop_chunks,\n", + " donwload_onnx_model) = utils\n", "\n", "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n", + "donwload_onnx_model('number_detector')\n", "\n", "def init_onnx_model(model_path: str):\n", " return onnxruntime.InferenceSession(model_path)\n", @@ -477,7 +502,7 @@ }, "outputs": [], "source": [ - "model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n", + "model = init_onnx_model('number_detector.onnx')\n", "wav = read_audio(f'{files_dir}/en_num.wav')\n", "\n", "# get number timestamps from full audio file\n", @@ -556,7 +581,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "hidden": true, "id": "iNkDWJ3H2Fw6" }, @@ -579,8 +603,10 @@ " force_reload=True)\n", "\n", "(get_language,\n", - " read_audio) = utils\n", + " read_audio,\n", + " donwload_onnx_model) = utils\n", "\n", + "donwload_onnx_model('number_detector')\n", "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n", "\n", "def init_onnx_model(model_path: str):\n", @@ -613,7 +639,7 @@ }, "outputs": [], "source": [ - "model = init_onnx_model(f'{files_dir}/number_detector.onnx')\n", + "model = init_onnx_model('number_detector.onnx')\n", "wav = read_audio(f'{files_dir}/en.wav')\n", "\n", "lang = get_language(wav, model, run_function=validate_onnx)\n", diff --git a/utils_vad.py b/utils_vad.py index f8a7b0c..ae51a7f 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -5,6 +5,18 @@ import torch.nn.functional as F import warnings languages = ['ru', 'en', 'de', 'es'] +onnx_url_dict = { + 'lang_classifier_95': 'https://models.silero.ai/vad_models/lang_classifier_95.onnx', + 'number_detector':'https://models.silero.ai/vad_models/number_detector.onnx' + } + + +def donwload_onnx_model(model_name): + + if model_name not in ['lang_classifier_95', 'number_detector']: + raise ValueError + + torch.hub.download_url_to_file(onnx_url_dict[model_name], f'{model_name}.onnx') def validate(model,