mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Minor fixes
This commit is contained in:
55
utils.py
55
utils.py
@@ -1,16 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
import tempfile
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from typing import List
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from collections import deque
|
|
||||||
import numpy as np
|
|
||||||
from itertools import repeat
|
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
from itertools import repeat
|
||||||
|
from collections import deque
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
torchaudio.set_audio_backend("soundfile") # switch backend
|
torchaudio.set_audio_backend("soundfile") # switch backend
|
||||||
|
|
||||||
|
|
||||||
def read_audio(path: str,
|
def read_audio(path: str,
|
||||||
target_sr: int = 16000):
|
target_sr: int = 16000):
|
||||||
|
|
||||||
@@ -29,32 +29,16 @@ def read_audio(path: str,
|
|||||||
assert sr == target_sr
|
assert sr == target_sr
|
||||||
return wav.squeeze(0)
|
return wav.squeeze(0)
|
||||||
|
|
||||||
def save_audio(path: str,
|
|
||||||
tensor: torch.Tensor,
|
|
||||||
sr: int):
|
|
||||||
torchaudio.save(path, tensor, sr)
|
|
||||||
|
|
||||||
|
def init_jit_model(model_path: str,
|
||||||
#def init_jit_model(model_url: str,
|
device=torch.device('cpu')):
|
||||||
# device: torch.device = torch.device('cpu')):
|
|
||||||
# torch.set_grad_enabled(False)
|
|
||||||
# with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
|
|
||||||
# torch.hub.download_url_to_file(model_url,
|
|
||||||
# f.name,
|
|
||||||
# progress=True)
|
|
||||||
# model = torch.jit.load(f.name, map_location=device)
|
|
||||||
# model.eval()
|
|
||||||
# return model
|
|
||||||
|
|
||||||
|
|
||||||
def init_jit_model(model_path,
|
|
||||||
device):
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
model = torch.jit.load(model_path, map_location=device)
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def init_onnx_model(model_path):
|
|
||||||
|
def init_onnx_model(model_path: str):
|
||||||
return onnxruntime.InferenceSession(model_path)
|
return onnxruntime.InferenceSession(model_path)
|
||||||
|
|
||||||
|
|
||||||
@@ -86,12 +70,12 @@ def get_speech_ts(wav, model,
|
|||||||
|
|
||||||
outs = torch.cat(outs, dim=0)
|
outs = torch.cat(outs, dim=0)
|
||||||
|
|
||||||
buffer = deque(maxlen=num_steps) # when max queue len is reach, first element is dropped
|
buffer = deque(maxlen=num_steps) # when max queue len is reached, first element is dropped
|
||||||
triggered = False
|
triggered = False
|
||||||
speeches = []
|
speeches = []
|
||||||
current_speech = {}
|
current_speech = {}
|
||||||
|
|
||||||
speech_probs = outs[:, 1]
|
speech_probs = outs[:, 1] # this is very misleading
|
||||||
for i, predict in enumerate(speech_probs): # add name
|
for i, predict in enumerate(speech_probs): # add name
|
||||||
buffer.append(predict)
|
buffer.append(predict)
|
||||||
if (np.mean(buffer) >= trig_sum) and not triggered:
|
if (np.mean(buffer) >= trig_sum) and not triggered:
|
||||||
@@ -108,6 +92,7 @@ def get_speech_ts(wav, model,
|
|||||||
speeches.append(current_speech)
|
speeches.append(current_speech)
|
||||||
return speeches
|
return speeches
|
||||||
|
|
||||||
|
|
||||||
class VADiterator:
|
class VADiterator:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
trig_sum=0.26, neg_trig_sum=0.01,
|
trig_sum=0.26, neg_trig_sum=0.01,
|
||||||
@@ -139,7 +124,7 @@ class VADiterator:
|
|||||||
assert len(wav_chunk) <= self.num_samples
|
assert len(wav_chunk) <= self.num_samples
|
||||||
self.num_frames += len(wav_chunk)
|
self.num_frames += len(wav_chunk)
|
||||||
if len(wav_chunk) < self.num_samples:
|
if len(wav_chunk) < self.num_samples:
|
||||||
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of the audio
|
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # assume that short chunk means end of audio
|
||||||
self.last = True
|
self.last = True
|
||||||
|
|
||||||
stacked = torch.hstack([self.prev, wav_chunk])
|
stacked = torch.hstack([self.prev, wav_chunk])
|
||||||
@@ -150,7 +135,7 @@ class VADiterator:
|
|||||||
|
|
||||||
def state(self, model_out):
|
def state(self, model_out):
|
||||||
current_speech = {}
|
current_speech = {}
|
||||||
speech_probs = model_out[:, 1]
|
speech_probs = model_out[:, 1] # this is very misleading
|
||||||
for i, predict in enumerate(speech_probs): # add name
|
for i, predict in enumerate(speech_probs): # add name
|
||||||
self.buffer.append(predict)
|
self.buffer.append(predict)
|
||||||
if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
|
if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
|
||||||
@@ -219,7 +204,8 @@ def stream_imitator(audios, audios_in_stream):
|
|||||||
values.append((out, wav_name))
|
values.append((out, wav_name))
|
||||||
yield values
|
yield values
|
||||||
|
|
||||||
def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
|
|
||||||
|
def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
|
||||||
neg_trig_sum=0.01, num_steps=8):
|
neg_trig_sum=0.01, num_steps=8):
|
||||||
num_samples = 4000
|
num_samples = 4000
|
||||||
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
|
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
|
||||||
@@ -227,9 +213,9 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
|
|||||||
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
|
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
|
||||||
for chunk in wav_chunks:
|
for chunk in wav_chunks:
|
||||||
batch = VADiter.prepare_batch(chunk)
|
batch = VADiter.prepare_batch(chunk)
|
||||||
|
|
||||||
outs = validate(model, batch)
|
outs = validate(model, batch)
|
||||||
vad_outs = outs[-2]
|
vad_outs = outs[-2] # this is very misleading
|
||||||
|
|
||||||
states = []
|
states = []
|
||||||
state = VADiter.state(vad_outs)
|
state = VADiter.state(vad_outs)
|
||||||
@@ -237,6 +223,7 @@ def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
|
|||||||
states.append(state[0])
|
states.append(state[0])
|
||||||
yield states
|
yield states
|
||||||
|
|
||||||
|
|
||||||
def validate(model, inputs):
|
def validate(model, inputs):
|
||||||
onnx = False
|
onnx = False
|
||||||
if type(model) == onnxruntime.capi.session.InferenceSession:
|
if type(model) == onnxruntime.capi.session.InferenceSession:
|
||||||
|
|||||||
Reference in New Issue
Block a user