From 11da69d88bfcd762f3432742260e8de2a8253413 Mon Sep 17 00:00:00 2001 From: Saenyakorn Siangsanoh Date: Wed, 28 Dec 2022 14:13:48 +0700 Subject: [PATCH] add progress_tracking callback to get_speech_timestamps --- utils_vad.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/utils_vad.py b/utils_vad.py index 9e4b1bb..09aff37 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -1,6 +1,6 @@ import torch import torchaudio -from typing import List +from typing import Callable, List import torch.nn.functional as F import warnings @@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor, window_size_samples: int = 512, speech_pad_ms: int = 30, return_seconds: bool = False, - visualize_probs: bool = False): + visualize_probs: bool = False, + progress_tracking_callback: Callable[[float], None] = None): """ This method is used for splitting long audios into speech chunks using silero VAD @@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor, visualize_probs: bool (default - False) whether draw prob hist or not + progress_tracking_callback: Callable[[float], None] (default - None) + callback function taking progress in percents as an argument + Returns ---------- speeches: list of dicts @@ -259,6 +263,14 @@ def get_speech_timestamps(audio: torch.Tensor, chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk)))) speech_prob = model(chunk, sampling_rate).item() speech_probs.append(speech_prob) + # caculate progress and seng it to callback function + progress = (current_start_sample / audio_length_samples) * 100 + if progress_tracking_callback: + progress_tracking_callback(progress) + + # sending 100% progress to callback function after processing with actual model + if progress_tracking_callback: + progress_tracking_callback(100) triggered = False speeches = []