mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
additional vad utils
This commit is contained in:
15
utils_vad.py
15
utils_vad.py
@@ -86,8 +86,11 @@ def get_speech_ts(wav: torch.Tensor,
|
||||
min_speech_samples: int = 10000, #samples
|
||||
min_silence_samples: int = 500,
|
||||
run_function=validate,
|
||||
visualize_probs=False):
|
||||
visualize_probs=False,
|
||||
smoothed_prob_func='mean',
|
||||
device='cpu'):
|
||||
|
||||
assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
|
||||
num_samples = num_samples_per_window
|
||||
assert num_samples % num_steps == 0
|
||||
step = int(num_samples / num_steps) # stride / hop
|
||||
@@ -99,13 +102,13 @@ def get_speech_ts(wav: torch.Tensor,
|
||||
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
|
||||
to_concat.append(chunk.unsqueeze(0))
|
||||
if len(to_concat) >= batch_size:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
to_concat = []
|
||||
|
||||
if to_concat:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
|
||||
@@ -123,7 +126,11 @@ def get_speech_ts(wav: torch.Tensor,
|
||||
temp_end = 0
|
||||
for i, predict in enumerate(speech_probs): # add name
|
||||
buffer.append(predict)
|
||||
smoothed_prob = (sum(buffer) / len(buffer))
|
||||
if smoothed_prob_func == 'mean':
|
||||
smoothed_prob = (sum(buffer) / len(buffer))
|
||||
elif smoothed_prob_func == 'max':
|
||||
smoothed_prob = max(buffer)
|
||||
|
||||
if visualize_probs:
|
||||
smoothed_probs.append(float(smoothed_prob))
|
||||
if (smoothed_prob >= trig_sum) and temp_end:
|
||||
|
||||
Reference in New Issue
Block a user