diff --git a/utils_vad.py b/utils_vad.py
index ea0293b..88c11b9 100644
--- a/utils_vad.py
+++ b/utils_vad.py
@@ -379,72 +379,6 @@ def get_speech_timestamps(audio: torch.Tensor,
     return speeches
 
 
-def get_number_ts(wav: torch.Tensor,
-                  model,
-                  model_stride=8,
-                  hop_length=160,
-                  sample_rate=16000):
-    wav = torch.unsqueeze(wav, dim=0)
-    perframe_logits = model(wav)[0]
-    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()   # (1, num_frames_strided)
-    extended_preds = []
-    for i in perframe_preds:
-        extended_preds.extend([i.item()] * model_stride)
-    # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
-    triggered = False
-    timings = []
-    cur_timing = {}
-    for i, pred in enumerate(extended_preds):
-        if pred == 1:
-            if not triggered:
-                cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
-                triggered = True
-        elif pred == 0:
-            if triggered:
-                cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
-                timings.append(cur_timing)
-                cur_timing = {}
-                triggered = False
-    if cur_timing:
-        cur_timing['end'] = int(len(wav) / (sample_rate / 1000))
-        timings.append(cur_timing)
-    return timings
-
-
-def get_language(wav: torch.Tensor,
-                 model):
-    wav = torch.unsqueeze(wav, dim=0)
-    lang_logits = model(wav)[2]
-    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()   # from 0 to len(languages) - 1
-    assert lang_pred < len(languages)
-    return languages[lang_pred]
-
-
-def get_language_and_group(wav: torch.Tensor,
-                           model,
-                           lang_dict: dict,
-                           lang_group_dict: dict,
-                           top_n=1):
-    wav = torch.unsqueeze(wav, dim=0)
-    lang_logits, lang_group_logits = model(wav)
-
-    softm = torch.softmax(lang_logits, dim=1).squeeze()
-    softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
-
-    srtd = torch.argsort(softm, descending=True)
-    srtd_group = torch.argsort(softm_group, descending=True)
-
-    outs = []
-    outs_group = []
-    for i in range(top_n):
-        prob = round(softm[srtd[i]].item(), 2)
-        prob_group = round(softm_group[srtd_group[i]].item(), 2)
-        outs.append((lang_dict[str(srtd[i].item())], prob))
-        outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
-
-    return outs, outs_group
-
-
 class VADIterator:
     def __init__(self,
                  model,