mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
initial 3.0 commit
This commit is contained in:
603
utils_vad.py
603
utils_vad.py
@@ -4,6 +4,7 @@ from typing import List
|
||||
from itertools import repeat
|
||||
from collections import deque
|
||||
import torch.nn.functional as F
|
||||
import warnings
|
||||
|
||||
|
||||
torchaudio.set_audio_backend("soundfile") # switch backend
|
||||
@@ -12,39 +13,18 @@ torchaudio.set_audio_backend("soundfile") # switch backend
|
||||
languages = ['ru', 'en', 'de', 'es']
|
||||
|
||||
|
||||
class IterativeMedianMeter():
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.median = 0
|
||||
self.counts = {}
|
||||
for i in range(0, 101, 1):
|
||||
self.counts[i / 100] = 0
|
||||
self.total_values = 0
|
||||
|
||||
def __call__(self, val):
|
||||
self.total_values += 1
|
||||
rounded = round(abs(val), 2)
|
||||
self.counts[rounded] += 1
|
||||
bin_sum = 0
|
||||
for j in self.counts:
|
||||
bin_sum += self.counts[j]
|
||||
if bin_sum >= self.total_values / 2:
|
||||
self.median = j
|
||||
break
|
||||
return self.median
|
||||
|
||||
|
||||
def validate(model,
|
||||
inputs: torch.Tensor):
|
||||
inputs: torch.Tensor,
|
||||
**kwargs):
|
||||
with torch.no_grad():
|
||||
outs = model(inputs)
|
||||
return outs
|
||||
outs = model(inputs, **kwargs)
|
||||
if len(outs.shape) == 1:
|
||||
return outs[1:]
|
||||
return outs[:, 1] # 0 for noise, 1 for speech
|
||||
|
||||
|
||||
def read_audio(path: str,
|
||||
target_sr: int = 16000):
|
||||
sampling_rate: int = 16000):
|
||||
|
||||
assert torchaudio.get_audio_backend() == 'soundfile'
|
||||
wav, sr = torchaudio.load(path)
|
||||
@@ -64,7 +44,7 @@ def read_audio(path: str,
|
||||
|
||||
def save_audio(path: str,
|
||||
tensor: torch.Tensor,
|
||||
sr: int = 16000):
|
||||
sampling_rate: int = 16000):
|
||||
torchaudio.save(path, tensor.unsqueeze(0), sr)
|
||||
|
||||
|
||||
@@ -76,192 +56,121 @@ def init_jit_model(model_path: str,
|
||||
return model
|
||||
|
||||
|
||||
def get_speech_ts(wav: torch.Tensor,
|
||||
model,
|
||||
trig_sum: float = 0.25,
|
||||
neg_trig_sum: float = 0.07,
|
||||
num_steps: int = 8,
|
||||
batch_size: int = 200,
|
||||
num_samples_per_window: int = 4000,
|
||||
min_speech_samples: int = 10000, #samples
|
||||
min_silence_samples: int = 500,
|
||||
run_function=validate,
|
||||
visualize_probs=False,
|
||||
smoothed_prob_func='mean',
|
||||
device='cpu'):
|
||||
|
||||
assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
|
||||
num_samples = num_samples_per_window
|
||||
assert num_samples % num_steps == 0
|
||||
step = int(num_samples / num_steps) # stride / hop
|
||||
outs = []
|
||||
to_concat = []
|
||||
for i in range(0, len(wav), step):
|
||||
chunk = wav[i: i+num_samples]
|
||||
if len(chunk) < num_samples:
|
||||
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
|
||||
to_concat.append(chunk.unsqueeze(0))
|
||||
if len(to_concat) >= batch_size:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
to_concat = []
|
||||
|
||||
if to_concat:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
|
||||
outs = torch.cat(outs, dim=0)
|
||||
|
||||
buffer = deque(maxlen=num_steps) # maxlen reached => first element dropped
|
||||
triggered = False
|
||||
speeches = []
|
||||
current_speech = {}
|
||||
if visualize_probs:
|
||||
import pandas as pd
|
||||
smoothed_probs = []
|
||||
|
||||
speech_probs = outs[:, 1] # this is very misleading
|
||||
temp_end = 0
|
||||
for i, predict in enumerate(speech_probs): # add name
|
||||
buffer.append(predict)
|
||||
if smoothed_prob_func == 'mean':
|
||||
smoothed_prob = (sum(buffer) / len(buffer))
|
||||
elif smoothed_prob_func == 'max':
|
||||
smoothed_prob = max(buffer)
|
||||
|
||||
if visualize_probs:
|
||||
smoothed_probs.append(float(smoothed_prob))
|
||||
if (smoothed_prob >= trig_sum) and temp_end:
|
||||
temp_end=0
|
||||
if (smoothed_prob >= trig_sum) and not triggered:
|
||||
triggered = True
|
||||
current_speech['start'] = step * max(0, i-num_steps)
|
||||
continue
|
||||
if (smoothed_prob < neg_trig_sum) and triggered:
|
||||
if not temp_end:
|
||||
temp_end = step * i
|
||||
if step * i - temp_end < min_silence_samples:
|
||||
continue
|
||||
else:
|
||||
current_speech['end'] = temp_end
|
||||
if (current_speech['end'] - current_speech['start']) > min_speech_samples:
|
||||
speeches.append(current_speech)
|
||||
temp_end = 0
|
||||
current_speech = {}
|
||||
triggered = False
|
||||
continue
|
||||
if current_speech:
|
||||
current_speech['end'] = len(wav)
|
||||
speeches.append(current_speech)
|
||||
|
||||
if visualize_probs:
|
||||
pd.DataFrame({'probs':smoothed_probs}).plot(figsize=(16,8))
|
||||
return speeches
|
||||
def make_visualization(probs, step):
|
||||
import pandas as pd
|
||||
pd.DataFrame({'probs': probs},
|
||||
index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
|
||||
kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
|
||||
xlabel='seconds',
|
||||
ylabel='speech probability',
|
||||
colormap='tab20')
|
||||
|
||||
|
||||
def get_speech_ts_adaptive(wav: torch.Tensor,
|
||||
model,
|
||||
batch_size: int = 200,
|
||||
step: int = 500,
|
||||
num_samples_per_window: int = 4000, # Number of samples per audio chunk to feed to NN (4000 for 16k SR, 2000 for 8k SR is optimal)
|
||||
min_speech_samples: int = 10000, # samples
|
||||
min_silence_samples: int = 4000,
|
||||
speech_pad_samples: int = 2000,
|
||||
run_function=validate,
|
||||
visualize_probs=False,
|
||||
device='cpu'):
|
||||
def get_speech_timestamps(audio: torch.Tensor,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sample_rate: int = 16000,
|
||||
min_speech_duration_ms: int = 250,
|
||||
min_silence_duration_ms: int = 100,
|
||||
window_size_samples: int = 1536,
|
||||
speech_pad_ms: int = 30,
|
||||
return_seconds: bool = False,
|
||||
visualize_probs: bool = False):
|
||||
|
||||
"""
|
||||
This function is used for splitting long audios into speech chunks using silero VAD
|
||||
Attention! All default sample rate values are optimal for 16000 sample rate model, if you are using 8000 sample rate model optimal values are half as much!
|
||||
This method is used for splitting long audios into speech chunks using silero VAD
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
batch size to feed to silero VAD (default - 200)
|
||||
audio: torch.Tensor, one dimensional
|
||||
One dimensional float torch.Tensor, other types are casted to torch if possible
|
||||
|
||||
step: int
|
||||
step size in samples, (default - 500)
|
||||
model: preloaded .jit silero VAD model
|
||||
|
||||
num_samples_per_window: int
|
||||
window size in samples (chunk length in samples to feed to NN, default - 4000)
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
|
||||
min_speech_samples: int
|
||||
if speech duration is shorter than this value, do not consider it speech (default - 10000)
|
||||
sample_rate: int (default - 16000)
|
||||
Currently silero VAD models support 8000 and 16000 sample rates
|
||||
|
||||
min_silence_samples: int
|
||||
number of samples to wait before considering as the end of speech (default - 4000)
|
||||
min_speech_duration_ms: int (default - 250 milliseconds)
|
||||
Final speech chunks shorter min_speech_duration_ms are thrown out
|
||||
|
||||
speech_pad_samples: int
|
||||
widen speech by this amount of samples each side (default - 2000)
|
||||
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
||||
|
||||
run_function: function
|
||||
function to use for the model call
|
||||
window_size_samples: int (default - 1536 samples)
|
||||
Audio chunks of window_size_samples size are fed to the silero VAD model.
|
||||
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
|
||||
Values other than these may affect model perfomance!!
|
||||
|
||||
visualize_probs: bool
|
||||
whether draw prob hist or not (default: False)
|
||||
speech_pad_ms: int (default - 30 milliseconds)
|
||||
Final speech chunks are padded by speech_pad_ms each side
|
||||
|
||||
device: string
|
||||
torch device to use for the model call (default - "cpu")
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
visualize_probs: bool (default - False)
|
||||
whether draw prob hist or not
|
||||
|
||||
Returns
|
||||
----------
|
||||
speeches: list
|
||||
list containing ends and beginnings of speech chunks (in samples)
|
||||
speeches: list of dicts
|
||||
list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
|
||||
"""
|
||||
if visualize_probs:
|
||||
import pandas as pd
|
||||
|
||||
num_samples = num_samples_per_window
|
||||
num_steps = int(num_samples / step)
|
||||
assert min_silence_samples >= step
|
||||
outs = []
|
||||
to_concat = []
|
||||
for i in range(0, len(wav), step):
|
||||
chunk = wav[i: i+num_samples]
|
||||
if len(chunk) < num_samples:
|
||||
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
|
||||
to_concat.append(chunk.unsqueeze(0))
|
||||
if len(to_concat) >= batch_size:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
to_concat = []
|
||||
if not torch.is_tensor(audio):
|
||||
try:
|
||||
audio = torch.Tensor(audio)
|
||||
except:
|
||||
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
||||
|
||||
if to_concat:
|
||||
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
|
||||
out = run_function(model, chunks)
|
||||
outs.append(out)
|
||||
if len(audio.shape) > 1:
|
||||
for i in range(len(audio.shape)): # trying to squeeze empty dimensions
|
||||
audio = audio.squeeze(0)
|
||||
if len(audio.shape) > 1:
|
||||
raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
|
||||
|
||||
outs = torch.cat(outs, dim=0).cpu()
|
||||
if sample_rate == 8000 and window_size_samples > 768:
|
||||
warnings.warn('window_size_samples is too big for 8000 sample_rate! Better set window_size_samples to 256, 512 or 1536 for 8000 sample rate!')
|
||||
if window_size_samples not in [256, 512, 768, 1024, 1536]:
|
||||
warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sample_rate\n - [256, 512, 768] for 8000 sample_rate')
|
||||
|
||||
model.reset_states()
|
||||
min_speech_samples = sample_rate * min_speech_duration_ms / 1000
|
||||
min_silence_samples = sample_rate * min_silence_duration_ms / 1000
|
||||
speech_pad_samples = sample_rate * speech_pad_ms / 1000
|
||||
|
||||
audio_length_samples = len(audio)
|
||||
|
||||
speech_probs = []
|
||||
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
||||
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
|
||||
if len(chunk) < window_size_samples:
|
||||
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||
speech_prob = model(chunk, sample_rate).item()
|
||||
speech_probs.append(speech_prob)
|
||||
|
||||
buffer = deque(maxlen=num_steps)
|
||||
triggered = False
|
||||
speeches = []
|
||||
smoothed_probs = []
|
||||
current_speech = {}
|
||||
speech_probs = outs[:, 1] # 0 index for silence probs, 1 index for speech probs
|
||||
median_probs = speech_probs.median()
|
||||
|
||||
trig_sum = 0.89 * median_probs + 0.08 # 0.08 when median is zero, 0.97 when median is 1
|
||||
|
||||
neg_threshold = threshold - 0.15
|
||||
temp_end = 0
|
||||
for i, predict in enumerate(speech_probs):
|
||||
buffer.append(predict)
|
||||
smoothed_prob = max(buffer)
|
||||
if visualize_probs:
|
||||
smoothed_probs.append(float(smoothed_prob))
|
||||
if (smoothed_prob >= trig_sum) and temp_end:
|
||||
|
||||
for i, speech_prob in enumerate(speech_probs):
|
||||
if (speech_prob >= threshold) and temp_end:
|
||||
temp_end = 0
|
||||
if (smoothed_prob >= trig_sum) and not triggered:
|
||||
|
||||
if (speech_prob >= threshold) and not triggered:
|
||||
triggered = True
|
||||
current_speech['start'] = step * max(0, i-num_steps)
|
||||
current_speech['start'] = window_size_samples * i
|
||||
continue
|
||||
if (smoothed_prob < trig_sum) and triggered:
|
||||
|
||||
if (speech_prob < neg_threshold) and triggered:
|
||||
if not temp_end:
|
||||
temp_end = step * i
|
||||
if step * i - temp_end < min_silence_samples:
|
||||
temp_end = window_size_samples * i
|
||||
if (window_size_samples * i) - temp_end < min_silence_samples:
|
||||
continue
|
||||
else:
|
||||
current_speech['end'] = temp_end
|
||||
@@ -271,24 +180,31 @@ def get_speech_ts_adaptive(wav: torch.Tensor,
|
||||
current_speech = {}
|
||||
triggered = False
|
||||
continue
|
||||
if current_speech:
|
||||
current_speech['end'] = len(wav)
|
||||
speeches.append(current_speech)
|
||||
if visualize_probs:
|
||||
pd.DataFrame({'probs': smoothed_probs}).plot(figsize=(16, 8))
|
||||
|
||||
for i, ts in enumerate(speeches):
|
||||
if current_speech:
|
||||
current_speech['end'] = audio_length_samples
|
||||
speeches.append(current_speech)
|
||||
|
||||
for i, speech in enumerate(speeches):
|
||||
if i == 0:
|
||||
ts['start'] = max(0, ts['start'] - speech_pad_samples)
|
||||
speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
|
||||
if i != len(speeches) - 1:
|
||||
silence_duration = speeches[i+1]['start'] - ts['end']
|
||||
silence_duration = speeches[i+1]['start'] - speech['end']
|
||||
if silence_duration < 2 * speech_pad_samples:
|
||||
ts['end'] += silence_duration // 2
|
||||
speeches[i+1]['start'] = max(0, speeches[i+1]['start'] - silence_duration // 2)
|
||||
speech['end'] += int(silence_duration // 2)
|
||||
speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
|
||||
else:
|
||||
ts['end'] += speech_pad_samples
|
||||
speech['end'] += int(speech_pad_samples)
|
||||
else:
|
||||
ts['end'] = min(len(wav), ts['end'] + speech_pad_samples)
|
||||
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
|
||||
|
||||
if return_seconds:
|
||||
for speech_dict in speeches:
|
||||
speech_dict['start'] = round(speech_dict['start'] / sample_rate, 1)
|
||||
speech_dict['end'] = round(speech_dict['end'] / sample_rate, 1)
|
||||
|
||||
if visualize_probs:
|
||||
make_visualization(speech_probs, window_size_samples / sample_rate)
|
||||
|
||||
return speeches
|
||||
|
||||
@@ -344,13 +260,13 @@ def get_language_and_group(wav: torch.Tensor,
|
||||
run_function=validate):
|
||||
wav = torch.unsqueeze(wav, dim=0)
|
||||
lang_logits, lang_group_logits = run_function(model, wav)
|
||||
|
||||
|
||||
softm = torch.softmax(lang_logits, dim=1).squeeze()
|
||||
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
|
||||
|
||||
|
||||
srtd = torch.argsort(softm, descending=True)
|
||||
srtd_group = torch.argsort(softm_group, descending=True)
|
||||
|
||||
|
||||
outs = []
|
||||
outs_group = []
|
||||
for i in range(top_n):
|
||||
@@ -362,256 +278,83 @@ def get_language_and_group(wav: torch.Tensor,
|
||||
return outs, outs_group
|
||||
|
||||
|
||||
class VADiterator:
|
||||
class VADIterator:
|
||||
def __init__(self,
|
||||
trig_sum: float = 0.26,
|
||||
neg_trig_sum: float = 0.07,
|
||||
num_steps: int = 8,
|
||||
num_samples_per_window: int = 4000):
|
||||
self.num_samples = num_samples_per_window
|
||||
self.num_steps = num_steps
|
||||
assert self.num_samples % num_steps == 0
|
||||
self.step = int(self.num_samples / num_steps) # 500 samples is good enough
|
||||
self.prev = torch.zeros(self.num_samples)
|
||||
self.last = False
|
||||
self.triggered = False
|
||||
self.buffer = deque(maxlen=num_steps)
|
||||
self.num_frames = 0
|
||||
self.trig_sum = trig_sum
|
||||
self.neg_trig_sum = neg_trig_sum
|
||||
self.current_name = ''
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sample_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
def refresh(self):
|
||||
self.prev = torch.zeros(self.num_samples)
|
||||
self.last = False
|
||||
self.triggered = False
|
||||
self.buffer = deque(maxlen=self.num_steps)
|
||||
self.num_frames = 0
|
||||
|
||||
def prepare_batch(self, wav_chunk, name=None):
|
||||
if (name is not None) and (name != self.current_name):
|
||||
self.refresh()
|
||||
self.current_name = name
|
||||
assert len(wav_chunk) <= self.num_samples
|
||||
self.num_frames += len(wav_chunk)
|
||||
if len(wav_chunk) < self.num_samples:
|
||||
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
|
||||
self.last = True
|
||||
|
||||
stacked = torch.cat([self.prev, wav_chunk])
|
||||
self.prev = wav_chunk
|
||||
|
||||
overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
|
||||
for i in range(self.step, self.num_samples+1, self.step)]
|
||||
return torch.cat(overlap_chunks, dim=0)
|
||||
|
||||
def state(self, model_out):
|
||||
current_speech = {}
|
||||
speech_probs = model_out[:, 1] # this is very misleading
|
||||
for i, predict in enumerate(speech_probs):
|
||||
self.buffer.append(predict)
|
||||
if ((sum(self.buffer) / len(self.buffer)) >= self.trig_sum) and not self.triggered:
|
||||
self.triggered = True
|
||||
current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'start'
|
||||
if ((sum(self.buffer) / len(self.buffer)) < self.neg_trig_sum) and self.triggered:
|
||||
current_speech[self.num_frames - (self.num_steps-i) * self.step] = 'end'
|
||||
self.triggered = False
|
||||
if self.triggered and self.last:
|
||||
current_speech[self.num_frames] = 'end'
|
||||
if self.last:
|
||||
self.refresh()
|
||||
return current_speech, self.current_name
|
||||
|
||||
|
||||
class VADiteratorAdaptive:
|
||||
def __init__(self,
|
||||
trig_sum: float = 0.26,
|
||||
neg_trig_sum: float = 0.06,
|
||||
step: int = 500,
|
||||
num_samples_per_window: int = 4000,
|
||||
speech_pad_samples: int = 1000,
|
||||
accum_period: int = 50):
|
||||
"""
|
||||
This class is used for streaming silero VAD usage
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trig_sum: float
|
||||
trigger value for speech probability, probs above this value are considered speech, switch to TRIGGERED state (default - 0.26)
|
||||
model: preloaded .jit silero VAD model
|
||||
|
||||
neg_trig_sum: float
|
||||
in triggered state probabilites below this value are considered nonspeech, switch to NONTRIGGERED state (default - 0.06)
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
|
||||
step: int
|
||||
step size in samples, (default - 500)
|
||||
sample_rate: int (default - 16000)
|
||||
Currently silero VAD models support 8000 and 16000 sample rates
|
||||
|
||||
num_samples_per_window: int
|
||||
window size in samples (chunk length in samples to feed to NN, default - 4000)
|
||||
|
||||
speech_pad_samples: int
|
||||
widen speech by this amount of samples each side (default - 1000)
|
||||
|
||||
accum_period: int
|
||||
number of chunks / iterations to wait before switching from constant (initial) trig and neg_trig coeffs to adaptive median coeffs (default - 50)
|
||||
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
||||
|
||||
speech_pad_ms: int (default - 30 milliseconds)
|
||||
Final speech chunks are padded by speech_pad_ms each side
|
||||
"""
|
||||
self.num_samples = num_samples_per_window
|
||||
self.num_steps = int(num_samples_per_window / step)
|
||||
self.step = step
|
||||
self.prev = torch.zeros(self.num_samples)
|
||||
self.last = False
|
||||
|
||||
self.model = model
|
||||
self.threshold = threshold
|
||||
self.sample_rate = sample_rate
|
||||
self.min_silence_samples = sample_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sample_rate * speech_pad_ms / 1000
|
||||
self.reset_states()
|
||||
|
||||
def reset_states(self):
|
||||
|
||||
self.model.reset_states()
|
||||
self.triggered = False
|
||||
self.buffer = deque(maxlen=self.num_steps)
|
||||
self.num_frames = 0
|
||||
self.trig_sum = trig_sum
|
||||
self.neg_trig_sum = neg_trig_sum
|
||||
self.current_name = ''
|
||||
self.median_meter = IterativeMedianMeter()
|
||||
self.median = 0
|
||||
self.total_steps = 0
|
||||
self.accum_period = accum_period
|
||||
self.speech_pad_samples = speech_pad_samples
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
def refresh(self):
|
||||
self.prev = torch.zeros(self.num_samples)
|
||||
self.last = False
|
||||
self.triggered = False
|
||||
self.buffer = deque(maxlen=self.num_steps)
|
||||
self.num_frames = 0
|
||||
self.median_meter.reset()
|
||||
self.median = 0
|
||||
self.total_steps = 0
|
||||
def __call__(self, x, return_seconds=False):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
def prepare_batch(self, wav_chunk, name=None):
|
||||
if (name is not None) and (name != self.current_name):
|
||||
self.refresh()
|
||||
self.current_name = name
|
||||
assert len(wav_chunk) <= self.num_samples
|
||||
self.num_frames += len(wav_chunk)
|
||||
if len(wav_chunk) < self.num_samples:
|
||||
wav_chunk = F.pad(wav_chunk, (0, self.num_samples - len(wav_chunk))) # short chunk => eof audio
|
||||
self.last = True
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
"""
|
||||
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||
self.current_sample += window_size_samples
|
||||
|
||||
stacked = torch.cat([self.prev, wav_chunk])
|
||||
self.prev = wav_chunk
|
||||
speech_prob = self.model(x, self.sample_rate).item()
|
||||
|
||||
overlap_chunks = [stacked[i:i+self.num_samples].unsqueeze(0)
|
||||
for i in range(self.step, self.num_samples+1, self.step)]
|
||||
return torch.cat(overlap_chunks, dim=0)
|
||||
if (speech_prob >= self.threshold) and self.temp_end:
|
||||
self.temp_end = 0
|
||||
|
||||
def state(self, model_out):
|
||||
current_speech = {}
|
||||
speech_probs = model_out[:, 1] # 0 index for silence probs, 1 index for speech probs
|
||||
for i, predict in enumerate(speech_probs):
|
||||
self.median = self.median_meter(predict.item())
|
||||
if self.total_steps < self.accum_period:
|
||||
trig_sum = self.trig_sum
|
||||
neg_trig_sum = self.neg_trig_sum
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = self.current_sample - self.speech_pad_samples
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sample_rate, 1)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
self.temp_end = self.current_sample
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
trig_sum = 0.89 * self.median + 0.08 # 0.08 when median is zero, 0.97 when median is 1
|
||||
neg_trig_sum = 0.6 * self.median
|
||||
self.total_steps += 1
|
||||
self.buffer.append(predict)
|
||||
smoothed_prob = max(self.buffer)
|
||||
if (smoothed_prob >= trig_sum) and not self.triggered:
|
||||
self.triggered = True
|
||||
current_speech[max(0, self.num_frames - (self.num_steps-i) * self.step - self.speech_pad_samples)] = 'start'
|
||||
if (smoothed_prob < neg_trig_sum) and self.triggered:
|
||||
current_speech[self.num_frames - (self.num_steps-i) * self.step + self.speech_pad_samples] = 'end'
|
||||
speech_end = self.temp_end + self.speech_pad_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
if self.triggered and self.last:
|
||||
current_speech[self.num_frames] = 'end'
|
||||
if self.last:
|
||||
self.refresh()
|
||||
return current_speech, self.current_name
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sample_rate, 1)}
|
||||
|
||||
|
||||
def state_generator(model,
|
||||
audios: List[str],
|
||||
onnx: bool = False,
|
||||
trig_sum: float = 0.26,
|
||||
neg_trig_sum: float = 0.07,
|
||||
num_steps: int = 8,
|
||||
num_samples_per_window: int = 4000,
|
||||
audios_in_stream: int = 2,
|
||||
run_function=validate):
|
||||
VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps, num_samples_per_window) for i in range(audios_in_stream)]
|
||||
for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream, num_samples_per_window)):
|
||||
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
|
||||
batch = torch.cat(for_batch)
|
||||
|
||||
outs = run_function(model, batch)
|
||||
vad_outs = torch.split(outs, num_steps)
|
||||
|
||||
states = []
|
||||
for x, y in zip(VADiters, vad_outs):
|
||||
cur_st = x.state(y)
|
||||
if cur_st[0]:
|
||||
states.append(cur_st)
|
||||
yield states
|
||||
|
||||
|
||||
def stream_imitator(audios: List[str],
|
||||
audios_in_stream: int,
|
||||
num_samples_per_window: int = 4000):
|
||||
audio_iter = iter(audios)
|
||||
iterators = []
|
||||
num_samples = num_samples_per_window
|
||||
# initial wavs
|
||||
for i in range(audios_in_stream):
|
||||
next_wav = next(audio_iter)
|
||||
wav = read_audio(next_wav)
|
||||
wav_chunks = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
|
||||
iterators.append(wav_chunks)
|
||||
print('Done initial Loading')
|
||||
good_iters = audios_in_stream
|
||||
while True:
|
||||
values = []
|
||||
for i, it in enumerate(iterators):
|
||||
try:
|
||||
out, wav_name = next(it)
|
||||
except StopIteration:
|
||||
try:
|
||||
next_wav = next(audio_iter)
|
||||
print('Loading next wav: ', next_wav)
|
||||
wav = read_audio(next_wav)
|
||||
iterators[i] = iter([(wav[i:i+num_samples], next_wav) for i in range(0, len(wav), num_samples)])
|
||||
out, wav_name = next(iterators[i])
|
||||
except StopIteration:
|
||||
good_iters -= 1
|
||||
iterators[i] = repeat((torch.zeros(num_samples), 'junk'))
|
||||
out, wav_name = next(iterators[i])
|
||||
if good_iters == 0:
|
||||
return
|
||||
values.append((out, wav_name))
|
||||
yield values
|
||||
|
||||
|
||||
def single_audio_stream(model,
|
||||
audio: torch.Tensor,
|
||||
num_samples_per_window:int = 4000,
|
||||
run_function=validate,
|
||||
iterator_type='basic',
|
||||
**kwargs):
|
||||
|
||||
num_samples = num_samples_per_window
|
||||
if iterator_type == 'basic':
|
||||
VADiter = VADiterator(num_samples_per_window=num_samples_per_window, **kwargs)
|
||||
elif iterator_type == 'adaptive':
|
||||
VADiter = VADiteratorAdaptive(num_samples_per_window=num_samples_per_window, **kwargs)
|
||||
|
||||
wav = read_audio(audio)
|
||||
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
|
||||
for chunk in wav_chunks:
|
||||
batch = VADiter.prepare_batch(chunk)
|
||||
|
||||
outs = run_function(model, batch)
|
||||
|
||||
states = []
|
||||
state = VADiter.state(outs)
|
||||
if state[0]:
|
||||
states.append(state[0])
|
||||
yield states
|
||||
return None
|
||||
|
||||
|
||||
def collect_chunks(tss: List[dict],
|
||||
|
||||
Reference in New Issue
Block a user