Add a tiny visualization tool

This commit is contained in:
Alexander Veysov
2021-01-29 09:59:26 +03:00
committed by GitHub
parent 2cc23551e8
commit 41dd68aa24

View File

@@ -60,7 +60,8 @@ def get_speech_ts(wav: torch.Tensor,
batch_size: int = 200,
num_samples_per_window: int = 4000,
min_speech_samples: int = 10000, #samples
run_function=validate):
run_function=validate,
visualize_probs=False):
num_samples = num_samples_per_window
assert num_samples % num_steps == 0
@@ -89,14 +90,20 @@ def get_speech_ts(wav: torch.Tensor,
triggered = False
speeches = []
current_speech = {}
if visualize_probs:
import pandas as pd
smoothed_probs = []
speech_probs = outs[:, 1] # this is very misleading
for i, predict in enumerate(speech_probs): # add name
buffer.append(predict)
if ((sum(buffer) / len(buffer))>= trig_sum) and not triggered:
smoothed_prob = (sum(buffer) / len(buffer))
if visualize_probs:
smoothed_probs.append(float(smoothed_prob))
if (smoothed_prob >= trig_sum) and not triggered:
triggered = True
current_speech['start'] = step * max(0, i-num_steps)
if ((sum(buffer) / len(buffer)) < neg_trig_sum) and triggered:
if (smoothed_prob < neg_trig_sum) and triggered:
current_speech['end'] = step * i
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
speeches.append(current_speech)
@@ -105,6 +112,9 @@ def get_speech_ts(wav: torch.Tensor,
if current_speech:
current_speech['end'] = len(wav)
speeches.append(current_speech)
if visualize_probs:
pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
return speeches