добавлен поиск порогов

This commit is contained in:
adamnsandle
2024-08-19 16:53:28 +00:00
parent e706ec6fee
commit 827e86e685
6 changed files with 138 additions and 25 deletions

View File

@@ -120,7 +120,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
@misc{Silero VAD, @misc{Silero VAD,
author = {Silero Team}, author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2021}, year = {2024},
publisher = {GitHub}, publisher = {GitHub},
journal = {GitHub repository}, journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}}, howpublished = {\url{https://github.com/snakers4/silero-vad}},

View File

@@ -7,12 +7,12 @@
## Зависимости ## Зависимости
Следующие зависимости используются при тюнинге VAD модели: Следующие зависимости используются при тюнинге VAD модели:
- `torch>=1.12.0`
- `torchaudio>=0.12.0` - `torchaudio>=0.12.0`
- `sklearn>=1.2.0`
- `tqdm`
- `pandas>=2.2.2`
- `omegaconf>=2.3.0` - `omegaconf>=2.3.0`
- `sklearn>=1.2.0`
- `torch>=1.12.0`
- `pandas>=2.2.2`
- `tqdm`
## Подготовка данных ## Подготовка данных
@@ -29,6 +29,7 @@
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения: Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
- `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; - `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; - `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`;
- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub`
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`); - `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц; - `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
- `model_save_path` - путь сохранения добученной модели; - `model_save_path` - путь сохранения добученной модели;
@@ -43,17 +44,27 @@
## Дообучение ## Дообучение
Дообучение запускается командой `python tune.py` Дообучение запускается командой
`python tune.py`
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit. Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
## Поиск пороговых значений
Порог на вход и порог на выход можно подобрать, используя команду
`python search_thresholds`
Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете.
## Цитирование ## Цитирование
``` ```
@misc{Silero VAD, @misc{Silero VAD,
author = {Silero Team}, author = {Silero Team},
title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
year = {2021}, year = {2024},
publisher = {GitHub}, publisher = {GitHub},
journal = {GitHub repository}, journal = {GitHub repository},
howpublished = {\url{https://github.com/snakers4/silero-vad}}, howpublished = {\url{https://github.com/snakers4/silero-vad}},

View File

@@ -1,3 +1,4 @@
jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
@@ -10,7 +11,7 @@ max_train_length_sec: 8 # во время тюнинга аудио длинн
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
learning_rate: 5e-4 # темп дообучения модели learning_rate: 5e-4 # темп дообучения модели
batch_size: 384 # размер батча при дообучении и валидации batch_size: 128 # размер батча при дообучении и валидации
num_workers: 4 # количество потоков, используемых для даталоадеров num_workers: 4 # количество потоков, используемых для даталоадеров
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
device: 'cpu' # cpu или cuda, на чем будет производится дообучение device: 'cuda' # cpu или cuda, на чем будет производится дообучение

View File

@@ -0,0 +1,36 @@
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch
torch.set_num_threads(1)
if __name__ == '__main__':
config = OmegaConf.load('config.yml')
loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
batch_size=config.batch_size,
collate_fn=SileroVadPadder,
num_workers=config.num_workers)
if config.jit_model_path:
print(f'Loading model from the local folder: {config.jit_model_path}')
model = init_jit_model(config.jit_model_path, device=config.device)
else:
if config.use_torchhub:
print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device)
print('Making predicts...')
all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
print('Calculating thresholds...')
best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')

View File

@@ -1,7 +1,7 @@
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
from omegaconf import OmegaConf from omegaconf import OmegaConf
import torch
import torch.nn as nn import torch.nn as nn
import torch
if __name__ == '__main__': if __name__ == '__main__':
@@ -19,15 +19,22 @@ if __name__ == '__main__':
collate_fn=SileroVadPadder, collate_fn=SileroVadPadder,
num_workers=config.num_workers) num_workers=config.num_workers)
if config.use_torchhub: if config.jit_model_path:
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', print(f'Loading model from the local folder: {config.jit_model_path}')
model='silero_vad', model = init_jit_model(config.jit_model_path, device=config.device)
onnx=False,
force_reload=True)
else: else:
from silero_vad import load_silero_vad if config.use_torchhub:
model = load_silero_vad(onnx=False) print('Loading model using torch.hub')
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
onnx=False,
force_reload=True)
else:
print('Loading model using silero-vad library')
from silero_vad import load_silero_vad
model = load_silero_vad(onnx=False)
print('Model loaded')
model.to(config.device) model.to(config.device)
decoder = VADDecoderRNNJIT().to(config.device) decoder = VADDecoderRNNJIT().to(config.device)
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict()) decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())

View File

@@ -1,14 +1,14 @@
import torch from sklearn.metrics import roc_auc_score, accuracy_score
import torch.nn as nn
from torch.utils.data import Dataset from torch.utils.data import Dataset
import torchaudio import torch.nn as nn
import numpy as np
import random
import gc
from sklearn.metrics import roc_auc_score
from tqdm import tqdm from tqdm import tqdm
import pandas as pd import pandas as pd
import numpy as np
import torchaudio
import warnings import warnings
import random
import torch
import gc
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@@ -297,3 +297,61 @@ def validate(config,
gc.collect() gc.collect()
return losses.avg, round(score, 3) return losses.avg, round(score, 3)
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
def predict(model, loader, device, sr):
with torch.no_grad():
all_predicts = []
all_gts = []
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
x = x.to(device)
out = model.audio_forward(x, sr=sr)
for i, out_chunk in enumerate(out):
predict = out_chunk[masks[i] != 0].cpu().tolist()
gt = targets[i, masks[i] != 0].cpu().tolist()
all_predicts.append(predict)
all_gts.append(gt)
return all_predicts, all_gts
def calculate_best_thresholds(all_predicts, all_gts):
best_acc = 0
for ths_enter in tqdm(np.linspace(0, 1, 20)):
for ths_exit in np.linspace(0, 1, 20):
if ths_exit >= ths_enter:
continue
accs = []
for j, predict in enumerate(all_predicts):
predict_bool = []
is_speech = False
for i in predict:
if i >= ths_enter:
is_speech = True
predict_bool.append(1)
elif i <= ths_exit:
is_speech = False
predict_bool.append(0)
else:
val = 1 if is_speech else 0
predict_bool.append(val)
score = round(accuracy_score(all_gts[j], predict_bool), 4)
accs.append(score)
mean_acc = round(np.mean(accs), 3)
if mean_acc > best_acc:
best_acc = mean_acc
best_ths_enter = round(ths_enter, 2)
best_ths_exit = round(ths_exit, 2)
return best_ths_enter, best_ths_exit, best_acc