Merge pull request #137 from snakers4/adamnsandle

Adamnsandle
This commit is contained in:
Dimitrii Voronin
2021-12-17 17:50:13 +03:00
committed by GitHub
10 changed files with 185 additions and 377 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
files/silero_vad.onnx Normal file

Binary file not shown.

View File

@@ -1,7 +1,6 @@
dependencies = ['torch', 'torchaudio']
import torch
import json
from utils_vad import (init_jit_model,
get_speech_timestamps,
get_number_ts,
@@ -12,16 +11,20 @@ from utils_vad import (init_jit_model,
VADIterator,
collect_chunks,
drop_chunks,
donwload_onnx_model)
Validator,
OnnxWrapper)
def silero_vad(**kwargs):
def silero_vad(onnx=False):
"""Silero Voice Activity Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
if onnx:
model = OnnxWrapper(f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.onnx')
else:
model = init_jit_model(model_path=f'{hub_dir}/snakers4_silero-vad_master/files/silero_vad.jit')
utils = (get_speech_timestamps,
save_audio,
read_audio,
@@ -31,46 +34,53 @@ def silero_vad(**kwargs):
return model, utils
def silero_number_detector(**kwargs):
def silero_number_detector(onnx=False):
"""Silero Number Detector
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit')
model = init_jit_model(model_path='number_detector.jit')
if onnx:
url = 'https://models.silero.ai/vad_models/number_detector.onnx'
else:
url = 'https://models.silero.ai/vad_models/number_detector.jit'
model = Validator(url)
utils = (get_number_ts,
save_audio,
read_audio,
collect_chunks,
drop_chunks,
donwload_onnx_model)
drop_chunks)
return model, utils
def silero_lang_detector(**kwargs):
def silero_lang_detector(onnx=False):
"""Silero Language Classifier
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
torch.hub.download_url_to_file('https://models.silero.ai/vad_models/number_detector.jit', 'number_detector.jit')
model = init_jit_model(model_path='number_detector.jit')
if onnx:
url = 'https://models.silero.ai/vad_models/number_detector.onnx'
else:
url = 'https://models.silero.ai/vad_models/number_detector.jit'
model = Validator(url)
utils = (get_language,
read_audio,
donwload_onnx_model)
read_audio)
return model, utils
def silero_lang_detector_95(**kwargs):
def silero_lang_detector_95(onnx=False):
"""Silero Language Classifier (95 languages)
Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples
"""
hub_dir = torch.hub.get_dir()
torch.hub.download_url_to_file('https://models.silero.ai/vad_models/lang_classifier_95.jit', 'lang_classifier_95.jit')
model = init_jit_model(model_path='lang_classifier_95.jit')
if onnx:
url = 'https://models.silero.ai/vad_models/lang_classifier_95.onnx'
else:
url = 'https://models.silero.ai/vad_models/lang_classifier_95.jit'
model = Validator(url)
with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_dict_95.json', 'r') as f:
lang_dict = json.load(f)
@@ -78,6 +88,6 @@ def silero_lang_detector_95(**kwargs):
with open(f'{hub_dir}/snakers4_silero-vad_master/files/lang_group_dict_95.json', 'r') as f:
lang_group_dict = json.load(f)
utils = (get_language_and_group, read_audio, donwload_onnx_model)
utils = (get_language_and_group, read_audio)
return model, lang_dict, lang_group_dict, utils

View File

@@ -1,21 +1,12 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sVNOuHQQjsrp"
},
"source": [
"# PyTorch Examples"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FpMplOCA2Fwp"
},
"source": [
"## VAD"
"#VAD"
]
},
{
@@ -25,7 +16,7 @@
"id": "62A6F_072Fwq"
},
"source": [
"### Install Dependencies"
"## Install Dependencies"
]
},
{
@@ -42,26 +33,39 @@
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio\n",
"\n",
"SAMPLE_RATE = 16000\n",
"SAMPLING_RATE = 16000\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"\n",
"# download example\n",
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pSifus5IilRp"
},
"outputs": [],
"source": [
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
" \n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True)\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"(get_speech_timestamps,\n",
" save_audio,\n",
" read_audio,\n",
" VADIterator,\n",
" collect_chunks) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
" collect_chunks) = utils"
]
},
{
@@ -70,29 +74,7 @@
"id": "fXbbaUO3jsrw"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RJRBksv39xf5"
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tEKb0YF_9y-i"
},
"outputs": [],
"source": [
"wav"
"## Full Audio"
]
},
{
@@ -112,9 +94,9 @@
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"# get speech timestamps from full audio file\n",
"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLE_RATE)\n",
"speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
"pprint(speech_timestamps)"
]
},
@@ -128,7 +110,7 @@
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav',\n",
" collect_chunks(speech_timestamps, wav), sampling_rate=16000) \n",
" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
"Audio('only_speech.wav')"
]
},
@@ -138,7 +120,7 @@
"id": "iDKQbVr8jsry"
},
"source": [
"### Stream imitation example"
"## Stream imitation example"
]
},
{
@@ -152,7 +134,7 @@
"## using VADIterator class\n",
"\n",
"vad_iterator = VADIterator(model)\n",
"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"\n",
"window_size_samples = 1536 # number of samples in a single audio chunk\n",
"for i in range(0, len(wav), window_size_samples):\n",
@@ -172,14 +154,15 @@
"source": [
"## just probabilities\n",
"\n",
"wav = read_audio(f'{files_dir}/en.wav', sampling_rate=SAMPLE_RATE)\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"speech_probs = []\n",
"window_size_samples = 1536\n",
"for i in range(0, len(wav), window_size_samples):\n",
" speech_prob = model(wav[i: i+ window_size_samples], SAMPLE_RATE).item()\n",
" speech_prob = model(wav[i: i+ window_size_samples], SAMPLING_RATE).item()\n",
" speech_probs.append(speech_prob)\n",
"vad_iterator.reset_states() # reset model states after each audio\n",
"\n",
"pprint(speech_probs[:100])"
"print(speech_probs[:10]) # first 10 chunks predicts"
]
},
{
@@ -189,7 +172,7 @@
"id": "36jY0niD2Fww"
},
"source": [
"## Number detector"
"# Number detector"
]
},
{
@@ -200,7 +183,7 @@
"id": "scd1DlS42Fwx"
},
"source": [
"### Install Dependencies"
"## Install Dependencies"
]
},
{
@@ -215,27 +198,41 @@
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile\n",
"!pip install -q torchaudio\n",
"\n",
"SAMPLING_RATE = 16000\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"\n",
"# download example\n",
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en_num.wav', 'en_number_example.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dPwCFHmFycUF"
},
"outputs": [],
"source": [
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
" \n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_number_detector',\n",
" force_reload=True)\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"(get_number_ts,\n",
" save_audio,\n",
" read_audio,\n",
" collect_chunks,\n",
" drop_chunks,\n",
" _) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
" drop_chunks) = utils\n"
]
},
{
@@ -246,7 +243,7 @@
"id": "qhPa30ij2Fwy"
},
"source": [
"### Full audio"
"## Full audio"
]
},
{
@@ -258,7 +255,7 @@
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en_num.wav')\n",
"wav = read_audio('en_number_example.wav', sampling_rate=SAMPLING_RATE)\n",
"# get number timestamps from full audio file\n",
"number_timestamps = get_number_ts(wav, model)\n",
"pprint(number_timestamps)"
@@ -273,11 +270,10 @@
},
"outputs": [],
"source": [
"sample_rate = 16000\n",
"# convert ms in timestamps to samples\n",
"for timestamp in number_timestamps:\n",
" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
" timestamp['start'] = int(timestamp['start'] * SAMPLING_RATE / 1000)\n",
" timestamp['end'] = int(timestamp['end'] * SAMPLING_RATE / 1000)"
]
},
{
@@ -291,7 +287,7 @@
"source": [
"# merge all number chunks to one audio\n",
"save_audio('only_numbers.wav',\n",
" collect_chunks(number_timestamps, wav), sample_rate) \n",
" collect_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
"Audio('only_numbers.wav')"
]
},
@@ -306,7 +302,7 @@
"source": [
"# drop all number chunks from audio\n",
"save_audio('no_numbers.wav',\n",
" drop_chunks(number_timestamps, wav), sample_rate) \n",
" drop_chunks(number_timestamps, wav), SAMPLING_RATE) \n",
"Audio('no_numbers.wav')"
]
},
@@ -317,7 +313,7 @@
"id": "PnKtJKbq2Fwz"
},
"source": [
"## Language detector"
"# Language detector"
]
},
{
@@ -328,7 +324,7 @@
"id": "F5cAmMbP2Fwz"
},
"source": [
"### Install Dependencies"
"## Install Dependencies"
]
},
{
@@ -343,24 +339,37 @@
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile\n",
"!pip install -q torchaudio\n",
"\n",
"SAMPLING_RATE = 16000\n",
"\n",
"import glob\n",
"import torch\n",
"torch.set_num_threads(1)\n",
"\n",
"from IPython.display import Audio\n",
"from pprint import pprint\n",
"\n",
"# download example\n",
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JfRKDZiRztFe"
},
"outputs": [],
"source": [
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
" \n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_lang_detector',\n",
" force_reload=True)\n",
" force_reload=True,\n",
" onnx=USE_ONNX)\n",
"\n",
"(get_language,\n",
" read_audio,\n",
" _) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'"
"get_language, read_audio = utils"
]
},
{
@@ -371,7 +380,7 @@
"id": "iC696eMX2Fwz"
},
"source": [
"### Full audio"
"## Full audio"
]
},
{
@@ -383,268 +392,10 @@
},
"outputs": [],
"source": [
"wav = read_audio(f'{files_dir}/en.wav')\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"lang = get_language(wav, model)\n",
"print(lang)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "57avIBd6jsrz"
},
"source": [
"# ONNX Example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hEhnfORV2Fw0"
},
"source": [
"## VAD"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Cy7y-NAyALSe"
},
"source": [
"**TO BE DONE**"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "7QMvUvpg2Fw4"
},
"source": [
"## Number detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "tBPDkpHr2Fw4"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "PdjGd56R2Fw5"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile onnxruntime\n",
"\n",
"import glob\n",
"import torch\n",
"import onnxruntime\n",
"from pprint import pprint\n",
"\n",
"from IPython.display import Audio\n",
"\n",
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_number_detector',\n",
" force_reload=True)\n",
"\n",
"(get_number_ts,\n",
" save_audio,\n",
" read_audio,\n",
" collect_chunks,\n",
" drop_chunks,\n",
" donwload_onnx_model) = utils\n",
"\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
"donwload_onnx_model('number_detector')\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "I9QWSFZh2Fw5"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "_r6QZiwu2Fw5"
},
"outputs": [],
"source": [
"model = init_onnx_model('number_detector.onnx')\n",
"wav = read_audio(f'{files_dir}/en_num.wav')\n",
"\n",
"# get number timestamps from full audio file\n",
"number_timestamps = get_number_ts(wav, model, run_function=validate_onnx)\n",
"pprint(number_timestamps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "FN4aDwLV2Fw5"
},
"outputs": [],
"source": [
"sample_rate = 16000\n",
"# convert ms in timestamps to samples\n",
"for timestamp in number_timestamps:\n",
" timestamp['start'] = int(timestamp['start'] * sample_rate / 1000)\n",
" timestamp['end'] = int(timestamp['end'] * sample_rate / 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "JnvS6WTK2Fw5"
},
"outputs": [],
"source": [
"# merge all number chunks to one audio\n",
"save_audio('only_numbers.wav',\n",
" collect_chunks(number_timestamps, wav), 16000) \n",
"Audio('only_numbers.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "yUxOcOFG2Fw6"
},
"outputs": [],
"source": [
"# drop all number chunks from audio\n",
"save_audio('no_numbers.wav',\n",
" drop_chunks(number_timestamps, wav), 16000) \n",
"Audio('no_numbers.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"id": "SR8Bgcd52Fw6"
},
"source": [
"## Language detector"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true,
"hidden": true,
"id": "PBnXPtKo2Fw6"
},
"source": [
"### Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "iNkDWJ3H2Fw6"
},
"outputs": [],
"source": [
"#@title Install and Import Dependencies\n",
"\n",
"# this assumes that you have a relevant version of PyTorch installed\n",
"!pip install -q torchaudio soundfile onnxruntime\n",
"\n",
"import glob\n",
"import torch\n",
"import onnxruntime\n",
"from pprint import pprint\n",
"\n",
"from IPython.display import Audio\n",
"\n",
"_, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_lang_detector',\n",
" force_reload=True)\n",
"\n",
"(get_language,\n",
" read_audio,\n",
" donwload_onnx_model) = utils\n",
"\n",
"donwload_onnx_model('number_detector')\n",
"files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs"
]
},
{
"cell_type": "markdown",
"metadata": {
"hidden": true,
"id": "G8N8oP4q2Fw6"
},
"source": [
"### Full Audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true,
"id": "WHXnh9IV2Fw6"
},
"outputs": [],
"source": [
"model = init_onnx_model('number_detector.onnx')\n",
"wav = read_audio(f'{files_dir}/en.wav')\n",
"\n",
"lang = get_language(wav, model, run_function=validate_onnx)\n",
"print(lang)"
]
}
],
"metadata": {

View File

@@ -5,25 +5,68 @@ import torch.nn.functional as F
import warnings
languages = ['ru', 'en', 'de', 'es']
onnx_url_dict = {
'lang_classifier_95': 'https://models.silero.ai/vad_models/lang_classifier_95.onnx',
'number_detector':'https://models.silero.ai/vad_models/number_detector.onnx'
}
def donwload_onnx_model(model_name):
class OnnxWrapper():
if model_name not in ['lang_classifier_95', 'number_detector']:
raise ValueError
def __init__(self, path):
import numpy as np
global np
import onnxruntime
self.session = onnxruntime.InferenceSession(path)
self.session.intra_op_num_threads = 1
self.session.inter_op_num_threads = 1
torch.hub.download_url_to_file(onnx_url_dict[model_name], f'{model_name}.onnx')
self.reset_states()
def reset_states(self):
self._h = np.zeros((2, 1, 64)).astype('float32')
self._c = np.zeros((2, 1, 64)).astype('float32')
def __call__(self, x, sr: int):
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
if x.shape[0] > 1:
raise ValueError("Onnx model does not support batching")
if sr not in [16000]:
raise ValueError(f"Supported sample rates: {[16000]}")
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
ort_inputs = {'input': x.numpy(), 'h0': self._h, 'c0': self._c}
ort_outs = self.session.run(None, ort_inputs)
out, self._h, self._c = ort_outs
out = torch.tensor(out).squeeze(2)[:, 1] # make output type match JIT analog
return out
def validate(model,
inputs: torch.Tensor):
with torch.no_grad():
outs = model(inputs)
return outs
class Validator():
def __init__(self, url):
self.onnx = True if url.endswith('.onnx') else False
torch.hub.download_url_to_file(url, 'inf.model')
if self.onnx:
import onnxruntime
self.model = onnxruntime.InferenceSession('inf.model')
else:
self.model = init_jit_model(model_path='inf.model')
def __call__(self, inputs: torch.Tensor):
with torch.no_grad():
if self.onnx:
ort_inputs = {'input': inputs.cpu().numpy()}
outs = self.model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
else:
outs = self.model(inputs)
return outs
def read_audio(path: str,
@@ -215,10 +258,9 @@ def get_number_ts(wav: torch.Tensor,
model,
model_stride=8,
hop_length=160,
sample_rate=16000,
run_function=validate):
sample_rate=16000):
wav = torch.unsqueeze(wav, dim=0)
perframe_logits = run_function(model, wav)[0]
perframe_logits = model(wav)[0]
perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
extended_preds = []
for i in perframe_preds:
@@ -245,10 +287,9 @@ def get_number_ts(wav: torch.Tensor,
def get_language(wav: torch.Tensor,
model,
run_function=validate):
model):
wav = torch.unsqueeze(wav, dim=0)
lang_logits = run_function(model, wav)[2]
lang_logits = model(wav)[2]
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
assert lang_pred < len(languages)
return languages[lang_pred]
@@ -258,10 +299,9 @@ def get_language_and_group(wav: torch.Tensor,
model,
lang_dict: dict,
lang_group_dict: dict,
top_n=1,
run_function=validate):
top_n=1):
wav = torch.unsqueeze(wav, dim=0)
lang_logits, lang_group_logits = run_function(model, wav)
lang_logits, lang_group_logits = model(wav)
softm = torch.softmax(lang_logits, dim=1).squeeze()
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
@@ -332,6 +372,13 @@ class VADIterator:
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
"""
if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
except:
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples