add 116 lang classifier

This commit is contained in:
adamnsandle
2021-07-09 11:32:22 +00:00
parent 588e459a00
commit d097d1630d
7 changed files with 112 additions and 2 deletions

View File

@@ -329,6 +329,20 @@ def get_language(wav: torch.Tensor,
return languages[lang_pred]
def get_language_and_group(wav: torch.Tensor,
model,
lang_dict: dict,
lang_group_dict: dict,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
lang_logits, lang_group_logits = run_function(model, wav)
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
lang_group_pred = torch.argmax(torch.softmax(lang_group_logits, dim=1), dim=1).item()
return lang_dict[str(lang_pred)], lang_group_dict[str(lang_group_pred)]
class VADiterator:
def __init__(self,
trig_sum: float = 0.26,