mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 09:59:20 +08:00
Add Number Detector + utils
This commit is contained in:
53
utils.py
53
utils.py
@@ -105,6 +105,39 @@ def get_speech_ts(wav: torch.Tensor,
|
||||
return speeches
|
||||
|
||||
|
||||
def get_number_ts(wav: torch.Tensor,
                  model,
                  model_stride=8,
                  hop_length=160,
                  sample_rate=16000,
                  run_function=validate):
    """Return the time spans of ``wav`` that the model labels as class 1.

    The audio is passed to ``run_function(model, batch)`` with a leading
    batch axis added; the first element of the result is treated as
    per-frame logits. Each strided prediction is repeated ``model_stride``
    times so that there is one prediction per real frame, and runs of
    class-1 frames are converted to times via
    ``frame_index * hop_length / sample_rate``.

    Args:
        wav: audio tensor; its last dimension is taken as the number of
            samples (assumes un-batched audio - TODO confirm with callers).
        model: passed through unchanged to ``run_function``.
        model_stride: how many real frames each model prediction covers.
        hop_length: multiplier turning a frame index into a sample offset.
        sample_rate: divisor turning a sample offset into a time.
        run_function: callable invoked as ``run_function(model, batch)``.

    Returns:
        A list of ``{'start': float, 'end': float}`` dicts, one per
        contiguous run of class-1 frames, in order of appearance.
    """
    # Measure the audio BEFORE adding the batch axis: after unsqueeze the
    # leading dimension is always 1, so len() would no longer be the length.
    num_samples = wav.shape[-1]
    batch = torch.unsqueeze(wav, dim=0)
    perframe_logits = run_function(model, batch)[0]

    # argmax gives (1, num_frames_strided); reshape(-1) flattens it and,
    # unlike squeeze(), stays 1-D even when there is a single strided frame.
    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).reshape(-1)

    # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
    extended_preds = perframe_preds.repeat_interleave(model_stride).tolist()

    triggered = False
    timings = []
    cur_timing = {}
    for i, pred in enumerate(extended_preds):
        if pred == 1:
            if not triggered:
                cur_timing['start'] = (i * hop_length) / sample_rate
                triggered = True
        elif pred == 0:
            if triggered:
                cur_timing['end'] = (i * hop_length) / sample_rate
                timings.append(cur_timing)
                cur_timing = {}
                triggered = False

    # A run still open when the predictions end is closed at the end of the audio.
    if cur_timing:
        cur_timing['end'] = num_samples / sample_rate
        timings.append(cur_timing)
    return timings
|
||||
|
||||
|
||||
class VADiterator:
|
||||
def __init__(self,
|
||||
trig_sum: float = 0.26,
|
||||
@@ -252,9 +285,19 @@ def single_audio_stream(model,
|
||||
yield states
|
||||
|
||||
|
||||
def collect_speeches(tss: List[dict],
                     wav: torch.Tensor):
    """Concatenate the slices of ``wav`` selected by ``tss``.

    Behaves exactly like ``collect_chunks``; both names appear in this diff
    (presumably this is the pre-rename name - confirm before removing).

    Args:
        tss: dicts whose 'start' and 'end' values are used directly as
            slice bounds into ``wav``.
        wav: tensor to slice along its first dimension.

    Returns:
        The selected slices joined in order with ``torch.cat``.
    """
    speech_chunks = [wav[ts['start']: ts['end']] for ts in tss]
    return torch.cat(speech_chunks)


def collect_chunks(tss: List[dict],
                   wav: torch.Tensor):
    """Concatenate the slices of ``wav`` selected by ``tss``.

    Args:
        tss: dicts whose 'start' and 'end' values are used directly as
            slice bounds into ``wav``.
        wav: tensor to slice along its first dimension.

    Returns:
        The selected slices joined in order with ``torch.cat``.
    """
    chunks = [wav[ts['start']: ts['end']] for ts in tss]
    return torch.cat(chunks)
|
||||
|
||||
|
||||
def drop_chunks(tss: List[dict],
                wav: torch.Tensor):
    """Return ``wav`` with the spans listed in ``tss`` removed.

    Args:
        tss: dicts whose 'start' and 'end' values are used directly as
            slice bounds into ``wav``; they are walked in the given order
            (assumes ascending, non-overlapping spans - TODO confirm).
        wav: tensor to cut along its first dimension.

    Returns:
        The parts of ``wav`` outside every span, joined with ``torch.cat``.
        With an empty ``tss`` this is a copy of the whole of ``wav``.
    """
    chunks = []
    cur_start = 0
    for i in tss:
        # Keep whatever lies between the previous span's end and this span's start.
        chunks.append(wav[cur_start: i['start']])
        cur_start = i['end']
    # Keep the audio after the last span too; without this the tail is
    # silently lost and an empty ``tss`` would make torch.cat fail.
    chunks.append(wav[cur_start:])
    return torch.cat(chunks)
|
||||
|
||||
Reference in New Issue
Block a user