From 1fc6b72ac926b294021cd744896ebaea3d384062 Mon Sep 17 00:00:00 2001 From: adamnsandle Date: Fri, 27 Aug 2021 10:10:14 +0000 Subject: [PATCH] additional vad utils --- utils_vad.py | 15 ++++++++--- utils_vad_additional.py | 56 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 utils_vad_additional.py diff --git a/utils_vad.py b/utils_vad.py index f65a089..ffb3c3a 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -86,8 +86,11 @@ def get_speech_ts(wav: torch.Tensor, min_speech_samples: int = 10000, #samples min_silence_samples: int = 500, run_function=validate, - visualize_probs=False): + visualize_probs=False, + smoothed_prob_func='mean', + device='cpu'): + assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]' num_samples = num_samples_per_window assert num_samples % num_steps == 0 step = int(num_samples / num_steps) # stride / hop @@ -99,13 +102,13 @@ def get_speech_ts(wav: torch.Tensor, chunk = F.pad(chunk, (0, num_samples - len(chunk))) to_concat.append(chunk.unsqueeze(0)) if len(to_concat) >= batch_size: - chunks = torch.Tensor(torch.cat(to_concat, dim=0)) + chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device) out = run_function(model, chunks) outs.append(out) to_concat = [] if to_concat: - chunks = torch.Tensor(torch.cat(to_concat, dim=0)) + chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device) out = run_function(model, chunks) outs.append(out) @@ -123,7 +126,11 @@ def get_speech_ts(wav: torch.Tensor, temp_end = 0 for i, predict in enumerate(speech_probs): # add name buffer.append(predict) - smoothed_prob = (sum(buffer) / len(buffer)) + if smoothed_prob_func == 'mean': + smoothed_prob = (sum(buffer) / len(buffer)) + elif smoothed_prob_func == 'max': + smoothed_prob = max(buffer) + if visualize_probs: smoothed_probs.append(float(smoothed_prob)) if (smoothed_prob >= trig_sum) and temp_end: diff --git a/utils_vad_additional.py b/utils_vad_additional.py new file mode 100644 index 0000000..92bf33f --- /dev/null +++ b/utils_vad_additional.py @@ -0,0 +1,56 @@ +from utils_vad import * +import sys +import os +from pathlib import Path +sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/pipelines/align/bin/') +from align_utils import load_audio_norm +import torch +import pandas as pd +import numpy as np +sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/utils/') +from open_stt import soundfile_opus as sf + +def split_save_audio_chunks(audio_path, model_path, save_path=None, device='cpu', absolute=True, max_duration=10, adaptive=False, **kwargs): + + if not save_path: + save_path = str(Path(audio_path).with_name('after_vad')) + print(f'No save path specified! Using {save_path} to save audio chunks!') + + SAMPLE_RATE = 16000 + if type(model_path) == str: + #print('Loading model...') + model = init_jit_model(model_path, device) + else: + #print('Using loaded model') + model = model_path + save_name = Path(audio_path).stem + audio, sr = load_audio_norm(audio_path) + wav = torch.tensor(audio) + if adaptive: + speech_timestamps = get_speech_ts_adaptive(wav, model, device=device, **kwargs) + else: + speech_timestamps = get_speech_ts(wav, model, device=device, **kwargs) + + full_save_path = Path(save_path, save_name) + if not os.path.exists(full_save_path): + os.makedirs(full_save_path, exist_ok=True) + + chunks = [] + if not speech_timestamps: + return pd.DataFrame() + for ts in speech_timestamps: + start_ts = int(ts['start']) + end_ts = int(ts['end']) + + for i in range(start_ts, end_ts, max_duration * SAMPLE_RATE): + new_start = i + new_end = min(end_ts, i + max_duration * SAMPLE_RATE) + duration = round((new_end - new_start) / SAMPLE_RATE, 2) + chunk_path = Path(full_save_path, f'{save_name}_{new_start}-{new_end}.opus') + chunk_path = chunk_path.absolute() if absolute else chunk_path + sf.write(str(chunk_path), audio[new_start: new_end], 16000, format='OGG', subtype='OPUS') + chunks.append({'audio_path': chunk_path, + 'text': '', + 'duration': duration, + 'domain': ''}) + return pd.DataFrame(chunks) \ No newline at end of file