mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
add progress_tracking callback to get_speech_timestamps
This commit is contained in:
16
utils_vad.py
16
utils_vad.py
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from typing import List
|
from typing import Callable, List
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
window_size_samples: int = 512,
|
window_size_samples: int = 512,
|
||||||
speech_pad_ms: int = 30,
|
speech_pad_ms: int = 30,
|
||||||
return_seconds: bool = False,
|
return_seconds: bool = False,
|
||||||
visualize_probs: bool = False):
|
visualize_probs: bool = False,
|
||||||
|
progress_tracking_callback: Callable[[float], None] = None):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This method is used for splitting long audios into speech chunks using silero VAD
|
This method is used for splitting long audios into speech chunks using silero VAD
|
||||||
@@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
visualize_probs: bool (default - False)
|
visualize_probs: bool (default - False)
|
||||||
whether draw prob hist or not
|
whether draw prob hist or not
|
||||||
|
|
||||||
|
progress_tracking_callback: Callable[[float], None] (default - None)
|
||||||
|
callback function taking progress in percents as an argument
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
speeches: list of dicts
|
speeches: list of dicts
|
||||||
@@ -259,6 +263,14 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||||
speech_prob = model(chunk, sampling_rate).item()
|
speech_prob = model(chunk, sampling_rate).item()
|
||||||
speech_probs.append(speech_prob)
|
speech_probs.append(speech_prob)
|
||||||
|
# caculate progress and seng it to callback function
|
||||||
|
progress = (current_start_sample / audio_length_samples) * 100
|
||||||
|
if progress_tracking_callback:
|
||||||
|
progress_tracking_callback(progress)
|
||||||
|
|
||||||
|
# sending 100% progress to callback function after processing with actual model
|
||||||
|
if progress_tracking_callback:
|
||||||
|
progress_tracking_callback(100)
|
||||||
|
|
||||||
triggered = False
|
triggered = False
|
||||||
speeches = []
|
speeches = []
|
||||||
|
|||||||
Reference in New Issue
Block a user