diff --git a/files/de.wav b/files/de.wav
deleted file mode 100644
index 562e38f..0000000
Binary files a/files/de.wav and /dev/null differ
diff --git a/files/en.wav b/files/en.wav
deleted file mode 100644
index 0b1b64e..0000000
Binary files a/files/en.wav and /dev/null differ
diff --git a/files/en_num.wav b/files/en_num.wav
deleted file mode 100644
index 906ff26..0000000
Binary files a/files/en_num.wav and /dev/null differ
diff --git a/files/es.wav b/files/es.wav
deleted file mode 100644
index f76e59d..0000000
Binary files a/files/es.wav and /dev/null differ
diff --git a/files/ru.wav b/files/ru.wav
deleted file mode 100644
index f18bc28..0000000
Binary files a/files/ru.wav and /dev/null differ
diff --git a/files/ru_num.wav b/files/ru_num.wav
deleted file mode 100644
index 2f90971..0000000
Binary files a/files/ru_num.wav and /dev/null differ
diff --git a/files/silero_vad.onnx b/files/silero_vad.onnx
new file mode 100644
index 0000000..127ffc9
Binary files /dev/null and b/files/silero_vad.onnx differ
diff --git a/hubconf.py b/hubconf.py
index f51c726..a2b3754 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -1,7 +1,6 @@
 dependencies = ['torch', 'torchaudio']
 import torch
 import json
-
 from utils_vad import (init_jit_model,
                        get_speech_timestamps,
                        get_number_ts,
@@ -12,16 +11,20 @@ from utils_vad import (init_jit_model,
                        VADIterator,
                        collect_chunks,
                        drop_chunks,
-                       donwload_onnx_model)
+                       Validator,
+                       OnnxWrapper)
 
 
-def silero_vad(**kwargs):
+def silero_vad(onnx=False):
     """Silero Voice Activity Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
     hub_dir = torch.hub.get_dir()
-    model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
+    if onnx:
+        model = OnnxWrapper(f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.onnx')
+    else:
+        model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
     utils = (get_speech_timestamps,
              save_audio,
              read_audio,
@@ -31,46 +34,53 @@ def silero_vad(**kwargs):
     return model, utils
 
 
-def silero_number_detector(**kwargs):
+def silero_number_detector(onnx=False):
     """Silero Number Detector
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit')
-    model = init_jit_model(model_path='number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url)
     utils = (get_number_ts,
              save_audio,
              read_audio,
             collect_chunks,
-             drop_chunks,
-             donwload_onnx_model)
+             drop_chunks)
 
     return model, utils
 
 
-def silero_lang_detector(**kwargs):
+def silero_lang_detector(onnx=False):
     """Silero Language Classifier
     Returns a model with a set of utils
     Please see https://github.com/snakers4/silero-vad for usage examples
     """
-    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit')
-    model = init_jit_model(model_path='number_detector.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/number_detector.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/number_detector.jit'
+    model = Validator(url)
     utils = (get_language,
-             read_audio,
-             donwload_onnx_model)
+             read_audio)
 
     return model, utils
 
 
-def silero_lang_detector_95(**kwargs):
+def silero_lang_detector_95(onnx=False):
    """Silero Language Classifier (95 languages)
    Returns a model with a set of utils
    Please see https://github.com/snakers4/silero-vad for usage examples
    """
     hub_dir = torch.hub.get_dir()
-    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/lang_classifier_95.jit', 'lang_classifier_95.jit')
-    model = init_jit_model(model_path='lang_classifier_95.jit')
+    if onnx:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
+    else:
+        url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
+    model = Validator(url)
 
     with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
         lang_dict = json.load(f)
@@ -78,6 +88,6 @@ def silero_lang_detector_95(**kwargs):
     with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
         lang_group_dict = json.load(f)
 
-    utils = (get_language_and_group, read_audio, donwload_onnx_model)
+    utils = (get_language_and_group, read_audio)
 
     return model, lang_dict, lang_group_dict, utils
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index 59346e8..f2d3fa2 100644
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -1,21 +1,12 @@
 {
  "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "sVNOuHQQjsrp"
-   },
-   "source": [
-    "# PyTorch Examples"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {
     "id": "FpMplOCA2Fwp"
    },
    "source": [
-    "## VAD"
+    "# VAD"
    ]
   },
   {
@@ -25,7 +16,7 @@
     "id": "62A6F_072Fwq"
    },
    "source": [
-    "### Install Dependencies"
+    "## Install Dependencies"
    ]
   },
   {
@@ -42,26 +33,39 @@
    "outputs": [],
    "source": [
     "#@title Install and Import Dependencies\n",
     "\n",
     "# this assumes that you have a relevant version of PyTorch installed\n",
     "!pip install -q torchaudio\n",
     "\n",
-    "SAMPLE_RATE = 16000\n",
+    "SAMPLING_RATE = 16000\n",
     "\n",
-    "import glob\n",
     "import torch\n",
     "torch.set_num_threads(1)\n",
     "\n",
     "from IPython.display import Audio\n",
     "from pprint import pprint\n",
-    "\n",
+    "# download example\n",
+    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "pSifus5IilRp"
+   },
+   "outputs": [],
+   "source": [
+    "USE_ONNX = False # change this to True if you want to test the ONNX model\n",
+    "if USE_ONNX:\n",
+    "  !pip install -q onnxruntime\n",
+    "  \n",
     "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
     "                              model='silero_vad',\n",
-    "                              force_reload=True)\n",
+    "                              force_reload=True,\n",
+    "                              onnx=USE_ONNX)\n",
     "\n",
     "(get_speech_timestamps,\n",
     " save_audio,\n",
     " read_audio,\n",
     " VADIterator,\n",
-    " collect_chunks) = utils\n",
-    "\n",
-    "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+    " collect_chunks) = utils"
    ]
   },
   {
@@ -70,29 +74,7 @@
    "id": "fXbbaUO3jsrw"
   },
   "source": [
-    "### Full Audio"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "RJRBksv39xf5"
-   },
-   "outputs": [],
-   "source": [
-    "wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "tEKb0YF_9y-i"
-   },
-   "outputs": [],
-   "source": [
-    "wav"
+    "## Full Audio"
    ]
   },
   {
@@ -112,9 +94,9 @@
    },
    "outputs": [],
    "source": [
-    "wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
     "# get speech timestamps from full audio file\n",
-    "speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLE_RATE)\n",
+    "speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
     "pprint(speech_timestamps)"
    ]
   },
@@ -128,7 +110,7 @@
    "source": [
     "# merge all speech chunks to one audio\n",
     "save_audio('only_speech.wav',\n",
-    "           collect_chunks(speech_timestamps, wav), sampling_rate=16000) \n",
+    "           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
     "Audio('only_speech.wav')"
    ]
   },
@@ -138,7 +120,7 @@
    "id": "iDKQbVr8jsry"
   },
   "source": [
-    "### Stream imitation example"
+    "## Stream imitation example"
   ]
  },
  {
@@ -152,7 +134,7 @@
     "## using VADIterator class\n",
     "\n",
     "vad_iterator = VADIterator(model)\n",
-    "wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
     "\n",
     "window_size_samples = 1536 # number of samples in a single audio chunk\n",
     "for i in range(0, len(wav), window_size_samples):\n",
@@ -172,14 +154,15 @@
    "outputs": [],
    "source": [
     "## just probabilities\n",
     "\n",
-    "wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
+    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
     "speech_probs = []\n",
     "window_size_samples = 1536\n",
     "for i in range(0, len(wav), window_size_samples):\n",
-    "    speech_prob = model(wav[i: i+ window_size_samples], SAMPLE_RATE).item()\n",
+    "    speech_prob = model(wav[i: i + window_size_samples], SAMPLING_RATE).item()\n",
     "    speech_probs.append(speech_prob)\n",
+    "model.reset_states() # reset model states after each audio\n",
     "\n",
-    "pprint(speech_probs[:100])"
+    "print(speech_probs[:10]) # predictions for the first 10 chunks"
    ]
   },
   {
@@ -189,7 +172,7 @@
    "id": "36jY0niD2Fww"
   },
   "source": [
-    "## Number detector"
+    "# Number detector"
   ]
  },
  {
@@ -200,7 +183,7 @@
    "id": "scd1DlS42Fwx"
   },
   "source": [
-    "### Install Dependencies"
+    "## Install Dependencies"
   ]
  },
  {
@@ -215,27 +198,41 @@
    "source": [
     "#@title Install and Import Dependencies\n",
     "\n",
     "# this assumes that you have a relevant version of PyTorch installed\n",
-    "!pip install -q torchaudio soundfile\n",
+    "!pip install -q torchaudio\n",
+    "\n",
+    "SAMPLING_RATE = 16000\n",
     "\n",
-    "import glob\n",
     "import torch\n",
     "torch.set_num_threads(1)\n",
     "\n",
     "from IPython.display import Audio\n",
     "from pprint import pprint\n",
-    "\n",
+    "# download example\n",
+    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en_num.wav', 'en_number_example.wav')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "dPwCFHmFycUF"
+   },
+   "outputs": [],
+   "source": [
+    "USE_ONNX = False # change this to True if you want to test the ONNX model\n",
+    "if USE_ONNX:\n",
+    "  !pip install -q onnxruntime\n",
+    "  \n",
     "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
     "                              model='silero_number_detector',\n",
-    "                              force_reload=True)\n",
+    "                              force_reload=True,\n",
+    "                              onnx=USE_ONNX)\n",
     "\n",
     "(get_number_ts,\n",
     " save_audio,\n",
     " read_audio,\n",
     " collect_chunks,\n",
-    " drop_chunks,\n",
-    " _) = utils\n",
-    "\n",
-    "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+    " drop_chunks) = utils\n"
    ]
   },
   {
@@ -246,7 +243,7 @@
    "id": "qhPa30ij2Fwy"
   },
   "source": [
-    "### Full audio"
+    "## Full audio"
   ]
  },
  {
@@ -258,7 +255,7 @@
    },
    "outputs": [],
    "source": [
-    "wav = read_audio(f'{files_dir}/en_num.wav')\n",
+    "wav = read_audio('en_number_example.wav', sampling_rate=SAMPLING_RATE)\n",
     "# get number timestamps from full audio file\n",
     "number_timestamps = get_number_ts(wav, model)\n",
     "pprint(number_timestamps)"
    ]
   },
@@ -273,11 +270,10 @@
    },
    "outputs": [],
    "source": [
-    "sample_rate = 16000\n",
     "# convert ms in timestamps to samples\n",
     "for timestamp in number_timestamps:\n",
-    "    timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-    "    timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
+    "    timestamp['start'] = int(timestamp['start'] * SAMPLING_RATE / 1000)\n",
+    "    timestamp['end'] = int(timestamp['end'] * SAMPLING_RATE / 1000)"
    ]
   },
   {
@@ -291,7 +287,7 @@
    "source": [
     "# merge all number chunks to one audio\n",
     "save_audio('only_numbers.wav',\n",
-    "           collect_chunks(number_timestamps, wav), sample_rate) \n",
+    "           collect_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
     "Audio('only_numbers.wav')"
    ]
   },
@@ -306,7 +302,7 @@
    "source": [
     "# drop all number chunks from audio\n",
     "save_audio('no_numbers.wav',\n",
-    "           drop_chunks(number_timestamps, wav), sample_rate) \n",
+    "           drop_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
     "Audio('no_numbers.wav')"
    ]
   },
@@ -317,7 +313,7 @@
    "id": "PnKtJKbq2Fwz"
   },
   "source": [
-    "## Language detector"
+    "# Language detector"
   ]
  },
  {
@@ -328,7 +324,7 @@
    "id": "F5cAmMbP2Fwz"
   },
   "source": [
-    "### Install Dependencies"
+    "## Install Dependencies"
   ]
  },
  {
@@ -343,24 +339,37 @@
    "source": [
     "#@title Install and Import Dependencies\n",
     "\n",
     "# this assumes that you have a relevant version of PyTorch installed\n",
-    "!pip install -q torchaudio soundfile\n",
+    "!pip install -q torchaudio\n",
+    "\n",
+    "SAMPLING_RATE = 16000\n",
     "\n",
-    "import glob\n",
     "import torch\n",
     "torch.set_num_threads(1)\n",
     "\n",
     "from IPython.display import Audio\n",
     "from pprint import pprint\n",
-    "\n",
+    "# download example\n",
+    "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "JfRKDZiRztFe"
+   },
+   "outputs": [],
+   "source": [
+    "USE_ONNX = False # change this to True if you want to test the ONNX model\n",
+    "if USE_ONNX:\n",
+    "  !pip install -q onnxruntime\n",
+    "  \n",
     "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
     "                              model='silero_lang_detector',\n",
-    "                              force_reload=True)\n",
+    "                              force_reload=True,\n",
+    "                              onnx=USE_ONNX)\n",
     "\n",
-    "(get_language,\n",
-    " read_audio,\n",
-    " _) = utils\n",
-    "\n",
-    "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
+    "get_language, read_audio = utils"
    ]
   },
   {
@@ -371,7 +380,7 @@
    "id": "iC696eMX2Fwz"
   },
   "source": [
-    "### Full audio"
+    "## Full audio"
   ]
  },
  {
@@ -383,268 +392,10 @@
    },
    "outputs": [],
    "source": [
-    "wav = read_audio(f'{files_dir}/en.wav')\n",
+    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
     "lang = get_language(wav, model)\n",
     "print(lang)"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "57avIBd6jsrz"
-   },
-   "source": [
-    "# ONNX Example"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "hEhnfORV2Fw0"
-   },
-   "source": [
-    "## VAD"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "Cy7y-NAyALSe"
-   },
-   "source": [
-    "**TO BE DONE**"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "heading_collapsed": true,
-    "id": "7QMvUvpg2Fw4"
-   },
-   "source": [
-    "## Number detector"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "heading_collapsed": true,
-    "hidden": true,
-    "id": "tBPDkpHr2Fw4"
-   },
-   "source": [
-    "### Install Dependencies"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "PdjGd56R2Fw5"
-   },
-   "outputs": [],
-   "source": [
-    "#@title Install and Import Dependencies\n",
-    "\n",
-    "# this assumes that you have a relevant version of PyTorch installed\n",
-    "!pip install -q torchaudio soundfile onnxruntime\n",
-    "\n",
-    "import glob\n",
-    "import torch\n",
-    "import onnxruntime\n",
-    "from pprint import pprint\n",
-    "\n",
-    "from IPython.display import Audio\n",
-    "\n",
-    "_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-    "                          model='silero_number_detector',\n",
-    "                          force_reload=True)\n",
-    "\n",
-    "(get_number_ts,\n",
-    " save_audio,\n",
-    " read_audio,\n",
-    " collect_chunks,\n",
-    " drop_chunks,\n",
-    " donwload_onnx_model) = utils\n",
-    "\n",
-    "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-    "donwload_onnx_model('number_detector')\n",
-    "\n",
-    "def init_onnx_model(model_path: str):\n",
-    "    return onnxruntime.InferenceSession(model_path)\n",
-    "\n",
-    "def validate_onnx(model, inputs):\n",
-    "    with torch.no_grad():\n",
-    "        ort_inputs = {'input': inputs.cpu().numpy()}\n",
-    "        outs = model.run(None, ort_inputs)\n",
-    "        outs = [torch.Tensor(x) for x in outs]\n",
-    "    return outs"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "heading_collapsed": true,
-    "hidden": true,
-    "id": "I9QWSFZh2Fw5"
-   },
-   "source": [
-    "### Full Audio"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "_r6QZiwu2Fw5"
-   },
-   "outputs": [],
-   "source": [
-    "model = init_onnx_model('number_detector.onnx')\n",
-    "wav = read_audio(f'{files_dir}/en_num.wav')\n",
-    "\n",
-    "# get number timestamps from full audio file\n",
-    "number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
-    "pprint(number_timestamps)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "FN4aDwLV2Fw5"
-   },
-   "outputs": [],
-   "source": [
-    "sample_rate = 16000\n",
-    "# convert ms in timestamps to samples\n",
-    "for timestamp in number_timestamps:\n",
-    "    timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
-    "    timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "JnvS6WTK2Fw5"
-   },
-   "outputs": [],
-   "source": [
-    "# merge all number chunks to one audio\n",
-    "save_audio('only_numbers.wav',\n",
-    "           collect_chunks(number_timestamps, wav), 16000) \n",
-    "Audio('only_numbers.wav')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "yUxOcOFG2Fw6"
-   },
-   "outputs": [],
-   "source": [
-    "# drop all number chunks from audio\n",
-    "save_audio('no_numbers.wav',\n",
-    "           drop_chunks(number_timestamps, wav), 16000) \n",
-    "Audio('no_numbers.wav')"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "heading_collapsed": true,
-    "id": "SR8Bgcd52Fw6"
-   },
-   "source": [
-    "## Language detector"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "heading_collapsed": true,
-    "hidden": true,
-    "id": "PBnXPtKo2Fw6"
-   },
-   "source": [
-    "### Install Dependencies"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "iNkDWJ3H2Fw6"
-   },
-   "outputs": [],
-   "source": [
-    "#@title Install and Import Dependencies\n",
-    "\n",
-    "# this assumes that you have a relevant version of PyTorch installed\n",
-    "!pip install -q torchaudio soundfile onnxruntime\n",
-    "\n",
-    "import glob\n",
-    "import torch\n",
-    "import onnxruntime\n",
-    "from pprint import pprint\n",
-    "\n",
-    "from IPython.display import Audio\n",
-    "\n",
-    "_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
-    "                          model='silero_lang_detector',\n",
-    "                          force_reload=True)\n",
-    "\n",
-    "(get_language,\n",
-    " read_audio,\n",
-    " donwload_onnx_model) = utils\n",
-    "\n",
-    "donwload_onnx_model('number_detector')\n",
-    "files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
-    "\n",
-    "def init_onnx_model(model_path: str):\n",
-    "    return onnxruntime.InferenceSession(model_path)\n",
-    "\n",
-    "def validate_onnx(model, inputs):\n",
-    "    with torch.no_grad():\n",
-    "        ort_inputs = {'input': inputs.cpu().numpy()}\n",
-    "        outs = model.run(None, ort_inputs)\n",
-    "        outs = [torch.Tensor(x) for x in outs]\n",
-    "    return outs"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "hidden": true,
-    "id": "G8N8oP4q2Fw6"
-   },
-   "source": [
-    "### Full Audio"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "hidden": true,
-    "id": "WHXnh9IV2Fw6"
-   },
-   "outputs": [],
-   "source": [
-    "model = init_onnx_model('number_detector.onnx')\n",
-    "wav = read_audio(f'{files_dir}/en.wav')\n",
-    "\n",
-    "lang = get_language(wav, model, run_function=validate_onnx)\n",
-    "print(lang)"
-   ]
   }
  ],
 "metadata": {
diff --git a/utils_vad.py b/utils_vad.py
index ae51a7f..29b69c3 100644
--- a/utils_vad.py
+++ b/utils_vad.py
@@ -5,25 +5,68 @@ import torch.nn.functional as F
 import warnings
 
 languages = ['ru', 'en', 'de', 'es']
-onnx_url_dict = {
-    'lang_classifier_95': 'https://models.silero.ai/vad_models/lang_classifier_95.onnx',
-    'number_detector':'https://models.silero.ai/vad_models/number_detector.onnx'
-    }
 
-def donwload_onnx_model(model_name):
+class OnnxWrapper():
 
-    if model_name not in ['lang_classifier_95', 'number_detector']:
-        raise ValueError
+    def __init__(self, path):
+        import numpy as np
+        global np
+        import onnxruntime
+        sess_opts = onnxruntime.SessionOptions()
+        sess_opts.intra_op_num_threads = sess_opts.inter_op_num_threads = 1  # thread caps must go through SessionOptions
+        self.session = onnxruntime.InferenceSession(path, sess_options=sess_opts)
 
-    torch.hub.download_url_to_file(onnx_url_dict[model_name], f'{model_name}.onnx')
+        self.reset_states()
+
+    def reset_states(self):
+        self._h = np.zeros((2, 1, 64)).astype('float32')
+        self._c = np.zeros((2, 1, 64)).astype('float32')
+
+    def __call__(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+        if x.shape[0] > 1:
+            raise ValueError("Onnx model does not support batching")
+
+        if sr not in [16000]:
+            raise ValueError(f"Supported sample rates: {[16000]}")
+
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
+        ort_outs = self.session.run(None, ort_inputs)
+        out, self._h, self._c = ort_outs
+
+        out = torch.tensor(out).squeeze(2)[:, 1]  # make output type match JIT analog
+
+        return out
 
 
-def validate(model,
-             inputs: torch.Tensor):
-    with torch.no_grad():
-        outs = model(inputs)
-    return outs
+class Validator():
+    def __init__(self, url):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+
+        return outs
 
 
 def read_audio(path: str,
@@ -215,10 +258,9 @@ def get_number_ts(wav: torch.Tensor,
                   model,
                   model_stride=8,
                   hop_length=160,
-                  sample_rate=16000,
-                  run_function=validate):
+                  sample_rate=16000):
     wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = run_function(model, wav)[0]
+    perframe_logits = model(wav)[0]
     perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
     extended_preds = []
     for i in perframe_preds:
@@ -245,10 +287,9 @@ def get_number_ts(wav: torch.Tensor,
 
 
 def get_language(wav: torch.Tensor,
-                 model,
-                 run_function=validate):
+                 model):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = run_function(model, wav)[2]
+    lang_logits = model(wav)[2]
     lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
     assert lang_pred < len(languages)
     return languages[lang_pred]
@@ -258,10 +299,9 @@ def get_language_and_group(wav: torch.Tensor,
                            model,
                            lang_dict: dict,
                            lang_group_dict: dict,
-                           top_n=1,
-                           run_function=validate):
+                           top_n=1):
     wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = run_function(model, wav)
+    lang_logits, lang_group_logits = model(wav)
     softm = torch.softmax(lang_logits, dim=1).squeeze()
     softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
 
@@ -332,6 +372,13 @@ class VADIterator:
         return_seconds: bool (default - False)
             whether return timestamps in seconds (default - samples)
         """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except Exception:
+                raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
         self.current_sample += window_size_samples
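
Usage sketch (not part of the patch): the new onnx=True path in hubconf.py, exercised end to end. This is a minimal example assembled only from names, URLs and the 1536-sample chunk size that already appear in this patch, not an official snippet.

    import torch

    SAMPLING_RATE = 16000

    # requires `pip install onnxruntime` in addition to torch/torchaudio
    model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_vad',
                                  force_reload=True,
                                  onnx=True)
    (get_speech_timestamps, save_audio, read_audio,
     VADIterator, collect_chunks) = utils

    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav',
                                   'en_example.wav')
    wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)

    # OnnxWrapper carries the LSTM state (h0, c0) across calls, so it can be
    # fed chunk by chunk exactly like the JIT model
    speech_probs = []
    window_size_samples = 1536
    for i in range(0, len(wav), window_size_samples):
        chunk = wav[i: i + window_size_samples]
        if len(chunk) < window_size_samples:
            break  # OnnxWrapper raises on chunks shorter than sr / 31.25 (= 512) samples
        speech_probs.append(model(chunk, SAMPLING_RATE).item())
    model.reset_states()  # clear h0/c0 before the next audio file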
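
The Validator class replaces the old run_function=validate / validate_onnx plumbing: get_number_ts, get_language and get_language_and_group now simply call model(wav), which dispatches to whichever backend was downloaded. A minimal sketch of the reworked language-detector path, again reusing only names and URLs from this patch:

    import torch

    SAMPLING_RATE = 16000

    model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_lang_detector',
                                  force_reload=True,
                                  onnx=True)  # Validator downloads number_detector.onnx
    get_language, read_audio = utils

    torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav',
                                   'en_example.wav')
    wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)
    print(get_language(wav, model))  # one of ['ru', 'en', 'de', 'es']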
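
The tensor-cast guard added to VADIterator.__call__ means streaming callers can now pass plain array-likes instead of tensors. A tiny sketch, assuming model and VADIterator were loaded as in the first sketch above; the NumPy buffer is a hypothetical stand-in for a chunk arriving from an audio callback:

    import numpy as np

    vad_iterator = VADIterator(model)
    chunk = np.zeros(1536, dtype=np.float32)  # dummy silent chunk
    # previously a non-tensor chunk failed inside the model call; the new guard
    # casts it via torch.Tensor(...) or raises an explicit TypeError
    speech_dict = vad_iterator(chunk, return_seconds=True)
    print(speech_dict)  # None until a speech start/end is detected
    vad_iterator.reset_states()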