Mirror of https://github.com/snakers4/silero-vad.git (synced 2026-02-04 09:29:22 +08:00).
Python · 37 lines · 1.6 KiB
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
# Limit PyTorch CPU intra-op parallelism to a single thread (presumably to
# avoid oversubscription alongside DataLoader workers). Kept at module level
# so the setting still applies as soon as the script is imported/run.
torch.set_num_threads(1)


def _load_model(config):
    """Load the Silero VAD model from the source selected in ``config``.

    Precedence:
      1. ``config.jit_model_path`` (truthy) -> local JIT file via ``init_jit_model``.
      2. ``config.use_torchhub`` (truthy)  -> ``torch.hub.load`` from GitHub
         (``force_reload=True`` re-fetches the hub repo on every run).
      3. otherwise                         -> the ``silero_vad`` pip package.

    Returns the loaded (non-ONNX) model.
    """
    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        return init_jit_model(config.jit_model_path, device=config.device)

    if config.use_torchhub:
        print('Loading model using torch.hub')
        model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_vad',
                                  onnx=False,
                                  force_reload=True)
        return model

    print('Loading model using silero-vad library')
    # Imported lazily so the package is only required when this path is used.
    from silero_vad import load_silero_vad
    return load_silero_vad(onnx=False)


def main(config_path='config.yml'):
    """Search for the best VAD enter/exit thresholds on the validation split.

    Loads the OmegaConf config at ``config_path``, runs the model over the
    ``mode='val'`` dataset, and prints the thresholds and accuracy returned
    by ``calculate_best_thresholds``.
    """
    config = OmegaConf.load(config_path)

    loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
                                         batch_size=config.batch_size,
                                         collate_fn=SileroVadPadder,
                                         num_workers=config.num_workers)

    model = _load_model(config)
    print('Model loaded')
    model.to(config.device)

    print('Making predicts...')
    # Sample rate fed to the model: 8 kHz when tuning the 8k variant, else 16 kHz.
    sample_rate = 8000 if config.tune_8k else 16000
    all_predicts, all_gts = predict(model, loader, config.device, sr=sample_rate)

    print('Calculating thresholds...')
    best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Search for the best silero-vad enter/exit thresholds on the validation set.')
    # Defaults to the previously hard-coded path, so plain invocation is unchanged.
    parser.add_argument('--config', default='config.yml',
                        help='path to the OmegaConf YAML config (default: config.yml)')
    args = parser.parse_args()
    main(args.config)