diff --git a/utils.py b/utils.py
index a5c00bb..a9d1f38 100644
--- a/utils.py
+++ b/utils.py
@@ -1,17 +1,22 @@
 import torch
 import torchaudio
 import numpy as np
+from typing import List
 from itertools import repeat
 from collections import deque
 import torch.nn.functional as F
+
 torchaudio.set_audio_backend("soundfile")  # switch backend
 
-def validate(model, inputs):
+
+def validate(model,
+             inputs: torch.Tensor):
     with torch.no_grad():
         outs = model(inputs)
     return outs
 
+
 def read_audio(path: str,
                target_sr: int = 16000):
@@ -44,9 +49,14 @@ def init_jit_model(model_path: str,
     model.eval()
     return model
 
-def get_speech_ts(wav, model,
-                  trig_sum=0.25, neg_trig_sum=0.02,
-                  num_steps=8, batch_size=200, run_function=validate):
+
+def get_speech_ts(wav: torch.Tensor,
+                  model,
+                  trig_sum: float = 0.25,
+                  neg_trig_sum: float = 0.02,
+                  num_steps: int = 8,
+                  batch_size: int = 200,
+                  run_function=validate):
     num_samples = 4000
     assert num_samples % num_steps == 0
@@ -97,8 +107,9 @@ def get_speech_ts(wav, model,
 
 class VADiterator:
     def __init__(self,
-                 trig_sum=0.26, neg_trig_sum=0.02,
-                 num_steps=8):
+                 trig_sum: float = 0.26,
+                 neg_trig_sum: float = 0.02,
+                 num_steps: int = 8):
         self.num_samples = 4000
         self.num_steps = num_steps
         assert self.num_samples % num_steps == 0
@@ -126,19 +137,20 @@ class VADiterator:
         assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < self.num_samples:
-            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of audio
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # short chunk => eof audio
             self.last = True
 
         stacked = torch.cat([self.prev, wav_chunk])
         self.prev = wav_chunk
 
-        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)]  # 500 step is good enough
+        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
+                          for i in range(self.step, self.num_samples+1, self.step)]  # 500 sample step is good enough
         return torch.cat(overlap_chunks, dim=0)
 
     def state(self, model_out):
         current_speech = {}
         speech_probs = model_out[:, 1]  # this is very misleading
-        for i, predict in enumerate(speech_probs):  # add name
+        for i, predict in enumerate(speech_probs):
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
                 self.triggered = True
@@ -153,10 +165,14 @@ class VADiterator:
         return current_speech, self.current_name
 
 
-def state_generator(model, audios,
-                    onnx=False,
-                    trig_sum=0.26, neg_trig_sum=0.02,
-                    num_steps=8, audios_in_stream=5, run_function=validate):
+def state_generator(model,
+                    audios: List[str],
+                    onnx: bool = False,
+                    trig_sum: float = 0.26,
+                    neg_trig_sum: float = 0.02,
+                    num_steps: int = 8,
+                    audios_in_stream: int = 5,
+                    run_function=validate):
     VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
     for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
         for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
@@ -173,7 +189,8 @@ def state_generator(model, audios,
         yield states
 
 
-def stream_imitator(audios, audios_in_stream):
+def stream_imitator(audios: List[str],
+                    audios_in_stream: int):
     audio_iter = iter(audios)
     iterators = []
     num_samples = 4000
@@ -207,8 +224,13 @@ def stream_imitator(audios, audios_in_stream):
         yield values
 
 
-def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
-                        neg_trig_sum=0.02, num_steps=8, run_function=validate):
+def single_audio_stream(model,
+                        audio: str,
+                        onnx: bool = False,
+                        trig_sum: float = 0.26,
+                        neg_trig_sum: float = 0.02,
+                        num_steps: int = 8,
+                        run_function=validate):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
     wav = read_audio(audio)
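
Reviewer note: a minimal usage sketch of the annotated API above, not part of the diff. The paths model.jit and test.wav are placeholders, init_jit_model may take further arguments not visible in these hunks, and the exact shape of the returned timestamps is assumed rather than confirmed by the visible code.

# usage sketch under the assumptions stated above
from utils import init_jit_model, read_audio, get_speech_ts, single_audio_stream

model = init_jit_model('model.jit')    # placeholder path to a TorchScript VAD model
wav = read_audio('test.wav')           # placeholder path; read_audio resamples to 16 kHz

# offline API: scores overlapping 4000-sample windows and returns speech timestamps
speech_timestamps = get_speech_ts(wav, model, trig_sum=0.25, num_steps=8)
print(speech_timestamps)

# streaming API: single_audio_stream is assumed to yield per-chunk VAD states,
# mirroring the visible "yield states" in state_generator
for state in single_audio_stream(model, 'test.wav'):
    if state:
        print(state)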