mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Merge pull request #282 from saenyakorn/master
Add `progress_tracking_callback` argument to `get_speech_timestamps` function
This commit is contained in:
15
utils_vad.py
15
utils_vad.py
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from typing import List
|
from typing import Callable, List
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -168,7 +168,8 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
window_size_samples: int = 512,
|
window_size_samples: int = 512,
|
||||||
speech_pad_ms: int = 30,
|
speech_pad_ms: int = 30,
|
||||||
return_seconds: bool = False,
|
return_seconds: bool = False,
|
||||||
visualize_probs: bool = False):
|
visualize_probs: bool = False,
|
||||||
|
progress_tracking_callback: Callable[[float], None] = None):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This method is used for splitting long audios into speech chunks using silero VAD
|
This method is used for splitting long audios into speech chunks using silero VAD
|
||||||
@@ -212,6 +213,9 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
visualize_probs: bool (default - False)
|
visualize_probs: bool (default - False)
|
||||||
whether draw prob hist or not
|
whether draw prob hist or not
|
||||||
|
|
||||||
|
progress_tracking_callback: Callable[[float], None] (default - None)
|
||||||
|
callback function that takes the current progress, as a percentage, as its argument
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
speeches: list of dicts
|
speeches: list of dicts
|
||||||
@@ -259,6 +263,13 @@ def get_speech_timestamps(audio: torch.Tensor,
|
|||||||
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||||
speech_prob = model(chunk, sampling_rate).item()
|
speech_prob = model(chunk, sampling_rate).item()
|
||||||
speech_probs.append(speech_prob)
|
speech_probs.append(speech_prob)
|
||||||
|
# calculate progress and send it to the callback function
|
||||||
|
progress = current_start_sample + window_size_samples
|
||||||
|
if progress > audio_length_samples:
|
||||||
|
progress = audio_length_samples
|
||||||
|
progress_percent = (progress / audio_length_samples) * 100
|
||||||
|
if progress_tracking_callback:
|
||||||
|
progress_tracking_callback(progress_percent)
|
||||||
|
|
||||||
triggered = False
|
triggered = False
|
||||||
speeches = []
|
speeches = []
|
||||||
|
|||||||
Reference in New Issue
Block a user