mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 09:29:22 +08:00
Remove old unused utils
This commit is contained in:
66
utils_vad.py
66
utils_vad.py
@@ -379,72 +379,6 @@ def get_speech_timestamps(audio: torch.Tensor,
|
||||
return speeches
|
||||
|
||||
|
||||
def get_number_ts(wav: torch.Tensor,
|
||||
model,
|
||||
model_stride=8,
|
||||
hop_length=160,
|
||||
sample_rate=16000):
|
||||
wav = torch.unsqueeze(wav, dim=0)
|
||||
perframe_logits = model(wav)[0]
|
||||
perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
|
||||
extended_preds = []
|
||||
for i in perframe_preds:
|
||||
extended_preds.extend([i.item()] * model_stride)
|
||||
# len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
|
||||
triggered = False
|
||||
timings = []
|
||||
cur_timing = {}
|
||||
for i, pred in enumerate(extended_preds):
|
||||
if pred == 1:
|
||||
if not triggered:
|
||||
cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
|
||||
triggered = True
|
||||
elif pred == 0:
|
||||
if triggered:
|
||||
cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
|
||||
timings.append(cur_timing)
|
||||
cur_timing = {}
|
||||
triggered = False
|
||||
if cur_timing:
|
||||
cur_timing['end'] = int(len(wav) / (sample_rate / 1000))
|
||||
timings.append(cur_timing)
|
||||
return timings
|
||||
|
||||
|
||||
def get_language(wav: torch.Tensor,
|
||||
model):
|
||||
wav = torch.unsqueeze(wav, dim=0)
|
||||
lang_logits = model(wav)[2]
|
||||
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
|
||||
assert lang_pred < len(languages)
|
||||
return languages[lang_pred]
|
||||
|
||||
|
||||
def get_language_and_group(wav: torch.Tensor,
|
||||
model,
|
||||
lang_dict: dict,
|
||||
lang_group_dict: dict,
|
||||
top_n=1):
|
||||
wav = torch.unsqueeze(wav, dim=0)
|
||||
lang_logits, lang_group_logits = model(wav)
|
||||
|
||||
softm = torch.softmax(lang_logits, dim=1).squeeze()
|
||||
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
|
||||
|
||||
srtd = torch.argsort(softm, descending=True)
|
||||
srtd_group = torch.argsort(softm_group, descending=True)
|
||||
|
||||
outs = []
|
||||
outs_group = []
|
||||
for i in range(top_n):
|
||||
prob = round(softm[srtd[i]].item(), 2)
|
||||
prob_group = round(softm_group[srtd_group[i]].item(), 2)
|
||||
outs.append((lang_dict[str(srtd[i].item())], prob))
|
||||
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
|
||||
|
||||
return outs, outs_group
|
||||
|
||||
|
||||
class VADIterator:
|
||||
def __init__(self,
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user