add top n to lang classifier

This commit is contained in:
adamnsandle
2021-07-12 10:15:26 +00:00
parent a3cf5f722f
commit 2939077a90
2 changed files with 28 additions and 10 deletions

View File

@@ -333,14 +333,26 @@ def get_language_and_group(wav: torch.Tensor,
model,
lang_dict: dict,
lang_group_dict: dict,
top_n=1,
run_function=validate):
wav = torch.unsqueeze(wav, dim=0)
lang_logits, lang_group_logits = run_function(model, wav)
softm = torch.softmax(lang_logits, dim=1).squeeze()
softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
srtd = torch.argsort(softm, descending=True)
srtd_group = torch.argsort(softm_group, descending=True)
outs = []
outs_group = []
for i in range(top_n):
prob = round(softm[srtd[i]].item(), 2)
prob_group = round(softm_group[srtd_group[i]].item(), 2)
outs.append((lang_dict[str(srtd[i].item())], prob))
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
lang_group_pred = torch.argmax(torch.softmax(lang_group_logits, dim=1), dim=1).item()
return lang_dict[str(lang_pred)], lang_group_dict[str(lang_group_pred)]
return outs, outs_group
class VADiterator: