mirror of https://github.com/snakers4/silero-vad.git
add single stream example
Changed files:

silero-vad.ipynb (353): file diff suppressed because one or more lines are too long
utils.py (77)

--- a/utils.py
+++ b/utils.py
@@ -77,15 +77,16 @@ def get_speech_ts(wav, model, extractor,
                   trig_sum=0.25, neg_trig_sum=0.01,
                   num_steps=8, batch_size=200):

-    assert 4000 % num_steps == 0
-    step = int(4000 / num_steps)  # stride / hop
+    num_samples = 4000
+    assert num_samples % num_steps == 0
+    step = int(num_samples / num_steps)  # stride / hop
     outs = []
     to_concat = []

     for i in range(0, len(wav), step):
-        chunk = wav[i: i+4000]
-        if len(chunk) < 4000:
-            chunk = F.pad(chunk, (0, 4000 - len(chunk)))
+        chunk = wav[i: i+num_samples]
+        if len(chunk) < num_samples:
+            chunk = F.pad(chunk, (0, num_samples - len(chunk)))
         to_concat.append(chunk)
         if len(to_concat) >= batch_size:
             chunks = torch.Tensor(torch.vstack(to_concat))
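For context, this hunk only replaces the hard-coded 4000-sample window with a named `num_samples` constant; the chunking behavior is unchanged. A minimal, self-contained sketch of that loop on a synthetic tensor (not the repo's code path):

```python
import torch
import torch.nn.functional as F

# Sketch of the windowing above: slide a num_samples window over the wav
# with a step of num_samples / num_steps, zero-padding short tail windows.
num_samples = 4000
num_steps = 8
step = num_samples // num_steps  # 500-sample hop between window starts

wav = torch.randn(9300)  # stand-in for a mono waveform from read_audio
windows = []
for i in range(0, len(wav), step):
    chunk = wav[i:i + num_samples]
    if len(chunk) < num_samples:
        # F.pad with a (left, right) tuple pads the last dimension
        chunk = F.pad(chunk, (0, num_samples - len(chunk)))
    windows.append(chunk)

batch = torch.vstack(windows)
print(batch.shape)  # (num_windows, num_samples)
```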
@@ -107,7 +108,8 @@ def get_speech_ts(wav, model, extractor,
     speeches = []
     current_speech = {}

-    for i, predict in enumerate(outs[:, 1]):  # add name
+    speech_probs = outs[:, 1]
+    for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
             triggered = True
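Column 1 of `outs` holds the per-window speech probability, so naming it `speech_probs` makes the triggering loop read naturally. A toy sketch of the smoothing-and-threshold logic on synthetic probabilities, using the same defaults as above:

```python
from collections import deque
import numpy as np

# Rolling mean of the last num_steps probabilities is compared against
# trig_sum to open a segment and neg_trig_sum to close it (hysteresis).
trig_sum, neg_trig_sum, num_steps = 0.25, 0.01, 8
speech_probs = np.concatenate([np.zeros(10), np.full(20, 0.9), np.zeros(10)])

buffer = deque(maxlen=num_steps)
triggered = False
for i, p in enumerate(speech_probs):
    buffer.append(p)
    smoothed = np.mean(buffer)
    if smoothed >= trig_sum and not triggered:
        triggered = True
        print(f'speech starts near step {i}')
    elif smoothed < neg_trig_sum and triggered:
        triggered = False
        print(f'speech ends near step {i}')
```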
@@ -158,44 +160,46 @@ class VADiterator:
     def __init__(self,
                  trig_sum=0.26, neg_trig_sum=0.01,
                  num_steps=8):
+        self.num_samples = 4000
         self.num_steps = num_steps
-        assert 4000 % num_steps == 0
-        self.step = int(4000 / num_steps)
-        self.prev = torch.zeros(4000)
+        assert self.num_samples % num_steps == 0
+        self.step = int(self.num_samples / num_steps)
+        self.prev = torch.zeros(self.num_samples)
         self.last = False
         self.triggered = False
-        self.buffer = deque(maxlen=8)
+        self.buffer = deque(maxlen=num_steps)
         self.num_frames = 0
         self.trig_sum = trig_sum
         self.neg_trig_sum = neg_trig_sum
         self.current_name = ''

     def refresh(self):
-        self.prev = torch.zeros(4000)
+        self.prev = torch.zeros(self.num_samples)
         self.last = False
         self.triggered = False
-        self.buffer = deque(maxlen=8)
+        self.buffer = deque(maxlen=self.num_steps)
         self.num_frames = 0

     def prepare_batch(self, wav_chunk, name=None):
         if (name is not None) and (name != self.current_name):
             self.refresh()
             self.current_name = name
-        assert len(wav_chunk) <= 4000
+        assert len(wav_chunk) <= self.num_samples
         self.num_frames += len(wav_chunk)
-        if len(wav_chunk) < 4000:
-            wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk)))  # assume that short chunk means end of the audio
+        if len(wav_chunk) < self.num_samples:
+            wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk)))  # assume that short chunk means end of the audio
             self.last = True

         stacked = torch.hstack([self.prev, wav_chunk])
         self.prev = wav_chunk

-        overlap_chunks = [stacked[i:i+4000] for i in range(self.step, 4001, self.step)]  # 500 step is good enough
+        overlap_chunks = [stacked[i:i+self.num_samples] for i in range(self.step, self.num_samples+1, self.step)]  # 500 step is good enough
         return torch.vstack(overlap_chunks)

     def state(self, model_out):
         current_speech = {}
-        for i, predict in enumerate(model_out[:, 1]):  # add name
+        speech_probs = model_out[:, 1]
+        for i, predict in enumerate(speech_probs):  # add name
             self.buffer.append(predict)
             if (np.mean(self.buffer) >= self.trig_sum) and not self.triggered:
                 self.triggered = True
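With `num_samples` factored out, `prepare_batch` still stacks the previous chunk in front of the current one and slices `num_steps` overlapping windows out of the doubled buffer. A standalone sketch of the shapes involved (synthetic data, one call in isolation):

```python
import torch

# One prepare_batch step: prev + current chunk give an 8000-sample buffer,
# from which num_steps windows of num_samples are cut at step offsets.
num_samples, num_steps = 4000, 8
step = num_samples // num_steps  # 500

prev = torch.zeros(num_samples)        # state carried between calls
wav_chunk = torch.randn(num_samples)   # synthetic current chunk

stacked = torch.hstack([prev, wav_chunk])  # length 8000
overlap_chunks = [stacked[i:i + num_samples]
                  for i in range(step, num_samples + 1, step)]
batch = torch.vstack(overlap_chunks)
print(batch.shape)  # torch.Size([8, 4000])
```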
@@ -236,14 +240,15 @@ def state_generator(model, audios, extractor,
         yield states


-def stream_imitator(stereo, audios_in_stream):
-    stereo_iter = iter(stereo)
+def stream_imitator(audios, audios_in_stream):
+    audio_iter = iter(audios)
     iterators = []
+    num_samples = 4000
     # initial wavs
     for i in range(audios_in_stream):
-        next_wav = next(stereo_iter)
+        next_wav = next(audio_iter)
         wav = read_audio(next_wav)
-        wav_chunks = iter([(wav[i:i+4000], next_wav) for i in range(0, len(wav), 4000)])
+        wav_chunks = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
         iterators.append(wav_chunks)
     print('Done initial Loading')
     good_iters = audios_in_stream
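One design note on the `repeat((torch.zeros(num_samples), 'junk'))` fallback in the next hunk: once a stream slot runs out of audio it is swapped for an endless iterator of silent dummy pairs, so every step still yields one `(chunk, name)` pair per slot and the downstream batch shape stays fixed. A tiny illustration:

```python
import torch
from itertools import repeat, islice

# An exhausted slot keeps producing silent "junk" chunks of the same shape,
# which keeps the per-step batch size constant until all streams finish.
num_samples = 4000
junk_stream = repeat((torch.zeros(num_samples), 'junk'))

for chunk, name in islice(junk_stream, 3):
    print(name, chunk.shape)  # prints: junk torch.Size([4000]), three times
```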
@@ -254,16 +259,40 @@ def stream_imitator(stereo, audios_in_stream):
                 out, wav_name = next(it)
             except StopIteration:
                 try:
-                    next_wav = next(stereo_iter)
+                    next_wav = next(audio_iter)
                     print('Loading next wav: ', next_wav)
                     wav = read_audio(next_wav)
-                    iterators[i] = iter([(wav[i:i+4000], next_wav) for i in range(0, len(wav), 4000)])
+                    iterators[i] = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
                     out, wav_name = next(iterators[i])
                 except StopIteration:
                     good_iters -= 1
-                    iterators[i] = repeat((torch.zeros(4000), 'junk'))
+                    iterators[i] = repeat((torch.zeros(num_samples), 'junk'))
                     out, wav_name = next(iterators[i])
                     if good_iters == 0:
                         return
             values.append((out, wav_name))
         yield values
+
+def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
+                        neg_trig_sum=0.01, num_steps=8):
+    num_samples = 4000
+    VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
+    wav = read_audio(audio)
+    wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
+    for chunk in wav_chunks:
+        batch = VADiter.prepare_batch(chunk)
+
+        with torch.no_grad():
+            if onnx:
+                ort_inputs = {'input': to_numpy(extractor(batch))}
+                ort_outs = model.run(None, ort_inputs)
+                vad_outs = ort_outs[-2]
+            else:
+                outs = model(extractor(batch))
+                vad_outs = outs[-2]
+
+        states = []
+        state = VADiter.state(vad_outs)
+        if state[0]:
+            states.append(state[0])
+        yield states
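A hedged usage skeleton for the new `single_audio_stream` generator. How `model` and `extractor` are constructed is not part of this diff (the notebook handles loading), so the two assignments below are placeholders and this sketch is not runnable as-is:

```python
model = ...      # PyTorch VAD model, or an onnxruntime InferenceSession with onnx=True
extractor = ...  # the feature extractor the model expects on each 8-window batch

for states in single_audio_stream(model, 'example.wav', extractor,
                                  onnx=False, num_steps=8):
    if states:  # non-empty when this chunk opened or closed a speech segment
        print(states)
```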