diff --git a/files/silero_vad.jit b/files/silero_vad.jit index 38237dc..42a5ed4 100644 Binary files a/files/silero_vad.jit and b/files/silero_vad.jit differ diff --git a/files/silero_vad.onnx b/files/silero_vad.onnx index e6db48d..d0ccd9d 100644 Binary files a/files/silero_vad.onnx and b/files/silero_vad.onnx differ diff --git a/silero-vad.ipynb b/silero-vad.ipynb index f3b521d..22f528a 100644 --- a/silero-vad.ipynb +++ b/silero-vad.ipynb @@ -46,7 +46,7 @@ "USE_ONNX = False # change this to True if you want to test onnx model\n", "if USE_ONNX:\n", " !pip install -q onnxruntime\n", - " \n", + "\n", "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", " model='silero_vad',\n", " force_reload=True,\n", @@ -65,16 +65,7 @@ "id": "fXbbaUO3jsrw" }, "source": [ - "## Full Audio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RAfJPb_a-Auj" - }, - "source": [ - "**Speech timestapms from full audio**" + "## Speech timestapms from full audio" ] }, { @@ -101,10 +92,33 @@ "source": [ "# merge all speech chunks to one audio\n", "save_audio('only_speech.wav',\n", - " collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n", + " collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n", "Audio('only_speech.wav')" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "zeO1xCqxUC6w" + }, + "source": [ + "## Entire audio inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LjZBcsaTT7Mk" + }, + "outputs": [], + "source": [ + "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n", + "# audio is being splitted into 31.25 ms long pieces\n", + "# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n", + "predicts = model.audio_forward(wav, sr=SAMPLING_RATE)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -124,10 +138,10 @@ "source": [ "## using VADIterator class\n", "\n", - "vad_iterator = VADIterator(model)\n", + "vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n", "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n", "\n", - "window_size_samples = 1536 # number of samples in a single audio chunk\n", + "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n", "for i in range(0, len(wav), window_size_samples):\n", " chunk = wav[i: i+ window_size_samples]\n", " if len(chunk) < window_size_samples:\n", @@ -150,7 +164,7 @@ "\n", "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n", "speech_probs = []\n", - "window_size_samples = 1536\n", + "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n", "for i in range(0, len(wav), window_size_samples):\n", " chunk = wav[i: i+ window_size_samples]\n", " if len(chunk) < window_size_samples:\n", diff --git a/utils_vad.py b/utils_vad.py index 1ebf2d7..ea0293b 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -1,7 +1,6 @@ import torch import torchaudio from typing import Callable, List -import torch.nn.functional as F import warnings languages = ['ru', 'en', 'de', 'es'] @@ -39,22 +38,27 @@ class OnnxWrapper(): if sr not in self.sample_rates: raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)") - if sr / x.shape[1] > 31.25: raise ValueError("Input audio chunk is too short") return x, sr def reset_states(self, batch_size=1): - self._h = np.zeros((2, batch_size, 64)).astype('float32') - self._c = np.zeros((2, batch_size, 64)).astype('float32') + self._state = torch.zeros((2, batch_size, 128)).float() + self._context = torch.zeros(0) self._last_sr = 0 self._last_batch_size = 0 def __call__(self, x, sr: int): x, sr = self._validate_input(x, sr) + num_samples = 512 if sr == 16000 else 256 + + if x.shape[-1] != num_samples: + raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)") + batch_size = x.shape[0] + context_size = 64 if sr == 16000 else 32 if not self._last_batch_size: self.reset_states(batch_size) @@ -63,28 +67,35 @@ class OnnxWrapper(): if (self._last_batch_size) and (self._last_batch_size != batch_size): self.reset_states(batch_size) + if not len(self._context): + self._context = torch.zeros(batch_size, context_size) + + x = torch.cat([self._context, x], dim=1) if sr in [8000, 16000]: - ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')} + ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)} ort_outs = self.session.run(None, ort_inputs) - out, self._h, self._c = ort_outs + out, state = ort_outs + self._state = torch.from_numpy(state) else: raise ValueError() + self._context = x[..., -context_size:] self._last_sr = sr self._last_batch_size = batch_size - out = torch.tensor(out) + out = torch.from_numpy(out) return out - def audio_forward(self, x, sr: int, num_samples: int = 512): + def audio_forward(self, x, sr: int): outs = [] x, sr = self._validate_input(x, sr) + self.reset_states() + num_samples = 512 if sr == 16000 else 256 if x.shape[1] % num_samples: pad_num = num_samples - (x.shape[1] % num_samples) x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0) - self.reset_states(x.shape[0]) for i in range(0, x.shape[1], num_samples): wavs_batch = x[:, i:i+num_samples] out_chunk = self.__call__(wavs_batch, sr) @@ -179,11 +190,11 @@ def get_speech_timestamps(audio: torch.Tensor, min_speech_duration_ms: int = 250, max_speech_duration_s: float = float('inf'), min_silence_duration_ms: int = 100, - window_size_samples: int = 512, speech_pad_ms: int = 30, return_seconds: bool = False, visualize_probs: bool = False, - progress_tracking_callback: Callable[[float], None] = None): + progress_tracking_callback: Callable[[float], None] = None, + window_size_samples: int = 512,): """ This method is used for splitting long audios into speech chunks using silero VAD @@ -193,14 +204,14 @@ def get_speech_timestamps(audio: torch.Tensor, audio: torch.Tensor, one dimensional One dimensional float torch.Tensor, other types are casted to torch if possible - model: preloaded .jit silero VAD model + model: preloaded .jit/.onnx silero VAD model threshold: float (default - 0.5) Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. sampling_rate: int (default - 16000) - Currently silero VAD models support 8000 and 16000 sample rates + Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates min_speech_duration_ms: int (default - 250 milliseconds) Final speech chunks shorter min_speech_duration_ms are thrown out @@ -213,11 +224,6 @@ def get_speech_timestamps(audio: torch.Tensor, min_silence_duration_ms: int (default - 100 milliseconds) In the end of each speech chunk wait for min_silence_duration_ms before separating it - window_size_samples: int (default - 1536 samples) - Audio chunks of window_size_samples size are fed to the silero VAD model. - WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate. - Values other than these may affect model perfomance!! - speech_pad_ms: int (default - 30 milliseconds) Final speech chunks are padded by speech_pad_ms each side @@ -230,6 +236,9 @@ def get_speech_timestamps(audio: torch.Tensor, progress_tracking_callback: Callable[[float], None] (default - None) callback function taking progress in percents as an argument + window_size_samples: int (default - 512 samples) + !!! DEPRECATED, DOES NOTHING !!! + Returns ---------- speeches: list of dicts @@ -256,10 +265,10 @@ def get_speech_timestamps(audio: torch.Tensor, else: step = 1 - if sampling_rate == 8000 and window_size_samples > 768: - warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!') - if window_size_samples not in [256, 512, 768, 1024, 1536]: - warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate') + if sampling_rate not in [8000, 16000]: + raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates") + + window_size_samples = 512 if sampling_rate == 16000 else 256 model.reset_states() min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 @@ -450,7 +459,7 @@ class VADIterator: Parameters ---------- - model: preloaded .jit silero VAD model + model: preloaded .jit/.onnx silero VAD model threshold: float (default - 0.5) Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.