Minor fixes

Author: snakers41
Date: 2020-12-14 14:47:50 +00:00
parent bb02e92ff9
commit 2fb68b2fad

@@ -1,16 +1,16 @@
 import torch
 import tempfile
 import torchaudio
-from typing import List
-import torch.nn as nn
-import torch.nn.functional as F
-from collections import deque
-import numpy as np
-from itertools import repeat
 import onnxruntime
+import numpy as np
+from typing import List
+from itertools import repeat
+from collections import deque
+import torch.nn.functional as F

 torchaudio.set_audio_backend("soundfile")  # switch backend


 def read_audio(path: str,
                target_sr: int = 16000):
@@ -29,32 +29,16 @@ def read_audio(path: str,
     assert sr == target_sr
     return wav.squeeze(0)


 def save_audio(path: str,
                tensor: torch.Tensor,
                sr: int):
     torchaudio.save(path, tensor, sr)


-#def init_jit_model(model_url: str,
-#                   device: torch.device = torch.device('cpu')):
-#    torch.set_grad_enabled(False)
-#    with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
-#        torch.hub.download_url_to_file(model_url,
-#                                       f.name,
-#                                       progress=True)
-#        model = torch.jit.load(f.name, map_location=device)
-#        model.eval()
-#    return model


-def init_jit_model(model_path,
-                   device):
+def init_jit_model(model_path: str,
+                   device=torch.device('cpu')):
     torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
     return model


-def init_onnx_model(model_path):
+def init_onnx_model(model_path: str):
     return onnxruntime.InferenceSession(model_path)
@@ -86,12 +70,12 @@ def get_speech_ts(wav, model,
     outs = torch.cat(outs, dim=0)

-    buffer = deque(maxlen=num_steps)  # when max queue len is reach, first element is dropped
+    buffer = deque(maxlen=num_steps)  # when max queue len is reached, first element is dropped
     triggered = False
     speeches = []
     current_speech = {}

-    speech_probs = outs[:, 1]
+    speech_probs = outs[:, 1]  # this is very misleading
     for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
@@ -108,6 +92,7 @@ def get_speech_ts(wav, model,
         speeches.append(current_speech)
     return speeches

+
 class VADiterator:
     def __init__(self,
                  trig_sum=0.26, neg_trig_sum=0.01,
@@ -139,7 +124,7 @@ class VADiterator:
         assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < self.num_samples:
-            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of the audio
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of audio
             self.last = True

         stacked = torch.hstack([self.prev, wav_chunk])
@@ -150,7 +135,7 @@ class VADiterator:
     def state(self, model_out):
         current_speech = {}
-        speech_probs = model_out[:, 1]
+        speech_probs = model_out[:, 1]  # this is very misleading
         for i, predict in enumerate(speech_probs):  # add name
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
@@ -219,7 +204,8 @@ def stream_imitator(audios, audios_in_stream):
             values.append((out, wav_name))
         yield values

-def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
+
+def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
                         neg_trig_sum=0.01, num_steps=8):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -227,9 +213,9 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
     wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
     for chunk in wav_chunks:
         batch = VADiter.prepare_batch(chunk)

         outs = validate(model, batch)
-        vad_outs = outs[-2]
+        vad_outs = outs[-2]  # this is very misleading
         states = []
         state = VADiter.state(vad_outs)
@@ -237,6 +223,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
             states.append(state[0])
         yield states

+
 def validate(model, inputs):
     onnx = False
     if type(model) == onnxruntime.capi.session.InferenceSession:
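
For context, a minimal usage sketch of the entry points touched by this commit (not part of the diff). It assumes the diffed file is importable as `utils` and that a local TorchScript VAD model file is available; `model.jit` and `test.wav` are illustrative placeholders, not names from the commit.

# Usage sketch (illustrative, not from the commit). Assumes the diffed
# file is importable as `utils` and a TorchScript model file exists locally.
from utils import init_jit_model, read_audio, get_speech_ts

model = init_jit_model('model.jit')            # device now defaults to torch.device('cpu')
wav = read_audio('test.wav', target_sr=16000)  # mono tensor; sample rate must match target_sr
speeches = get_speech_ts(wav, model)           # list of detected speech segments
print(speeches)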