mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Add a tiny visualization tool
This commit is contained in:
16
utils.py
16
utils.py
@@ -60,7 +60,8 @@ def get_speech_ts(wav: torch.Tensor,
|
|||||||
batch_size: int = 200,
|
batch_size: int = 200,
|
||||||
num_samples_per_window: int = 4000,
|
num_samples_per_window: int = 4000,
|
||||||
min_speech_samples: int = 10000, #samples
|
min_speech_samples: int = 10000, #samples
|
||||||
run_function=validate):
|
run_function=validate,
|
||||||
|
visualize_probs=False):
|
||||||
|
|
||||||
num_samples = num_samples_per_window
|
num_samples = num_samples_per_window
|
||||||
assert num_samples % num_steps == 0
|
assert num_samples % num_steps == 0
|
||||||
@@ -89,14 +90,20 @@ def get_speech_ts(wav: torch.Tensor,
|
|||||||
triggered = False
|
triggered = False
|
||||||
speeches = []
|
speeches = []
|
||||||
current_speech = {}
|
current_speech = {}
|
||||||
|
if visualize_probs:
|
||||||
|
import pandas as pd
|
||||||
|
smoothed_probs = []
|
||||||
|
|
||||||
speech_probs = outs[:, 1] # this is very misleading
|
speech_probs = outs[:, 1] # this is very misleading
|
||||||
for i, predict in enumerate(speech_probs): # add name
|
for i, predict in enumerate(speech_probs): # add name
|
||||||
buffer.append(predict)
|
buffer.append(predict)
|
||||||
if ((sum(buffer) / len(buffer))>= trig_sum) and not triggered:
|
smoothed_prob = (sum(buffer) / len(buffer))
|
||||||
|
if visualize_probs:
|
||||||
|
smoothed_probs.append(float(smoothed_prob))
|
||||||
|
if (smoothed_prob >= trig_sum) and not triggered:
|
||||||
triggered = True
|
triggered = True
|
||||||
current_speech['start'] = step * max(0, i-num_steps)
|
current_speech['start'] = step * max(0, i-num_steps)
|
||||||
if ((sum(buffer) / len(buffer)) < neg_trig_sum) and triggered:
|
if (smoothed_prob < neg_trig_sum) and triggered:
|
||||||
current_speech['end'] = step * i
|
current_speech['end'] = step * i
|
||||||
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
|
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
|
||||||
speeches.append(current_speech)
|
speeches.append(current_speech)
|
||||||
@@ -105,6 +112,9 @@ def get_speech_ts(wav: torch.Tensor,
|
|||||||
if current_speech:
|
if current_speech:
|
||||||
current_speech['end'] = len(wav)
|
current_speech['end'] = len(wav)
|
||||||
speeches.append(current_speech)
|
speeches.append(current_speech)
|
||||||
|
|
||||||
|
if visualize_probs:
|
||||||
|
pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
|
||||||
return speeches
|
return speeches
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user