diff --git a/utils.py b/utils.py
index 5fafca7..a8f8c60 100644
--- a/utils.py
+++ b/utils.py
@@ -60,7 +60,8 @@ def get_speech_ts(wav: torch.Tensor,
                   batch_size: int = 200,
                   num_samples_per_window: int = 4000,
                   min_speech_samples: int = 10000,  #samples
-                  run_function=validate):
+                  run_function=validate,
+                  visualize_probs=False):
 
     num_samples = num_samples_per_window
     assert num_samples % num_steps == 0
@@ -89,14 +90,20 @@ def get_speech_ts(wav: torch.Tensor,
     triggered = False
     speeches = []
     current_speech = {}
+    if visualize_probs:
+        import pandas as pd
+        smoothed_probs = []
 
     speech_probs = outs[:, 1]  # this is very misleading
     for i, predict in enumerate(speech_probs):  # add name
         buffer.append(predict)
-        if ((sum(buffer) / len(buffer)) >= trig_sum) and not triggered:
+        smoothed_prob = (sum(buffer) / len(buffer))
+        if visualize_probs:
+            smoothed_probs.append(float(smoothed_prob))
+        if (smoothed_prob >= trig_sum) and not triggered:
             triggered = True
             current_speech['start'] = step * max(0, i-num_steps)
-        if ((sum(buffer) / len(buffer)) < neg_trig_sum) and triggered:
+        if (smoothed_prob < neg_trig_sum) and triggered:
             current_speech['end'] = step * i
             if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                 speeches.append(current_speech)
@@ -105,6 +112,9 @@
     if current_speech:
         current_speech['end'] = len(wav)
         speeches.append(current_speech)
+
+    if visualize_probs:
+        pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
     return speeches
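
For reference, a minimal usage sketch of the new visualize_probs flag. The model loading and the read_audio helper here are illustrative assumptions rather than part of this diff, and the positional model argument is inferred from the fuller get_speech_ts signature that these hunks do not show:

    # Hypothetical usage sketch: 'model.jit' and read_audio are illustrative
    # assumptions, not part of this diff.
    import torch
    from utils import get_speech_ts, read_audio  # read_audio assumed to live in utils.py

    model = torch.jit.load('model.jit')  # assumed: a TorchScript VAD checkpoint
    wav = read_audio('example.wav')      # assumed: returns a 1-D torch.Tensor

    # With visualize_probs=True, the call also plots the smoothed speech
    # probabilities via pandas (which needs matplotlib installed for .plot()).
    speeches = get_speech_ts(wav, model, visualize_probs=True)
    print(speeches)  # list of {'start': ..., 'end': ...} dicts, in samples

Note that the plot only renders automatically in interactive environments such as Jupyter; a plain script would still need a matplotlib.pyplot.show() call. A side benefit of caching smoothed_prob in a local variable is that the moving average over the buffer is now computed once per step instead of twice, whether or not the flag is set.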