diff --git a/utils_vad.py b/utils_vad.py
index a8f8c60..258001b 100644
--- a/utils_vad.py
+++ b/utils_vad.py
@@ -60,6 +60,7 @@ def get_speech_ts(wav: torch.Tensor,
                    batch_size: int = 200,
                    num_samples_per_window: int = 4000,
                    min_speech_samples: int = 10000,  # samples
+                   min_silence_samples: int = 8000,
                    run_function=validate,
                    visualize_probs=False):
@@ -95,20 +96,31 @@ def get_speech_ts(wav: torch.Tensor,
     smoothed_probs = []
     speech_probs = outs[:, 1]  # this is very misleading
+    temp_end = 0
     for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
         smoothed_prob = (sum(buffer) / len(buffer))
         if visualize_probs:
             smoothed_probs.append(float(smoothed_prob))
+        if (smoothed_prob >= trig_sum) and temp_end:
+            temp_end = 0
         if (smoothed_prob >= trig_sum) and not triggered:
             triggered = True
             current_speech['start'] = step * max(0, i-num_steps)
+            continue
         if (smoothed_prob < neg_trig_sum) and triggered:
-            current_speech['end'] = step * i
-            if (current_speech['end'] - current_speech['start']) > min_speech_samples:
-                speeches.append(current_speech)
-            current_speech = {}
-            triggered = False
+            if not temp_end:
+                temp_end = step * i
+            if step * i - temp_end < min_silence_samples:
+                continue
+            else:
+                current_speech['end'] = temp_end
+                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
+                    speeches.append(current_speech)
+                temp_end = 0
+                current_speech = {}
+                triggered = False
+                continue
     if current_speech:
         current_speech['end'] = len(wav)
         speeches.append(current_speech)
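
For readers reviewing the hunk above: the new min_silence_samples parameter adds hysteresis to segment endings. Instead of closing a segment the moment the smoothed probability drops below neg_trig_sum, the candidate end is parked in temp_end and only committed once at least min_silence_samples of continuous silence has elapsed; a probability back above trig_sum cancels the pending end. The sketch below is a minimal, self-contained restatement of that logic on a plain list of probabilities; segment_speech, its parameter values, and the omission of the num_steps back-dating of starts are illustrative simplifications, not part of utils_vad.py.

# Minimal sketch of the min_silence_samples hysteresis introduced above.
# Operates on a plain list of smoothed probabilities; segment_speech and
# its parameter values are illustrative, not names from utils_vad.py.

def segment_speech(probs, step=4000, trig_sum=0.25, neg_trig_sum=0.07,
                   min_speech_samples=10000, min_silence_samples=8000):
    speeches, current, triggered, temp_end = [], {}, False, 0
    for i, p in enumerate(probs):
        if p >= trig_sum and temp_end:
            temp_end = 0                    # speech resumed: cancel the pending end
        if p >= trig_sum and not triggered:
            triggered = True
            current['start'] = step * i
            continue
        if p < neg_trig_sum and triggered:
            if not temp_end:
                temp_end = step * i         # remember where the silence started
            if step * i - temp_end < min_silence_samples:
                continue                    # not enough silence yet, keep the segment open
            if temp_end - current['start'] > min_speech_samples:
                current['end'] = temp_end
                speeches.append(current)
            current, triggered, temp_end = {}, False, 0
    if triggered and current:
        current['end'] = step * len(probs)  # input ended while still inside speech
        speeches.append(current)
    return speeches

# A one-frame dip below neg_trig_sum no longer splits the utterance:
probs = [0.9] * 5 + [0.01] + [0.9] * 5 + [0.01] * 4
print(segment_speech(probs))  # -> [{'start': 0, 'end': 44000}] (one segment, not two)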