Merge pull request #83 from snakers4/adamnsandle

Adamnsandle
This commit is contained in:
Alexander Veysov
2021-07-12 13:18:45 +03:00
committed by GitHub
2 changed files with 28 additions and 10 deletions

View File

@@ -237,10 +237,13 @@ get_language_and_group, read_audio = utils
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
wav = read_audio(f'{files_dir}/de.wav')
language, language_group = get_language_and_group(wav, model, lang_dict, lang_group_dict)
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2)
pprint(f'Language: {language}')
pprint(f'Language group: {language_group}')
for i in languages:
pprint(f'Language: {i[0]} with prob {i[-1]}')
for i in language_groups:
pprint(f'Language group: {i[0]} with prob {i[-1]}')
```
### ONNX
@@ -388,10 +391,13 @@ def validate_onnx(model, inputs):
model = init_onnx_model(f'{files_dir}/lang_classifier_116.onnx')
wav = read_audio(f'{files_dir}/de.wav')
language, language_group = get_language_and_group(wav, model, lang_dict, lang_group_dict, run_function=validate_onnx)
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2, run_function=validate_onnx)
pprint(f'Language: {language}')
pprint(f'Language group: {language_group}')
for i in languages:
pprint(f'Language: {i[0]} with prob {i[-1]}')
for i in language_groups:
pprint(f'Language group: {i[0]} with prob {i[-1]}')
```
[![Open on Torch Hub](https://img.shields.io/badge/Torch-Hub-red?logo=pytorch&style=for-the-badge)](https://pytorch.org/hub/snakers4_silero-vad_language/)

View File

@@ -333,14 +333,26 @@ def get_language_and_group(wav: torch.Tensor,
model,
lang_dict: dict,
lang_group_dict: dict,
top_n=1,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
lang_logits, lang_group_logits = run_function(model, wav)
softm = torch.softmax(lang_logits, dim=1).squeeze()
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
srtd = torch.argsort(softm, descending=True)
srtd_group = torch.argsort(softm_group, descending=True)
outs = []
outs_group = []
for i in range(top_n):
prob = round(softm[srtd[i]].item(), 2)
prob_group = round(softm_group[srtd_group[i]].item(), 2)
outs.append((lang_dict[str(srtd[i].item())], prob))
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
lang_group_pred = torch.argmax(torch.softmax(lang_group_logits, dim=1), dim=1).item()
return lang_dict[str(lang_pred)], lang_group_dict[str(lang_group_pred)]
return outs, outs_group
class VADiterator: