diff --git a/tuning/README.md b/tuning/README.md new file mode 100644 index 0000000..ba38aa8 --- /dev/null +++ b/tuning/README.md @@ -0,0 +1,51 @@ +# Тюнинг Silero-VAD модели + +> Код тюнинга создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный +интеллект» национальной программы «Цифровая экономика Российской Федерации». + +Тюнинг используется для улучшения качества детекции речи Silero-VAD модели на кастомных данных. + +## Подготовка данных + +Датафреймы для тюнинга должны быть подготовлены и сохранены в формате `.feather`. Следующие колонки в `.feather` файлах тренировки и валидации являются обязательными: +- **audio_path** - абсолютный путь до аудиофайла в дисковой системе. Аудиофайлы должны представлять собой `PCM` данные, предпочтительно в форматах `.wav` или `.opus` (иные популярные форматы аудио тоже поддерживаются). Для ускорения темпа дообучения рекомендуется предварительно выполнить ресемплинг аудиофайлов (изменить частоту дискретизации) до 16000 Гц; +- **speech_ts** - разметка для соответствующего аудиофайла. Список, состоящий из словарей формата `{'start': START_SEC, 'end': 'END_SEC'}`, где `START_SEC` и `END_SEC` - время начало и конца речевого отрезка в секундах соответственно. + +Пример `.feather` датафрейма можно посмотреть в файле `example_dataframe.feather` + +## Файл конфигурации `config.yml` + +Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения: +- `train_dataset_path` - абсолютный путь до тренирового датафрема в формате `.feather`, Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрема можно посмотреть в `example_dataframe.feather`; +- `val_dataset_path` - абсолютный путь до валидационного датафрема в формате `.feather`, Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрема можно посмотреть в `example_dataframe.feather`; +- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`); +- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц; +- `model_save_path` - путь сохранения добученной модели; +- `noise_loss` - коэффициент лосса, применяемый для неречевых окон аудио; +- `max_train_length_sec` - максимальная длина аудио в секундах на этапе дообучения. Более длительные аудио будут обрезаны до этого показателя; +- `aug_prob` - вероятность применения аугментаций к аудиофайлу на этапе дообучения; +- `learning_rate` - темп дообучения; +- `batch_size` - размер батча при дообучении и валидации; +- `num_workers` - количество потоков, используемых для загрузки данных; +- `num_epochs` - количество эпох дообучения. За одну эпоху прогоняются все тренировочные данные; +- `device` - `cpu` или `cuda`. + +## Дообучение + +Дообучение запускается командой `python tune.py` + +Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit. + +## Цитирование + +``` +@misc{Silero VAD Dataset, + author = {Silero Team}, + title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages}, + year = {2024}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}}, + email = {hello@silero.ai} +} +``` \ No newline at end of file diff --git a/tuning/__init__.py b/tuning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tuning/config.yml b/tuning/config.yml new file mode 100644 index 0000000..7129a36 --- /dev/null +++ b/tuning/config.yml @@ -0,0 +1,16 @@ +use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False + +tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True +train_dataset_path: 'train_dataset_path.feather' # путь до датасета в формате feather для дообучения, подробности в README +val_dataset_path: 'val_dataset_path.feather' # путь до датасета в формате feather для валидации, подробности в README +model_save_path: 'model_save_path.jit' # путь сохранения дообученной модели + +noise_loss: 0.5 # коэффициент, применяемый к лоссу на неречевых окнах +max_train_length_sec: 8 # во время тюнинга аудио длиннее будут обрезаны до данного значения +aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения + +learning_rate: 5e-4 # темп дообучения модели +batch_size: 384 # размер батча при дообучении и валидации +num_workers: 4 # количество потоков, используемых для даталоадеров +num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных +device: 'cpu' # cpu или cuda, на чем будет производится дообучение \ No newline at end of file diff --git a/tuning/example_dataframe.feather b/tuning/example_dataframe.feather new file mode 100644 index 0000000..d8b8592 Binary files /dev/null and b/tuning/example_dataframe.feather differ diff --git a/tuning/tune.py b/tuning/tune.py new file mode 100644 index 0000000..2be573a --- /dev/null +++ b/tuning/tune.py @@ -0,0 +1,58 @@ +from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate +from omegaconf import OmegaConf +import torch +import torch.nn as nn + + +if __name__ == '__main__': + config = OmegaConf.load('config.yml') + + train_dataset = SileroVadDataset(config, mode='train') + train_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=config.batch_size, + collate_fn=SileroVadPadder, + num_workers=config.num_workers) + + val_dataset = SileroVadDataset(config, mode='val') + val_loader = torch.utils.data.DataLoader(val_dataset, + batch_size=config.batch_size, + collate_fn=SileroVadPadder, + num_workers=config.num_workers) + + if config.use_torchhub: + model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + onnx=False, + force_reload=True) + else: + from silero_vad import load_silero_vad + model = load_silero_vad(onnx=False) + + model.to(config.device) + decoder = VADDecoderRNNJIT().to(config.device) + decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict()) + decoder.train() + params = decoder.parameters() + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params), + lr=config.learning_rate) + criterion = nn.BCELoss(reduction='none') + + best_val_roc = 0 + for i in range(config.num_epochs): + print(f'Starting epoch {i + 1}') + train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device) + val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device) + print(f'Metrics after epoch {i + 1}:\n' + f'\tTrain loss: {round(train_loss, 3)}\n', + f'\tValidation loss: {round(val_loss, 3)}\n' + f'\tValidation ROC-AUC: {round(val_roc, 3)}') + + if val_roc > best_val_roc: + print('New best ROC-AUC, saving model') + best_val_roc = val_roc + if config.tune_8k: + model._model_8k.decoder.load_state_dict(decoder.state_dict()) + else: + model._model.decoder.load_state_dict(decoder.state_dict()) + torch.jit.save(model, config.model_save_path) + print('Done') diff --git a/tuning/utils.py b/tuning/utils.py new file mode 100644 index 0000000..e9a0b27 --- /dev/null +++ b/tuning/utils.py @@ -0,0 +1,299 @@ +import torch +import torch.nn as nn +from torch.utils.data import Dataset +import torchaudio +import numpy as np +import random +import gc +from sklearn.metrics import roc_auc_score +from tqdm import tqdm +import pandas as pd +import warnings +warnings.filterwarnings('ignore') + + +def read_audio(path: str, + sampling_rate: int = 16000, + normalize=False): + + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + if sampling_rate: + if sr != sampling_rate: + transform = torchaudio.transforms.Resample(orig_freq=sr, + new_freq=sampling_rate) + wav = transform(wav) + sr = sampling_rate + + if normalize and wav.abs().max() != 0: + wav = wav / wav.abs().max() + + return wav.squeeze(0) + + +def build_audiomentations_augs(p): + from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \ + LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \ + Aliasing, AddGaussianNoise + transforms = [Aliasing(p=1), + AddGaussianNoise(p=1), + AirAbsorption(p=1), + BandPassFilter(p=1), + BandStopFilter(p=1), + ClippingDistortion(p=1), + HighPassFilter(p=1), + HighShelfFilter(p=1), + LowPassFilter(p=1), + LowShelfFilter(p=1), + Mp3Compression(p=1), + PeakingFilter(p=1), + PitchShift(p=1), + RoomSimulator(p=1, leave_length_unchanged=True), + SevenBandParametricEQ(p=1)] + tr = SomeOf((1, 3), transforms=transforms, p=p) + return tr + + +class SileroVadDataset(Dataset): + def __init__(self, + config, + mode='train'): + + self.num_samples = 512 # constant, do not change + self.sr = 16000 # constant, do not change + + self.resample_to_8k = config.tune_8k + self.noise_loss = config.noise_loss + self.max_train_length_sec = config.max_train_length_sec + self.max_train_length_samples = config.max_train_length_sec * self.sr + + assert self.max_train_length_samples % self.num_samples == 0 + assert mode in ['train', 'val'] + + dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path + self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True) + self.index_dict = self.dataframe.to_dict('index') + self.mode = mode + print(f'DATASET SIZE : {len(self.dataframe)}') + + if mode == 'train': + self.augs = build_audiomentations_augs(p=config.aug_prob) + else: + self.augs = None + + def __getitem__(self, idx): + idx = None if self.mode == 'train' else idx + wav, gt, mask = self.load_speech_sample(idx) + + if self.mode == 'train': + wav = self.add_augs(wav) + if len(wav) > self.max_train_length_samples: + wav = wav[:self.max_train_length_samples] + gt = gt[:int(self.max_train_length_samples / self.num_samples)] + mask = mask[:int(self.max_train_length_samples / self.num_samples)] + + wav = torch.FloatTensor(wav) + if self.resample_to_8k: + transform = torchaudio.transforms.Resample(orig_freq=self.sr, + new_freq=8000) + wav = transform(wav) + return wav, torch.FloatTensor(gt), torch.from_numpy(mask) + + def __len__(self): + return len(self.index_dict) + + def load_speech_sample(self, idx=None): + if idx is None: + idx = random.randint(0, len(self.index_dict) - 1) + wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy() + + if len(wav) % self.num_samples != 0: + pad_num = self.num_samples - (len(wav) % (self.num_samples)) + wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0) + + gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav)) + + assert len(gt) == len(wav) / self.num_samples + + mask[gt == 0] + + return wav, gt, mask + + def get_ground_truth_annotated(self, annotation, audio_length_samples): + gt = np.zeros(audio_length_samples) + + for i in annotation: + gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1 + + squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1) + squeezed_predicts = (squeezed_predicts > 0.5).astype(int) + mask = np.ones(len(squeezed_predicts)) + mask[squeezed_predicts == 0] = self.noise_loss + return squeezed_predicts, mask + + def add_augs(self, wav): + while True: + try: + wav_aug = self.augs(wav, self.sr) + if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()): + return wav + return wav_aug + except Exception as e: + continue + + +def SileroVadPadder(batch): + wavs = [batch[i][0] for i in range(len(batch))] + labels = [batch[i][1] for i in range(len(batch))] + masks = [batch[i][2] for i in range(len(batch))] + + wavs = torch.nn.utils.rnn.pad_sequence( + wavs, batch_first=True, padding_value=0) + + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=0) + + masks = torch.nn.utils.rnn.pad_sequence( + masks, batch_first=True, padding_value=0) + + return wavs, labels, masks + + +class VADDecoderRNNJIT(nn.Module): + + def __init__(self): + super(VADDecoderRNNJIT, self).__init__() + + self.rnn = nn.LSTMCell(128, 128) + self.decoder = nn.Sequential(nn.Dropout(0.1), + nn.ReLU(), + nn.Conv1d(128, 1, kernel_size=1), + nn.Sigmoid()) + + def forward(self, x, state=torch.zeros(0)): + x = x.squeeze(-1) + if len(state): + h, c = self.rnn(x, (state[0], state[1])) + else: + h, c = self.rnn(x) + + x = h.unsqueeze(-1).float() + state = torch.stack([h, c]) + x = self.decoder(x) + return x, state + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def train(config, + loader, + jit_model, + decoder, + criterion, + optimizer, + device): + + losses = AverageMeter() + decoder.train() + + context_size = 32 if config.tune_8k else 64 + num_samples = 256 if config.tune_8k else 512 + stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft + encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder + + with torch.enable_grad(): + for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): + targets = targets.to(device) + x = x.to(device) + masks = masks.to(device) + x = torch.nn.functional.pad(x, (context_size, 0)) + + outs = [] + state = torch.zeros(0) + for i in range(context_size, x.shape[1], num_samples): + input_ = x[:, i-context_size:i+num_samples] + out = stft_layer(input_) + out = encoder_layer(out) + out, state = decoder(out, state) + outs.append(out) + stacked = torch.cat(outs, dim=2).squeeze(1) + + loss = criterion(stacked, targets) + loss = (loss * masks).mean() + loss.backward() + optimizer.step() + losses.update(loss.item(), masks.numel()) + + torch.cuda.empty_cache() + gc.collect() + + return losses.avg + + +def validate(config, + loader, + jit_model, + decoder, + criterion, + device): + + losses = AverageMeter() + decoder.eval() + + predicts = [] + gts = [] + + context_size = 32 if config.tune_8k else 64 + num_samples = 256 if config.tune_8k else 512 + stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft + encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder + + with torch.no_grad(): + for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): + targets = targets.to(device) + x = x.to(device) + masks = masks.to(device) + x = torch.nn.functional.pad(x, (context_size, 0)) + + outs = [] + state = torch.zeros(0) + for i in range(context_size, x.shape[1], num_samples): + input_ = x[:, i-context_size:i+num_samples] + out = stft_layer(input_) + out = encoder_layer(out) + out, state = decoder(out, state) + outs.append(out) + stacked = torch.cat(outs, dim=2).squeeze(1) + + predicts.extend(stacked[masks != 0].tolist()) + gts.extend(targets[masks != 0].tolist()) + + loss = criterion(stacked, targets) + loss = (loss * masks).mean() + losses.update(loss.item(), masks.numel()) + score = roc_auc_score(gts, predicts) + + torch.cuda.empty_cache() + gc.collect() + + return losses.avg, round(score, 3)