mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 01:49:22 +08:00
Add Language Detector
This commit is contained in:
13
utils.py
13
utils.py
@@ -9,6 +9,9 @@ import torch.nn.functional as F
|
||||
# Select the "soundfile" backend for torchaudio.
# NOTE(review): newer torchaudio releases appear to deprecate
# set_audio_backend -- confirm against the torchaudio version in use.
torchaudio.set_audio_backend("soundfile") # switch backend
|
||||
|
||||
|
||||
# Language codes returned by get_language(). That function indexes this list
# with the argmax of the model's language logits, so the order presumably has
# to match the model's output classes -- TODO confirm against the model.
languages = ['ru', 'en', 'de', 'es']
|
||||
|
||||
|
||||
def validate(model,
|
||||
inputs: torch.Tensor):
|
||||
with torch.no_grad():
|
||||
@@ -138,6 +141,16 @@ def get_number_ts(wav: torch.Tensor,
|
||||
return timings
|
||||
|
||||
|
||||
def get_language(wav: torch.Tensor,
                 model,
                 run_function=validate):
    """Return the language code predicted for a single waveform.

    A leading dimension is added to ``wav`` and the result is passed to
    ``run_function(model, wav)``; element ``[2]`` of what that call returns
    is treated as the language logits. The index of the largest logit along
    ``dim=1`` selects an entry of the module-level ``languages`` list.

    Args:
        wav: audio tensor for one utterance (no leading batch dimension;
            one is added here).
        model: passed through unchanged to ``run_function``.
        run_function: callable invoked as ``run_function(model, wav)``;
            its result must be indexable and hold the language logits at
            position 2.

    Returns:
        The entry of ``languages`` at the predicted class index.

    Raises:
        AssertionError: if the predicted class index is not a valid index
            into ``languages`` (e.g. the model has more language classes
            than the list has entries).
    """
    batched_wav = torch.unsqueeze(wav, dim=0)
    lang_logits = run_function(model, batched_wav)[2]

    # softmax is monotonic, so argmax over the raw logits selects the same
    # class without the extra normalisation pass.
    # .item() requires the argmax result to hold exactly one element.
    lang_pred = torch.argmax(lang_logits, dim=1).item()  # from 0 to len(languages) - 1

    # Explicit raise instead of `assert`, so the check is not stripped when
    # Python runs with -O; the exception type is kept for compatibility.
    if lang_pred >= len(languages):
        raise AssertionError(
            f"predicted language index {lang_pred} is out of range for "
            f"{len(languages)} known languages"
        )
    return languages[lang_pred]
|
||||
|
||||
|
||||
class VADiterator:
|
||||
def __init__(self,
|
||||
trig_sum: float = 0.26,
|
||||
|
||||
Reference in New Issue
Block a user