delete onnx from utils

adamnsandle
2020-12-15 12:00:37 +00:00
parent 2c41efaa27
commit 557a32ed1b
4 changed files with 255 additions and 86 deletions
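In short, the ONNX-specific helpers (init_onnx_model and the ONNX branch of validate) move out of utils and into the example notebook, and get_speech_ts, single_audio_stream and state_generator gain a run_function argument that defaults to the plain PyTorch validate. A minimal sketch of the resulting calling convention, assembled from the notebook cells in this diff (file paths and the num_steps value are simply the ones used there):

    import torch
    import onnxruntime
    from utils import init_jit_model, get_speech_ts, read_audio

    # JIT path: run_function defaults to the plain PyTorch validate in utils
    model = init_jit_model('files/model.jit', 'cpu')
    wav = read_audio('files/en.wav')
    speech_timestamps = get_speech_ts(wav, model, num_steps=4)

    # ONNX path: the caller now owns the session and the inference function
    def init_onnx_model(model_path: str):
        return onnxruntime.InferenceSession(model_path)

    def validate_onnx(model, inputs):
        with torch.no_grad():
            ort_inputs = {'input': inputs.cpu().numpy()}
            outs = model.run(None, ort_inputs)
        return [torch.Tensor(x) for x in outs]

    onnx_model = init_onnx_model('files/model.onnx')
    onnx_timestamps = get_speech_ts(wav, onnx_model, num_steps=4,
                                    run_function=validate_onnx)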

BIN  files/model.jit  (new file; binary content not shown)

@@ -1,30 +1,39 @@
{ {
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Jit example"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:24.487521Z", "end_time": "2020-12-15T11:54:25.940761Z",
"start_time": "2020-12-14T13:43:23.780570Z" "start_time": "2020-12-15T11:54:25.933842Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# imports\n",
"import glob\n", "import glob\n",
"import torch\n", "import torch\n",
"from IPython.display import Audio\n", "from IPython.display import Audio\n",
"torch.set_num_threads(1)\n", "torch.set_num_threads(1)\n",
"from utils import (init_jit_model, get_speech_ts, \n", "\n",
"from utils import (init_jit_model, get_speech_ts,\n",
" save_audio, read_audio, \n", " save_audio, read_audio, \n",
" state_generator, single_audio_stream, init_onnx_model)" " state_generator, single_audio_stream)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Full audio example" "## Full audio"
] ]
}, },
{ {
@@ -32,8 +41,8 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:24.492506Z", "end_time": "2020-12-15T11:54:27.939388Z",
"start_time": "2020-12-14T13:43:24.489440Z" "start_time": "2020-12-15T11:54:27.936636Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -50,14 +59,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:24.760714Z", "end_time": "2020-12-15T11:54:28.415177Z",
"start_time": "2020-12-14T13:43:24.493992Z" "start_time": "2020-12-15T11:54:28.231677Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", "model = init_jit_model('files/model.jit', 'cpu')"
"model = init_onnx_model('files/joint_VAD_just_RU.onnx')"
] ]
}, },
{ {
@@ -65,14 +73,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:24.793384Z", "end_time": "2020-12-15T11:54:28.560822Z",
"start_time": "2020-12-14T13:43:24.762311Z" "start_time": "2020-12-15T11:54:28.549811Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"Audio('files/test_audio_2.wav')\n", "wav = read_audio('files/en.wav')"
"wav = read_audio('files/test_audio_2.wav')"
] ]
}, },
{ {
@@ -80,13 +87,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:25.320324Z", "end_time": "2020-12-15T11:54:30.088721Z",
"start_time": "2020-12-14T13:43:24.808594Z" "start_time": "2020-12-15T11:54:29.019358Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"speech_timestamps = get_speech_ts(wav, model, num_steps=4) # kill extractor" "speech_timestamps = get_speech_ts(wav, model, num_steps=4) # get speech timestamps from full audio file"
] ]
}, },
{ {
@@ -94,8 +101,8 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:25.324901Z", "end_time": "2020-12-15T11:54:30.198484Z",
"start_time": "2020-12-14T13:43:25.321759Z" "start_time": "2020-12-15T11:54:30.188311Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -108,13 +115,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:25.344065Z", "end_time": "2020-12-15T11:54:30.816893Z",
"start_time": "2020-12-14T13:43:25.326162Z" "start_time": "2020-12-15T11:54:30.782667Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000)\n", "save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000) # merge all speech chunks to one audio\n",
"Audio('only_speech.wav')" "Audio('only_speech.wav')"
] ]
}, },
@@ -122,7 +129,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Single stream example" "## Single audio stream"
] ]
}, },
{ {
@@ -130,15 +137,14 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:25.778585Z", "end_time": "2020-12-15T11:54:31.886189Z",
"start_time": "2020-12-14T13:43:25.496583Z" "start_time": "2020-12-15T11:54:31.572194Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", "model = init_jit_model('files/model.jit', 'cpu')\n",
"#model = init_onnx_model('files/joint_VAD_just_RU.onnx')\n", "wav = 'files/en.wav'"
"audio = 'files/test_audio_6.wav'"
] ]
}, },
{ {
@@ -146,13 +152,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:29.402604Z", "end_time": "2020-12-15T11:54:35.624279Z",
"start_time": "2020-12-14T13:43:25.780037Z" "start_time": "2020-12-15T11:54:32.049532Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"for i in single_audio_stream(model, audio):\n", "for i in single_audio_stream(model, wav):\n",
" if i:\n", " if i:\n",
" print(i)" " print(i)"
] ]
@@ -161,7 +167,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Multiple stream example" "## Multiple audio stream"
] ]
}, },
{ {
@@ -169,14 +175,13 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:29.674262Z", "end_time": "2020-12-15T11:40:13.406225Z",
"start_time": "2020-12-14T13:43:29.403972Z" "start_time": "2020-12-15T11:40:13.206354Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"model = init_jit_model('files/joint_VAD_just_RU_jit_cut_q.pth.tar', 'cpu') # from yml file\n", "model = init_jit_model('files/model.jit', 'cpu')"
"model = init_onnx_model('files/joint_VAD_just_RU.onnx')"
] ]
}, },
{ {
@@ -184,14 +189,14 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:29.678449Z", "end_time": "2020-12-15T11:41:08.470917Z",
"start_time": "2020-12-14T13:43:29.675519Z" "start_time": "2020-12-15T11:41:08.467369Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"audios_for_stream = glob.glob('files/test*.wav')\n", "audios_for_stream = glob.glob('files/*.wav')\n",
"len(audios_for_stream)" "len(audios_for_stream) # total 4 audios"
] ]
}, },
{ {
@@ -199,29 +204,211 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:43:40.236387Z", "end_time": "2020-12-15T11:41:25.685356Z",
"start_time": "2020-12-14T13:43:29.679274Z" "start_time": "2020-12-15T11:41:16.222672Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"for i in state_generator(model, audios_for_stream, audios_in_stream=2):\n", "for i in state_generator(model, audios_for_stream, audios_in_stream=2): # 2 audio stream\n",
" if i:\n", " if i:\n",
" print(i)" " print(i)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Onnx example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-12-14T13:46:49.812052Z", "end_time": "2020-12-15T11:55:45.597504Z",
"start_time": "2020-12-14T13:46:49.586637Z" "start_time": "2020-12-15T11:55:45.582356Z"
} }
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!cp ../silero-models-research/model_saves/joint_VAD_just_RU_jit_cut_q.pth.tar files/" "# imports\n",
"import glob\n",
"import torch\n",
"from IPython.display import Audio\n",
"torch.set_num_threads(1)\n",
"import onnxruntime\n",
"\n",
"from utils import (get_speech_ts, save_audio, read_audio, \n",
" state_generator, single_audio_stream)\n",
"\n",
"def init_onnx_model(model_path: str):\n",
" return onnxruntime.InferenceSession(model_path)\n",
"\n",
"def validate_onnx(model, inputs):\n",
" with torch.no_grad():\n",
" ort_inputs = {'input': inputs.cpu().numpy()}\n",
" outs = model.run(None, ort_inputs)\n",
" outs = [torch.Tensor(x) for x in outs]\n",
" return outs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:55:56.874376Z",
"start_time": "2020-12-15T11:55:56.782230Z"
}
},
"outputs": [],
"source": [
"model = init_onnx_model('files/model.onnx')\n",
"wav = read_audio('files/en.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:56:12.159463Z",
"start_time": "2020-12-15T11:56:11.446991Z"
}
},
"outputs": [],
"source": [
"speech_timestamps = get_speech_ts(wav, model, num_steps=4, run_function=validate_onnx) # get speech timestamps from full audio file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:56:20.488863Z",
"start_time": "2020-12-15T11:56:20.485485Z"
}
},
"outputs": [],
"source": [
"speech_timestamps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:56:27.908128Z",
"start_time": "2020-12-15T11:56:27.870978Z"
}
},
"outputs": [],
"source": [
"save_audio('only_speech.wav', collect_speeches(speech_timestamps, wav), 16000) # merge all speech chunks to one audio\n",
"Audio('only_speech.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single audio stream"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:58:09.012892Z",
"start_time": "2020-12-15T11:58:08.940907Z"
}
},
"outputs": [],
"source": [
"model = init_onnx_model('files/model.onnx')\n",
"wav = 'files/en.wav'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:58:11.562186Z",
"start_time": "2020-12-15T11:58:09.949825Z"
}
},
"outputs": [],
"source": [
"for i in single_audio_stream(model, wav, run_function=validate_onnx):\n",
" if i:\n",
" print(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multiple audio stream"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = init_onnx_model('files/model.onnx')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:59:09.381687Z",
"start_time": "2020-12-15T11:59:09.378552Z"
}
},
"outputs": [],
"source": [
"audios_for_stream = glob.glob('files/*.wav')\n",
"len(audios_for_stream) # total 4 audios"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2020-12-15T11:59:27.712905Z",
"start_time": "2020-12-15T11:59:21.608435Z"
}
},
"outputs": [],
"source": [
"for i in state_generator(model, audios_for_stream, audios_in_stream=2, run_function=validate_onnx): # 2 audio stream\n",
" if i:\n",
" print(i)"
] ]
}, },
{ {


@@ -1,15 +1,16 @@
import torch import torch
import torchaudio import torchaudio
import onnxruntime
import numpy as np import numpy as np
from typing import List
from itertools import repeat from itertools import repeat
from collections import deque from collections import deque
import torch.nn.functional as F import torch.nn.functional as F
torchaudio.set_audio_backend("soundfile") # switch backend torchaudio.set_audio_backend("soundfile") # switch backend
def validate(model, inputs):
with torch.no_grad():
outs = model(inputs)
return outs
def read_audio(path: str, def read_audio(path: str,
target_sr: int = 16000): target_sr: int = 16000):
@@ -43,14 +44,9 @@ def init_jit_model(model_path: str,
model.eval() model.eval()
return model return model
def init_onnx_model(model_path: str):
return onnxruntime.InferenceSession(model_path)
def get_speech_ts(wav, model, def get_speech_ts(wav, model,
trig_sum=0.25, neg_trig_sum=0.01, trig_sum=0.25, neg_trig_sum=0.02,
num_steps=8, batch_size=200): num_steps=8, batch_size=200, run_function=validate):
num_samples = 4000 num_samples = 4000
assert num_samples % num_steps == 0 assert num_samples % num_steps == 0
@@ -62,16 +58,16 @@ def get_speech_ts(wav, model,
chunk = wav[i: i+num_samples] chunk = wav[i: i+num_samples]
if len(chunk) < num_samples: if len(chunk) < num_samples:
chunk = F.pad(chunk, (0, num_samples - len(chunk))) chunk = F.pad(chunk, (0, num_samples - len(chunk)))
to_concat.append(chunk) to_concat.append(chunk.unsqueeze(0))
if len(to_concat) >= batch_size: if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.vstack(to_concat)) chunks = torch.Tensor(torch.cat(to_concat, dim=0))
out = validate(model, chunks)[-2] out = run_function(model, chunks)[-2]
outs.append(out) outs.append(out)
to_concat = [] to_concat = []
if to_concat: if to_concat:
chunks = torch.Tensor(torch.vstack(to_concat)) chunks = torch.Tensor(torch.cat(to_concat, dim=0))
out = validate(model, chunks)[-2] out = run_function(model, chunks)[-2]
outs.append(out) outs.append(out)
outs = torch.cat(outs, dim=0) outs = torch.cat(outs, dim=0)
@@ -101,7 +97,7 @@ def get_speech_ts(wav, model,
class VADiterator: class VADiterator:
def __init__(self, def __init__(self,
trig_sum=0.26, neg_trig_sum=0.01, trig_sum=0.26, neg_trig_sum=0.02,
num_steps=8): num_steps=8):
self.num_samples = 4000 self.num_samples = 4000
self.num_steps = num_steps self.num_steps = num_steps
@@ -133,11 +129,11 @@ class VADiterator:
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio
self.last = True self.last = True
stacked = torch.hstack([self.prev, wav_chunk]) stacked = torch.cat([self.prev, wav_chunk])
self.prev = wav_chunk self.prev = wav_chunk
overlap_chunks = [stacked[i:i+self.num_samples] for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough
return torch.vstack(overlap_chunks) return torch.cat(overlap_chunks, dim=0)
def state(self, model_out): def state(self, model_out):
current_speech = {} current_speech = {}
@@ -159,14 +155,14 @@ class VADiterator:
def state_generator(model, audios, def state_generator(model, audios,
onnx=False, onnx=False,
trig_sum=0.26, neg_trig_sum=0.01, trig_sum=0.26, neg_trig_sum=0.02,
num_steps=8, audios_in_stream=5): num_steps=8, audios_in_stream=5, run_function=validate):
VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)] VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)): for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)] for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
batch = torch.cat(for_batch) batch = torch.cat(for_batch)
outs = validate(model, batch) outs = run_function(model, batch)
vad_outs = np.split(outs[-2].numpy(), audios_in_stream) vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
states = [] states = []
@@ -212,7 +208,7 @@ def stream_imitator(audios, audios_in_stream):
def single_audio_stream(model, audio, onnx=False, trig_sum=0.26, def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
neg_trig_sum=0.01, num_steps=8): neg_trig_sum=0.02, num_steps=8, run_function=validate):
num_samples = 4000 num_samples = 4000
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps) VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
wav = read_audio(audio) wav = read_audio(audio)
@@ -220,7 +216,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
for chunk in wav_chunks: for chunk in wav_chunks:
batch = VADiter.prepare_batch(chunk) batch = VADiter.prepare_batch(chunk)
outs = validate(model, batch) outs = run_function(model, batch)
vad_outs = outs[-2] # this is very misleading vad_outs = outs[-2] # this is very misleading
states = [] states = []
@@ -228,17 +224,3 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
if state[0]: if state[0]:
states.append(state[0]) states.append(state[0])
yield states yield states
def validate(model, inputs):
onnx = False
if type(model) == onnxruntime.capi.session.InferenceSession:
onnx = True
with torch.no_grad():
if onnx:
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
else:
outs = model(inputs)
return outs