This commit is contained in:
adamnsandle
2021-12-07 10:59:30 +00:00
parent 8f16c14066
commit 8794d6f835

View File

@@ -62,7 +62,7 @@ def make_visualization(probs, step):
def get_speech_timestamps(audio: torch.Tensor, def get_speech_timestamps(audio: torch.Tensor,
model, model,
threshold: float = 0.5, threshold: float = 0.5,
sample_rate: int = 16000, sampling_rate: int = 16000,
min_speech_duration_ms: int = 250, min_speech_duration_ms: int = 250,
min_silence_duration_ms: int = 100, min_silence_duration_ms: int = 100,
window_size_samples: int = 1536, window_size_samples: int = 1536,
@@ -84,7 +84,7 @@ def get_speech_timestamps(audio: torch.Tensor,
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sample_rate: int (default - 16000) sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates Currently silero VAD models support 8000 and 16000 sample rates
min_speech_duration_ms: int (default - 250 milliseconds) min_speech_duration_ms: int (default - 250 milliseconds)
@@ -125,15 +125,15 @@ def get_speech_timestamps(audio: torch.Tensor,
if len(audio.shape) > 1: if len(audio.shape) > 1:
raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?") raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
if sample_rate == 8000 and window_size_samples > 768: if sampling_rate == 8000 and window_size_samples > 768:
warnings.warn('window_size_samples is too big for 8000 sample_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!') warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!')
if window_size_samples not in [256, 512, 768, 1024, 1536]: if window_size_samples not in [256, 512, 768, 1024, 1536]:
warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sample_rate\n - [256, 512, 768] for 8000 sample_rate') warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
model.reset_states() model.reset_states()
min_speech_samples = sample_rate * min_speech_duration_ms / 1000 min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
min_silence_samples = sample_rate * min_silence_duration_ms / 1000 min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
speech_pad_samples = sample_rate * speech_pad_ms / 1000 speech_pad_samples = sampling_rate * speech_pad_ms / 1000
audio_length_samples = len(audio) audio_length_samples = len(audio)
@@ -142,7 +142,7 @@ def get_speech_timestamps(audio: torch.Tensor,
chunk = audio[current_start_sample: current_start_sample + window_size_samples] chunk = audio[current_start_sample: current_start_sample + window_size_samples]
if len(chunk) < window_size_samples: if len(chunk) < window_size_samples:
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk)))) chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob = model(chunk, sample_rate).item() speech_prob = model(chunk, sampling_rate).item()
speech_probs.append(speech_prob) speech_probs.append(speech_prob)
triggered = False triggered = False
@@ -193,11 +193,11 @@ def get_speech_timestamps(audio: torch.Tensor,
if return_seconds: if return_seconds:
for speech_dict in speeches: for speech_dict in speeches:
speech_dict['start'] = round(speech_dict['start'] / sample_rate, 1) speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
speech_dict['end'] = round(speech_dict['end'] / sample_rate, 1) speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
if visualize_probs: if visualize_probs:
make_visualization(speech_probs, window_size_samples / sample_rate) make_visualization(speech_probs, window_size_samples / sampling_rate)
return speeches return speeches
@@ -275,7 +275,7 @@ class VADIterator:
def __init__(self, def __init__(self,
model, model,
threshold: float = 0.5, threshold: float = 0.5,
sample_rate: int = 16000, sampling_rate: int = 16000,
min_silence_duration_ms: int = 100, min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30 speech_pad_ms: int = 30
): ):
@@ -291,7 +291,7 @@ class VADIterator:
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sample_rate: int (default - 16000) sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates Currently silero VAD models support 8000 and 16000 sample rates
min_silence_duration_ms: int (default - 100 milliseconds) min_silence_duration_ms: int (default - 100 milliseconds)
@@ -303,9 +303,9 @@ class VADIterator:
self.model = model self.model = model
self.threshold = threshold self.threshold = threshold
self.sample_rate = sample_rate self.sampling_rate = sampling_rate
self.min_silence_samples = sample_rate * min_silence_duration_ms / 1000 self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sample_rate * speech_pad_ms / 1000 self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
self.reset_states() self.reset_states()
def reset_states(self): def reset_states(self):
@@ -326,7 +326,7 @@ class VADIterator:
window_size_samples = len(x[0]) if x.dim() == 2 else len(x) window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples self.current_sample += window_size_samples
speech_prob = self.model(x, self.sample_rate).item() speech_prob = self.model(x, self.sampling_rate).item()
if (speech_prob >= self.threshold) and self.temp_end: if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0 self.temp_end = 0
@@ -334,7 +334,7 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered: if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True self.triggered = True
speech_start = self.current_sample - self.speech_pad_samples speech_start = self.current_sample - self.speech_pad_samples
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sample_rate, 1)} return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if (speech_prob < self.threshold - 0.15) and self.triggered: if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end: if not self.temp_end:
@@ -345,7 +345,7 @@ class VADIterator:
speech_end = self.temp_end + self.speech_pad_samples speech_end = self.temp_end + self.speech_pad_samples
self.temp_end = 0 self.temp_end = 0
self.triggered = False self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sample_rate, 1)} return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return None return None