Add Language Detector

This commit is contained in:
sontref
2021-01-07 08:19:54 +00:00
parent b9704fbe35
commit fb5af4966e
3 changed files with 192 additions and 1 deletions

View File

@@ -9,6 +9,9 @@ import torch.nn.functional as F
# Tell torchaudio to use its "soundfile" backend.
torchaudio.set_audio_backend("soundfile") # switch backend
# Language codes that get_language() can return. get_language() uses the
# predicted class index as a position in this list, so the order here has to
# line up with the model's language output — TODO confirm against the model.
languages = ['ru', 'en', 'de', 'es']
def validate(model,
inputs: torch.Tensor):
with torch.no_grad():
@@ -138,6 +141,16 @@ def get_number_ts(wav: torch.Tensor,
return timings
def get_language(wav: torch.Tensor,
                 model,
                 run_function=validate):
    """Predict the language code for an audio waveform.

    A leading batch dimension is added to ``wav`` and the result is passed
    to ``run_function(model, batch)``. Element ``[2]`` of the value that
    comes back is treated as the per-language scores, and the index of the
    highest score along ``dim=1`` is looked up in the module-level
    ``languages`` list.

    Args:
        wav: Audio tensor without a batch dimension.
        model: Model object, forwarded unchanged to ``run_function``.
        run_function: Callable invoked as ``run_function(model, batch)``;
            its result must be indexable and hold the language scores at
            position 2. NOTE(review): scores are assumed to have shape
            ``(1, n_languages)`` - confirm against the model.

    Returns:
        The entry of ``languages`` selected by the highest score.

    Raises:
        AssertionError: If the predicted index has no entry in
            ``languages`` (the model emits more classes than are listed).
    """
    batch = torch.unsqueeze(wav, dim=0)
    lang_logits = run_function(model, batch)[2]
    # softmax is monotonic along the reduced dimension, so taking argmax of
    # the raw scores selects the same class without the extra pass.
    lang_pred = torch.argmax(lang_logits, dim=1).item()  # from 0 to len(languages) - 1
    # Raised explicitly rather than via `assert` so the check is still
    # performed when Python runs with -O.
    if lang_pred >= len(languages):
        raise AssertionError(
            f"predicted language index {lang_pred} is out of range for "
            f"{len(languages)} known languages"
        )
    return languages[lang_pred]
class VADiterator:
def __init__(self,
trig_sum: float = 0.26,