Add type annotations, clean-up code

This commit is contained in:
snakers41
2020-12-15 12:30:47 +00:00
parent 557a32ed1b
commit 95111b9535

View File

@@ -1,17 +1,22 @@
import torch import torch
import torchaudio import torchaudio
import numpy as np import numpy as np
from typing import List
from itertools import repeat from itertools import repeat
from collections import deque from collections import deque
import torch.nn.functional as F import torch.nn.functional as F
torchaudio.set_audio_backend("soundfile") # switch backend torchaudio.set_audio_backend("soundfile") # switch backend
def validate(model, inputs):
def validate(model,
inputs: torch.Tensor):
with torch.no_grad(): with torch.no_grad():
outs = model(inputs) outs = model(inputs)
return outs return outs
def read_audio(path: str, def read_audio(path: str,
target_sr: int = 16000): target_sr: int = 16000):
@@ -44,9 +49,14 @@ def init_jit_model(model_path: str,
model.eval() model.eval()
return model return model
def get_speech_ts(wav, model,
trig_sum=0.25, neg_trig_sum=0.02, def get_speech_ts(wav: torch.Tensor,
num_steps=8, batch_size=200, run_function=validate): model,
trig_sum: float = 0.25,
neg_trig_sum: float = 0.02,
num_steps: int = 8,
batch_size: int = 200,
run_function=validate):
num_samples = 4000 num_samples = 4000
assert num_samples % num_steps == 0 assert num_samples % num_steps == 0
@@ -97,8 +107,9 @@ def get_speech_ts(wav, model,
class VADiterator: class VADiterator:
def __init__(self, def __init__(self,
trig_sum=0.26, neg_trig_sum=0.02, trig_sum: float = 0.26,
num_steps=8): neg_trig_sum: float = 0.02,
num_steps: int = 8):
self.num_samples = 4000 self.num_samples = 4000
self.num_steps = num_steps self.num_steps = num_steps
assert self.num_samples % num_steps == 0 assert self.num_samples % num_steps == 0
@@ -126,19 +137,20 @@ class VADiterator:
assert len(wav_chunk) <= self.num_samples assert len(wav_chunk) <= self.num_samples
self.num_frames += len(wav_chunk) self.num_frames += len(wav_chunk)
if len(wav_chunk) < self.num_samples: if len(wav_chunk) < self.num_samples:
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
self.last = True self.last = True
stacked = torch.cat([self.prev, wav_chunk]) stacked = torch.cat([self.prev, wav_chunk])
self.prev = wav_chunk self.prev = wav_chunk
overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
for i in range(self.step, self.num_samples+1, self.step)] # 500 sample step is good enough
return torch.cat(overlap_chunks, dim=0) return torch.cat(overlap_chunks, dim=0)
def state(self, model_out): def state(self, model_out):
current_speech = {} current_speech = {}
speech_probs = model_out[:, 1] # this is very misleading speech_probs = model_out[:, 1] # this is very misleading
for i, predict in enumerate(speech_probs): # add name for i, predict in enumerate(speech_probs):
self.buffer.append(predict) self.buffer.append(predict)
if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered: if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
self.triggered = True self.triggered = True
@@ -153,10 +165,14 @@ class VADiterator:
return current_speech, self.current_name return current_speech, self.current_name
def state_generator(model, audios, def state_generator(model,
onnx=False, audios: List[str],
trig_sum=0.26, neg_trig_sum=0.02, onnx: bool = False,
num_steps=8, audios_in_stream=5, run_function=validate): trig_sum: float = 0.26,
neg_trig_sum: float = 0.02,
num_steps: int = 8,
audios_in_stream: int = 5,
run_function=validate):
VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)] VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)): for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)] for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
@@ -173,7 +189,8 @@ def state_generator(model, audios,
yield states yield states
def stream_imitator(audios, audios_in_stream): def stream_imitator(audios: List[str],
audios_in_stream: int):
audio_iter = iter(audios) audio_iter = iter(audios)
iterators = [] iterators = []
num_samples = 4000 num_samples = 4000
@@ -207,8 +224,13 @@ def stream_imitator(audios, audios_in_stream):
yield values yield values
def single_audio_stream(model, audio, onnx=False, trig_sum=0.26, def single_audio_stream(model,
neg_trig_sum=0.02, num_steps=8, run_function=validate): audio: str,
onnx: bool = False,
trig_sum: float = 0.26,
neg_trig_sum: float = 0.02,
num_steps: int = 8,
run_function=validate):
num_samples = 4000 num_samples = 4000
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps) VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
wav = read_audio(audio) wav = read_audio(audio)