add progress_tracking callback to get_speech_timestamps

This commit is contained in:
Saenyakorn Siangsanoh
2022-12-28 14:13:48 +07:00
parent d5a944b9f1
commit 11da69d88b

View File

@@ -1,6 +1,6 @@
import torch
import torchaudio
from typing import List
from typing import Callable, List
import torch.nn.functional as F
import warnings
@@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor,
window_size_samples: int = 512,
speech_pad_ms: int = 30,
return_seconds: bool = False,
visualize_probs: bool = False):
visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None):
"""
This method is used for splitting long audios into speech chunks using silero VAD
@@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor,
visualize_probs: bool (default - False)
whether to draw prob hist or not
progress_tracking_callback: Callable[[float], None] (default - None)
callback function that takes the progress in percent as an argument
Returns
----------
speeches: list of dicts
@@ -259,6 +263,14 @@ def get_speech_timestamps(audio: torch.Tensor,
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob = model(chunk, sampling_rate).item()
speech_probs.append(speech_prob)
# calculate progress and send it to the callback function
progress = (current_start_sample / audio_length_samples) * 100
if progress_tracking_callback:
progress_tracking_callback(progress)
# sending 100% progress to callback function after processing with actual model
if progress_tracking_callback:
progress_tracking_callback(100)
triggered = False
speeches = []