def get_language_and_group(wav: torch.Tensor,
                           model,
                           lang_dict: dict,
                           lang_group_dict: dict,
                           top_n=1,
                           run_function=validate):
    """Return the ``top_n`` most probable languages and language groups.

    A batch dimension is prepended to ``wav``, ``run_function(model, wav)``
    is called to obtain language logits and language-group logits, and each
    set of logits is turned into probabilities with a softmax over dim 1.

    Args:
        wav: Audio tensor; a leading batch dimension is added before the
            model is run.
        model: Object forwarded unchanged to ``run_function``.
        lang_dict: Maps ``str(class_index)`` to a language name.
        lang_group_dict: Maps ``str(class_index)`` to a language-group name.
        top_n: How many predictions to return per head. Values larger than
            the number of classes are clamped to that number; values below
            one yield empty lists.
        run_function: Callable ``(model, wav) -> (lang_logits,
            lang_group_logits)``.

    Returns:
        A pair ``(languages, language_groups)``. Each element is a list of
        ``(name, probability)`` tuples ordered from most to least probable,
        with the probability rounded to two decimals.

    Example::

        languages, language_groups = get_language_and_group(
            wav, model, lang_dict, lang_group_dict, top_n=2)

        for name, prob in languages:
            pprint(f'Language: {name} with prob {prob}')

        for name, prob in language_groups:
            pprint(f'Language group: {name} with prob {prob}')
    """
    wav = torch.unsqueeze(wav, dim=0)
    lang_logits, lang_group_logits = run_function(model, wav)

    def _top_predictions(logits, names):
        # Rank one head's classes and map the best ``top_n`` to their names.
        # reshape(-1) instead of squeeze(): a head with a single class must
        # stay 1-D so that it can still be ranked and indexed.
        # NOTE(review): assumes logits are (1, n_classes) - confirm for
        # custom run_function implementations.
        probs = torch.softmax(logits, dim=1).reshape(-1)
        # Clamp so an oversized top_n returns every class instead of raising
        # IndexError; the two heads may have different class counts.
        k = max(0, min(top_n, probs.numel()))
        top_probs, top_ids = torch.topk(probs, k)
        return [(names[str(idx)], round(prob, 2))
                for idx, prob in zip(top_ids.tolist(), top_probs.tolist())]

    return (_top_predictions(lang_logits, lang_dict),
            _top_predictions(lang_group_logits, lang_group_dict))