mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
18
README.md
18
README.md
@@ -237,10 +237,13 @@ get_language_and_group, read_audio = utils
|
||||
files_dir = torch.hub.get_dir() + '/snakers4_silero-vad_master/files'
|
||||
|
||||
wav = read_audio(f'{files_dir}/de.wav')
|
||||
language, language_group = get_language_and_group(wav, model, lang_dict, lang_group_dict)
|
||||
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2)
|
||||
|
||||
pprint(f'Language: {language}')
|
||||
pprint(f'Language group: {language_group}')
|
||||
for i in languages:
|
||||
pprint(f'Language: {i[0]} with prob {i[-1]}')
|
||||
|
||||
for i in language_groups:
|
||||
pprint(f'Language group: {i[0]} with prob {i[-1]}')
|
||||
```
|
||||
|
||||
### ONNX
|
||||
@@ -388,10 +391,13 @@ def validate_onnx(model, inputs):
|
||||
model = init_onnx_model(f'{files_dir}/lang_classifier_116.onnx')
|
||||
wav = read_audio(f'{files_dir}/de.wav')
|
||||
|
||||
language, language_group = get_language_and_group(wav, model, lang_dict, lang_group_dict, run_function=validate_onnx)
|
||||
languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=2, run_function=validate_onnx)
|
||||
|
||||
pprint(f'Language: {language}')
|
||||
pprint(f'Language group: {language_group}')
|
||||
for i in languages:
|
||||
pprint(f'Language: {i[0]} with prob {i[-1]}')
|
||||
|
||||
for i in language_groups:
|
||||
pprint(f'Language group: {i[0]} with prob {i[-1]}')
|
||||
|
||||
```
|
||||
[](https://pytorch.org/hub/snakers4_silero-vad_language/)
|
||||
|
||||
20
utils_vad.py
20
utils_vad.py
@@ -333,14 +333,26 @@ def get_language_and_group(wav: torch.Tensor,
|
||||
model,
|
||||
lang_dict: dict,
|
||||
lang_group_dict: dict,
|
||||
top_n=1,
|
||||
run_function=validate):
|
||||
wav = torch.unsqueeze(wav, dim=0)
|
||||
lang_logits, lang_group_logits = run_function(model, wav)
|
||||
|
||||
softm = torch.softmax(lang_logits, dim=1).squeeze()
|
||||
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
|
||||
|
||||
srtd = torch.argsort(softm, descending=True)
|
||||
srtd_group = torch.argsort(softm_group, descending=True)
|
||||
|
||||
outs = []
|
||||
outs_group = []
|
||||
for i in range(top_n):
|
||||
prob = round(softm[srtd[i]].item(), 2)
|
||||
prob_group = round(softm_group[srtd_group[i]].item(), 2)
|
||||
outs.append((lang_dict[str(srtd[i].item())], prob))
|
||||
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
|
||||
|
||||
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
|
||||
lang_group_pred = torch.argmax(torch.softmax(lang_group_logits, dim=1), dim=1).item()
|
||||
|
||||
return lang_dict[str(lang_pred)], lang_group_dict[str(lang_group_pred)]
|
||||
return outs, outs_group
|
||||
|
||||
|
||||
class VADiterator:
|
||||
|
||||
Reference in New Issue
Block a user