From 827e86e6850023e3a23faa30c1484fa23acd093a Mon Sep 17 00:00:00 2001 From: adamnsandle Date: Mon, 19 Aug 2024 16:53:28 +0000 Subject: [PATCH] =?UTF-8?q?=D0=B4=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=20=D0=BF=D0=BE=D0=B8=D1=81=D0=BA=20=D0=BF=D0=BE=D1=80?= =?UTF-8?q?=D0=BE=D0=B3=D0=BE=D0=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- tuning/README.md | 23 ++++++++---- tuning/config.yml | 5 +-- tuning/search_thresholds.py | 36 +++++++++++++++++++ tuning/tune.py | 25 ++++++++----- tuning/utils.py | 72 +++++++++++++++++++++++++++++++++---- 6 files changed, 138 insertions(+), 25 deletions(-) create mode 100644 tuning/search_thresholds.py diff --git a/README.md b/README.md index 3d7cafa..a6f0ed9 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva @misc{Silero VAD, author = {Silero Team}, title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, - year = {2021}, + year = {2024}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/snakers4/silero-vad}}, diff --git a/tuning/README.md b/tuning/README.md index b058382..2f071eb 100644 --- a/tuning/README.md +++ b/tuning/README.md @@ -7,12 +7,12 @@ ## Зависимости Следующие зависимости используются при тюнинге VAD модели: -- `torch>=1.12.0` - `torchaudio>=0.12.0` -- `sklearn>=1.2.0` -- `tqdm` -- `pandas>=2.2.2` - `omegaconf>=2.3.0` +- `sklearn>=1.2.0` +- `torch>=1.12.0` +- `pandas>=2.2.2` +- `tqdm` ## Подготовка данных @@ -29,6 +29,7 @@ Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения: - `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; - `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; +- `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub` - `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`); - `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц; - `model_save_path` - путь сохранения добученной модели; @@ -43,17 +44,27 @@ ## Дообучение -Дообучение запускается командой `python tune.py` +Дообучение запускается командой + +`python tune.py` Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit. +## Поиск пороговых значений + +Порог на вход и порог на выход можно подобрать, используя команду + +`python search_thresholds` + +Данный скрипт использует файл конфигурации, описанный выше. Указанная в конфигурации модель будет использована для поиска оптимальных порогов на валидационном датасете. + ## Цитирование ``` @misc{Silero VAD, author = {Silero Team}, title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, - year = {2021}, + year = {2024}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/snakers4/silero-vad}}, diff --git a/tuning/config.yml b/tuning/config.yml index 7129a36..7b59777 100644 --- a/tuning/config.yml +++ b/tuning/config.yml @@ -1,3 +1,4 @@ +jit_model_path: '' # путь до Silero-VAD модели в формате jit, эта модель будет использована для дообучения. Если оставить поле пустым, то модель будет загружена автоматически use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True @@ -10,7 +11,7 @@ max_train_length_sec: 8 # во время тюнинга аудио длинн aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения learning_rate: 5e-4 # темп дообучения модели -batch_size: 384 # размер батча при дообучении и валидации +batch_size: 128 # размер батча при дообучении и валидации num_workers: 4 # количество потоков, используемых для даталоадеров num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных -device: 'cpu' # cpu или cuda, на чем будет производится дообучение \ No newline at end of file +device: 'cuda' # cpu или cuda, на чем будет производится дообучение \ No newline at end of file diff --git a/tuning/search_thresholds.py b/tuning/search_thresholds.py new file mode 100644 index 0000000..d83e8c1 --- /dev/null +++ b/tuning/search_thresholds.py @@ -0,0 +1,36 @@ +from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder +from omegaconf import OmegaConf +import torch +torch.set_num_threads(1) + +if __name__ == '__main__': + config = OmegaConf.load('config.yml') + + loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'), + batch_size=config.batch_size, + collate_fn=SileroVadPadder, + num_workers=config.num_workers) + + if config.jit_model_path: + print(f'Loading model from the local folder: {config.jit_model_path}') + model = init_jit_model(config.jit_model_path, device=config.device) + else: + if config.use_torchhub: + print('Loading model using torch.hub') + model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + onnx=False, + force_reload=True) + else: + print('Loading model using silero-vad library') + from silero_vad import load_silero_vad + model = load_silero_vad(onnx=False) + + print('Model loaded') + model.to(config.device) + + print('Making predicts...') + all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000) + print('Calculating thresholds...') + best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts) + print(f'Best threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}') diff --git a/tuning/tune.py b/tuning/tune.py index 2be573a..36b41c8 100644 --- a/tuning/tune.py +++ b/tuning/tune.py @@ -1,7 +1,7 @@ -from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate +from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model from omegaconf import OmegaConf -import torch import torch.nn as nn +import torch if __name__ == '__main__': @@ -19,15 +19,22 @@ if __name__ == '__main__': collate_fn=SileroVadPadder, num_workers=config.num_workers) - if config.use_torchhub: - model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', - model='silero_vad', - onnx=False, - force_reload=True) + if config.jit_model_path: + print(f'Loading model from the local folder: {config.jit_model_path}') + model = init_jit_model(config.jit_model_path, device=config.device) else: - from silero_vad import load_silero_vad - model = load_silero_vad(onnx=False) + if config.use_torchhub: + print('Loading model using torch.hub') + model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + onnx=False, + force_reload=True) + else: + print('Loading model using silero-vad library') + from silero_vad import load_silero_vad + model = load_silero_vad(onnx=False) + print('Model loaded') model.to(config.device) decoder = VADDecoderRNNJIT().to(config.device) decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict()) diff --git a/tuning/utils.py b/tuning/utils.py index e9a0b27..c4d58a5 100644 --- a/tuning/utils.py +++ b/tuning/utils.py @@ -1,14 +1,14 @@ -import torch -import torch.nn as nn +from sklearn.metrics import roc_auc_score, accuracy_score from torch.utils.data import Dataset -import torchaudio -import numpy as np -import random -import gc -from sklearn.metrics import roc_auc_score +import torch.nn as nn from tqdm import tqdm import pandas as pd +import numpy as np +import torchaudio import warnings +import random +import torch +import gc warnings.filterwarnings('ignore') @@ -297,3 +297,61 @@ def validate(config, gc.collect() return losses.avg, round(score, 3) + + +def init_jit_model(model_path: str, + device=torch.device('cpu')): + torch.set_grad_enabled(False) + model = torch.jit.load(model_path, map_location=device) + model.eval() + return model + + +def predict(model, loader, device, sr): + with torch.no_grad(): + all_predicts = [] + all_gts = [] + for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): + x = x.to(device) + out = model.audio_forward(x, sr=sr) + + for i, out_chunk in enumerate(out): + predict = out_chunk[masks[i] != 0].cpu().tolist() + gt = targets[i, masks[i] != 0].cpu().tolist() + + all_predicts.append(predict) + all_gts.append(gt) + return all_predicts, all_gts + + +def calculate_best_thresholds(all_predicts, all_gts): + best_acc = 0 + for ths_enter in tqdm(np.linspace(0, 1, 20)): + for ths_exit in np.linspace(0, 1, 20): + if ths_exit >= ths_enter: + continue + + accs = [] + for j, predict in enumerate(all_predicts): + predict_bool = [] + is_speech = False + for i in predict: + if i >= ths_enter: + is_speech = True + predict_bool.append(1) + elif i <= ths_exit: + is_speech = False + predict_bool.append(0) + else: + val = 1 if is_speech else 0 + predict_bool.append(val) + + score = round(accuracy_score(all_gts[j], predict_bool), 4) + accs.append(score) + + mean_acc = round(np.mean(accs), 3) + if mean_acc > best_acc: + best_acc = mean_acc + best_ths_enter = round(ths_enter, 2) + best_ths_exit = round(ths_exit, 2) + return best_ths_enter, best_ths_exit, best_acc