mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
добавлен поиск порогов
This commit is contained in:
@@ -7,12 +7,12 @@
|
||||
|
||||
## Зависимости
|
||||
Следующие зависимости используются при тюнинге VAD модели:
|
||||
- `torch>=1.12.0`
|
||||
- `torchaudio>=0.12.0`
|
||||
- `sklearn>=1.2.0`
|
||||
- `tqdm`
|
||||
- `pandas>=2.2.2`
|
||||
- `omegaconf>=2.3.0`
|
||||
- `sklearn>=1.2.0`
|
||||
- `torch>=1.12.0`
|
||||
- `pandas>=2.2.2`
|
||||
- `tqdm`
|
||||
|
||||
## Подготовка данных
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
|
||||
- `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
|
||||
- `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
|
||||
- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub`
|
||||
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
|
||||
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
|
||||
- `model_save_path` - путь сохранения добученной модели;
|
||||
@@ -43,17 +44,27 @@
|
||||
|
||||
## Дообучение
|
||||
|
||||
Дообучение запускается командой `python tune.py`
|
||||
Дообучение запускается командой
|
||||
|
||||
`python tune.py`
|
||||
|
||||
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
|
||||
|
||||
## Поиск пороговых значений
|
||||
|
||||
Порог на вход и порог на выход можно подобрать, используя команду
|
||||
|
||||
`python search_thresholds`
|
||||
|
||||
Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете.
|
||||
|
||||
## Цитирование
|
||||
|
||||
```
|
||||
@misc{Silero VAD,
|
||||
author = {Silero Team},
|
||||
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
|
||||
year = {2021},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/snakers4/silero-vad}},
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически
|
||||
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
|
||||
|
||||
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
|
||||
@@ -10,7 +11,7 @@ max_train_length_sec: 8 # во время тюнинга аудио длинн
|
||||
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
|
||||
|
||||
learning_rate: 5e-4 # темп дообучения модели
|
||||
batch_size: 384 # размер батча при дообучении и валидации
|
||||
batch_size: 128 # размер батча при дообучении и валидации
|
||||
num_workers: 4 # количество потоков, используемых для даталоадеров
|
||||
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
|
||||
device: 'cpu' # cpu или cuda, на чем будет производится дообучение
|
||||
device: 'cuda' # cpu или cuda, на чем будет производится дообучение
|
||||
36
tuning/search_thresholds.py
Normal file
36
tuning/search_thresholds.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = OmegaConf.load('config.yml')
|
||||
|
||||
loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
|
||||
batch_size=config.batch_size,
|
||||
collate_fn=SileroVadPadder,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
if config.jit_model_path:
|
||||
print(f'Loading model from the local folder: {config.jit_model_path}')
|
||||
model = init_jit_model(config.jit_model_path, device=config.device)
|
||||
else:
|
||||
if config.use_torchhub:
|
||||
print('Loading model using torch.hub')
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
else:
|
||||
print('Loading model using silero-vad library')
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
|
||||
print('Model loaded')
|
||||
model.to(config.device)
|
||||
|
||||
print('Making predicts...')
|
||||
all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
|
||||
print('Calculating thresholds...')
|
||||
best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
|
||||
print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')
|
||||
@@ -1,7 +1,7 @@
|
||||
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate
|
||||
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -19,15 +19,22 @@ if __name__ == '__main__':
|
||||
collate_fn=SileroVadPadder,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
if config.use_torchhub:
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
if config.jit_model_path:
|
||||
print(f'Loading model from the local folder: {config.jit_model_path}')
|
||||
model = init_jit_model(config.jit_model_path, device=config.device)
|
||||
else:
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
if config.use_torchhub:
|
||||
print('Loading model using torch.hub')
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
else:
|
||||
print('Loading model using silero-vad library')
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
|
||||
print('Model loaded')
|
||||
model.to(config.device)
|
||||
decoder = VADDecoderRNNJIT().to(config.device)
|
||||
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sklearn.metrics import roc_auc_score, accuracy_score
|
||||
from torch.utils.data import Dataset
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import random
|
||||
import gc
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
import warnings
|
||||
import random
|
||||
import torch
|
||||
import gc
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@@ -297,3 +297,61 @@ def validate(config,
|
||||
gc.collect()
|
||||
|
||||
return losses.avg, round(score, 3)
|
||||
|
||||
|
||||
def init_jit_model(model_path: str,
|
||||
device=torch.device('cpu')):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def predict(model, loader, device, sr):
|
||||
with torch.no_grad():
|
||||
all_predicts = []
|
||||
all_gts = []
|
||||
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
|
||||
x = x.to(device)
|
||||
out = model.audio_forward(x, sr=sr)
|
||||
|
||||
for i, out_chunk in enumerate(out):
|
||||
predict = out_chunk[masks[i] != 0].cpu().tolist()
|
||||
gt = targets[i, masks[i] != 0].cpu().tolist()
|
||||
|
||||
all_predicts.append(predict)
|
||||
all_gts.append(gt)
|
||||
return all_predicts, all_gts
|
||||
|
||||
|
||||
def calculate_best_thresholds(all_predicts, all_gts):
|
||||
best_acc = 0
|
||||
for ths_enter in tqdm(np.linspace(0, 1, 20)):
|
||||
for ths_exit in np.linspace(0, 1, 20):
|
||||
if ths_exit >= ths_enter:
|
||||
continue
|
||||
|
||||
accs = []
|
||||
for j, predict in enumerate(all_predicts):
|
||||
predict_bool = []
|
||||
is_speech = False
|
||||
for i in predict:
|
||||
if i >= ths_enter:
|
||||
is_speech = True
|
||||
predict_bool.append(1)
|
||||
elif i <= ths_exit:
|
||||
is_speech = False
|
||||
predict_bool.append(0)
|
||||
else:
|
||||
val = 1 if is_speech else 0
|
||||
predict_bool.append(val)
|
||||
|
||||
score = round(accuracy_score(all_gts[j], predict_bool), 4)
|
||||
accs.append(score)
|
||||
|
||||
mean_acc = round(np.mean(accs), 3)
|
||||
if mean_acc > best_acc:
|
||||
best_acc = mean_acc
|
||||
best_ths_enter = round(ths_enter, 2)
|
||||
best_ths_exit = round(ths_exit, 2)
|
||||
return best_ths_enter, best_ths_exit, best_acc
|
||||
|
||||
Reference in New Issue
Block a user