Added params hop_size_ratio and min_silence_at_max_speech: when max_speech_duration_s is reached, cut at a possible silence instead of cutting abruptly

shashank14k
2025-07-25 20:51:40 +05:30
parent 94811cbe12
commit bbf22a0064


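For orientation before the diff, here is a minimal usage sketch of the extended signature. The keyword arguments are the ones this commit adds; the loader call and all values are illustrative (init_jit_model is the model loader in the stock utils_vad.py, and the model path is a placeholder):

import torch
from utils_vad import get_speech_timestamps, init_jit_model

model = init_jit_model('silero_vad.jit')  # placeholder path
audio = torch.randn(16000 * 30)           # stand-in for 30 s of 16 kHz audio

speech_timestamps = get_speech_timestamps(
    audio, model,
    sampling_rate=16000,
    max_speech_duration_s=15,
    hop_size_ratio=0.5,                   # shift the 512-sample window by 256 samples
    min_silence_at_max_speech=98,         # silences >= 98 ms become candidate cut points
    use_max_poss_sil_at_max_speech=True,  # cut at the longest candidate silence
)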
@@ -201,7 +201,10 @@ def get_speech_timestamps(audio: torch.Tensor,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
                           neg_threshold: float = None,
-                          window_size_samples: int = 512,):
+                          window_size_samples: int = 512,
+                          hop_size_ratio: float = 1,
+                          min_silence_at_max_speech: float = 98,
+                          use_max_poss_sil_at_max_speech: bool = True):
     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -251,13 +254,16 @@ def get_speech_timestamps(audio: torch.Tensor,
     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
+    hop_size_ratio: float (default - 1), ratio of hop size to window size; the window is shifted by int(window_size_samples * hop_size_ratio) samples per step, so 1 means hop_size_samples == window_size_samples
+    min_silence_at_max_speech: float (default - 98 ms), minimum silence duration in ms that is recorded as a possible cut point, used to avoid abrupt cuts when max_speech_duration_s is reached
+    use_max_poss_sil_at_max_speech: bool (default - True), whether to cut at the longest recorded silence once max_speech_duration_s is reached; if False, the last recorded silence is used
 
     Returns
     ----------
     speeches: list of dicts
         list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
     """
 
     if not torch.is_tensor(audio):
         try:
             audio = torch.Tensor(audio)
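The ratio semantics of hop_size_ratio documented above reduce to one line of arithmetic. A quick standalone sketch of the values it produces at 16 kHz (the 512-sample window is fixed by the model):

window_size_samples = 512  # fixed by the model at 16 kHz
for hop_size_ratio in (1.0, 0.5, 0.25):
    hop_size_samples = int(window_size_samples * hop_size_ratio)
    print(f"ratio {hop_size_ratio}: hop {hop_size_samples} samples "
          f"({hop_size_samples / 16000 * 1000:.0f} ms per probability step)")
# ratio 1.0 -> 512 samples (32 ms), 0.5 -> 256 (16 ms), 0.25 -> 128 (8 ms)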
@@ -282,25 +288,29 @@ def get_speech_timestamps(audio: torch.Tensor,
         raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiple of 16000) sample rates")
 
     window_size_samples = 512 if sampling_rate == 16000 else 256
+    hop_size_samples = int(window_size_samples * hop_size_ratio)
 
     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
-    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000
     audio_length_samples = len(audio)
 
     speech_probs = []
-    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob = model(chunk, sampling_rate).item()
+        try:
+            speech_prob = model(chunk, sampling_rate).item()
+        except Exception as e:
+            raise RuntimeError(f"VAD model failed on chunk starting at sample {current_start_sample}") from e
         speech_probs.append(speech_prob)
         # calculate progress and send it to the callback function
-        progress = current_start_sample + window_size_samples
+        progress = current_start_sample + hop_size_samples
         if progress > audio_length_samples:
             progress = audio_length_samples
         progress_percent = (progress / audio_length_samples) * 100
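To make the hopped loop above concrete, a standalone sketch (model call omitted, names mirror the diff) showing that hop_size_ratio = 0.5 roughly doubles the number of windows scored over the same audio, since consecutive windows now overlap by half:

import torch

audio = torch.randn(16000)                        # 1 s of stand-in 16 kHz audio
window_size_samples, hop_size_samples = 512, 256  # hop_size_ratio = 0.5

windows = []
for current_start_sample in range(0, len(audio), hop_size_samples):
    chunk = audio[current_start_sample: current_start_sample + window_size_samples]
    if len(chunk) < window_size_samples:
        chunk = torch.nn.functional.pad(chunk, (0, window_size_samples - len(chunk)))
    windows.append(chunk)

print(len(windows))  # 63, versus 32 when the hop equals the window size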
@@ -315,42 +325,56 @@ def get_speech_timestamps(audio: torch.Tensor,
         neg_threshold = max(threshold - 0.15, 0.01)
     temp_end = 0  # to save potential segment end (and tolerate some silence)
     prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+    possible_ends = []
 
     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
-            temp_end = 0
+            if temp_end != 0:
+                sil_dur = (hop_size_samples * i) - temp_end
+                if sil_dur > min_silence_samples_at_max_speech:
+                    possible_ends.append((temp_end, sil_dur))
+            temp_end = 0
             if next_start < prev_end:
-                next_start = window_size_samples * i
+                next_start = hop_size_samples * i
 
         if (speech_prob >= threshold) and not triggered:
             triggered = True
-            current_speech['start'] = window_size_samples * i
+            current_speech['start'] = hop_size_samples * i
             continue
-        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
-            if prev_end:
+        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if possible_ends:
+                if use_max_poss_sil_at_max_speech:
+                    prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
+                else:
+                    prev_end, dur = possible_ends[-1]  # use the last possible silence segment
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
-                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
-                    triggered = False
-                else:
+                next_start = prev_end + dur
+                if next_start < prev_end + hop_size_samples * i:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    #triggered = False
                     current_speech['start'] = next_start
+                else:
+                    triggered = False
+                    #current_speech['start'] = next_start
                 prev_end = next_start = temp_end = 0
+                possible_ends = []
             else:
-                current_speech['end'] = window_size_samples * i
+                current_speech['end'] = hop_size_samples * i
                 speeches.append(current_speech)
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
             continue
 
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
-                prev_end = temp_end
-            if (window_size_samples * i) - temp_end < min_silence_samples:
+                temp_end = hop_size_samples * i
+            # if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+            #     prev_end = temp_end
+            if (hop_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
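The cut-point selection in the hunk above can be read in isolation. A small sketch with made-up (silence_start, silence_duration) pairs in samples; the helper name pick_cut is hypothetical, the two branches mirror the diff:

possible_ends = [(8000, 1600), (24000, 4800), (40000, 2400)]  # hypothetical candidates

def pick_cut(possible_ends, use_max_poss_sil_at_max_speech=True):
    if use_max_poss_sil_at_max_speech:
        return max(possible_ends, key=lambda x: x[1])  # longest recorded silence
    return possible_ends[-1]                           # most recent recorded silence

prev_end, dur = pick_cut(possible_ends)  # (24000, 4800): the chunk ends at sample 24000
next_start = prev_end + dur              # 28800: the next chunk may resume after the pause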
@@ -359,6 +383,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue
 
     if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -390,7 +415,7 @@ def get_speech_timestamps(audio: torch.Tensor,
             speech_dict['end'] *= step
 
     if visualize_probs:
-        make_visualization(speech_probs, window_size_samples / sampling_rate)
+        make_visualization(speech_probs, hop_size_samples / sampling_rate)
 
     return speeches
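A closing note on consuming the result: unless return_seconds is set, the returned chunk boundaries are in samples. A hypothetical conversion, continuing the call sketched at the top:

sampling_rate = 16000
for ts in speech_timestamps:  # each dict has 'start' and 'end' in samples
    print(f"speech {ts['start'] / sampling_rate:.2f}s -> {ts['end'] / sampling_rate:.2f}s")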