mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
Merge pull request #282 from saenyakorn/master
Add `progress_tracking_callback` argument to `get_speech_timestamps` function
This commit is contained in:
15
utils_vad.py
15
utils_vad.py
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from typing import List
|
||||
from typing import Callable, List
|
||||
import torch.nn.functional as F
|
||||
import warnings
|
||||
|
||||
@@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor,
|
||||
window_size_samples: int = 512,
|
||||
speech_pad_ms: int = 30,
|
||||
return_seconds: bool = False,
|
||||
visualize_probs: bool = False):
|
||||
visualize_probs: bool = False,
|
||||
progress_tracking_callback: Callable[[float], None] = None):
|
||||
|
||||
"""
|
||||
This method is used for splitting long audios into speech chunks using silero VAD
|
||||
@@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor,
|
||||
visualize_probs: bool (default - False)
|
||||
whether to draw the probability histogram or not
|
||||
|
||||
progress_tracking_callback: Callable[[float], None] (default - None)
|
||||
callback function taking the progress, as a percentage, as its argument
|
||||
|
||||
Returns
|
||||
----------
|
||||
speeches: list of dicts
|
||||
@@ -259,6 +263,13 @@ def get_speech_timestamps(audio: torch.Tensor,
|
||||
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||
speech_prob = model(chunk, sampling_rate).item()
|
||||
speech_probs.append(speech_prob)
|
||||
# calculate progress and send it to the callback function
|
||||
progress = current_start_sample + window_size_samples
|
||||
if progress > audio_length_samples:
|
||||
progress = audio_length_samples
|
||||
progress_percent = (progress / audio_length_samples) * 100
|
||||
if progress_tracking_callback:
|
||||
progress_tracking_callback(progress_percent)
|
||||
|
||||
triggered = False
|
||||
speeches = []
|
||||
|
||||
Reference in New Issue
Block a user