Mirror of https://github.com/snakers4/silero-vad.git (synced 2026-02-05 18:09:22 +08:00)
fx bugs, del extractor
BIN  files/joint_VAD_just_RU.onnx  (new file; binary file not shown)
350  silero-vad.ipynb  (file diff suppressed because one or more lines are too long)
109  utils.py
@@ -7,23 +7,10 @@ import torch.nn.functional as F
 from collections import deque
 import numpy as np
 from itertools import repeat
+import onnxruntime
 
 torchaudio.set_audio_backend("soundfile")  # switch backend
 
 
-def read_batch(audio_paths: List[str]):
-    return [read_audio(audio_path)
-            for audio_path
-            in audio_paths]
-
-
-def split_into_batches(lst: List[str],
-                       batch_size: int = 10):
-    return [lst[i:i + batch_size]
-            for i in
-            range(0, len(lst), batch_size)]
-
-
 def read_audio(path: str,
                target_sr: int = 16000):
-
@@ -42,15 +29,10 @@ def read_audio(path: str,
     assert sr == target_sr
     return wav.squeeze(0)
 
-
-def prepare_model_input(batch: List[torch.Tensor],
-                        device=torch.device('cpu')):
-    max_seqlength = max(max([len(_) for _ in batch]), 12800)
-    inputs = torch.zeros(len(batch), max_seqlength)
-    for i, wav in enumerate(batch):
-        inputs[i, :len(wav)].copy_(wav)
-    inputs = inputs.to(device)
-    return inputs
+def save_audio(path: str,
+               tensor: torch.Tensor,
+               sr: int):
+    torchaudio.save(path, tensor, sr)
 
 
 #def init_jit_model(model_url: str,
@@ -72,8 +54,11 @@ def init_jit_model(model_path,
     model.eval()
     return model
 
 
-def get_speech_ts(wav, model, extractor,
+def init_onnx_model(model_path):
+    return onnxruntime.InferenceSession(model_path)
+
+
+def get_speech_ts(wav, model,
                   trig_sum=0.25, neg_trig_sum=0.01,
                   num_steps=8, batch_size=200):
-
@@ -90,15 +75,13 @@ def get_speech_ts(wav, model, extractor,
         to_concat.append(chunk)
         if len(to_concat) >= batch_size:
             chunks = torch.Tensor(torch.vstack(to_concat))
-            with torch.no_grad():
-                out = model(extractor(chunks))[-2]
+            out = validate(model, chunks)[-2]
             outs.append(out)
             to_concat = []
 
     if to_concat:
         chunks = torch.Tensor(torch.vstack(to_concat))
-        with torch.no_grad():
-            out = model(extractor(chunks))[-2]
+        out = validate(model, chunks)[-2]
         outs.append(out)
 
     outs = torch.cat(outs, dim=0)
@@ -125,37 +108,6 @@ def get_speech_ts(wav, model, extractor,
         speeches.append(current_speech)
     return speeches
 
-
-class STFTExtractor(nn.Module):
-    def __init__(self, sr=16000, win_size=0.02, mode='mag'):
-        super(STFTExtractor, self).__init__()
-        self.sr = sr
-        self.n_fft = int(sr * (win_size + 1e-8))
-        self.win_length = self.n_fft
-        self.hop_length = self.win_length // 2
-        self.mode = 'mag' if mode == '' else mode
-
-    def forward(self, wav):
-        # center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
-        stft_sample = torch.stft(wav,
-                                 n_fft=self.n_fft,
-                                 win_length=self.win_length,
-                                 hop_length=self.hop_length,
-                                 center=True)
-        mag, phase = torchaudio.functional.magphase(stft_sample)
-
-        # It seems it is not a "mag", it is "power" (exp == 1).
-        # Also there is "energy" (exp == 2).
-        if self.mode == 'mag':
-            return mag
-        if self.mode == 'phase':
-            return phase
-        elif self.mode == 'magphase':
-            return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
-        else:
-            raise NotImplementedError()
-
-
 class VADiterator:
     def __init__(self,
                  trig_sum=0.26, neg_trig_sum=0.01,
@@ -214,7 +166,7 @@ class VADiterator:
         return current_speech, self.current_name
 
 
-def state_generator(model, audios, extractor,
+def state_generator(model, audios,
                     onnx=False,
                     trig_sum=0.26, neg_trig_sum=0.01,
                     num_steps=8, audios_in_stream=5):
@@ -223,14 +175,8 @@ def state_generator(model, audios, extractor,
         for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
         batch = torch.cat(for_batch)
 
-        with torch.no_grad():
-            if onnx:
-                ort_inputs = {'input': to_numpy(extractor(batch))}
-                ort_outs = model.run(None, ort_inputs)
-                vad_outs = np.split(ort_outs[-2], audios_in_stream)
-            else:
-                outs = model(extractor(batch))
-                vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
+        outs = validate(model, batch)
+        vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
 
         states = []
         for x, y in zip(VADiters, vad_outs):
@@ -273,7 +219,7 @@ def stream_imitator(audios, audios_in_stream):
         values.append((out, wav_name))
         yield values
 
-def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
+def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
                         neg_trig_sum=0.01, num_steps=8):
     num_samples = 4000
     VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -281,18 +227,25 @@ def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
     wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
     for chunk in wav_chunks:
         batch = VADiter.prepare_batch(chunk)
 
-        with torch.no_grad():
-            if onnx:
-                ort_inputs = {'input': to_numpy(extractor(batch))}
-                ort_outs = model.run(None, ort_inputs)
-                vad_outs = ort_outs[-2]
-            else:
-                outs = model(extractor(batch))
-                vad_outs = outs[-2]
+        outs = validate(model, batch)
+        vad_outs = outs[-2]
 
         states = []
         state = VADiter.state(vad_outs)
         if state[0]:
            states.append(state[0])
         yield states
+
+def validate(model, inputs):
+    onnx = False
+    if type(model) == onnxruntime.capi.session.InferenceSession:
+        onnx = True
+    with torch.no_grad():
+        if onnx:
+            ort_inputs = {'input': inputs.cpu().numpy()}
+            outs = model.run(None, ort_inputs)
+            outs = [torch.Tensor(x) for x in outs]
+        else:
+            outs = model(inputs)
+        return outs
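After this commit a caller no longer constructs an STFTExtractor or threads it through the API: init_onnx_model loads an onnxruntime InferenceSession directly, and the new validate helper dispatches between a TorchScript module and an ONNX session, wrapping torch.no_grad() and converting ONNX outputs back to torch tensors so the outs[-2] indexing works for both backends. A minimal usage sketch under the assumptions in this diff follows; the input wav and the .jit path are hypothetical, while the .onnx path is the file added by this commit:

import torch
from utils import init_jit_model, init_onnx_model, read_audio, get_speech_ts

wav = read_audio('test.wav', target_sr=16000)  # hypothetical input file

# TorchScript model: validate() inside get_speech_ts falls through to model(inputs)
jit_model = init_jit_model('files/model.jit')  # hypothetical path
speech_ts = get_speech_ts(wav, jit_model, num_steps=8)

# ONNX model: validate() detects the InferenceSession and calls model.run(...)
onnx_model = init_onnx_model('files/joint_VAD_just_RU.onnx')
speech_ts = get_speech_ts(wav, onnx_model, num_steps=8)

Because both model types go through the same validate() entry point, get_speech_ts, state_generator, and single_audio_stream no longer need the extractor argument or the onnx branching they carried before this commit.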