diff --git a/files/silero_vad.jit b/files/silero_vad.jit index 9fcd286..e29f1e1 100644 Binary files a/files/silero_vad.jit and b/files/silero_vad.jit differ diff --git a/utils_vad.py b/utils_vad.py index 29b69c3..fca2d82 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -29,6 +29,11 @@ class OnnxWrapper(): if x.dim() > 2: raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}") + if sr != 16000 and (sr % 16000 == 0): + step = sr // 16000 + x = x[::step] + sr = 16000 + if x.shape[0] > 1: raise ValueError("Onnx model does not support batching") @@ -177,6 +182,14 @@ def get_speech_timestamps(audio: torch.Tensor, if len(audio.shape) > 1: raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?") + if sampling_rate > 16000 and (sampling_rate % 16000 == 0): + step = sampling_rate // 16000 + sampling_rate = 16000 + audio = audio[::step] + warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!') + else: + step = 1 + if sampling_rate == 8000 and window_size_samples > 768: warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!') if window_size_samples not in [256, 512, 768, 1024, 1536]: @@ -247,6 +260,10 @@ def get_speech_timestamps(audio: torch.Tensor, for speech_dict in speeches: speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1) speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1) + elif step > 1: + for speech_dict in speeches: + speech_dict['start'] *= step + speech_dict['end'] *= step if visualize_probs: make_visualization(speech_probs, window_size_samples / sampling_rate) @@ -353,6 +370,10 @@ class VADIterator: self.model = model self.threshold = threshold self.sampling_rate = sampling_rate + + if sampling_rate not in [8000, 16000]: + raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]') + self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 self.reset_states()