From 2fb68b2fad29d722ab9b7383887789fead34efef Mon Sep 17 00:00:00 2001
From: snakers41
Date: Mon, 14 Dec 2020 14:47:50 +0000
Subject: [PATCH] Minor fixes

---
 utils.py | 55 +++++++++++++++++++++----------------------------------
 1 file changed, 21 insertions(+), 34 deletions(-)

diff --git a/utils.py b/utils.py
index 0cda630..60da2dd 100644
--- a/utils.py
+++ b/utils.py
@@ -1,16 +1,16 @@
 import torch
-import tempfile
 import torchaudio
-from typing import List
-import torch.nn as nn
-import torch.nn.functional as F
-from collections import deque
-import numpy as np
-from itertools import repeat
 import onnxruntime
+import numpy as np
+from typing import List
+from itertools import repeat
+from collections import deque
+import torch.nn.functional as F
+
 
 torchaudio.set_audio_backend("soundfile")  # switch backend
+
 
 
 def read_audio(path: str,
                target_sr: int = 16000):
@@ -29,32 +29,16 @@ def read_audio(path: str,
     assert sr == target_sr
     return wav.squeeze(0)
 
 
-def save_audio(path: str,
-               tensor: torch.Tensor,
-               sr: int):
-    torchaudio.save(path, tensor, sr)
-
-#def init_jit_model(model_url: str,
-#                   device: torch.device = torch.device('cpu')):
-#    torch.set_grad_enabled(False)
-#    with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
-#        torch.hub.download_url_to_file(model_url,
-#                                       f.name,
-#                                       progress=True)
-#    model = torch.jit.load(f.name, map_location=device)
-#    model.eval()
-#    return model
-
-
-def init_jit_model(model_path,
-                   device):
+def init_jit_model(model_path: str,
+                   device=torch.device('cpu')):
     torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
     return model
 
-def init_onnx_model(model_path):
+
+def init_onnx_model(model_path: str):
     return onnxruntime.InferenceSession(model_path)
 
 
@@ -86,12 +70,12 @@ def get_speech_ts(wav, model,
 
     outs = torch.cat(outs, dim=0)
 
-    buffer = deque(maxlen=num_steps)  # when max queue len is reach, first element is dropped
+    buffer = deque(maxlen=num_steps)  # when max queue len is reached, first element is dropped
     triggered = False
     speeches = []
     current_speech = {}
 
-    speech_probs = outs[:, 1]
+    speech_probs = outs[:, 1]  # this is very misleading
     for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
@@ -108,6 +92,7 @@ def get_speech_ts(wav, model,
         speeches.append(current_speech)
     return speeches
 
+
 class VADiterator:
     def __init__(self,
                  trig_sum=0.26, neg_trig_sum=0.01,
@@ -139,7 +124,7 @@ class VADiterator:
         assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < self.num_samples:
-            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of the audio
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of audio
             self.last = True
 
         stacked = torch.hstack([self.prev, wav_chunk])
@@ -150,7 +135,7 @@ class VADiterator:
 
     def state(self, model_out):
         current_speech = {}
-        speech_probs = model_out[:, 1]
+        speech_probs = model_out[:, 1]  # this is very misleading
         for i, predict in enumerate(speech_probs):  # add name
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
@@ -219,7 +204,8 @@ def stream_imitator(audios, audios_in_stream):
             values.append((out, wav_name))
         yield values
 
-def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
+
+def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
                         neg_trig_sum=0.01, num_steps=8):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -227,9 +213,9 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
     wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
     for chunk in wav_chunks:
         batch = VADiter.prepare_batch(chunk)
-        
+
         outs = validate(model, batch)
-        vad_outs = outs[-2]
+        vad_outs = outs[-2]  # this is very misleading
 
         states = []
         state = VADiter.state(vad_outs)
@@ -237,6 +223,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
         states.append(state[0])
     yield states
 
+
 def validate(model, inputs):
     onnx = False
     if type(model) == onnxruntime.capi.session.InferenceSession:
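
For reviewers, a minimal sketch of how the helpers touched by this patch fit together after the change. The file names below are hypothetical, get_speech_ts is called with its keyword defaults, and the exact structure of the returned segments is not shown in this diff:

    from utils import read_audio, init_jit_model, get_speech_ts

    wav = read_audio('example.wav', target_sr=16000)  # 1-D tensor; the helper asserts the file is at 16 kHz
    model = init_jit_model('model.jit')               # device now defaults to torch.device('cpu')
    speeches = get_speech_ts(wav, model)              # one entry per detected speech segment
    print(speeches)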