From 5412e5a769c7a90580f97a5aab701c5f73df207e Mon Sep 17 00:00:00 2001
From: snakers41
Date: Fri, 11 Dec 2020 13:57:22 +0000
Subject: [PATCH] Fx

---
 silero-vad.ipynb |  2 +-
 utils.py         | 41 +++++++++++++++++++++++------------------
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index 403df15..81fa68f 100644
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -301,7 +301,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.7.7"
   },
   "toc": {
    "base_numbering": 1,
diff --git a/utils.py b/utils.py
index 8355a73..81c2805 100644
--- a/utils.py
+++ b/utils.py
@@ -73,11 +73,15 @@ def init_jit_model(model_path,
     return model


-def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_steps=8, batch_size=200):
+def get_speech_ts(wav, model, extractor,
+                  trig_sum=0.25, neg_trig_sum=0.01,
+                  num_steps=8, batch_size=200):
+
     assert 4000 % num_steps == 0
-    step = int(4000 / num_steps)
+    step = int(4000 / num_steps)  # stride / hop
     outs = []
     to_concat = []
+
     for i in range(0, len(wav), step):
         chunk = wav[i: i+4000]
         if len(chunk) < 4000:
@@ -89,20 +93,21 @@ def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_s
             out = model(extractor(chunks))[-2]
             outs.append(out)
             to_concat = []
-
+
     if to_concat:
         chunks = torch.Tensor(torch.vstack(to_concat))
         with torch.no_grad():
             out = model(extractor(chunks))[-2]
         outs.append(out)
-
+
     outs = torch.cat(outs, dim=0)
-
-    buffer = deque(maxlen=num_steps)
+
+    buffer = deque(maxlen=num_steps)  # when max queue len is reach, first element is dropped
     triggered = False
     speeches = []
     current_speech = {}
-    for i, predict in enumerate(outs[:, 1]):
+
+    for i, predict in enumerate(outs[:, 1]):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
             triggered = True
@@ -150,7 +155,9 @@ class STFTExtractor(nn.Module):


 class VADiterator:
-    def __init__(self, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8):
+    def __init__(self,
+                 trig_sum=0.26, neg_trig_sum=0.01,
+                 num_steps=8):
         self.num_steps = num_steps
         assert 4000 % num_steps == 0
         self.step = int(4000 / num_steps)
@@ -162,14 +169,14 @@ class VADiterator:
         self.trig_sum = trig_sum
         self.neg_trig_sum = neg_trig_sum
         self.current_name = ''
-
+
     def refresh(self):
         self.prev = torch.zeros(4000)
         self.last = False
         self.triggered = False
         self.buffer = deque(maxlen=8)
         self.num_frames = 0
-
+
     def prepare_batch(self, wav_chunk, name=None):
         if (name is not None) and (name != self.current_name):
             self.refresh()
@@ -177,15 +184,15 @@ class VADiterator:
         assert len(wav_chunk) <= 4000
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < 4000:
-            wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk))) # assume that short chunk means end of the audio
+            wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk)))  # assume that short chunk means end of the audio
             self.last = True
-
+
         stacked = torch.hstack([self.prev, wav_chunk])
         self.prev = wav_chunk
-
-        overlap_chunks = [stacked[i:i+4000] for i in range(500, 4001, self.step)] # 500 step is good enough
+
+        overlap_chunks = [stacked[i:i+4000] for i in range(500, 4001, self.step)]  # 500 step is good enough
         return torch.vstack(overlap_chunks)
-
+
     def state(self, model_out):
         current_speech = {}
         for i, predict in enumerate(model_out[:, 1]):
@@ -203,7 +210,6 @@ class VADiterator:
         return current_speech, self.current_name


-
 def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8, audios_in_stream=5):
     VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
     for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
@@ -218,7 +224,7 @@ def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_tri
         else:
             outs = model(extractor(batch))
         vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
-
+
         states = []
         for x, y in zip(VADiters, vad_outs):
             cur_st = x.state(y)
@@ -259,4 +265,3 @@ def stream_imitator(stereo, audios_in_stream):
             values.append((out, wav_name))
         yield values

-
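
Below is a minimal calling sketch for the reformatted get_speech_ts signature introduced by this patch. It is not part of the commit: the wrapper name detect_speech and the assumption that a loaded model and extractor instance are handed in from the rest of the repo (e.g. built via init_jit_model and STFTExtractor from utils.py) are illustrative only.

import torch

from utils import get_speech_ts


def detect_speech(wav: torch.Tensor, model, extractor):
    # Pass the tuned defaults explicitly so the new keyword layout is visible;
    # `model` and `extractor` are assumed to be whatever utils.py constructs
    # elsewhere (JIT VAD model and STFT feature extractor).
    return get_speech_ts(wav, model, extractor,
                         trig_sum=0.25, neg_trig_sum=0.01,
                         num_steps=8, batch_size=200)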