mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00

Add type annotations, clean-up code

utils.py (54 changed lines)
@@ -1,17 +1,22 @@
 import torch
 import torchaudio
 import numpy as np
+from typing import List
 from itertools import repeat
 from collections import deque
 import torch.nn.functional as F


 torchaudio.set_audio_backend("soundfile") # switch backend


-def validate(model, inputs):
+def validate(model,
+             inputs: torch.Tensor):
     with torch.no_grad():
         outs = model(inputs)
     return outs


 def read_audio(path: str,
                target_sr: int = 16000):

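Not part of the commit, but for orientation: a minimal sketch of exercising the newly annotated validate together with read_audio. The model and wav paths are placeholders, and the assumptions that init_jit_model can be called with only a path and that read_audio returns a 1-D tensor resampled to target_sr are inferred from the signatures visible in this diff, not stated by it.

# Sketch only -- paths are placeholders; helper behaviour is assumed from the signatures above.
model = init_jit_model('files/model.jit')          # hypothetical model path
wav = read_audio('example.wav', target_sr=16000)   # hypothetical wav file; 1-D tensor assumed
chunk = wav[:4000].unsqueeze(0)                    # one 4000-sample window, shaped as a batch of one
out = validate(model, chunk)                       # forward pass under torch.no_grad()
print(out.shape)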
@@ -44,9 +49,14 @@ def init_jit_model(model_path: str,
     model.eval()
     return model


-def get_speech_ts(wav, model,
-                  trig_sum=0.25, neg_trig_sum=0.02,
-                  num_steps=8, batch_size=200, run_function=validate):
+def get_speech_ts(wav: torch.Tensor,
+                  model,
+                  trig_sum: float = 0.25,
+                  neg_trig_sum: float = 0.02,
+                  num_steps: int = 8,
+                  batch_size: int = 200,
+                  run_function=validate):

     num_samples = 4000
     assert num_samples % num_steps == 0
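Likewise outside the diff: a hedged usage sketch of the re-annotated get_speech_ts, spelling out the new typed defaults and reusing the model from the previous sketch. That it returns speech timestamps as a list of {'start': ..., 'end': ...} dicts in samples is an assumption about the unchanged function body, which this hunk does not show.

# Sketch only -- 'speech_sample.wav' is a placeholder and the return format is an assumption.
wav = read_audio('speech_sample.wav', target_sr=16000)
speech_timestamps = get_speech_ts(wav, model,
                                  trig_sum=0.25,      # same values as the new typed defaults
                                  neg_trig_sum=0.02,
                                  num_steps=8,
                                  batch_size=200)
print(speech_timestamps)  # e.g. [{'start': ..., 'end': ...}, ...] (assumed)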
@@ -97,8 +107,9 @@ def get_speech_ts(wav, model,

 class VADiterator:
     def __init__(self,
-                 trig_sum=0.26, neg_trig_sum=0.02,
-                 num_steps=8):
+                 trig_sum: float = 0.26,
+                 neg_trig_sum: float = 0.02,
+                 num_steps: int = 8):
         self.num_samples = 4000
         self.num_steps = num_steps
         assert self.num_samples % num_steps == 0
@@ -126,19 +137,20 @@ class VADiterator:
         assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < self.num_samples:
-            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
             self.last = True

         stacked = torch.cat([self.prev, wav_chunk])
         self.prev = wav_chunk

-        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0) for i in range(self.step, self.num_samples+1, self.step)] # 500 step is good enough
+        overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
+                          for i in range(self.step, self.num_samples+1, self.step)] # 500 sample step is good enough
         return torch.cat(overlap_chunks, dim=0)

     def state(self, model_out):
         current_speech = {}
         speech_probs = model_out[:, 1] # this is very misleading
-        for i, predict in enumerate(speech_probs): # add name
+        for i, predict in enumerate(speech_probs):
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
                 self.triggered = True
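The reflowed overlap_chunks comprehension is purely cosmetic; its behaviour is unchanged. With num_samples = 4000 and num_steps = 8 the step is 500 samples, so the 8000-sample concatenation of the previous and current chunk is sliced into 8 overlapping 4000-sample windows. A self-contained sketch of that same slicing (variable names here are illustrative, not from utils.py):

import torch

num_samples, num_steps = 4000, 8
step = num_samples // num_steps              # 500, matching VADiterator's step (num_samples // num_steps)

prev = torch.zeros(num_samples)              # previous chunk (zeros before the first real chunk)
cur = torch.randn(num_samples)               # current chunk
stacked = torch.cat([prev, cur])             # 8000 samples

windows = [stacked[i:i + num_samples].unsqueeze(0)
           for i in range(step, num_samples + 1, step)]
batch = torch.cat(windows, dim=0)
print(batch.shape)                           # torch.Size([8, 4000])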
@@ -153,10 +165,14 @@ class VADiterator:
         return current_speech, self.current_name


-def state_generator(model, audios,
-                    onnx=False,
-                    trig_sum=0.26, neg_trig_sum=0.02,
-                    num_steps=8, audios_in_stream=5, run_function=validate):
+def state_generator(model,
+                    audios: List[str],
+                    onnx: bool = False,
+                    trig_sum: float = 0.26,
+                    neg_trig_sum: float = 0.02,
+                    num_steps: int = 8,
+                    audios_in_stream: int = 5,
+                    run_function=validate):
     VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
     for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
         for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
@@ -173,7 +189,8 @@ def state_generator(model, audios,
         yield states


-def stream_imitator(audios, audios_in_stream):
+def stream_imitator(audios: List[str],
+                    audios_in_stream: int):
     audio_iter = iter(audios)
     iterators = []
     num_samples = 4000
@@ -207,8 +224,13 @@ def stream_imitator(audios, audios_in_stream):
         yield values


-def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
-                        neg_trig_sum=0.02, num_steps=8, run_function=validate):
+def single_audio_stream(model,
+                        audio: str,
+                        onnx: bool = False,
+                        trig_sum: float = 0.26,
+                        neg_trig_sum: float = 0.02,
+                        num_steps: int = 8,
+                        run_function=validate):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
     wav = read_audio(audio)
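Finally, a hedged sketch of driving single_audio_stream with its new keyword types. Only the header appears in this diff; that it is a generator yielding per-chunk speech-state dicts is an assumption based on its use of VADiterator, and the file name is a placeholder.

# Sketch only -- the yielded value format is assumed, not shown in this diff.
model = init_jit_model('files/model.jit')    # hypothetical path, as in the earlier sketches
for speech_state in single_audio_stream(model, 'stream_test.wav',
                                        trig_sum=0.26,
                                        neg_trig_sum=0.02,
                                        num_steps=8):
    if speech_state:
        print(speech_state)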