добавлен поиск порогов

This commit is contained in:
adamnsandle
2024-08-19 16:53:28 +00:00
parent e706ec6fee
commit 827e86e685
6 changed files with 138 additions and 25 deletions

View File

@@ -7,12 +7,12 @@
## Зависимости
Следующие зависимости используются при тюнинге VAD модели:
- `torch>=1.12.0`
- `torchaudio>=0.12.0`
- `sklearn>=1.2.0`
- `tqdm`
- `pandas>=2.2.2`
- `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`
## Подготовка данных
@@ -29,6 +29,7 @@
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
- `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub`
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
- `model_save_path` - путь сохранения добученной модели;
@@ -43,17 +44,27 @@
## Дообучение
Дообучение запускается командой `python tune.py`
Дообучение запускается командой
`python tune.py`
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
## Поиск пороговых значений
Порог на вход и порог на выход можно подобрать, используя команду
`python search_thresholds`
Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете.
## Цитирование
```
@misc{Silero VAD,
author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2021},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}},

View File

@@ -1,3 +1,4 @@
jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
@@ -10,7 +11,7 @@ max_train_length_sec: 8 # во время тюнинга аудио длинн
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
learning_rate: 5e-4 # темп дообучения модели
batch_size: 384 # размер батча при дообучении и валидации
batch_size: 128 # размер батча при дообучении и валидации
num_workers: 4 # количество потоков, используемых для даталоадеров
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
device: 'cpu' # cpu или cuda, на чем будет производится дообучение
device: 'cuda' # cpu или cuda, на чем будет производится дообучение

View File

@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
print('Making predicts...')
all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
print('Calculating thresholds...')
best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')

View File

@@ -1,7 +1,7 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch
if __name__ == '__main__':
@@ -19,15 +19,22 @@ if __name__ == '__main__':
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.use_torchhub:
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
decoder = VADDecoderRNNJIT().to(config.device)
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())

View File

@@ -1,14 +1,14 @@
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torchaudio
import numpy as np
import random
import gc
from sklearn.metrics import roc_auc_score
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore')
@@ -297,3 +297,61 @@ def validate(config,
gc.collect()
return losses.avg, round(score, 3)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def predict(model, loader, device, sr):
with torch.no_grad():
all_predicts = []
all_gts = []
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
x = x.to(device)
out = model.audio_forward(x, sr=sr)
for i, out_chunk in enumerate(out):
predict = out_chunk[masks[i] != 0].cpu().tolist()
gt = targets[i, masks[i] != 0].cpu().tolist()
all_predicts.append(predict)
all_gts.append(gt)
return all_predicts, all_gts
def calculate_best_thresholds(all_predicts, all_gts):
best_acc = 0
for ths_enter in tqdm(np.linspace(0, 1, 20)):
for ths_exit in np.linspace(0, 1, 20):
if ths_exit >= ths_enter:
continue
accs = []
for j, predict in enumerate(all_predicts):
predict_bool = []
is_speech = False
for i in predict:
if i >= ths_enter:
is_speech = True
predict_bool.append(1)
elif i <= ths_exit:
is_speech = False
predict_bool.append(0)
else:
val = 1 if is_speech else 0
predict_bool.append(val)
score = round(accuracy_score(all_gts[j], predict_bool), 4)
accs.append(score)
mean_acc = round(np.mean(accs), 3)
if mean_acc > best_acc:
best_acc = mean_acc
best_ths_enter = round(ths_enter, 2)
best_ths_exit = round(ths_exit, 2)
return best_ths_enter, best_ths_exit, best_acc