mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 09:29:22 +08:00
код для тюнинга
This commit is contained in:
51
tuning/README.md
Normal file
51
tuning/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Тюнинг Silero-VAD модели
|
||||
|
||||
> Код тюнинга создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный
|
||||
интеллект» национальной программы «Цифровая экономика Российской Федерации».
|
||||
|
||||
Тюнинг используется для улучшения качества детекции речи Silero-VAD модели на кастомных данных.
|
||||
|
||||
## Подготовка данных
|
||||
|
||||
Датафреймы для тюнинга должны быть подготовлены и сохранены в формате `.feather`. Следующие колонки в `.feather` файлах тренировки и валидации являются обязательными:
|
||||
- **audio_path** - абсолютный путь до аудиофайла в дисковой системе. Аудиофайлы должны представлять собой `PCM` данные, предпочтительно в форматах `.wav` или `.opus` (иные популярные форматы аудио тоже поддерживаются). Для ускорения темпа дообучения рекомендуется предварительно выполнить ресемплинг аудиофайлов (изменить частоту дискретизации) до 16000 Гц;
|
||||
- **speech_ts** - разметка для соответствующего аудиофайла. Список, состоящий из словарей формата `{'start': START_SEC, 'end': 'END_SEC'}`, где `START_SEC` и `END_SEC` - время начало и конца речевого отрезка в секундах соответственно.
|
||||
|
||||
Пример `.feather` датафрейма можно посмотреть в файле `example_dataframe.feather`
|
||||
|
||||
## Файл конфигурации `config.yml`
|
||||
|
||||
Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения:
|
||||
- `train_dataset_path` - абсолютный путь до тренирового датафрема в формате `.feather`, Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрема можно посмотреть в `example_dataframe.feather`;
|
||||
- `val_dataset_path` - абсолютный путь до валидационного датафрема в формате `.feather`, Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрема можно посмотреть в `example_dataframe.feather`;
|
||||
- `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`);
|
||||
- `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц;
|
||||
- `model_save_path` - путь сохранения добученной модели;
|
||||
- `noise_loss` - коэффициент лосса, применяемый для неречевых окон аудио;
|
||||
- `max_train_length_sec` - максимальная длина аудио в секундах на этапе дообучения. Более длительные аудио будут обрезаны до этого показателя;
|
||||
- `aug_prob` - вероятность применения аугментаций к аудиофайлу на этапе дообучения;
|
||||
- `learning_rate` - темп дообучения;
|
||||
- `batch_size` - размер батча при дообучении и валидации;
|
||||
- `num_workers` - количество потоков, используемых для загрузки данных;
|
||||
- `num_epochs` - количество эпох дообучения. За одну эпоху прогоняются все тренировочные данные;
|
||||
- `device` - `cpu` или `cuda`.
|
||||
|
||||
## Дообучение
|
||||
|
||||
Дообучение запускается командой `python tune.py`
|
||||
|
||||
Длится в течение `num_epochs`, лучший чекпоинт по показателю ROC-AUC на валидационной выборке будет сохранен в `model_save_path` в формате jit.
|
||||
|
||||
## Цитирование
|
||||
|
||||
```
|
||||
@misc{Silero VAD Dataset,
|
||||
author = {Silero Team},
|
||||
title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}},
|
||||
email = {hello@silero.ai}
|
||||
}
|
||||
```
|
||||
0
tuning/__init__.py
Normal file
0
tuning/__init__.py
Normal file
16
tuning/config.yml
Normal file
16
tuning/config.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
use_torchhub: True # jit модель будет загружена через torchhub, если True, или через pip, если False
|
||||
|
||||
tune_8k: False # дообучает 16к голову, если False, и 8к голову, если True
|
||||
train_dataset_path: 'train_dataset_path.feather' # путь до датасета в формате feather для дообучения, подробности в README
|
||||
val_dataset_path: 'val_dataset_path.feather' # путь до датасета в формате feather для валидации, подробности в README
|
||||
model_save_path: 'model_save_path.jit' # путь сохранения дообученной модели
|
||||
|
||||
noise_loss: 0.5 # коэффициент, применяемый к лоссу на неречевых окнах
|
||||
max_train_length_sec: 8 # во время тюнинга аудио длиннее будут обрезаны до данного значения
|
||||
aug_prob: 0.4 # вероятность применения аугментаций к аудио в процессе дообучения
|
||||
|
||||
learning_rate: 5e-4 # темп дообучения модели
|
||||
batch_size: 384 # размер батча при дообучении и валидации
|
||||
num_workers: 4 # количество потоков, используемых для даталоадеров
|
||||
num_epochs: 20 # количество эпох дообучения, 1 эпоха = полный прогон тренировочных данных
|
||||
device: 'cpu' # cpu или cuda, на чем будет производится дообучение
|
||||
BIN
tuning/example_dataframe.feather
Normal file
BIN
tuning/example_dataframe.feather
Normal file
Binary file not shown.
58
tuning/tune.py
Normal file
58
tuning/tune.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = OmegaConf.load('config.yml')
|
||||
|
||||
train_dataset = SileroVadDataset(config, mode='train')
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
collate_fn=SileroVadPadder,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
val_dataset = SileroVadDataset(config, mode='val')
|
||||
val_loader = torch.utils.data.DataLoader(val_dataset,
|
||||
batch_size=config.batch_size,
|
||||
collate_fn=SileroVadPadder,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
if config.use_torchhub:
|
||||
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
onnx=False,
|
||||
force_reload=True)
|
||||
else:
|
||||
from silero_vad import load_silero_vad
|
||||
model = load_silero_vad(onnx=False)
|
||||
|
||||
model.to(config.device)
|
||||
decoder = VADDecoderRNNJIT().to(config.device)
|
||||
decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
|
||||
decoder.train()
|
||||
params = decoder.parameters()
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
|
||||
lr=config.learning_rate)
|
||||
criterion = nn.BCELoss(reduction='none')
|
||||
|
||||
best_val_roc = 0
|
||||
for i in range(config.num_epochs):
|
||||
print(f'Starting epoch {i + 1}')
|
||||
train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
|
||||
val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
|
||||
print(f'Metrics after epoch {i + 1}:\n'
|
||||
f'\tTrain loss: {round(train_loss, 3)}\n',
|
||||
f'\tValidation loss: {round(val_loss, 3)}\n'
|
||||
f'\tValidation ROC-AUC: {round(val_roc, 3)}')
|
||||
|
||||
if val_roc > best_val_roc:
|
||||
print('New best ROC-AUC, saving model')
|
||||
best_val_roc = val_roc
|
||||
if config.tune_8k:
|
||||
model._model_8k.decoder.load_state_dict(decoder.state_dict())
|
||||
else:
|
||||
model._model.decoder.load_state_dict(decoder.state_dict())
|
||||
torch.jit.save(model, config.model_save_path)
|
||||
print('Done')
|
||||
299
tuning/utils.py
Normal file
299
tuning/utils.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import random
|
||||
import gc
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def read_audio(path: str,
|
||||
sampling_rate: int = 16000,
|
||||
normalize=False):
|
||||
|
||||
wav, sr = torchaudio.load(path)
|
||||
|
||||
if wav.size(0) > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
if sampling_rate:
|
||||
if sr != sampling_rate:
|
||||
transform = torchaudio.transforms.Resample(orig_freq=sr,
|
||||
new_freq=sampling_rate)
|
||||
wav = transform(wav)
|
||||
sr = sampling_rate
|
||||
|
||||
if normalize and wav.abs().max() != 0:
|
||||
wav = wav / wav.abs().max()
|
||||
|
||||
return wav.squeeze(0)
|
||||
|
||||
|
||||
def build_audiomentations_augs(p):
|
||||
from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
|
||||
LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
|
||||
Aliasing, AddGaussianNoise
|
||||
transforms = [Aliasing(p=1),
|
||||
AddGaussianNoise(p=1),
|
||||
AirAbsorption(p=1),
|
||||
BandPassFilter(p=1),
|
||||
BandStopFilter(p=1),
|
||||
ClippingDistortion(p=1),
|
||||
HighPassFilter(p=1),
|
||||
HighShelfFilter(p=1),
|
||||
LowPassFilter(p=1),
|
||||
LowShelfFilter(p=1),
|
||||
Mp3Compression(p=1),
|
||||
PeakingFilter(p=1),
|
||||
PitchShift(p=1),
|
||||
RoomSimulator(p=1, leave_length_unchanged=True),
|
||||
SevenBandParametricEQ(p=1)]
|
||||
tr = SomeOf((1, 3), transforms=transforms, p=p)
|
||||
return tr
|
||||
|
||||
|
||||
class SileroVadDataset(Dataset):
|
||||
def __init__(self,
|
||||
config,
|
||||
mode='train'):
|
||||
|
||||
self.num_samples = 512 # constant, do not change
|
||||
self.sr = 16000 # constant, do not change
|
||||
|
||||
self.resample_to_8k = config.tune_8k
|
||||
self.noise_loss = config.noise_loss
|
||||
self.max_train_length_sec = config.max_train_length_sec
|
||||
self.max_train_length_samples = config.max_train_length_sec * self.sr
|
||||
|
||||
assert self.max_train_length_samples % self.num_samples == 0
|
||||
assert mode in ['train', 'val']
|
||||
|
||||
dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
|
||||
self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
|
||||
self.index_dict = self.dataframe.to_dict('index')
|
||||
self.mode = mode
|
||||
print(f'DATASET SIZE : {len(self.dataframe)}')
|
||||
|
||||
if mode == 'train':
|
||||
self.augs = build_audiomentations_augs(p=config.aug_prob)
|
||||
else:
|
||||
self.augs = None
|
||||
|
||||
def __getitem__(self, idx):
|
||||
idx = None if self.mode == 'train' else idx
|
||||
wav, gt, mask = self.load_speech_sample(idx)
|
||||
|
||||
if self.mode == 'train':
|
||||
wav = self.add_augs(wav)
|
||||
if len(wav) > self.max_train_length_samples:
|
||||
wav = wav[:self.max_train_length_samples]
|
||||
gt = gt[:int(self.max_train_length_samples / self.num_samples)]
|
||||
mask = mask[:int(self.max_train_length_samples / self.num_samples)]
|
||||
|
||||
wav = torch.FloatTensor(wav)
|
||||
if self.resample_to_8k:
|
||||
transform = torchaudio.transforms.Resample(orig_freq=self.sr,
|
||||
new_freq=8000)
|
||||
wav = transform(wav)
|
||||
return wav, torch.FloatTensor(gt), torch.from_numpy(mask)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index_dict)
|
||||
|
||||
def load_speech_sample(self, idx=None):
|
||||
if idx is None:
|
||||
idx = random.randint(0, len(self.index_dict) - 1)
|
||||
wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()
|
||||
|
||||
if len(wav) % self.num_samples != 0:
|
||||
pad_num = self.num_samples - (len(wav) % (self.num_samples))
|
||||
wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)
|
||||
|
||||
gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))
|
||||
|
||||
assert len(gt) == len(wav) / self.num_samples
|
||||
|
||||
mask[gt == 0]
|
||||
|
||||
return wav, gt, mask
|
||||
|
||||
def get_ground_truth_annotated(self, annotation, audio_length_samples):
|
||||
gt = np.zeros(audio_length_samples)
|
||||
|
||||
for i in annotation:
|
||||
gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1
|
||||
|
||||
squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
|
||||
squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
|
||||
mask = np.ones(len(squeezed_predicts))
|
||||
mask[squeezed_predicts == 0] = self.noise_loss
|
||||
return squeezed_predicts, mask
|
||||
|
||||
def add_augs(self, wav):
|
||||
while True:
|
||||
try:
|
||||
wav_aug = self.augs(wav, self.sr)
|
||||
if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
|
||||
return wav
|
||||
return wav_aug
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
|
||||
def SileroVadPadder(batch):
|
||||
wavs = [batch[i][0] for i in range(len(batch))]
|
||||
labels = [batch[i][1] for i in range(len(batch))]
|
||||
masks = [batch[i][2] for i in range(len(batch))]
|
||||
|
||||
wavs = torch.nn.utils.rnn.pad_sequence(
|
||||
wavs, batch_first=True, padding_value=0)
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
labels, batch_first=True, padding_value=0)
|
||||
|
||||
masks = torch.nn.utils.rnn.pad_sequence(
|
||||
masks, batch_first=True, padding_value=0)
|
||||
|
||||
return wavs, labels, masks
|
||||
|
||||
|
||||
class VADDecoderRNNJIT(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(VADDecoderRNNJIT, self).__init__()
|
||||
|
||||
self.rnn = nn.LSTMCell(128, 128)
|
||||
self.decoder = nn.Sequential(nn.Dropout(0.1),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(128, 1, kernel_size=1),
|
||||
nn.Sigmoid())
|
||||
|
||||
def forward(self, x, state=torch.zeros(0)):
|
||||
x = x.squeeze(-1)
|
||||
if len(state):
|
||||
h, c = self.rnn(x, (state[0], state[1]))
|
||||
else:
|
||||
h, c = self.rnn(x)
|
||||
|
||||
x = h.unsqueeze(-1).float()
|
||||
state = torch.stack([h, c])
|
||||
x = self.decoder(x)
|
||||
return x, state
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def train(config,
|
||||
loader,
|
||||
jit_model,
|
||||
decoder,
|
||||
criterion,
|
||||
optimizer,
|
||||
device):
|
||||
|
||||
losses = AverageMeter()
|
||||
decoder.train()
|
||||
|
||||
context_size = 32 if config.tune_8k else 64
|
||||
num_samples = 256 if config.tune_8k else 512
|
||||
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
|
||||
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
|
||||
|
||||
with torch.enable_grad():
|
||||
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
|
||||
targets = targets.to(device)
|
||||
x = x.to(device)
|
||||
masks = masks.to(device)
|
||||
x = torch.nn.functional.pad(x, (context_size, 0))
|
||||
|
||||
outs = []
|
||||
state = torch.zeros(0)
|
||||
for i in range(context_size, x.shape[1], num_samples):
|
||||
input_ = x[:, i-context_size:i+num_samples]
|
||||
out = stft_layer(input_)
|
||||
out = encoder_layer(out)
|
||||
out, state = decoder(out, state)
|
||||
outs.append(out)
|
||||
stacked = torch.cat(outs, dim=2).squeeze(1)
|
||||
|
||||
loss = criterion(stacked, targets)
|
||||
loss = (loss * masks).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
losses.update(loss.item(), masks.numel())
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return losses.avg
|
||||
|
||||
|
||||
def validate(config,
|
||||
loader,
|
||||
jit_model,
|
||||
decoder,
|
||||
criterion,
|
||||
device):
|
||||
|
||||
losses = AverageMeter()
|
||||
decoder.eval()
|
||||
|
||||
predicts = []
|
||||
gts = []
|
||||
|
||||
context_size = 32 if config.tune_8k else 64
|
||||
num_samples = 256 if config.tune_8k else 512
|
||||
stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
|
||||
encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
|
||||
|
||||
with torch.no_grad():
|
||||
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
|
||||
targets = targets.to(device)
|
||||
x = x.to(device)
|
||||
masks = masks.to(device)
|
||||
x = torch.nn.functional.pad(x, (context_size, 0))
|
||||
|
||||
outs = []
|
||||
state = torch.zeros(0)
|
||||
for i in range(context_size, x.shape[1], num_samples):
|
||||
input_ = x[:, i-context_size:i+num_samples]
|
||||
out = stft_layer(input_)
|
||||
out = encoder_layer(out)
|
||||
out, state = decoder(out, state)
|
||||
outs.append(out)
|
||||
stacked = torch.cat(outs, dim=2).squeeze(1)
|
||||
|
||||
predicts.extend(stacked[masks != 0].tolist())
|
||||
gts.extend(targets[masks != 0].tolist())
|
||||
|
||||
loss = criterion(stacked, targets)
|
||||
loss = (loss * masks).mean()
|
||||
losses.update(loss.item(), masks.numel())
|
||||
score = roc_auc_score(gts, predicts)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return losses.avg, round(score, 3)
|
||||
Reference in New Issue
Block a user