diff --git a/files/model.jit b/files/model.jit
index 0bec210..4a8b0ed 100644
Binary files a/files/model.jit and b/files/model.jit differ
diff --git a/files/model.onnx b/files/model.onnx
index e4596f1..2ec5c38 100644
Binary files a/files/model.onnx and b/files/model.onnx differ
diff --git a/silero-vad.ipynb b/silero-vad.ipynb
index 34dabb6..818b8e6 100644
--- a/silero-vad.ipynb
+++ b/silero-vad.ipynb
@@ -233,7 +233,7 @@
     "    ort_inputs = {'input': inputs.cpu().numpy()}\n",
     "    outs = model.run(None, ort_inputs)\n",
     "    outs = [torch.Tensor(x) for x in outs]\n",
-    "    return outs"
+    "    return outs[0]"
    ]
   },
   {
@@ -405,5 +405,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 1
 }
diff --git a/utils.py b/utils.py
index 00d9cc6..98a8f29 100644
--- a/utils.py
+++ b/utils.py
@@ -52,7 +52,7 @@ def init_jit_model(model_path: str,
 def get_speech_ts(wav: torch.Tensor,
                   model,
                   trig_sum: float = 0.25,
-                  neg_trig_sum: float = 0.02,
+                  neg_trig_sum: float = 0.1,
                   num_steps: int = 8,
                   batch_size: int = 200,
                   run_function=validate):
@@ -70,13 +70,13 @@ def get_speech_ts(wav: torch.Tensor,
         to_concat.append(chunk.unsqueeze(0))
         if len(to_concat) >= batch_size:
             chunks = torch.Tensor(torch.cat(to_concat, dim=0))
-            out = run_function(model, chunks)[-2]
+            out = run_function(model, chunks)
             outs.append(out)
             to_concat = []

     if to_concat:
         chunks = torch.Tensor(torch.cat(to_concat, dim=0))
-        out = run_function(model, chunks)[-2]
+        out = run_function(model, chunks)
         outs.append(out)

     outs = torch.cat(outs, dim=0)
@@ -107,7 +107,7 @@
 class VADiterator:
     def __init__(self,
                  trig_sum: float = 0.26,
-                 neg_trig_sum: float = 0.02,
+                 neg_trig_sum: float = 0.1,
                  num_steps: int = 8):
         self.num_samples = 4000
         self.num_steps = num_steps
@@ -168,7 +168,7 @@ def state_generator(model,
                     audios: List[str],
                     onnx: bool = False,
                     trig_sum: float = 0.26,
-                    neg_trig_sum: float = 0.02,
+                    neg_trig_sum: float = 0.1,
                     num_steps: int = 8,
                     audios_in_stream: int = 2,
                     run_function=validate):
@@ -178,7 +178,7 @@ def state_generator(model,
             batch = torch.cat(for_batch)

             outs = run_function(model, batch)
-            vad_outs = torch.split(outs[-2], num_steps)
+            vad_outs = torch.split(outs, num_steps)

             states = []
             for x, y in zip(VADiters, vad_outs):
@@ -227,7 +227,7 @@ def single_audio_stream(model,
                         audio: str,
                         onnx: bool = False,
                         trig_sum: float = 0.26,
-                        neg_trig_sum: float = 0.02,
+                        neg_trig_sum: float = 0.1,
                         num_steps: int = 8,
                         run_function=validate):
     num_samples = 4000
@@ -238,7 +238,7 @@ def single_audio_stream(model,
         batch = VADiter.prepare_batch(chunk)

         outs = run_function(model, batch)
-        vad_outs = outs[-2]  # this is very misleading
+        vad_outs = outs  # this is very misleading

         states = []
         state = VADiter.state(vad_outs)
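
Note on the contract change above: after this patch, run_function (the validate helper in the notebook and its counterpart in utils.py) is expected to return the VAD probability tensor directly, so call sites drop the old `[-2]` indexing into the model's output tuple. A minimal sketch of the updated ONNX-side helper, assuming the session's first output is the speech probability (the function name validate_onnx is illustrative, not taken from the diff):

    import torch

    def validate_onnx(model, inputs: torch.Tensor) -> torch.Tensor:
        # `model` is assumed to be an onnxruntime.InferenceSession.
        ort_inputs = {'input': inputs.cpu().numpy()}
        outs = model.run(None, ort_inputs)
        outs = [torch.Tensor(x) for x in outs]
        # Return only the first output (the speech probabilities), matching
        # the `return outs[0]` change in the notebook; callers such as
        # get_speech_ts() and state_generator() then use the tensor as-is.
        return outs[0]

Under this contract the neg_trig_sum default also moves from 0.02 to 0.1 across get_speech_ts, VADiterator, state_generator, and single_audio_stream, raising the threshold below which a triggered stream is released back to silence.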