Merge pull request #627 from b3by/feature/time_coordinates_resolution

Specify time resolution when returning speech coordinates in seconds
This commit is contained in:
Dimitrii Voronin
2025-03-24 18:59:25 +03:00
committed by GitHub

View File

@@ -197,6 +197,7 @@ def get_speech_timestamps(audio: torch.Tensor,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30,
return_seconds: bool = False,
time_resolution: int = 1,
visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None,
neg_threshold: float = None,
@@ -236,6 +237,9 @@ def get_speech_timestamps(audio: torch.Tensor,
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: bool (default - 1)
time resolution of speech coordinates when requested as seconds
visualize_probs: bool (default - False)
whether draw prob hist or not
@@ -378,8 +382,8 @@ def get_speech_timestamps(audio: torch.Tensor,
if return_seconds:
audio_length_seconds = audio_length_samples / sampling_rate
for speech_dict in speeches:
speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, 1), 0)
speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, 1), audio_length_seconds)
speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0)
speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds)
elif step > 1:
for speech_dict in speeches:
speech_dict['start'] *= step
@@ -440,13 +444,16 @@ class VADIterator:
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False):
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
"""
if not torch.is_tensor(x):
@@ -466,7 +473,7 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
@@ -477,7 +484,7 @@ class VADIterator:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
return None