mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
добавлен поиск порогов
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sklearn.metrics import roc_auc_score, accuracy_score
|
||||
from torch.utils.data import Dataset
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import random
|
||||
import gc
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
import warnings
|
||||
import random
|
||||
import torch
|
||||
import gc
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@@ -297,3 +297,61 @@ def validate(config,
|
||||
gc.collect()
|
||||
|
||||
return losses.avg, round(score, 3)
|
||||
|
||||
|
||||
def init_jit_model(model_path: str,
|
||||
device=torch.device('cpu')):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def predict(model, loader, device, sr):
|
||||
with torch.no_grad():
|
||||
all_predicts = []
|
||||
all_gts = []
|
||||
for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
|
||||
x = x.to(device)
|
||||
out = model.audio_forward(x, sr=sr)
|
||||
|
||||
for i, out_chunk in enumerate(out):
|
||||
predict = out_chunk[masks[i] != 0].cpu().tolist()
|
||||
gt = targets[i, masks[i] != 0].cpu().tolist()
|
||||
|
||||
all_predicts.append(predict)
|
||||
all_gts.append(gt)
|
||||
return all_predicts, all_gts
|
||||
|
||||
|
||||
def calculate_best_thresholds(all_predicts, all_gts):
|
||||
best_acc = 0
|
||||
for ths_enter in tqdm(np.linspace(0, 1, 20)):
|
||||
for ths_exit in np.linspace(0, 1, 20):
|
||||
if ths_exit >= ths_enter:
|
||||
continue
|
||||
|
||||
accs = []
|
||||
for j, predict in enumerate(all_predicts):
|
||||
predict_bool = []
|
||||
is_speech = False
|
||||
for i in predict:
|
||||
if i >= ths_enter:
|
||||
is_speech = True
|
||||
predict_bool.append(1)
|
||||
elif i <= ths_exit:
|
||||
is_speech = False
|
||||
predict_bool.append(0)
|
||||
else:
|
||||
val = 1 if is_speech else 0
|
||||
predict_bool.append(val)
|
||||
|
||||
score = round(accuracy_score(all_gts[j], predict_bool), 4)
|
||||
accs.append(score)
|
||||
|
||||
mean_acc = round(np.mean(accs), 3)
|
||||
if mean_acc > best_acc:
|
||||
best_acc = mean_acc
|
||||
best_ths_enter = round(ths_enter, 2)
|
||||
best_ths_exit = round(ths_exit, 2)
|
||||
return best_ths_enter, best_ths_exit, best_acc
|
||||
|
||||
Reference in New Issue
Block a user