Minor fixes

snakers41
2020-12-14 14:47:50 +00:00
parent bb02e92ff9
commit 2fb68b2fad


@@ -1,16 +1,16 @@
 import torch
-import tempfile
 import torchaudio
-from typing import List
-import torch.nn as nn
-import torch.nn.functional as F
-from collections import deque
-import numpy as np
-from itertools import repeat
 import onnxruntime
+import numpy as np
+from typing import List
+from itertools import repeat
+from collections import deque
+import torch.nn.functional as F

 torchaudio.set_audio_backend("soundfile")  # switch backend


 def read_audio(path: str,
                target_sr: int = 16000):
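For context, `read_audio` loads a file through torchaudio, asserts that the sample rate matches `target_sr`, and returns a flattened mono tensor. A minimal usage sketch, assuming these helpers live in a module named `utils` (the module and file names below are assumptions, not taken from this diff):

    import torch
    from utils import read_audio  # assumed module name

    wav = read_audio('speech_sample.wav', target_sr=16000)  # hypothetical path
    print(wav.shape)  # 1-D tensor of samples at 16 kHz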
@@ -29,32 +29,16 @@ def read_audio(path: str,
     assert sr == target_sr
     return wav.squeeze(0)


-#def init_jit_model(model_url: str,
-#                   device: torch.device = torch.device('cpu')):
-#    torch.set_grad_enabled(False)
-#    with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
-#        torch.hub.download_url_to_file(model_url,
-#                                       f.name,
-#                                       progress=True)
-#        model = torch.jit.load(f.name, map_location=device)
-#    model.eval()
-#    return model
-
-
-def init_jit_model(model_path,
-                   device):
+def save_audio(path: str,
+               tensor: torch.Tensor,
+               sr: int):
+    torchaudio.save(path, tensor, sr)
+
+
+def init_jit_model(model_path: str,
+                   device=torch.device('cpu')):
     torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
     return model


-def init_onnx_model(model_path):
+def init_onnx_model(model_path: str):
     return onnxruntime.InferenceSession(model_path)
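After this change the two entry points are symmetric: `init_jit_model` loads a TorchScript checkpoint onto a device with gradients disabled, and `init_onnx_model` wraps the file in an `onnxruntime.InferenceSession`. A hedged usage sketch (module and file names are assumptions, not paths from this repo):

    import torch
    from utils import init_jit_model, init_onnx_model  # assumed module name

    device = torch.device('cpu')
    jit_model = init_jit_model('model.jit', device=device)  # hypothetical file name
    ort_session = init_onnx_model('model.onnx')             # hypothetical file name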
@@ -86,12 +70,12 @@ def get_speech_ts(wav, model,
     outs = torch.cat(outs, dim=0)

-    buffer = deque(maxlen=num_steps)  # when max queue len is reach, first element is dropped
+    buffer = deque(maxlen=num_steps)  # when max queue len is reached, first element is dropped
     triggered = False
     speeches = []
     current_speech = {}

-    speech_probs = outs[:, 1]
+    speech_probs = outs[:, 1]  # this is very misleading
     for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
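The detection loop smooths the per-chunk speech probabilities with a fixed-length deque: once the mean of the last `num_steps` values reaches `trig_sum`, a speech segment opens. A self-contained sketch of that smoothing, with made-up probabilities:

    import numpy as np
    from collections import deque

    trig_sum, num_steps = 0.26, 8
    buffer = deque(maxlen=num_steps)  # oldest value falls out automatically
    triggered = False
    for i, prob in enumerate([0.05, 0.10, 0.90, 0.95, 0.90, 0.85]):  # made-up values
        buffer.append(prob)
        if np.mean(buffer) >= trig_sum and not triggered:
            triggered = True
            print(f'speech starts around chunk {i}')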
@@ -108,6 +92,7 @@ def get_speech_ts(wav, model,
         speeches.append(current_speech)
     return speeches

+
 class VADiterator:
     def __init__(self,
                  trig_sum=0.26, neg_trig_sum=0.01,
@@ -139,7 +124,7 @@ class VADiterator:
         assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < self.num_samples:
-            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of the audio
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of audio
             self.last = True

         stacked = torch.hstack([self.prev, wav_chunk])
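The chunk-preparation step in `VADiterator` zero-pads a short final chunk to `num_samples` and concatenates it with the previous chunk, so the model always sees a fixed-size window with left context. A sketch of just the padding and stacking, with made-up sizes:

    import torch
    import torch.nn.functional as F

    num_samples = 4000
    prev = torch.zeros(num_samples)  # previous chunk kept as left context
    wav_chunk = torch.randn(2500)    # made-up short final chunk
    wav_chunk = F.pad(wav_chunk, (0, num_samples - len(wav_chunk)))  # zero-pad the tail
    stacked = torch.hstack([prev, wav_chunk])  # fixed 8000-sample window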
@@ -150,7 +135,7 @@ class VADiterator:
     def state(self, model_out):
         current_speech = {}
-        speech_probs = model_out[:, 1]
+        speech_probs = model_out[:, 1]  # this is very misleading
         for i, predict in enumerate(speech_probs):  # add name
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
@@ -219,6 +204,7 @@ def stream_imitator(audios, audios_in_stream):
             values.append((out, wav_name))
         yield values

+
 def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
                         neg_trig_sum=0.01, num_steps=8):
     num_samples = 4000
@@ -229,7 +215,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
         batch = VADiter.prepare_batch(chunk)
         outs = validate(model, batch)
-        vad_outs = outs[-2]
+        vad_outs = outs[-2]  # this is very misleading
         states = []
         state = VADiter.state(vad_outs)
@@ -237,6 +223,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
             states.append(state[0])
         yield states

+
 def validate(model, inputs):
     onnx = False
     if type(model) == onnxruntime.capi.session.InferenceSession:
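The truncated `validate` shows the dispatch point: an `onnxruntime` `InferenceSession` is detected by type, while a JIT model can be called directly on the batch. A minimal sketch of such a branch, assuming a single ONNX input named 'input' (the name and the helper below are assumptions, not the upstream implementation):

    import torch
    import onnxruntime

    def validate_sketch(model, inputs):
        # dispatch on backend type; a sketch, not the exact upstream code
        if isinstance(model, onnxruntime.InferenceSession):
            ort_inputs = {'input': inputs.cpu().numpy()}  # input name is assumed
            outs = model.run(None, ort_inputs)
            return [torch.Tensor(o) for o in outs]
        with torch.no_grad():
            return model(inputs)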