diff --git a/utils_vad.py b/utils_vad.py index c4fd682..4c408f7 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -203,7 +203,7 @@ def get_speech_ts_adaptive(wav: torch.Tensor, speeches: list list containing ends and beginnings of speech chunks (in samples) """ - + num_samples = num_samples_per_window num_steps = int(num_samples / step) assert min_silence_samples >= step @@ -268,9 +268,18 @@ def get_speech_ts_adaptive(wav: torch.Tensor, if visualize_probs: pd.DataFrame({'probs': smoothed_probs}).plot(figsize=(16, 8)) - for ts in speeches: - ts['start'] = max(0, ts['start'] - speech_pad_samples) - ts['end'] += speech_pad_samples + for i, ts in enumerate(speeches): + if i == 0: + ts['start'] = max(0, ts['start'] - speech_pad_samples) + if i != len(speeches) - 1: + silence_duration = speeches[i+1]['start'] - ts['end'] + if silence_duration < 2 * speech_pad_samples: + ts['end'] += silence_duration // 2 + speeches[i+1]['start'] = max(0, speeches[i+1]['start'] - silence_duration // 2) + else: + ts['end'] += speech_pad_samples + else: + ts['end'] = min(len(wav), ts['end'] + speech_pad_samples) return speeches