1 Commits

Author            SHA1        Message           Date
Alexander Veysov  29102cf6a3  Update README.md  2024-09-24 15:16:05 +03:00
34 changed files with 417 additions and 1411 deletions

View File

@@ -1,39 +0,0 @@
name: Test Package
on:
  workflow_dispatch: # manual trigger
jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ["3.8","3.9","3.10","3.11","3.12","3.13"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build hatchling pytest soundfile
      - name: Build package
        run: python -m build --wheel --outdir dist
      - name: Install package
        run: |
          import glob, subprocess, sys
          whl = glob.glob("dist/*.whl")[0]
          subprocess.check_call([sys.executable, "-m", "pip", "install", whl])
        shell: python
      - name: Run tests
        run: pytest tests

View File

@@ -1,20 +0,0 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "Silero VAD"
authors:
  - family-names: "Silero Team"
    email: "hello@silero.ai"
type: software
repository-code: "https://github.com/snakers4/silero-vad"
license: MIT
abstract: "Pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
preferred-citation:
  type: software
  authors:
    - family-names: "Silero Team"
      email: "hello@silero.ai"
  title: "Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
  year: 2024
  publisher: "GitHub"
  journal: "GitHub repository"
  howpublished: "https://github.com/snakers4/silero-vad"

View File

@@ -1,6 +1,6 @@
[![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [![downloads](https://img.shields.io/pypi/dm/silero-vad?style=for-the-badge)](https://pypi.org/project/silero-vad/) [![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) [![Test Package](https://github.com/snakers4/silero-vad/actions/workflows/test.yml/badge.svg)](https://github.com/snakers4/silero-vad/actions/workflows/test.yml) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png) ![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png)
@@ -13,7 +13,7 @@
<br/> <br/>
<p align="center"> <p align="center">
<img src="https://github.com/user-attachments/assets/f2940867-0a51-4bdb-8c14-1129d3c44e64" /> <img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
</p> </p>
@@ -22,8 +22,6 @@
https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4 https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
Please note that the video loads only if you are logged in to your GitHub account.
</details> </details>
<br/> <br/>
@@ -66,11 +64,7 @@ If you are planning to run the VAD using solely the `onnx-runtime`, it will run
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
model = load_silero_vad() model = load_silero_vad()
wav = read_audio('path_to_audio_file') wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps( speech_timestamps = get_speech_timestamps(wav, model)
wav,
model,
return_seconds=True, # Return speech timestamps in seconds (default is samples)
)
``` ```
**Using torch.hub**: **Using torch.hub**:
@@ -82,11 +76,7 @@ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_v
(get_speech_timestamps, _, read_audio, _, _) = utils (get_speech_timestamps, _, read_audio, _, _) = utils
wav = read_audio('path_to_audio_file') wav = read_audio('path_to_audio_file')
speech_timestamps = get_speech_timestamps( speech_timestamps = get_speech_timestamps(wav, model)
wav,
model,
return_seconds=True, # Return speech timestamps in seconds (default is samples)
)
``` ```
<br/> <br/>
@@ -175,4 +165,4 @@ Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for releva
- Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web - Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
- [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example), [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp), [C#](https://github.com/snakers4/silero-vad/tree/master/examples/csharp) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) community examples - [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) examples

View File

@@ -17,7 +17,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#!apt install ffmpeg\n",
"!pip -q install pydub\n", "!pip -q install pydub\n",
"from google.colab import output\n", "from google.colab import output\n",
"from base64 import b64decode, b64encode\n", "from base64 import b64decode, b64encode\n",
@@ -38,12 +37,13 @@
" model='silero_vad',\n", " model='silero_vad',\n",
" force_reload=True)\n", " force_reload=True)\n",
"\n", "\n",
"def int2float(audio):\n", "def int2float(sound):\n",
" samples = audio.get_array_of_samples()\n", " abs_max = np.abs(sound).max()\n",
" new_sound = audio._spawn(samples)\n", " sound = sound.astype('float32')\n",
" arr = np.array(samples).astype(np.float32)\n", " if abs_max > 0:\n",
" arr = arr / np.abs(arr).max()\n", " sound *= 1/32768\n",
" return arr\n", " sound = sound.squeeze()\n",
" return sound\n",
"\n", "\n",
"AUDIO_HTML = \"\"\"\n", "AUDIO_HTML = \"\"\"\n",
"<script>\n", "<script>\n",
@@ -68,10 +68,10 @@
" //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n", " //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n",
" mimeType : 'audio/webm;codecs=opus'\n", " mimeType : 'audio/webm;codecs=opus'\n",
" //mimeType : 'audio/webm;codecs=pcm'\n", " //mimeType : 'audio/webm;codecs=pcm'\n",
" };\n", " }; \n",
" //recorder = new MediaRecorder(stream, options);\n", " //recorder = new MediaRecorder(stream, options);\n",
" recorder = new MediaRecorder(stream);\n", " recorder = new MediaRecorder(stream);\n",
" recorder.ondataavailable = function(e) {\n", " recorder.ondataavailable = function(e) { \n",
" var url = URL.createObjectURL(e.data);\n", " var url = URL.createObjectURL(e.data);\n",
" // var preview = document.createElement('audio');\n", " // var preview = document.createElement('audio');\n",
" // preview.controls = true;\n", " // preview.controls = true;\n",
@@ -79,7 +79,7 @@
" // document.body.appendChild(preview);\n", " // document.body.appendChild(preview);\n",
"\n", "\n",
" reader = new FileReader();\n", " reader = new FileReader();\n",
" reader.readAsDataURL(e.data);\n", " reader.readAsDataURL(e.data); \n",
" reader.onloadend = function() {\n", " reader.onloadend = function() {\n",
" base64data = reader.result;\n", " base64data = reader.result;\n",
" //console.log(\"Inside FileReader:\" + base64data);\n", " //console.log(\"Inside FileReader:\" + base64data);\n",
@@ -121,7 +121,7 @@
"\n", "\n",
"}\n", "}\n",
"});\n", "});\n",
"\n", " \n",
"</script>\n", "</script>\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
@@ -133,8 +133,8 @@
" audio.export('test.mp3', format='mp3')\n", " audio.export('test.mp3', format='mp3')\n",
" audio = audio.set_channels(1)\n", " audio = audio.set_channels(1)\n",
" audio = audio.set_frame_rate(16000)\n", " audio = audio.set_frame_rate(16000)\n",
" audio_float = int2float(audio)\n", " audio_float = int2float(np.array(audio.get_array_of_samples()))\n",
" audio_tens = torch.tensor(audio_float)\n", " audio_tens = torch.tensor(audio_float )\n",
" return audio_tens\n", " return audio_tens\n",
"\n", "\n",
"def make_animation(probs, audio_duration, interval=40):\n", "def make_animation(probs, audio_duration, interval=40):\n",
@@ -154,18 +154,19 @@
" def animate(i):\n", " def animate(i):\n",
" x = i * interval / 1000 - 0.04\n", " x = i * interval / 1000 - 0.04\n",
" y = np.linspace(0, 1.02, 2)\n", " y = np.linspace(0, 1.02, 2)\n",
"\n", " \n",
" line.set_data(x, y)\n", " line.set_data(x, y)\n",
" line.set_color('#990000')\n", " line.set_color('#990000')\n",
" return line,\n", " return line,\n",
" anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n",
"\n", "\n",
" f = r\"animation.mp4\"\n", " anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=audio_duration / (interval / 1000))\n",
" writervideo = FFMpegWriter(fps=1000/interval)\n", "\n",
" f = r\"animation.mp4\" \n",
" writervideo = FFMpegWriter(fps=1000/interval) \n",
" anim.save(f, writer=writervideo)\n", " anim.save(f, writer=writervideo)\n",
" plt.close('all')\n", " plt.close('all')\n",
"\n", "\n",
"def combine_audio(vidname, audname, outname, fps=25):\n", "def combine_audio(vidname, audname, outname, fps=25): \n",
" my_clip = mpe.VideoFileClip(vidname, verbose=False)\n", " my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
" audio_background = mpe.AudioFileClip(audname)\n", " audio_background = mpe.AudioFileClip(audname)\n",
" final_clip = my_clip.set_audio(audio_background)\n", " final_clip = my_clip.set_audio(audio_background)\n",
@@ -173,10 +174,15 @@
"\n", "\n",
"def record_make_animation():\n", "def record_make_animation():\n",
" tensor = record()\n", " tensor = record()\n",
"\n",
" print('Calculating probabilities...')\n", " print('Calculating probabilities...')\n",
" speech_probs = []\n", " speech_probs = []\n",
" window_size_samples = 512\n", " window_size_samples = 512\n",
" speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()\n", " for i in range(0, len(tensor), window_size_samples):\n",
" if len(tensor[i: i+ window_size_samples]) < window_size_samples:\n",
" break\n",
" speech_prob = model(tensor[i: i+ window_size_samples], 16000).item()\n",
" speech_probs.append(speech_prob)\n",
" model.reset_states()\n", " model.reset_states()\n",
" print('Making animation...')\n", " print('Making animation...')\n",
" make_animation(speech_probs, len(tensor) / 16000)\n", " make_animation(speech_probs, len(tensor) / 16000)\n",
@@ -190,9 +196,7 @@
" <video width=800 controls>\n", " <video width=800 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n", " <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n", " </video>\n",
" \"\"\" % data_url))\n", " \"\"\" % data_url))"
"\n",
" return speech_probs"
] ]
}, },
{ {
@@ -212,7 +216,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"speech_probs = record_make_animation()" "record_make_animation()"
] ]
} }
], ],

View File

@@ -1,227 +1,211 @@
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <chrono> #include <chrono>
#include <iomanip>
#include <memory> #include <memory>
#include <string> #include <string>
#include <stdexcept> #include <stdexcept>
#include <iostream>
#include <string>
#include "onnxruntime_cxx_api.h"
#include "wav.h"
#include <cstdio> #include <cstdio>
#include <cstdarg> #include <cstdarg>
#include <cmath> // for std::rint
#if __cplusplus < 201703L #if __cplusplus < 201703L
#include <memory> #include <memory>
#endif #endif
//#define __DEBUG_SPEECH_PROB___ //#define __DEBUG_SPEECH_PROB___
#include "onnxruntime_cxx_api.h" class timestamp_t
#include "wav.h" // For reading WAV files {
// timestamp_t class: stores the start and end (in samples) of a speech segment.
class timestamp_t {
public: public:
int start; int start;
int end; int end;
// default + parameterized constructor
timestamp_t(int start = -1, int end = -1) timestamp_t(int start = -1, int end = -1)
: start(start), end(end) { } : start(start), end(end)
{
};
timestamp_t& operator=(const timestamp_t& a) { // assignment operator modifies object, therefore non-const
timestamp_t& operator=(const timestamp_t& a)
{
start = a.start; start = a.start;
end = a.end; end = a.end;
return *this; return *this;
} };
bool operator==(const timestamp_t& a) const { // equality comparison. doesn't modify object. therefore const.
bool operator==(const timestamp_t& a) const
{
return (start == a.start && end == a.end); return (start == a.start && end == a.end);
} };
std::string c_str()
// Returns a formatted string of the timestamp. {
std::string c_str() const { //return std::format("timestamp {:08d}, {:08d}", start, end);
return format("{start:%08d, end:%08d}", start, end); return format("{start:%08d,end:%08d}", start, end);
} };
private: private:
// Helper function for formatting.
std::string format(const char* fmt, ...) const { std::string format(const char* fmt, ...)
{
char buf[256]; char buf[256];
va_list args; va_list args;
va_start(args, fmt); va_start(args, fmt);
const auto r = std::vsnprintf(buf, sizeof(buf), fmt, args); const auto r = std::vsnprintf(buf, sizeof buf, fmt, args);
va_end(args); va_end(args);
if (r < 0) if (r < 0)
// conversion failed
return {}; return {};
const size_t len = r; const size_t len = r;
if (len < sizeof(buf)) if (len < sizeof buf)
return std::string(buf, len); // we fit in the buffer
return { buf, len };
#if __cplusplus >= 201703L #if __cplusplus >= 201703L
// C++17: Create a string and write to its underlying array
std::string s(len, '\0'); std::string s(len, '\0');
va_start(args, fmt); va_start(args, fmt);
std::vsnprintf(s.data(), len + 1, fmt, args); std::vsnprintf(s.data(), len + 1, fmt, args);
va_end(args); va_end(args);
return s; return s;
#else #else
// C++11 or C++14: We need to allocate scratch memory
auto vbuf = std::unique_ptr<char[]>(new char[len + 1]); auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
va_start(args, fmt); va_start(args, fmt);
std::vsnprintf(vbuf.get(), len + 1, fmt, args); std::vsnprintf(vbuf.get(), len + 1, fmt, args);
va_end(args); va_end(args);
return std::string(vbuf.get(), len);
return { vbuf.get(), len };
#endif #endif
} };
}; };
// VadIterator class: uses ONNX Runtime to detect speech segments.
class VadIterator { class VadIterator
{
private: private:
// ONNX Runtime resources // OnnxRuntime resources
Ort::Env env; Ort::Env env;
Ort::SessionOptions session_options; Ort::SessionOptions session_options;
std::shared_ptr<Ort::Session> session = nullptr; std::shared_ptr<Ort::Session> session = nullptr;
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
// ----- Context-related additions ----- private:
const int context_samples = 64; // For 16kHz, 64 samples are added as context. void init_engine_threads(int inter_threads, int intra_threads)
std::vector<float> _context; // Holds the last 64 samples from the previous chunk (initialized to zero). {
// The method should be called in each thread/proc in multi-thread/proc work
// Original window size (e.g., 32ms corresponds to 512 samples)
int window_size_samples;
// Effective window size = window_size_samples + context_samples
int effective_window_size;
// Additional declaration: samples per millisecond
int sr_per_ms;
// ONNX Runtime input/output buffers
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_node_names = { "input", "state", "sr" };
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128;
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = { 2, 1, 128 };
const int64_t sr_node_dims[1] = { 1 };
std::vector<Ort::Value> ort_outputs;
std::vector<const char*> output_node_names = { "output", "stateN" };
// Model configuration parameters
int sample_rate;
float threshold;
int min_silence_samples;
int min_silence_samples_at_max_speech;
int min_speech_samples;
float max_speech_samples;
int speech_pad_samples;
int audio_length_samples;
// State management
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
int prev_end;
int next_start = 0;
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Loads the ONNX model.
void init_onnx_model(const std::wstring& model_path) {
init_engine_threads(1, 1);
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
}
// Initializes threading settings.
void init_engine_threads(int inter_threads, int intra_threads) {
session_options.SetIntraOpNumThreads(intra_threads); session_options.SetIntraOpNumThreads(intra_threads);
session_options.SetInterOpNumThreads(inter_threads); session_options.SetInterOpNumThreads(inter_threads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
} };
// Resets internal state (_state, _context, etc.) void init_onnx_model(const std::wstring& model_path)
void reset_states() { {
std::memset(_state.data(), 0, _state.size() * sizeof(float)); // Init threads = 1 for
init_engine_threads(1, 1);
// Load model
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};
void reset_states()
{
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false; triggered = false;
temp_end = 0; temp_end = 0;
current_sample = 0; current_sample = 0;
prev_end = next_start = 0; prev_end = next_start = 0;
speeches.clear(); speeches.clear();
current_speech = timestamp_t(); current_speech = timestamp_t();
std::fill(_context.begin(), _context.end(), 0.0f); };
}
// Inference: runs inference on one chunk of input data. void predict(const std::vector<float> &data)
// data_chunk is expected to have window_size_samples samples. {
void predict(const std::vector<float>& data_chunk) { // Infer
// Build new input: first context_samples from _context, followed by the current chunk (window_size_samples). // Create ort tensors
std::vector<float> new_data(effective_window_size, 0.0f); input.assign(data.begin(), data.end());
std::copy(_context.begin(), _context.end(), new_data.begin());
std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
input = new_data;
// Create input tensor (input_node_dims[1] is already set to effective_window_size).
Ort::Value input_ort = Ort::Value::CreateTensor<float>( Ort::Value input_ort = Ort::Value::CreateTensor<float>(
memory_info, input.data(), input.size(), input_node_dims, 2); memory_info, input.data(), input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>( Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3); memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>( Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, sr.data(), sr.size(), sr_node_dims, 1); memory_info, sr.data(), sr.size(), sr_node_dims, 1);
// Clear and add inputs
ort_inputs.clear(); ort_inputs.clear();
ort_inputs.emplace_back(std::move(input_ort)); ort_inputs.emplace_back(std::move(input_ort));
ort_inputs.emplace_back(std::move(state_ort)); ort_inputs.emplace_back(std::move(state_ort));
ort_inputs.emplace_back(std::move(sr_ort)); ort_inputs.emplace_back(std::move(sr_ort));
// Run inference. // Infer
ort_outputs = session->Run( ort_outputs = session->Run(
Ort::RunOptions{ nullptr }, Ort::RunOptions{nullptr},
input_node_names.data(), ort_inputs.data(), ort_inputs.size(), input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
output_node_names.data(), output_node_names.size()); output_node_names.data(), output_node_names.size());
// Output probability & update h,c recursively
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0]; float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float* stateN = ort_outputs[1].GetTensorMutableData<float>(); float *stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float)); std::memcpy(_state.data(), stateN, size_state * sizeof(float));
current_sample += static_cast<unsigned int>(window_size_samples); // Advance by the original window size.
// If speech is detected (probability >= threshold) // Push forward sample index
if (speech_prob >= threshold) { current_sample += window_size_samples;
// Reset temp_end when > threshold
if ((speech_prob >= threshold))
{
#ifdef __DEBUG_SPEECH_PROB___ #ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ start: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples); printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample- window_size_samples);
#endif #endif //__DEBUG_SPEECH_PROB___
if (temp_end != 0) { if (temp_end != 0)
{
temp_end = 0; temp_end = 0;
if (next_start < prev_end) if (next_start < prev_end)
next_start = current_sample - window_size_samples; next_start = current_sample - window_size_samples;
} }
if (!triggered) { if (triggered == false)
{
triggered = true; triggered = true;
current_speech.start = current_sample - window_size_samples; current_speech.start = current_sample - window_size_samples;
} }
// Update context: copy the last context_samples from new_data.
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return; return;
} }
// If the speech segment becomes too long. if (
if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) { (triggered == true)
&& ((current_sample - current_speech.start) > max_speech_samples)
) {
if (prev_end > 0) { if (prev_end > 0) {
current_speech.end = prev_end; current_speech.end = prev_end;
speeches.push_back(current_speech); speeches.push_back(current_speech);
current_speech = timestamp_t(); current_speech = timestamp_t();
// previously reached silence(< neg_thres) and is still not speech(< thres)
if (next_start < prev_end) if (next_start < prev_end)
triggered = false; triggered = false;
else else{
current_speech.start = next_start; current_speech.start = next_start;
}
prev_end = 0; prev_end = 0;
next_start = 0; next_start = 0;
temp_end = 0; temp_end = 0;
} }
else { else{
current_speech.end = current_sample; current_speech.end = current_sample;
speeches.push_back(current_speech); speeches.push_back(current_speech);
current_speech = timestamp_t(); current_speech = timestamp_t();
@@ -230,29 +214,53 @@ private:
temp_end = 0; temp_end = 0;
triggered = false; triggered = false;
} }
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return; return;
}
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) {
// When the speech probability temporarily drops but is still in speech, update context without changing state.
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
return;
} }
if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold))
if (speech_prob < (threshold - 0.15)) { {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples;
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif
if (triggered) { if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
else {
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
}
return;
}
// 4) End
if ((speech_prob < (threshold - 0.15)))
{
#ifdef __DEBUG_SPEECH_PROB___
float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point.
printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
if (triggered == true)
{
if (temp_end == 0) if (temp_end == 0)
{
temp_end = current_sample; temp_end = current_sample;
}
if (current_sample - temp_end > min_silence_samples_at_max_speech) if (current_sample - temp_end > min_silence_samples_at_max_speech)
prev_end = temp_end; prev_end = temp_end;
if ((current_sample - temp_end) >= min_silence_samples) { // a. silence < min_silence_samples, continue speaking
if ((current_sample - temp_end) < min_silence_samples)
{
}
// b. silence >= min_silence_samples, end speaking
else
{
current_speech.end = temp_end; current_speech.end = temp_end;
if (current_speech.end - current_speech.start > min_speech_samples) { if (current_speech.end - current_speech.start > min_speech_samples)
{
speeches.push_back(current_speech); speeches.push_back(current_speech);
current_speech = timestamp_t(); current_speech = timestamp_t();
prev_end = 0; prev_end = 0;
@@ -262,23 +270,27 @@ private:
} }
} }
} }
std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); else {
// may first windows see end state.
}
return; return;
} }
} };
public: public:
// Process the entire audio input. void process(const std::vector<float>& input_wav)
void process(const std::vector<float>& input_wav) { {
reset_states(); reset_states();
audio_length_samples = static_cast<int>(input_wav.size());
// Process audio in chunks of window_size_samples (e.g., 512 samples) audio_length_samples = input_wav.size();
for (size_t j = 0; j < static_cast<size_t>(audio_length_samples); j += static_cast<size_t>(window_size_samples)) {
if (j + static_cast<size_t>(window_size_samples) > static_cast<size_t>(audio_length_samples)) for (int j = 0; j < audio_length_samples; j += window_size_samples)
{
if (j + window_size_samples > audio_length_samples)
break; break;
std::vector<float> chunk(&input_wav[j], &input_wav[j] + window_size_samples); std::vector<float> r{ &input_wav[0] + j, &input_wav[0] + j + window_size_samples };
predict(chunk); predict(r);
} }
if (current_speech.start >= 0) { if (current_speech.start >= 0) {
current_speech.end = audio_length_samples; current_speech.end = audio_length_samples;
speeches.push_back(current_speech); speeches.push_back(current_speech);
@@ -288,80 +300,179 @@ public:
temp_end = 0; temp_end = 0;
triggered = false; triggered = false;
} }
};
void process(const std::vector<float>& input_wav, std::vector<float>& output_wav)
{
process(input_wav);
collect_chunks(input_wav, output_wav);
} }
// Returns the detected speech timestamps. void collect_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
const std::vector<timestamp_t> get_speech_timestamps() const { {
output_wav.clear();
for (int i = 0; i < speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
std::cout << speeches[i].c_str() << std::endl;
#endif //#ifdef __DEBUG_SPEECH_PROB___
std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
output_wav.insert(output_wav.end(),slice.begin(),slice.end());
}
};
const std::vector<timestamp_t> get_speech_timestamps() const
{
return speeches; return speeches;
} }
// Public method to reset the internal state. void drop_chunks(const std::vector<float>& input_wav, std::vector<float>& output_wav)
void reset() { {
reset_states(); output_wav.clear();
} int current_start = 0;
for (int i = 0; i < speeches.size(); i++) {
std::vector<float> slice(&input_wav[current_start],&input_wav[speeches[i].start]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
current_start = speeches[i].end;
}
std::vector<float> slice(&input_wav[current_start], &input_wav[input_wav.size()]);
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
};
private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
int sample_rate; //Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16
float threshold;
int min_silence_samples; // sr_per_ms * #ms
int min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
float max_speech_samples;
int speech_pad_samples; // usually a
int audio_length_samples;
// model states
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
int prev_end;
int next_start = 0;
//Output timestamp
std::vector<timestamp_t> speeches;
timestamp_t current_speech;
// Onnx model
// Inputs
std::vector<Ort::Value> ort_inputs;
std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;
int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
// Outputs
std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "stateN"};
public: public:
// Constructor: sets model path, sample rate, window size (ms), and other parameters. // Construction
// The parameters are set to match the Python version.
VadIterator(const std::wstring ModelPath, VadIterator(const std::wstring ModelPath,
int Sample_rate = 16000, int windows_frame_size = 32, int Sample_rate = 16000, int windows_frame_size = 32,
float Threshold = 0.5, int min_silence_duration_ms = 100, float Threshold = 0.5, int min_silence_duration_ms = 0,
int speech_pad_ms = 30, int min_speech_duration_ms = 250, int speech_pad_ms = 32, int min_speech_duration_ms = 32,
float max_speech_duration_s = std::numeric_limits<float>::infinity()) float max_speech_duration_s = std::numeric_limits<float>::infinity())
: sample_rate(Sample_rate), threshold(Threshold), speech_pad_samples(speech_pad_ms), prev_end(0)
{ {
sr_per_ms = sample_rate / 1000; // e.g., 16000 / 1000 = 16 init_onnx_model(ModelPath);
window_size_samples = windows_frame_size * sr_per_ms; // e.g., 32ms * 16 = 512 samples threshold = Threshold;
effective_window_size = window_size_samples + context_samples; // e.g., 512 + 64 = 576 samples sample_rate = Sample_rate;
sr_per_ms = sample_rate / 1000;
window_size_samples = windows_frame_size * sr_per_ms;
min_speech_samples = sr_per_ms * min_speech_duration_ms;
speech_pad_samples = sr_per_ms * speech_pad_ms;
max_speech_samples = (
sample_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
input.resize(window_size_samples);
input_node_dims[0] = 1; input_node_dims[0] = 1;
input_node_dims[1] = effective_window_size; input_node_dims[1] = window_size_samples;
_state.resize(size_state); _state.resize(size_state);
sr.resize(1); sr.resize(1);
sr[0] = sample_rate; sr[0] = sample_rate;
_context.assign(context_samples, 0.0f); };
min_speech_samples = sr_per_ms * min_speech_duration_ms;
max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples);
min_silence_samples = sr_per_ms * min_silence_duration_ms;
min_silence_samples_at_max_speech = sr_per_ms * 98;
init_onnx_model(ModelPath);
}
}; };
int main() { int main()
// Read the WAV file (expects 16000 Hz, mono, PCM). {
wav::WavReader wav_reader("audio/recorder.wav"); // File located in the "audio" folder. std::vector<timestamp_t> stamps;
int numSamples = wav_reader.num_samples();
std::vector<float> input_wav(static_cast<size_t>(numSamples)); // Read wav
for (size_t i = 0; i < static_cast<size_t>(numSamples); i++) { wav::WavReader wav_reader("recorder.wav"); //16000,1,32float
std::vector<float> input_wav(wav_reader.num_samples());
std::vector<float> output_wav;
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i)); input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
} }
// Set the ONNX model path (file located in the "model" folder).
std::wstring model_path = L"model/silero_vad.onnx";
// Initialize the VadIterator.
VadIterator vad(model_path);
// Process the audio. // ===== Test configs =====
std::wstring path = L"silero_vad.onnx";
VadIterator vad(path);
// ==============================================
// ==== = Example 1 of full function =====
// ==============================================
vad.process(input_wav); vad.process(input_wav);
// Retrieve the speech timestamps (in samples). // 1.a get_speech_timestamps
std::vector<timestamp_t> stamps = vad.get_speech_timestamps(); stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
// Convert timestamps to seconds and round to one decimal place (for 16000 Hz). std::cout << stamps[i].c_str() << std::endl;
const float sample_rate_float = 16000.0f;
for (size_t i = 0; i < stamps.size(); i++) {
float start_sec = std::rint((stamps[i].start / sample_rate_float) * 10.0f) / 10.0f;
float end_sec = std::rint((stamps[i].end / sample_rate_float) * 10.0f) / 10.0f;
std::cout << "Speech detected from "
<< std::fixed << std::setprecision(1) << start_sec
<< " s to "
<< std::fixed << std::setprecision(1) << end_sec
<< " s" << std::endl;
} }
// Optionally, reset the internal state. // 1.b collect_chunks output wav
vad.reset(); vad.collect_chunks(input_wav, output_wav);
return 0; // 1.c drop_chunks output wav
vad.drop_chunks(input_wav, output_wav);
// ==============================================
// ===== Example 2 of simple full function =====
// ==============================================
vad.process(input_wav, output_wav);
stamps = vad.get_speech_timestamps();
for (int i = 0; i < stamps.size(); i++) {
std::cout << stamps[i].c_str() << std::endl;
}
// ==============================================
// ===== Example 3 of full function =====
// ==============================================
for(int i = 0; i<2; i++)
vad.process(input_wav, output_wav);
} }

View File

@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef FRONTEND_WAV_H_ #ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_ #define FRONTEND_WAV_H_
#include <assert.h> #include <assert.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
@@ -24,8 +24,6 @@
#include <string> #include <string>
#include <iostream>
// #include "utils/log.h" // #include "utils/log.h"
namespace wav { namespace wav {
@@ -232,6 +230,6 @@ class WavWriter {
int bits_per_sample_; int bits_per_sample_;
}; };
} // namespace wav } // namespace wenet
#endif // FRONTEND_WAV_H_ #endif // FRONTEND_WAV_H_

View File

@@ -1,45 +0,0 @@
# Silero-VAD V5 in C++ (based on LibTorch)
This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested.
Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
## Requirements
- GCC 11.4.0 (GCC >= 5.1)
- LibTorch 1.13.0 (other versions are also acceptable)
## Download LibTorch
```bash
# CPU Version
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
unzip libtorch-shared-with-deps-1.13.0+cpu.zip

# CUDA Version
wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
unzip libtorch-shared-with-deps-1.13.0+cu116.zip
```
## Compilation
```bash
# CPU Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0

# CUDA Version
g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
```
## Optional Compilation Flags
- `-DUSE_BATCH`: enable batch inference (see the sketch after this list)
- `-DUSE_GPU`: use the GPU for inference
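These are ordinary preprocessor switches. Below is a minimal, hypothetical demo (not part of the example, file name `flags_demo.cc` is illustrative) showing how the `-D` flags select code paths at compile time, in the same way the `#ifdef USE_BATCH` / `#ifdef USE_GPU` blocks in `silero_torch.cc` do:

```cpp
#include <iostream>

// Build with e.g.: g++ flags_demo.cc -o flags_demo -DUSE_BATCH -DUSE_GPU
// and compare the output against a plain build to see which paths the flags enable.
int main() {
#ifdef USE_GPU
    std::cout << "USE_GPU: chunks are moved to CUDA before forward()\n";
#else
    std::cout << "CPU build: chunks stay on the host\n";
#endif
#ifdef USE_BATCH
    std::cout << "USE_BATCH: a single forward() over the whole stacked batch\n";
#else
    std::cout << "per-chunk build: one forward() per 512-sample window\n";
#endif
    return 0;
}
```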
## Run the Program
To run the program, use the following command:
`./silero aepyx.wav 16000 0.5`
The sample file `aepyx.wav` is part of the VoxConverse dataset; it is a 16 kHz, 16-bit audio file.

Binary file not shown.

View File

@@ -1,54 +0,0 @@
#include <iostream>
#include "silero_torch.h"
#include "wav.h"
int main(int argc, char* argv[]) {
if(argc != 4){
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> <SampleRate> <Threshold>"<<std::endl;
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 16000 0.5"<<std::endl;
return 1;
}
std::string wav_path = argv[1];
float sample_rate = std::stof(argv[2]);
float threshold = std::stof(argv[3]);
//Load Model
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
silero::VadIterator vad(model_path);
vad.threshold = threshold; // default: 0.5
vad.sample_rate = sample_rate; // 16000 Hz or 8000 Hz (default: 16000)
vad.print_as_samples = true; // if true, prints timestamps in samples; otherwise, in seconds (default: false)
vad.SetVariables();
// Read wav
wav::WavReader wav_reader(wav_path);
std::vector<float> input_wav(wav_reader.num_samples());
for (int i = 0; i < wav_reader.num_samples(); i++)
{
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
}
vad.SpeechProbs(input_wav);
std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
for(const auto& speech : speeches){
if(vad.print_as_samples){
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
}
else{
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
}
}
return 0;
}

Binary file not shown.

View File

@@ -1,285 +0,0 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#include "silero_torch.h"
namespace silero {
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
init_torch_model(model_path);
//init_engine(window_size_ms);
}
VadIterator::~VadIterator(){
}
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
// Set the sample rate (must match the model's expected sample rate)
// Process the waveform in chunks of 512 samples
int num_samples = input_wav.size();
int num_chunks = num_samples / window_size_samples;
int remainder_samples = num_samples % window_size_samples;
total_sample_size += num_samples;
torch::Tensor output;
std::vector<torch::Tensor> chunks;
for (int i = 0; i < num_chunks; i++) {
float* chunk_start = input_wav.data() + i *window_size_samples;
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
chunks.push_back(chunk);
if(i==num_chunks-1 && remainder_samples>0){ // last chunk && a remainder exists
int remaining_samples = num_samples - num_chunks * window_size_samples;
//std::cout<<"Remainder size : "<<remaining_samples;
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
torch::kFloat32);
// Pad the remainder chunk to match window_size_samples
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
- remaining_samples}, torch::kFloat32)}, 1);
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
chunks.push_back(padded_chunk);
}
}
if (!chunks.empty()) {
#ifdef USE_BATCH
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
//batched_chunks = batched_chunks.squeeze(1);
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
#endif
// Prepare input for model
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks); // Batch of chunks
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
// Run inference on the batch
torch::NoGradGuard no_grad;
torch::Tensor output = model.forward(inputs).toTensor();
#ifdef USE_GPU
output = output.to(at::kCPU); // Move the output back to CPU once
#endif
// Collect output probabilities
for (int i = 0; i < chunks.size(); i++) {
float output_f = output[i].item<float>();
outputs_prob.push_back(output_f);
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
}
#else
std::vector<torch::Tensor> outputs;
torch::Tensor batched_chunks = torch::stack(chunks);
#ifdef USE_GPU
batched_chunks = batched_chunks.to(at::kCUDA);
#endif
for (int i = 0; i < chunks.size(); i++) {
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(batched_chunks[i]);
inputs.push_back(sample_rate);
torch::Tensor output = model.forward(inputs).toTensor();
outputs.push_back(output);
}
torch::Tensor all_outputs = torch::stack(outputs);
#ifdef USE_GPU
all_outputs = all_outputs.to(at::kCPU);
#endif
for (int i = 0; i < chunks.size(); i++) {
float output_f = all_outputs[i].item<float>();
outputs_prob.push_back(output_f);
}
#endif
}
}
std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
std::vector<SpeechSegment> speeches = DoVad();
#ifdef USE_BATCH
//When using batch inference, it is better to use the 'mergeSpeeches' function to arrange the timestamps.
//This tends to produce more reasonable output, since the batched probabilities can be slightly distorted.
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
if(!print_as_samples){
for (auto& speech : speeches_merge) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches_merge;
#else
if(!print_as_samples){
for (auto& speech : speeches) { //samples to second
speech.start /= sample_rate;
speech.end /= sample_rate;
}
}
return speeches;
#endif
}
void VadIterator::SetVariables(){
init_engine(window_size_ms);
}
void VadIterator::init_engine(int window_size_ms) {
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
window_size_samples = sample_rate / 1000 * window_size_ms;
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}
void VadIterator::init_torch_model(const std::string& model_path) {
at::set_num_threads(1);
model = torch::jit::load(model_path);
#ifdef USE_GPU
if (!torch::cuda::is_available()) {
std::cout<<"CUDA is not available! Please check your GPU settings"<<std::endl;
throw std::runtime_error("CUDA is not available!");
model.to(at::Device(at::kCPU));
} else {
std::cout<<"CUDA available! Running on '0'th GPU"<<std::endl;
model.to(at::Device(at::kCUDA, 0)); //select 0'th machine
}
#endif
model.eval();
torch::NoGradGuard no_grad;
std::cout << "Model loaded successfully"<<std::endl;
}
void VadIterator::reset_states() {
triggered = false;
current_sample = 0;
temp_end = 0;
outputs_prob.clear();
model.run_method("reset_states");
total_sample_size = 0;
}
std::vector<SpeechSegment> VadIterator::DoVad() {
std::vector<SpeechSegment> speeches;
for (size_t i = 0; i < outputs_prob.size(); ++i) {
float speech_prob = outputs_prob[i];
//std::cout << speech_prob << std::endl;
//std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
//std::cout << speech_prob << " ";
current_sample += window_size_samples;
if (speech_prob >= threshold && temp_end != 0) {
temp_end = 0;
}
if (speech_prob >= threshold && !triggered) {
triggered = true;
SpeechSegment segment;
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
speeches.push_back(segment);
continue;
}
if (speech_prob < threshold - 0.15f && triggered) {
if (temp_end == 0) {
temp_end = current_sample;
}
if (current_sample - temp_end < min_silence_samples) {
continue;
} else {
SpeechSegment& segment = speeches.back();
segment.end = temp_end + speech_pad_samples - window_size_samples;
temp_end = 0;
triggered = false;
}
}
}
if (triggered) { // If the probabilities stay low and only the very last frame's prob comes out high, triggered is set above and a segment start is created at the same time, which could be a problem (start == end?). The post-processing below probably makes this a non-issue.
std::cout<<"triggered was still active at the last probability"<<std::endl;
SpeechSegment& segment = speeches.back();
segment.end = total_sample_size; // use the current sample as the end of the last segment
triggered = false; // reset the VAD state
}
speeches.erase(
std::remove_if(
speeches.begin(),
speeches.end(),
[this](const SpeechSegment& speech) {
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
//min_speech_samples is 4000samples(0.25sec)
//Key point: the length is measured after applying the 'speech_pad_samples' size to the start and end samples.
}
),
speeches.end()
);
//std::cout<<std::endl;
//std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;
reset_states();
return speeches;
}
std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
std::vector<SpeechSegment> mergedSpeeches;
if (speeches.empty()) {
return mergedSpeeches; // return an empty vector
}
// initialize with the first segment
SpeechSegment currentSegment = speeches[0];
for (size_t i = 1; i < speeches.size(); ++i) { // skip the first segment's start/end, hence i = 1
// if the gap between two segments is smaller than the threshold (duration_merge_samples), merge them
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
// update the end of the current segment
currentSegment.end = speeches[i].end;
} else {
// if the gap is at least duration_merge_samples, store the current segment and start a new one
mergedSpeeches.push_back(currentSegment);
currentSegment = speeches[i];
}
}
// append the last segment
mergedSpeeches.push_back(currentSegment);
return mergedSpeeches;
}
}

View File

@@ -1,75 +0,0 @@
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H
#include <string>
#include <memory>
#include <stdexcept>
#include <iostream>
#include <memory>
#include <vector>
#include <fstream>
#include <chrono>
#include <torch/torch.h>
#include <torch/script.h>
namespace silero{
struct SpeechSegment{
int start;
int end;
};
class VadIterator{
public:
VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
~VadIterator();
void SpeechProbs(std::vector<float>& input_wav);
std::vector<silero::SpeechSegment> GetSpeechTimestamps();
void SetVariables();
float threshold;
int sample_rate;
int window_size_ms;
int min_speech_duration_ms;
int max_duration_merge_ms;
bool print_as_samples;
private:
torch::jit::script::Module model;
std::vector<float> outputs_prob;
int min_silence_samples;
int min_speech_samples;
int speech_pad_samples;
int window_size_samples;
int duration_merge_samples;
int current_sample = 0;
int total_sample_size=0;
int min_silence_duration_ms;
int speech_pad_ms;
bool triggered = false;
int temp_end = 0;
void init_engine(int window_size_ms);
void init_torch_model(const std::string& model_path);
void reset_states();
std::vector<SpeechSegment> DoVad();
std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
};
}
#endif // SILERO_TORCH_H

View File

@@ -1,235 +0,0 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
// #include "utils/log.h"
namespace wav {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb"); // open the file for reading
if (NULL == fp) {
std::cout << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
printf("WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
if (header.data_size == 0) {
int offset = ftell(fp);
fseek(fp, 0, SEEK_END);
header.data_size = ftell(fp) - offset;
fseek(fp, offset, SEEK_SET);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data]; // Create 1-dim array
num_samples_ = num_data / num_channel_;
std::cout << "num_channel_ :" << num_channel_ << std::endl;
std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
std::cout << "num_samples :" << num_data << std::endl;
std::cout << "num_data_size :" << header.data_size << std::endl;
switch (bits_per_sample_) {
case 8: {
char sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 16: {
int16_t sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
break;
}
case 32:
{
if (header.format == 1) //S32
{
int sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample) / 32768;
}
}
else if (header.format == 3) // IEEE-float
{
float sample;
for (int i = 0; i < num_data; ++i) {
fread(&sample, 1, sizeof(float), fp);
data_[i] = static_cast<float>(sample);
}
}
else {
printf("unsupported quantization bits\n");
}
break;
}
default:
printf("unsupported quantization bits\n");
break;
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

View File

@@ -1,13 +0,0 @@
# Haskell example
To run the example, make sure you put an ``example.wav`` in this directory, and then run the following:
```bash
stack run
```
The ``example.wav`` file must meet the following requirements:

- Must have a 16 kHz sample rate.
- Must be mono (single channel).
- Must be 16-bit audio.

This uses the [silero-vad](https://hackage.haskell.org/package/silero-vad) package, a Haskell implementation based on the C# example.

View File

@@ -1,22 +0,0 @@
module Main (main) where
import qualified Data.Vector.Storable as Vector
import Data.WAVE
import Data.Function
import Silero
main :: IO ()
main =
  withModel $ \model -> do
    wav <- getWAVEFile "example.wav"
    let samples =
          concat (waveSamples wav)
            & Vector.fromList
            & Vector.map (realToFrac . sampleToDouble)
    let vad =
          (defaultVad model)
            { startThreshold = 0.5
            , endThreshold = 0.35
            }
    segments <- detectSegments vad samples
    print segments

View File

@@ -1,23 +0,0 @@
cabal-version: 1.12
-- This file has been generated from package.yaml by hpack version 0.37.0.
--
-- see: https://github.com/sol/hpack
name: example
version: 0.1.0.0
build-type: Simple
executable example-exe
main-is: Main.hs
other-modules:
Paths_example
hs-source-dirs:
app
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
WAVE
, base >=4.7 && <5
, silero-vad
, vector
default-language: Haskell2010

View File

@@ -1,28 +0,0 @@
name: example
version: 0.1.0.0
dependencies:
- base >= 4.7 && < 5
- silero-vad
- WAVE
- vector
ghc-options:
- -Wall
- -Wcompat
- -Widentities
- -Wincomplete-record-updates
- -Wincomplete-uni-patterns
- -Wmissing-export-lists
- -Wmissing-home-modules
- -Wpartial-fields
- -Wredundant-constraints
executables:
example-exe:
main: Main.hs
source-dirs: app
ghc-options:
- -threaded
- -rtsopts
- -with-rtsopts=-N

View File

@@ -1,11 +0,0 @@
snapshot:
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
packages:
- .
extra-deps:
- silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481

View File

@@ -1,41 +0,0 @@
# This file was autogenerated by Stack.
# You should not edit this file by hand.
# For more information, please see the documentation at:
# https://docs.haskellstack.org/en/stable/lock_files
packages:
- completed:
hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
pantry-tree:
sha256: a62e813f978d32c87769796fded981d25fcf2875bb2afdf60ed6279f931ccd7f
size: 1391
original:
hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
- completed:
hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
pantry-tree:
sha256: ee5ccd70fa7fe6ffc360ebd762b2e3f44ae10406aa27f3842d55b8cbd1a19498
size: 405
original:
hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
- completed:
hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
pantry-tree:
sha256: 48e35a72d1bb593173890616c8d7efd636a650a306a50bb3e1513e679939d27e
size: 902
original:
hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
- completed:
hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
pantry-tree:
sha256: 2176fd677a02a4c47337f7dca5aeca2745dbb821a6ea5c7099b3a991ecd7f4f0
size: 4478
original:
hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
snapshots:
- completed:
sha256: 5a59b2a405b3aba3c00188453be172b85893cab8ebc352b1ef58b0eae5d248a2
size: 650475
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
original:
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml

View File

@@ -1,6 +1,7 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@@ -17,19 +18,17 @@
"SAMPLING_RATE = 16000\n", "SAMPLING_RATE = 16000\n",
"import torch\n", "import torch\n",
"from pprint import pprint\n", "from pprint import pprint\n",
"import time\n",
"import shutil\n",
"\n", "\n",
"torch.set_num_threads(1)\n", "torch.set_num_threads(1)\n",
"NUM_PROCESS=4 # set to the number of CPU cores in the machine\n", "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
"NUM_COPIES=8\n", "NUM_COPIES=8\n",
"# download wav files, make multiple copies\n", "# download wav files, make multiple copies\n",
"torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example0.wav\")\n", "for idx in range(NUM_COPIES):\n",
"for idx in range(NUM_COPIES-1):\n", " torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
" shutil.copy(f\"en_example0.wav\", f\"en_example{idx+1}.wav\")"
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@@ -55,6 +54,7 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@@ -99,6 +99,7 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@@ -126,7 +127,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "diarization",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -140,20 +141,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.14" "version": "3.9.15"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -8,8 +8,6 @@ Currently, the notebook consists of two examples: Currently, the notebook consists of two examples:
- One that records audio of a predefined length from the microphone, processes it with Silero-VAD, and plots it afterwards. - One that records audio of a predefined length from the microphone, processes it with Silero-VAD, and plots it afterwards.
- The other one plots the speech probabilities in real-time (using jupyterplot) and records the audio until you press enter. - The other one plots the speech probabilities in real-time (using jupyterplot) and records the audio until you press enter.
This example does not work in google colab! For local usage only.
## Example Video for the Real-Time Visualization ## Example Video for the Real-Time Visualization
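As a rough orientation for the streaming pattern these examples implement, here is a hedged sketch of feeding microphone frames to the VAD. Assumptions: the model is loaded via `torch.hub` as elsewhere in this repo, 16 kHz mono input, 512-sample windows; this is not the notebook's exact code.
```python
import numpy as np
import pyaudio
import torch

model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')

SAMPLE_RATE = 16000
CHUNK = 512  # one VAD window at 16 kHz

p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=SAMPLE_RATE,
                input=True, frames_per_buffer=CHUNK)
try:
    while True:
        frames = stream.read(CHUNK)
        # int16 bytes -> float32 in [-1, 1], the range the model expects
        audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
        prob = model(torch.from_numpy(audio), SAMPLE_RATE).item()
        print(f"speech probability: {prob:.2f}", end='\r')
except KeyboardInterrupt:
    stream.stop_stream()
    stream.close()
    p.terminate()
```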

View File

@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "76aa55ba", "id": "62a0cccb",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Pyaudio Microphone Streaming Examples\n", "# Pyaudio Microphone Streaming Examples\n",
@@ -12,14 +12,12 @@
"I created it as an example on how binary data from a stream could be feed into Silero VAD.\n", "I created it as an example on how binary data from a stream could be feed into Silero VAD.\n",
"\n", "\n",
"\n", "\n",
"Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required.\n", "Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required."
"\n",
"This notebook does not work in google colab! For local usage only."
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "4a4e15c2", "id": "64cbe1eb",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Dependencies\n", "## Dependencies\n",
@@ -28,27 +26,22 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"id": "24205cce", "id": "57bc2aac",
"metadata": { "metadata": {},
"ExecuteTime": {
"end_time": "2024-10-09T08:47:34.056898Z",
"start_time": "2024-10-09T08:47:34.053418Z"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"#!pip install numpy>=1.24.0\n", "#!pip install numpy==2.0.2\n",
"#!pip install torch>=1.12.0\n", "#!pip install torch==2.4.1\n",
"#!pip install matplotlib>=3.6.0\n", "#!pip install matplotlib==3.9.2\n",
"#!pip install torchaudio>=0.12.0\n", "#!pip install torchaudio==2.4.1\n",
"#!pip install soundfile==0.12.1\n", "#!pip install soundfile==0.12.1\n",
"#!apt install python3-pyaudio (linux) or pip install pyaudio (windows)" "#!pip install pyaudio==0.2.11"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "cd22818f", "id": "110de761",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Imports" "## Imports"
@@ -56,27 +49,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"id": "994d7f3a", "id": "5a647d8d",
"metadata": { "metadata": {},
"ExecuteTime": { "outputs": [],
"end_time": "2024-10-09T08:47:39.005032Z",
"start_time": "2024-10-09T08:47:36.489952Z"
}
},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pyaudio'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpylab\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyaudio\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyaudio'"
]
}
],
"source": [ "source": [
"import io\n", "import io\n",
"import numpy as np\n", "import numpy as np\n",
@@ -91,7 +67,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "ac5c52f7", "id": "725d7066",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -103,7 +79,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "ad5919dc", "id": "1c0b2ea7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -116,7 +92,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "784d1ab6", "id": "f9112603",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Helper Methods" "### Helper Methods"
@@ -125,7 +101,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "af4bca64", "id": "5abc6330",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -148,7 +124,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "ca13e514", "id": "5124095e",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Pyaudio Set-up" "## Pyaudio Set-up"
@@ -157,7 +133,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "75f99022", "id": "a845356e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -171,7 +147,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "4da7d2ef", "id": "0b910c99",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Simple Example\n", "## Simple Example\n",
@@ -181,7 +157,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "6fe77661", "id": "9d3d2c10",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -191,7 +167,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "23f4da3e", "id": "3cb44a4a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -231,7 +207,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "fd243e8f", "id": "a3dda982",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Real Time Visualization\n", "## Real Time Visualization\n",
@@ -244,7 +220,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "d36980c2", "id": "05ef4100",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -254,7 +230,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "5607b616", "id": "d1d4cdd6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -311,7 +287,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "dc4f0108", "id": "1e398009",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -335,7 +311,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.14" "version": "3.9.10"
}, },
"toc": { "toc": {
"base_numbering": 1, "base_numbering": 1,

View File

@@ -23,14 +23,11 @@ def versiontuple(v):
return tuple(version_list) return tuple(version_list)
def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16): def silero_vad(onnx=False, force_onnx_cpu=False):
"""Silero Voice Activity Detector """Silero Voice Activity Detector
Returns a model with a set of utils Returns a model with a set of utils
Please see https://github.com/snakers4/silero-vad for usage examples Please see https://github.com/snakers4/silero-vad for usage examples
""" """
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if not onnx: if not onnx:
installed_version = torch.__version__ installed_version = torch.__version__
@@ -40,11 +37,7 @@ def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data') model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
if onnx: if onnx:
if opset_version == 16: model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
else: else:
model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit')) model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
utils = (get_speech_timestamps, utils = (get_speech_timestamps,
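For orientation, this `torch.hub` entrypoint is typically consumed as below. A minimal hedged sketch: the keyword arguments mirror the signature shown above, and the positional access into `utils` assumes the tuple order built in `hub.py`.
```python
import torch

# onnx / force_onnx_cpu correspond to the keyword arguments in the signature above;
# opset_version is only available on the side of the diff that still exposes it.
model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad',
                              onnx=True, force_onnx_cpu=True)

get_speech_timestamps = utils[0]  # first element of the utils tuple built above
```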

View File

@@ -3,7 +3,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[project] [project]
name = "silero-vad" name = "silero-vad"
version = "6.0.0" version = "5.1"
authors = [ authors = [
{name="Silero Team", email="hello@silero.ai"}, {name="Silero Team", email="hello@silero.ai"},
] ]
@@ -21,9 +21,6 @@ classifiers = [
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering",
] ]

View File

@@ -9,5 +9,4 @@ from silero_vad.utils_vad import (get_speech_timestamps,
save_audio, save_audio,
read_audio, read_audio,
VADIterator, VADIterator,
collect_chunks, collect_chunks)
drop_chunks)

Binary file not shown.

Binary file not shown.

View File

@@ -2,19 +2,8 @@ from .utils_vad import init_jit_model, OnnxWrapper
import torch import torch
torch.set_num_threads(1) torch.set_num_threads(1)
def load_silero_vad(onnx=False):
def load_silero_vad(onnx=False, opset_version=16): model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
available_ops = [15, 16]
if onnx and opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}')
if onnx:
if opset_version == 16:
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else:
model_name = 'silero_vad.jit'
package_path = "silero_vad.data" package_path = "silero_vad.data"
try: try:
@@ -29,7 +18,7 @@ def load_silero_vad(onnx=False, opset_version=16):
model_file_path = str(impresources.files(package_path).joinpath(model_name)) model_file_path = str(impresources.files(package_path).joinpath(model_name))
if onnx: if onnx:
model = OnnxWrapper(str(model_file_path), force_onnx_cpu=True) model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
else: else:
model = init_jit_model(model_file_path) model = init_jit_model(model_file_path)
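A minimal usage sketch of this loader, mirroring the calls exercised by the test file at the end of this diff (the wav path is a placeholder):
```python
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad(onnx=False)  # onnx=True selects the ONNX Runtime path
audio = read_audio("example.wav", sampling_rate=16000)
timestamps = get_speech_timestamps(audio, model, return_seconds=True)
print(timestamps)  # e.g. [{'start': 0.5, 'end': 2.1}, ...]
```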

View File

@@ -23,11 +23,7 @@ class OnnxWrapper():
self.session = onnxruntime.InferenceSession(path, sess_options=opts) self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states() self.reset_states()
if '16k' in path: self.sample_rates = [8000, 16000]
warnings.warn('This model supports only a 16000 Hz sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int): def _validate_input(self, x, sr: int):
if x.dim() == 1: if x.dim() == 1:
@@ -197,13 +193,10 @@ def get_speech_timestamps(audio: torch.Tensor,
min_silence_duration_ms: int = 100, min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30, speech_pad_ms: int = 30,
return_seconds: bool = False, return_seconds: bool = False,
time_resolution: int = 1,
visualize_probs: bool = False, visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None, progress_tracking_callback: Callable[[float], None] = None,
neg_threshold: float = None, neg_threshold: float = None,
window_size_samples: int = 512, window_size_samples: int = 512,):
min_silence_at_max_speech: float = 98,
use_max_poss_sil_at_max_speech: bool = True):
""" """
This method is used for splitting long audios into speech chunks using silero VAD This method is used for splitting long audios into speech chunks using silero VAD
@@ -239,9 +232,6 @@ def get_speech_timestamps(audio: torch.Tensor,
return_seconds: bool (default - False) return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples) whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
visualize_probs: bool (default - False) visualize_probs: bool (default - False)
whether to draw the probability histogram or not whether to draw the probability histogram or not
@@ -251,12 +241,6 @@ def get_speech_timestamps(audio: torch.Tensor,
neg_threshold: float (default = threshold - 0.15) neg_threshold: float (default = threshold - 0.15)
Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH. Negative threshold (noise or exit threshold). If model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.
min_silence_at_max_speech: float (default - 98ms)
Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached
use_max_poss_sil_at_max_speech: bool (default - True)
Whether to use the maximum possible silence at max_speech_duration_s or not. If not, the last silence is used.
window_size_samples: int (default - 512 samples) window_size_samples: int (default - 512 samples)
!!! DEPRECATED, DOES NOTHING !!! !!! DEPRECATED, DOES NOTHING !!!
@@ -265,6 +249,7 @@ def get_speech_timestamps(audio: torch.Tensor,
speeches: list of dicts speeches: list of dicts
list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds) list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio):
try: try:
audio = torch.Tensor(audio) audio = torch.Tensor(audio)
@@ -289,29 +274,25 @@ def get_speech_timestamps(audio: torch.Tensor,
raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates") raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
window_size_samples = 512 if sampling_rate == 16000 else 256 window_size_samples = 512 if sampling_rate == 16000 else 256
hop_size_samples = int(window_size_samples)
model.reset_states() model.reset_states()
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000 speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000 min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
audio_length_samples = len(audio) audio_length_samples = len(audio)
speech_probs = [] speech_probs = []
for current_start_sample in range(0, audio_length_samples, hop_size_samples): for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample: current_start_sample + window_size_samples] chunk = audio[current_start_sample: current_start_sample + window_size_samples]
if len(chunk) < window_size_samples: if len(chunk) < window_size_samples:
chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk)))) chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
try: speech_prob = model(chunk, sampling_rate).item()
speech_prob = model(chunk, sampling_rate).item()
except Exception as e:
import ipdb; ipdb.set_trace()
speech_probs.append(speech_prob) speech_probs.append(speech_prob)
# calculate progress and send it to the callback function # calculate progress and send it to the callback function
progress = current_start_sample + hop_size_samples progress = current_start_sample + window_size_samples
if progress > audio_length_samples: if progress > audio_length_samples:
progress = audio_length_samples progress = audio_length_samples
progress_percent = (progress / audio_length_samples) * 100 progress_percent = (progress / audio_length_samples) * 100
@@ -323,59 +304,45 @@ def get_speech_timestamps(audio: torch.Tensor,
current_speech = {} current_speech = {}
if neg_threshold is None: if neg_threshold is None:
neg_threshold = max(threshold - 0.15, 0.01) neg_threshold = threshold - 0.15
temp_end = 0 # to save potential segment end (and tolerate some silence) temp_end = 0 # to save potential segment end (and tolerate some silence)
prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
possible_ends = []
for i, speech_prob in enumerate(speech_probs): for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end: if (speech_prob >= threshold) and temp_end:
if temp_end != 0: temp_end = 0
sil_dur = (hop_size_samples * i) - temp_end
if sil_dur > min_silence_samples_at_max_speech:
possible_ends.append((temp_end, sil_dur))
temp_end = 0
if next_start < prev_end: if next_start < prev_end:
next_start = hop_size_samples * i next_start = window_size_samples * i
if (speech_prob >= threshold) and not triggered: if (speech_prob >= threshold) and not triggered:
triggered = True triggered = True
current_speech['start'] = hop_size_samples * i current_speech['start'] = window_size_samples * i
continue continue
if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples: if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
if possible_ends: if prev_end:
if use_max_poss_sil_at_max_speech:
prev_end, dur = max(possible_ends, key=lambda x: x[1]) # use the longest possible silence segment in the current speech chunk
else:
prev_end, dur = possible_ends[-1] # use the last possible silence segment
current_speech['end'] = prev_end current_speech['end'] = prev_end
speeches.append(current_speech) speeches.append(current_speech)
current_speech = {} current_speech = {}
next_start = prev_end + dur if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
if next_start < prev_end + hop_size_samples * i: # previously reached silence (< neg_thres) and is still not speech (< thres)
#triggered = False
current_speech['start'] = next_start
else:
triggered = False triggered = False
#current_speech['start'] = next_start else:
current_speech['start'] = next_start
prev_end = next_start = temp_end = 0 prev_end = next_start = temp_end = 0
possible_ends = []
else: else:
current_speech['end'] = hop_size_samples * i current_speech['end'] = window_size_samples * i
speeches.append(current_speech) speeches.append(current_speech)
current_speech = {} current_speech = {}
prev_end = next_start = temp_end = 0 prev_end = next_start = temp_end = 0
triggered = False triggered = False
possible_ends = []
continue continue
if (speech_prob < neg_threshold) and triggered: if (speech_prob < neg_threshold) and triggered:
if not temp_end: if not temp_end:
temp_end = hop_size_samples * i temp_end = window_size_samples * i
# if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: # condition to avoid cutting in very short silence
# prev_end = temp_end prev_end = temp_end
if (hop_size_samples * i) - temp_end < min_silence_samples: if (window_size_samples * i) - temp_end < min_silence_samples:
continue continue
else: else:
current_speech['end'] = temp_end current_speech['end'] = temp_end
@@ -384,7 +351,6 @@ def get_speech_timestamps(audio: torch.Tensor,
current_speech = {} current_speech = {}
prev_end = next_start = temp_end = 0 prev_end = next_start = temp_end = 0
triggered = False triggered = False
possible_ends = []
continue continue
if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples: if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -406,17 +372,16 @@ def get_speech_timestamps(audio: torch.Tensor,
speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples)) speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
if return_seconds: if return_seconds:
audio_length_seconds = audio_length_samples / sampling_rate
for speech_dict in speeches: for speech_dict in speeches:
speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0) speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds) speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
elif step > 1: elif step > 1:
for speech_dict in speeches: for speech_dict in speeches:
speech_dict['start'] *= step speech_dict['start'] *= step
speech_dict['end'] *= step speech_dict['end'] *= step
if visualize_probs: if visualize_probs:
make_visualization(speech_probs, hop_size_samples / sampling_rate) make_visualization(speech_probs, window_size_samples / sampling_rate)
return speeches return speeches
@@ -470,16 +435,13 @@ class VADIterator:
self.current_sample = 0 self.current_sample = 0
@torch.no_grad() @torch.no_grad()
def __call__(self, x, return_seconds=False, time_resolution: int = 1): def __call__(self, x, return_seconds=False):
""" """
x: torch.Tensor x: torch.Tensor
audio chunk (see examples in repo) audio chunk (see examples in repo)
return_seconds: bool (default - False) return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples) whether return timestamps in seconds (default - samples)
time_resolution: int (default - 1)
time resolution of speech coordinates when requested as seconds
""" """
if not torch.is_tensor(x): if not torch.is_tensor(x):
@@ -499,7 +461,7 @@ class VADIterator:
if (speech_prob >= self.threshold) and not self.triggered: if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True self.triggered = True
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples) speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)} return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if (speech_prob < self.threshold - 0.15) and self.triggered: if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end: if not self.temp_end:
@@ -510,110 +472,24 @@ class VADIterator:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0 self.temp_end = 0
self.triggered = False self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)} return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return None return None
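For context, `VADIterator` is normally driven one fixed-size window at a time. A hedged sketch follows: 512-sample windows at 16 kHz, matching the hop size used earlier in this file; `model` and a mono float tensor `wav` from `read_audio` are assumed to exist.
```python
vad_iterator = VADIterator(model, sampling_rate=16000)
window = 512  # use 256 for 8 kHz audio

for i in range(0, len(wav), window):
    chunk = wav[i: i + window]
    if len(chunk) < window:
        break  # drop the trailing partial window
    event = vad_iterator(chunk, return_seconds=True)
    if event is not None:
        print(event)  # either {'start': ...} or {'end': ...}

vad_iterator.reset_states()  # reset internal state before processing another file
```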
def collect_chunks(tss: List[dict], def collect_chunks(tss: List[dict],
wav: torch.Tensor, wav: torch.Tensor):
seconds: bool = False, chunks = []
sampling_rate: int = None) -> torch.Tensor: for i in tss:
"""Collect audio chunks from a longer audio clip chunks.append(wav[i['start']: i['end']])
This method extracts audio chunks from an audio clip, using a list of
provided coordinates, and concatenates them together. Coordinates can be
passed either as sample numbers or in seconds, in which case the audio
sampling rate is also needed.
Parameters
----------
tss: List[dict]
Coordinate list of the clips to collect from the audio.
wav: torch.Tensor, one dimensional
One dimensional float torch.Tensor, containing the audio to clip.
seconds: bool (default - False)
Whether input coordinates are passed as seconds or samples.
sampling_rate: int (default - None)
Input audio sampling rate. Required if seconds is True.
Returns
-------
torch.Tensor, one dimensional
One dimensional float torch.Tensor of the concatenated clipped audio
chunks.
Raises
------
ValueError
Raised if sampling_rate is not provided when seconds is True.
"""
if seconds and not sampling_rate:
raise ValueError('sampling_rate must be provided when seconds is True')
chunks = list()
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
for i in _tss:
chunks.append(wav[i['start']:i['end']])
return torch.cat(chunks) return torch.cat(chunks)
def drop_chunks(tss: List[dict], def drop_chunks(tss: List[dict],
wav: torch.Tensor, wav: torch.Tensor):
seconds: bool = False, chunks = []
sampling_rate: int = None) -> torch.Tensor:
"""Drop audio chunks from a longer audio clip
This method extracts audio chunks from an audio clip, using a list of
provided coordinates, and drops them. Coordinates can be passed either as
sample numbers or in seconds, in which case the audio sampling rate is also
needed.
Parameters
----------
tss: List[dict]
Coordinate list of the clips to drop from the audio.
wav: torch.Tensor, one dimensional
One dimensional float torch.Tensor, containing the audio to clip.
seconds: bool (default - False)
Whether input coordinates are passed as seconds or samples.
sampling_rate: int (default - None)
Input audio sampling rate. Required if seconds is True.
Returns
-------
torch.Tensor, one dimensional
One dimensional float torch.Tensor of the input audio minus the dropped
chunks.
Raises
------
ValueError
Raised if sampling_rate is not provided when seconds is True.
"""
if seconds and not sampling_rate:
raise ValueError('sampling_rate must be provided when seconds is True')
chunks = list()
cur_start = 0 cur_start = 0
for i in tss:
_tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss
for i in _tss:
chunks.append((wav[cur_start: i['start']])) chunks.append((wav[cur_start: i['start']]))
cur_start = i['end'] cur_start = i['end']
return torch.cat(chunks) return torch.cat(chunks)
def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]:
"""Convert coordinates expressed in seconds to sample coordinates.
"""
return [{
'start': round(crd['start'] * sampling_rate),
'end': round(crd['end'] * sampling_rate)
} for crd in tss]
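To round out the two helpers above, a short hedged sketch that keeps only the detected speech. It uses sample-based coordinates, which both sides of this diff accept; `wav` and `speech_timestamps` are assumed to come from `read_audio` and `get_speech_timestamps` with `return_seconds=False`.
```python
from silero_vad import collect_chunks, save_audio

# speech_timestamps holds sample coordinates, e.g. [{'start': 8000, 'end': 32000}, ...]
speech_only = collect_chunks(speech_timestamps, wav)
save_audio("only_speech.wav", speech_only, sampling_rate=16000)
```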

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,22 +0,0 @@
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
import torch
torch.set_num_threads(1)
def test_jit_model():
model = load_silero_vad(onnx=False)
for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
audio = read_audio(path, sampling_rate=16000)
speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
assert speech_timestamps is not None
out = model.audio_forward(audio, sr=16000)
assert out is not None
def test_onnx_model():
model = load_silero_vad(onnx=True)
for path in ["tests/data/test.wav", "tests/data/test.opus", "tests/data/test.mp3"]:
audio = read_audio(path, sampling_rate=16000)
speech_timestamps = get_speech_timestamps(audio, model, visualize_probs=False, return_seconds=True)
assert speech_timestamps is not None
out = model.audio_forward(audio, sr=16000)
assert out is not None