diff --git a/src/silero_vad/utils_vad.py b/src/silero_vad/utils_vad.py index 9a64717..3fd5635 100644 --- a/src/silero_vad/utils_vad.py +++ b/src/silero_vad/utils_vad.py @@ -197,6 +197,7 @@ def get_speech_timestamps(audio: torch.Tensor, min_silence_duration_ms: int = 100, speech_pad_ms: int = 30, return_seconds: bool = False, + time_resolution: int = 1, visualize_probs: bool = False, progress_tracking_callback: Callable[[float], None] = None, neg_threshold: float = None, @@ -236,6 +237,9 @@ def get_speech_timestamps(audio: torch.Tensor, return_seconds: bool (default - False) whether return timestamps in seconds (default - samples) + time_resolution: bool (default - 1) + time resolution of speech coordinates when requested as seconds + visualize_probs: bool (default - False) whether draw prob hist or not @@ -378,8 +382,8 @@ def get_speech_timestamps(audio: torch.Tensor, if return_seconds: audio_length_seconds = audio_length_samples / sampling_rate for speech_dict in speeches: - speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, 1), 0) - speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, 1), audio_length_seconds) + speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0) + speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds) elif step > 1: for speech_dict in speeches: speech_dict['start'] *= step @@ -440,13 +444,16 @@ class VADIterator: self.current_sample = 0 @torch.no_grad() - def __call__(self, x, return_seconds=False): + def __call__(self, x, return_seconds=False, time_resolution: int = 1): """ x: torch.Tensor audio chunk (see examples in repo) return_seconds: bool (default - False) whether return timestamps in seconds (default - samples) + + time_resolution: int (default - 1) + time resolution of speech coordinates when requested as seconds """ if not torch.is_tensor(x): @@ -466,7 +473,7 @@ class VADIterator: if (speech_prob >= self.threshold) and not self.triggered: self.triggered = True speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples) - return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)} + return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)} if (speech_prob < self.threshold - 0.15) and self.triggered: if not self.temp_end: @@ -477,7 +484,7 @@ class VADIterator: speech_end = self.temp_end + self.speech_pad_samples - window_size_samples self.temp_end = 0 self.triggered = False - return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)} + return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)} return None