Add Number Detector + utils

This commit is contained in:
sontref
2020-12-30 18:05:11 +00:00
parent 6d3f7f282b
commit b42d13f869
7 changed files with 400 additions and 29 deletions

View File

@@ -105,6 +105,39 @@ def get_speech_ts(wav: torch.Tensor,
return speeches
def get_number_ts(wav: torch.Tensor,
                  model,
                  model_stride=8,
                  hop_length=160,
                  sample_rate=16000,
                  run_function=validate):
    """Return the time spans in which the model predicts class 1.

    Args:
        wav: Audio tensor. A batch axis is prepended before it is handed to
            ``run_function``; the last axis is assumed to index samples
            (TODO confirm against callers).
        model: Passed through to ``run_function`` unchanged.
        model_stride: How many real frames each strided prediction covers.
        hop_length: Number of samples between consecutive real frames.
        sample_rate: Divisor used to turn a sample offset into a time value
            (seconds when given in Hz).
        run_function: Callable invoked as ``run_function(model, wav)``; the
            first element of its result is used as the per-frame logits.

    Returns:
        A list of ``{'start': ..., 'end': ...}`` dicts, one per contiguous
        run of class-1 frames, with both values computed as
        ``sample_offset / sample_rate``.
    """
    wav = torch.unsqueeze(wav, dim=0)
    # After unsqueeze, len(wav) is always 1 (the batch axis), so the audio
    # length has to be read from the last axis instead.
    num_samples = wav.shape[-1]

    perframe_logits = run_function(model, wav)[0]
    # Presumably shaped (1, num_classes, num_frames_strided) -- verify
    # against the model; the class axis is taken to be dim=1.
    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1)
    # Flatten rather than squeeze: squeeze() produces a 0-d tensor when
    # there is exactly one prediction, and a 0-d tensor cannot be iterated.
    strided_preds = perframe_preds.reshape(-1).tolist()

    # Repeat every strided prediction so there is one label per real frame.
    extended_preds = [pred
                      for pred in strided_preds
                      for _ in range(model_stride)]

    timings = []
    start = None  # start time of the run currently open, if any
    for frame_idx, pred in enumerate(extended_preds):
        if pred == 1:
            if start is None:
                start = (frame_idx * hop_length) / sample_rate
        elif pred == 0:
            if start is not None:
                timings.append({'start': start,
                                'end': (frame_idx * hop_length) / sample_rate})
                start = None

    # A run still open at the end of the predictions extends to the end of
    # the audio.
    if start is not None:
        timings.append({'start': start, 'end': num_samples / sample_rate})
    return timings
class VADiterator:
def __init__(self,
trig_sum: float = 0.26,
@@ -252,9 +285,19 @@ def single_audio_stream(model,
yield states
def collect_chunks(tss: List[dict],
                   wav: torch.Tensor):
    """Concatenate the slices of ``wav`` selected by ``tss``.

    Args:
        tss: Dicts whose ``'start'`` and ``'end'`` values are used directly
            as slice bounds into ``wav`` (so they must be valid indices,
            not fractional times).
        wav: Tensor to slice along its first axis.

    Returns:
        ``wav[start:end]`` for every entry of ``tss``, joined in order.
        An empty ``tss`` yields an empty slice of ``wav``.
    """
    chunks = [wav[ts['start']: ts['end']] for ts in tss]
    if not chunks:
        # torch.cat rejects an empty list; an empty slice keeps the dtype
        # and device of the input.
        return wav[:0]
    return torch.cat(chunks)


# Earlier name of collect_chunks, kept so existing callers keep working.
collect_speeches = collect_chunks
def drop_chunks(tss: List[dict],
                wav: torch.Tensor):
    """Return ``wav`` with the intervals listed in ``tss`` removed.

    Args:
        tss: Dicts whose ``'start'`` and ``'end'`` values are used directly
            as slice bounds into ``wav``. They are processed in the order
            given, so they are expected to be sorted and non-overlapping
            (TODO confirm with callers).
        wav: Tensor to cut along its first axis.

    Returns:
        The pieces of ``wav`` lying before, between and after the listed
        intervals, concatenated in order. With an empty ``tss`` nothing is
        dropped and the full content of ``wav`` is returned.
    """
    kept = []
    cur_start = 0
    for ts in tss:
        kept.append(wav[cur_start: ts['start']])
        cur_start = ts['end']
    # Keep whatever follows the last dropped interval; without this the
    # tail of the audio is lost and an empty ``tss`` makes torch.cat fail.
    kept.append(wav[cur_start:])
    return torch.cat(kept)