добавлен поиск порогов

This commit is contained in:
adamnsandle
2024-08-19 16:53:28 +00:00
parent e706ec6fee
commit 827e86e685
6 changed files with 138 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch
if __name__ == '__main__':
@@ -19,15 +19,22 @@ if __name__ == '__main__':
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.use_torchhub:
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
decoder = VADDecoderRNNJIT().to(config.device)
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())