delete onnx from utils

This commit is contained in:
adamnsandle
2020-12-15 12:00:37 +00:00
parent 2c41efaa27
commit 557a32ed1b
4 changed files with 255 additions and 86 deletions


@@ -1,15 +1,16 @@
 import torch
 import torchaudio
-import onnxruntime
 import numpy as np
 from typing import List
 from itertools import repeat
 from collections import deque
 import torch.nn.functional as F

 torchaudio.set_audio_backend("soundfile")  # switch backend

+def validate(model, inputs):
+    with torch.no_grad():
+        outs = model(inputs)
+    return outs

 def read_audio(path: str,
                target_sr: int = 16000):
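The relocated validate() is now a plain no-grad forward pass and becomes the default run_function used throughout the module. A minimal usage sketch, assuming a JIT model loaded with this module's init_jit_model (the model path and batch shape below are illustrative, not from the repo):

    model = init_jit_model('files/model.jit')  # hypothetical path
    batch = torch.randn(200, 4000)             # batch_size x num_samples, illustrative
    outs = validate(model, batch)              # gradient-free forward pass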
@@ -43,14 +44,9 @@ def init_jit_model(model_path: str,
     model.eval()
     return model

-def init_onnx_model(model_path: str):
-    return onnxruntime.InferenceSession(model_path)

 def get_speech_ts(wav, model,
-                  trig_sum=0.25, neg_trig_sum=0.01,
-                  num_steps=8, batch_size=200):
+                  trig_sum=0.25, neg_trig_sum=0.02,
+                  num_steps=8, batch_size=200, run_function=validate):

     num_samples = 4000
     assert num_samples % num_steps == 0
@@ -62,16 +58,16 @@ def get_speech_ts(wav, model,
         chunk = wav[i: i+num_samples]
         if len(chunk) < num_samples:
             chunk = F.pad(chunk, (0, num_samples - len(chunk)))
-        to_concat.append(chunk)
+        to_concat.append(chunk.unsqueeze(0))
         if len(to_concat) >= batch_size:
-            chunks = torch.Tensor(torch.vstack(to_concat))
-            out = validate(model, chunks)[-2]
+            chunks = torch.Tensor(torch.cat(to_concat, dim=0))
+            out = run_function(model, chunks)[-2]
             outs.append(out)
             to_concat = []

     if to_concat:
-        chunks = torch.Tensor(torch.vstack(to_concat))
-        out = validate(model, chunks)[-2]
+        chunks = torch.Tensor(torch.cat(to_concat, dim=0))
+        out = run_function(model, chunks)[-2]
         outs.append(out)

     outs = torch.cat(outs, dim=0)
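Swapping torch.vstack for unsqueeze(0) plus torch.cat is behavior-preserving for these 1-D chunks; presumably the motivation is that vstack/hstack only exist in newer PyTorch releases, while torch.cat is available everywhere. A quick self-contained check of the equivalence:

    import torch

    chunks = [torch.randn(4000) for _ in range(3)]                 # three 1-D audio windows
    stacked = torch.cat([c.unsqueeze(0) for c in chunks], dim=0)   # what the new code does
    assert stacked.shape == (3, 4000)                              # same result vstack would give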
@@ -101,7 +97,7 @@ def get_speech_ts(wav, model,

 class VADiterator:
     def __init__(self,
-                 trig_sum=0.26, neg_trig_sum=0.01,
+                 trig_sum=0.26, neg_trig_sum=0.02,
                  num_steps=8):
         self.num_samples = 4000
         self.num_steps = num_steps
@@ -133,11 +129,11 @@ class VADiterator:
             wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of audio
             self.last = True

-        stacked = torch.hstack([self.prev, wav_chunk])
+        stacked = torch.cat([self.prev, wav_chunk])
         self.prev = wav_chunk

-        overlap_chunks = [stacked[i:i+self.num_samples] for i in range(self.step, self.num_samples+1, self.step)]  # 500 step is good enough
-        return torch.vstack(overlap_chunks)
+        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)]  # 500 step is good enough
+        return torch.cat(overlap_chunks, dim=0)

     def state(self, model_out):
         current_speech = {}
@@ -159,14 +155,14 @@

 def state_generator(model, audios,
                     onnx=False,
-                    trig_sum=0.26, neg_trig_sum=0.01,
-                    num_steps=8, audios_in_stream=5):
+                    trig_sum=0.26, neg_trig_sum=0.02,
+                    num_steps=8, audios_in_stream=5, run_function=validate):
     VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
     for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
         for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]

         batch = torch.cat(for_batch)
-        outs = validate(model, batch)
+        outs = run_function(model, batch)
         vad_outs = np.split(outs[-2].numpy(), audios_in_stream)

         states = []
@@ -212,7 +208,7 @@ def stream_imitator(audios, audios_in_stream):

 def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
-                        neg_trig_sum=0.01, num_steps=8):
+                        neg_trig_sum=0.02, num_steps=8, run_function=validate):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
     wav = read_audio(audio)
@@ -220,7 +216,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
     for chunk in wav_chunks:
         batch = VADiter.prepare_batch(chunk)

-        outs = validate(model, batch)
+        outs = run_function(model, batch)
         vad_outs = outs[-2]  # this is very misleading

         states = []
@@ -228,17 +224,3 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
             if state[0]:
                 states.append(state[0])
         yield states
-
-
-def validate(model, inputs):
-    onnx = False
-    if type(model) == onnxruntime.capi.session.InferenceSession:
-        onnx = True
-    with torch.no_grad():
-        if onnx:
-            ort_inputs = {'input': inputs.cpu().numpy()}
-            outs = model.run(None, ort_inputs)
-            outs = [torch.Tensor(x) for x in outs]
-        else:
-            outs = model(inputs)
-    return outs
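With the type check on onnxruntime.capi.session.InferenceSession deleted, utils no longer imports onnxruntime at all; callers who still want ONNX inference can supply their own run_function instead. A sketch that mirrors the removed branch (the 'input' tensor name and InferenceSession usage are taken from the deleted code; the file paths are hypothetical):

    import onnxruntime
    import torch

    def run_onnx(model, inputs):
        # same steps as the deleted onnx branch of validate()
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        return [torch.Tensor(x) for x in outs]

    session = onnxruntime.InferenceSession('files/model.onnx')  # hypothetical path
    wav = read_audio('test.wav')                                # hypothetical input
    speech_ts = get_speech_ts(wav, session, run_function=run_onnx)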