mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
добавлен поиск порогов
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate
|
||||
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -19,15 +19,22 @@ if __name__ == '__main__':
|
||||
collate_fn=SileroVadPadder,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
if config.use_torchhub:
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
if config.jit_model_path:
|
||||
print(f'Loading model from the local folder: {config.jit_model_path}')
|
||||
model = init_jit_model(config.jit_model_path, device=config.device)
|
||||
else:
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
if config.use_torchhub:
|
||||
print('Loading model using torch.hub')
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
else:
|
||||
print('Loading model using silero-vad library')
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
|
||||
print('Model loaded')
|
||||
model.to(config.device)
|
||||
decoder = VADDecoderRNNJIT().to(config.device)
|
||||
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
|
||||
|
||||
Reference in New Issue
Block a user