additional vad utils

This commit is contained in:
adamnsandle
2021-08-27 10:10:14 +00:00
parent 17071068e1
commit 1fc6b72ac9
2 changed files with 67 additions and 4 deletions

View File

@@ -86,8 +86,11 @@ def get_speech_ts(wav: torch.Tensor,
min_speech_samples: int = 10000, #samples
min_silence_samples: int = 500,
run_function=validate,
visualize_probs=False):
visualize_probs=False,
smoothed_prob_func='mean',
device='cpu'):
assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
num_samples = num_samples_per_window
assert num_samples % num_steps == 0
step = int(num_samples / num_steps) # stride / hop
@@ -99,13 +102,13 @@ def get_speech_ts(wav: torch.Tensor,
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
to_concat.append(chunk.unsqueeze(0))
if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
to_concat = []
if to_concat:
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
out = run_function(model, chunks)
outs.append(out)
@@ -123,7 +126,11 @@ def get_speech_ts(wav: torch.Tensor,
temp_end = 0
for i, predict in enumerate(speech_probs): # add name
buffer.append(predict)
smoothed_prob = (sum(buffer) / len(buffer))
if smoothed_prob_func == 'mean':
smoothed_prob = (sum(buffer) / len(buffer))
elif smoothed_prob_func == 'max':
smoothed_prob = max(buffer)
if visualize_probs:
smoothed_probs.append(float(smoothed_prob))
if (smoothed_prob >= trig_sum) and temp_end: