From 4109b107c1376006d03e125dc84698cb0dcaaa6d Mon Sep 17 00:00:00 2001
From: adamnsandle
Date: Tue, 20 Aug 2024 08:53:15 +0000
Subject: [PATCH] add neg_threshold parameter explicitly

---
 src/silero_vad/utils_vad.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/src/silero_vad/utils_vad.py b/src/silero_vad/utils_vad.py
index c910db6..d95487d 100644
--- a/src/silero_vad/utils_vad.py
+++ b/src/silero_vad/utils_vad.py
@@ -53,10 +53,10 @@ class OnnxWrapper():
 
         x, sr = self._validate_input(x, sr)
         num_samples = 512 if sr == 16000 else 256
-        
+
         if x.shape[-1] != num_samples:
             raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
-        
+
         batch_size = x.shape[0]
 
         context_size = 64 if sr == 16000 else 32
@@ -133,7 +133,7 @@ class Validator():
 def read_audio(path: str,
                sampling_rate: int = 16000):
     list_backends = torchaudio.list_audio_backends()
-    
+
     assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
                                     \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
 
@@ -195,6 +195,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                           return_seconds: bool = False,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
+                          neg_threshold: float = None,
                           window_size_samples: int = 512,):
 
     """
@@ -237,6 +238,9 @@ def get_speech_timestamps(audio: torch.Tensor,
     progress_tracking_callback: Callable[[float], None] (default - None)
        callback function taking progress in percents as an argument

+    neg_threshold: float (default = threshold - 0.15)
+        Negative threshold (noise or exit threshold). If the model's current state is SPEECH, probabilities BELOW this value are considered NON-SPEECH.
+
     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
 
@@ -298,15 +302,17 @@ def get_speech_timestamps(audio: torch.Tensor,
     triggered = False
     speeches = []
     current_speech = {}
-    neg_threshold = threshold - 0.15
-    temp_end = 0 # to save potential segment end (and tolerate some silence)
-    prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
+
+    if neg_threshold is None:
+        neg_threshold = threshold - 0.15
+    temp_end = 0 # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
 
     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
             temp_end = 0
             if next_start < prev_end:
-                next_start = window_size_samples * i
+                next_start = window_size_samples * i
 
         if (speech_prob >= threshold) and not triggered:
             triggered = True
@@ -318,7 +324,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
-                if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
+                if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
                     triggered = False
                 else:
                     current_speech['start'] = next_start
@@ -334,7 +340,7 @@ def get_speech_timestamps(audio: torch.Tensor,
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
                 temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech : # condition to avoid cutting in very short silence
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
                 prev_end = temp_end
             if (window_size_samples * i) - temp_end < min_silence_samples:
                 continue
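
Note (outside the patch): a minimal usage sketch of the new parameter, assuming the silero_vad package exports load_silero_vad, read_audio and get_speech_timestamps; 'example.wav' is a placeholder path.

from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
wav = read_audio('example.wav', sampling_rate=16000)

# Omitting neg_threshold keeps the old behaviour (threshold - 0.15);
# passing it overrides the exit threshold explicitly.
speech_timestamps = get_speech_timestamps(
    wav,
    model,
    threshold=0.5,       # enter-speech probability
    neg_threshold=0.35,  # exit-speech probability, made explicit by this patch
    sampling_rate=16000,
    return_seconds=True,
)
print(speech_timestamps)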
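
Note (outside the patch): threshold and neg_threshold act as a hysteresis pair: the tracker enters the SPEECH state when the probability reaches threshold and leaves it only when the probability drops below neg_threshold. The simplified standalone sketch below illustrates this; it deliberately ignores the min-silence, min-speech, padding and max-duration handling of the real get_speech_timestamps.

def simple_hysteresis(speech_probs, threshold=0.5, neg_threshold=None, window=512):
    # Same default the patch makes explicit.
    if neg_threshold is None:
        neg_threshold = threshold - 0.15
    segments, start, triggered = [], 0, False
    for i, p in enumerate(speech_probs):
        if not triggered and p >= threshold:    # enter SPEECH state
            triggered, start = True, window * i
        elif triggered and p < neg_threshold:   # exit only below the lower threshold
            segments.append({'start': start, 'end': window * i})
            triggered = False
    if triggered:  # close a segment that runs to the end of the audio
        segments.append({'start': start, 'end': window * len(speech_probs)})
    return segments

print(simple_hysteresis([0.1, 0.7, 0.6, 0.4, 0.2, 0.8, 0.9, 0.1]))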