from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc

warnings.filterwarnings('ignore')


def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):
    """Load an audio file as a mono 1D tensor, resampled to `sampling_rate`."""
    wav, sr = torchaudio.load(path)

    # Downmix multi-channel audio to mono
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sampling_rate:
        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    # Peak-normalize, guarding against all-zero (silent) audio
    if normalize and wav.abs().max() != 0:
        wav = wav / wav.abs().max()

    return wav.squeeze(0)
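

# --- Illustrative sketch (not part of the original file) ---
# Shows how read_audio is meant to be called; 'example.wav' is a
# hypothetical path. The result is a 1D mono tensor at 16 kHz.
def _example_read_audio():
    wav = read_audio('example.wav', sampling_rate=16000, normalize=True)
    print(wav.shape)  # torch.Size([num_samples]), num_samples = seconds * 16000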


def build_audiomentations_augs(p):
    # Imported lazily so audiomentations is only required when augmentations are used
    from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
        LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
        Aliasing, AddGaussianNoise

    transforms = [Aliasing(p=1),
                  AddGaussianNoise(p=1),
                  AirAbsorption(p=1),
                  BandPassFilter(p=1),
                  BandStopFilter(p=1),
                  ClippingDistortion(p=1),
                  HighPassFilter(p=1),
                  HighShelfFilter(p=1),
                  LowPassFilter(p=1),
                  LowShelfFilter(p=1),
                  Mp3Compression(p=1),
                  PeakingFilter(p=1),
                  PitchShift(p=1),
                  RoomSimulator(p=1, leave_length_unchanged=True),
                  SevenBandParametricEQ(p=1)]
    # Apply between 1 and 3 randomly chosen transforms with overall probability p
    tr = SomeOf((1, 3), transforms=transforms, p=p)
    return tr
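

# --- Illustrative sketch (not part of the original file) ---
# Applies the augmentation pipeline to a mono numpy waveform; transforms
# are called positionally with (samples, sample_rate), as in add_augs below.
def _example_augment():
    augs = build_audiomentations_augs(p=0.5)
    wav = np.zeros(16000, dtype=np.float32)  # stand-in for one second of audio
    wav_aug = augs(wav, 16000)
    return wav_aug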


class SileroVadDataset(Dataset):
    def __init__(self,
                 config,
                 mode='train'):

        self.num_samples = 512  # constant, do not change
        self.sr = 16000  # constant, do not change

        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec
        self.max_train_length_samples = config.max_train_length_sec * self.sr

        # Training clips must split evenly into 512-sample chunks
        assert self.max_train_length_samples % self.num_samples == 0
        assert mode in ['train', 'val']

        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE: {len(self.dataframe)}')

        if mode == 'train':
            self.augs = build_audiomentations_augs(p=config.aug_prob)
        else:
            self.augs = None

    def __getitem__(self, idx):
        # During training, sample a random clip instead of indexing sequentially
        idx = None if self.mode == 'train' else idx
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
            if len(wav) > self.max_train_length_samples:
                wav = wav[:self.max_train_length_samples]
                gt = gt[:int(self.max_train_length_samples / self.num_samples)]
                mask = mask[:int(self.max_train_length_samples / self.num_samples)]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            transform = torchaudio.transforms.Resample(orig_freq=self.sr,
                                                       new_freq=8000)
            wav = transform(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()

        # Zero-pad so the waveform splits evenly into 512-sample chunks
        if len(wav) % self.num_samples != 0:
            pad_num = self.num_samples - (len(wav) % self.num_samples)
            wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))

        assert len(gt) == len(wav) / self.num_samples

        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        gt = np.zeros(audio_length_samples)

        # Mark speech regions sample-by-sample from the (start, end) timestamps
        for i in annotation:
            gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1

        # Collapse to one label per 512-sample chunk: speech if over half the samples are speech
        squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
        squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
        # Weight non-speech chunks in the loss by `noise_loss`
        mask = np.ones(len(squeezed_predicts))
        mask[squeezed_predicts == 0] = self.noise_loss
        return squeezed_predicts, mask

    def add_augs(self, wav):
        # Some augmentation combinations can fail or produce NaNs on certain
        # inputs; retry until one succeeds, falling back to the clean waveform
        while True:
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
                    return wav
                return wav_aug
            except Exception:
                continue


def SileroVadPadder(batch):
    # Collate function: pad variable-length waveforms, labels and masks
    # to the longest item in the batch
    wavs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    masks = [item[2] for item in batch]

    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0)

    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)

    masks = torch.nn.utils.rnn.pad_sequence(
        masks, batch_first=True, padding_value=0)

    return wavs, labels, masks
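

# --- Illustrative sketch (not part of the original file) ---
# Wires the dataset and collate function into a DataLoader; `config` is a
# hypothetical object carrying the fields read in SileroVadDataset.__init__
# (tune_8k, noise_loss, max_train_length_sec, aug_prob, *_dataset_path).
def _example_make_loader(config):
    from torch.utils.data import DataLoader
    dataset = SileroVadDataset(config, mode='train')
    loader = DataLoader(dataset,
                        batch_size=32,
                        collate_fn=SileroVadPadder,
                        num_workers=4)
    return loader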


class VADDecoderRNNJIT(nn.Module):

    def __init__(self):
        super(VADDecoderRNNJIT, self).__init__()

        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(nn.Dropout(0.1),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 1, kernel_size=1),
                                     nn.Sigmoid())

    def forward(self, x, state=torch.zeros(0)):
        x = x.squeeze(-1)
        # An empty state tensor means "start of stream": let the LSTM cell
        # initialize its hidden and cell states to zeros
        if len(state):
            h, c = self.rnn(x, (state[0], state[1]))
        else:
            h, c = self.rnn(x)

        x = h.unsqueeze(-1).float()
        state = torch.stack([h, c])
        x = self.decoder(x)
        return x, state
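

# --- Illustrative sketch (not part of the original file) ---
# Runs the decoder statefully over a sequence of encoder outputs. The
# (num_chunks, batch, 128, 1) feature shape is an assumption matching the
# LSTMCell and Conv1d dimensions above.
def _example_stateful_decode(features):
    decoder = VADDecoderRNNJIT()
    state = torch.zeros(0)
    probs = []
    for chunk in features:  # chunk: (batch, 128, 1)
        out, state = decoder(chunk, state)  # out: (batch, 1, 1)
        probs.append(out)
    return torch.cat(probs, dim=2)  # (batch, 1, num_chunks)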


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):

    losses = AverageMeter()
    decoder.train()

    # The 8 kHz head uses 256-sample chunks with 32 samples of left context;
    # the 16 kHz head uses 512-sample chunks with 64 samples of left context
    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.enable_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            # Left-pad so the first chunk also gets a full context window
            x = torch.nn.functional.pad(x, (context_size, 0))

            # Stream chunk by chunk through the frozen STFT + encoder and the
            # trainable decoder, carrying the LSTM state across chunks
            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i - context_size:i + num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            # Per-chunk elementwise loss, weighted by the mask before averaging
            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg
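

# --- Illustrative sketch (not part of the original file) ---
# One possible training driver. The criterion must be elementwise
# (reduction='none'), since train()/validate() multiply it by the mask
# before averaging. The checkpoint path and epoch count are hypothetical.
def _example_training_loop(config, train_loader, val_loader, device):
    jit_model = init_jit_model('silero_vad.jit', device=device)
    decoder = VADDecoderRNNJIT().to(device)
    criterion = torch.nn.BCELoss(reduction='none')
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)

    for epoch in range(5):
        train_loss = train(config, train_loader, jit_model, decoder,
                           criterion, optimizer, device)
        val_loss, val_roc_auc = validate(config, val_loader, jit_model,
                                         decoder, criterion, device)
        print(f'epoch {epoch}: train {train_loss:.4f} '
              f'val {val_loss:.4f} roc-auc {val_roc_auc}')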


def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):

    losses = AverageMeter()
    decoder.eval()

    predicts = []
    gts = []

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.no_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i - context_size:i + num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            # Score only chunks with a non-zero mask (excludes batch padding,
            # and noise chunks when noise_loss == 0)
            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            losses.update(loss.item(), masks.numel())
    score = roc_auc_score(gts, predicts)

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, round(score, 3)


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    # Load the TorchScript VAD model; gradients stay disabled globally
    # until train() re-enables them with torch.enable_grad()
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def predict(model, loader, device, sr):
    with torch.no_grad():
        all_predicts = []
        all_gts = []
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            x = x.to(device)
            out = model.audio_forward(x, sr=sr)

            # Keep one list per clip so thresholds can be tuned clip by clip
            for i, out_chunk in enumerate(out):
                predict = out_chunk[masks[i] != 0].cpu().tolist()
                gt = targets[i, masks[i] != 0].cpu().tolist()

                all_predicts.append(predict)
                all_gts.append(gt)
    return all_predicts, all_gts


def calculate_best_thresholds(all_predicts, all_gts):
    # Grid-search the hysteresis thresholds: a clip enters the speech state
    # at probability >= ths_enter and leaves it at <= ths_exit (ths_exit < ths_enter)
    best_acc = 0
    best_ths_enter, best_ths_exit = None, None
    for ths_enter in tqdm(np.linspace(0, 1, 20)):
        for ths_exit in np.linspace(0, 1, 20):
            if ths_exit >= ths_enter:
                continue

            accs = []
            for j, predict in enumerate(all_predicts):
                predict_bool = []
                is_speech = False
                for i in predict:
                    if i >= ths_enter:
                        is_speech = True
                        predict_bool.append(1)
                    elif i <= ths_exit:
                        is_speech = False
                        predict_bool.append(0)
                    else:
                        # Between the thresholds: keep the previous state
                        val = 1 if is_speech else 0
                        predict_bool.append(val)

                score = round(accuracy_score(all_gts[j], predict_bool), 4)
                accs.append(score)

            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc
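

# --- Illustrative sketch (not part of the original file) ---
# End-to-end threshold calibration on a validation loader, assuming the
# jit model exposes audio_forward() as used in predict() above; the
# checkpoint path is hypothetical.
def _example_calibrate(val_loader, device):
    model = init_jit_model('silero_vad.jit', device=device)
    all_predicts, all_gts = predict(model, val_loader, device, sr=16000)
    ths_enter, ths_exit, acc = calculate_best_thresholds(all_predicts, all_gts)
    print(f'enter={ths_enter} exit={ths_exit} accuracy={acc}')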