"""Search for the best enter/exit thresholds of a Silero VAD model.

The script reads ``config.yml``, builds a loader over the ``val`` split,
loads a model (local JIT file, torch.hub, or the ``silero_vad`` package),
collects predictions and prints the thresholds and accuracy returned by
``calculate_best_thresholds``.
"""
from utils import (
    init_jit_model,
    predict,
    calculate_best_thresholds,
    SileroVadDataset,
    SileroVadPadder,
)
from omegaconf import OmegaConf
import torch

torch.set_num_threads(1)


def _load_model(config):
    """Return the model selected by ``config``.

    Source priority, as read from the config:
      1. ``config.jit_model_path`` (truthy) -> local JIT model via
         ``init_jit_model``, placed on ``config.device``;
      2. ``config.use_torchhub`` (truthy) -> ``torch.hub.load`` from the
         ``snakers4/silero-vad`` repository;
      3. otherwise -> ``load_silero_vad`` from the ``silero_vad`` package.
    """
    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        return init_jit_model(config.jit_model_path, device=config.device)

    if config.use_torchhub:
        print('Loading model using torch.hub')
        # The hub entry point returns a pair; only the model is needed here.
        hub_model, _ = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            onnx=False,
            force_reload=True,
        )
        return hub_model

    print('Loading model using silero-vad library')
    # Imported lazily so the package is only required on this code path.
    from silero_vad import load_silero_vad
    return load_silero_vad(onnx=False)


def main():
    """Run prediction on the validation loader and print the best thresholds."""
    config = OmegaConf.load('config.yml')

    val_dataset = SileroVadDataset(config, mode='val')
    loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        collate_fn=SileroVadPadder,
        num_workers=config.num_workers,
    )

    model = _load_model(config)
    print('Model loaded')
    model.to(config.device)

    print('Making predicts...')
    # Sampling rate handed to predict(): 8000 when tune_8k is set, else 16000.
    sample_rate = 8000 if config.tune_8k else 16000
    all_predicts, all_gts = predict(model, loader, config.device, sr=sample_rate)

    print('Calculating thresholds...')
    best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')


if __name__ == '__main__':
    main()