diff --git a/utils_vad.py b/utils_vad.py index 9e4b1bb..603f0ca 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -1,6 +1,6 @@ import torch import torchaudio -from typing import List +from typing import Callable, List import torch.nn.functional as F import warnings @@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor, window_size_samples: int = 512, speech_pad_ms: int = 30, return_seconds: bool = False, - visualize_probs: bool = False): + visualize_probs: bool = False, + progress_tracking_callback: Callable[[float], None] = None): """ This method is used for splitting long audios into speech chunks using silero VAD @@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor, visualize_probs: bool (default - False) whether draw prob hist or not + progress_tracking_callback: Callable[[float], None] (default - None) + callback function taking progress in percent as an argument + Returns ---------- speeches: list of dicts @@ -259,6 +263,13 @@ def get_speech_timestamps(audio: torch.Tensor, chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk)))) speech_prob = model(chunk, sampling_rate).item() speech_probs.append(speech_prob) + # calculate progress and send it to the callback function + progress = current_start_sample + window_size_samples + if progress > audio_length_samples: + progress = audio_length_samples + progress_percent = (progress / audio_length_samples) * 100 + if progress_tracking_callback: + progress_tracking_callback(progress_percent) triggered = False speeches = []