mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
add neg_threshold parameter explicitly
This commit is contained in:
@@ -53,10 +53,10 @@ class OnnxWrapper():
|
|||||||
|
|
||||||
x, sr = self._validate_input(x, sr)
|
x, sr = self._validate_input(x, sr)
|
||||||
num_samples = 512 if sr == 16000 else 256
|
num_samples = 512 if sr == 16000 else 256
|
||||||
|
|
||||||
if x.shape[-1] != num_samples:
|
if x.shape[-1] != num_samples:
|
||||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
context_size = 64 if sr == 16000 else 32
|
context_size = 64 if sr == 16000 else 32
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class Validator():
|
|||||||
def read_audio(path: str,
|
def read_audio(path: str,
|
||||||
sampling_rate: int = 16000):
|
sampling_rate: int = 16000):
|
||||||
list_backends = torchaudio.list_audio_backends()
|
list_backends = torchaudio.list_audio_backends()
|
||||||
|
|
||||||
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
|
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
|
||||||
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
|
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
|
||||||
|
|
||||||
@@ -195,6 +195,7 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
return_seconds: bool = False,
|
return_seconds: bool = False,
|
||||||
visualize_probs: bool = False,
|
visualize_probs: bool = False,
|
||||||
progress_tracking_callback: Callable[[float], None] = None,
|
progress_tracking_callback: Callable[[float], None] = None,
|
||||||
|
neg_threshold: float = None,
|
||||||
window_size_samples: int = 512,):
|
window_size_samples: int = 512,):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -237,6 +238,9 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
progress_tracking_callback: Callable[[float], None] (default - None)
|
progress_tracking_callback: Callable[[float], None] (default - None)
|
||||||
callback function taking progress in percents as an argument
|
callback function taking progress in percents as an argument
|
||||||
|
|
||||||
|
neg_threshold: float (default - None, in which case threshold - 0.15 is used)
|
||||||
|
Negative threshold (noise or exit threshold). If the model's current state is SPEECH, speech probabilities BELOW this value are considered NON-SPEECH.
|
||||||
|
|
||||||
window_size_samples: int (default - 512 samples)
|
window_size_samples: int (default - 512 samples)
|
||||||
!!! DEPRECATED, DOES NOTHING !!!
|
!!! DEPRECATED, DOES NOTHING !!!
|
||||||
|
|
||||||
@@ -298,15 +302,17 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
triggered = False
|
triggered = False
|
||||||
speeches = []
|
speeches = []
|
||||||
current_speech = {}
|
current_speech = {}
|
||||||
neg_threshold = threshold - 0.15
|
|
||||||
temp_end = 0 # to save potential segment end (and tolerate some silence)
|
if neg_threshold is None:
|
||||||
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
|
neg_threshold = threshold - 0.15
|
||||||
|
temp_end = 0 # to save potential segment end (and tolerate some silence)
|
||||||
|
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
|
||||||
|
|
||||||
for i, speech_prob in enumerate(speech_probs):
|
for i, speech_prob in enumerate(speech_probs):
|
||||||
if (speech_prob >= threshold) and temp_end:
|
if (speech_prob >= threshold) and temp_end:
|
||||||
temp_end = 0
|
temp_end = 0
|
||||||
if next_start < prev_end:
|
if next_start < prev_end:
|
||||||
next_start = window_size_samples * i
|
next_start = window_size_samples * i
|
||||||
|
|
||||||
if (speech_prob >= threshold) and not triggered:
|
if (speech_prob >= threshold) and not triggered:
|
||||||
triggered = True
|
triggered = True
|
||||||
@@ -318,7 +324,7 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
current_speech['end'] = prev_end
|
current_speech['end'] = prev_end
|
||||||
speeches.append(current_speech)
|
speeches.append(current_speech)
|
||||||
current_speech = {}
|
current_speech = {}
|
||||||
if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
|
if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
|
||||||
triggered = False
|
triggered = False
|
||||||
else:
|
else:
|
||||||
current_speech['start'] = next_start
|
current_speech['start'] = next_start
|
||||||
@@ -334,7 +340,7 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
if (speech_prob < neg_threshold) and triggered:
|
if (speech_prob < neg_threshold) and triggered:
|
||||||
if not temp_end:
|
if not temp_end:
|
||||||
temp_end = window_size_samples * i
|
temp_end = window_size_samples * i
|
||||||
if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech : # condition to avoid cutting in very short silence
|
if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
|
||||||
prev_end = temp_end
|
prev_end = temp_end
|
||||||
if (window_size_samples * i) - temp_end < min_silence_samples:
|
if (window_size_samples * i) - temp_end < min_silence_samples:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user