Add C++ folder for supporting ONNX & LibTorch
examples/c++/README.md (new file, 49 lines)
@@ -0,0 +1,49 @@
# Silero-VAD V6 in C++ (LibTorch & ONNX Runtime)

This is the source code for Silero-VAD V6 in C++, utilizing LibTorch and ONNX Runtime.
Its results can be compared directly with the Python version.
Results at 16 kHz and 8 kHz have been tested. Batch and CUDA inference options are deprecated.

## Requirements

- GCC 11.4.0 (any GCC >= 5.1)
- ONNX Runtime 1.11.1 (other versions are also acceptable)
- LibTorch 1.13.0 (other versions are also acceptable)

## Download ONNX Runtime and LibTorch

```bash
# ONNX Runtime
wget https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz
tar -xvf onnxruntime-linux-x64-1.11.1.tgz
ln -s onnxruntime-linux-x64-1.11.1 onnxruntime-linux  # soft link

# LibTorch
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
unzip libtorch-shared-with-deps-1.13.0+cpu.zip
```

## Compilation

```bash
# ONNX build
g++ main.cc silero.cc -I ./onnxruntime-linux/include/ -L ./onnxruntime-linux/lib/ -lonnxruntime -Wl,-rpath,./onnxruntime-linux/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_ONNX

# LibTorch build
g++ main.cc silero.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_TORCH
```

## Optional Compilation Flags

- `-DUSE_TORCH`: build against the LibTorch (TorchScript) backend
- `-DUSE_ONNX`: build against the ONNX Runtime backend
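Exactly one of these flags must be defined; it selects which backend-specific code in `silero.h`, `silero.cc`, and `main.cc` is compiled. As a minimal sketch of that dispatch (not part of the example sources, and assuming the same relative model paths that `main.cc` uses):

```cpp
// Minimal sketch of the compile-time backend dispatch used by the sources.
// Build with exactly one of: -DUSE_TORCH or -DUSE_ONNX.
#include <iostream>
#include <string>

int main() {
#if defined(USE_TORCH)
  // TorchScript backend: silero.cc loads this file via torch::jit::load().
  std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
  std::cout << "LibTorch backend, model: " << model_path << std::endl;
#elif defined(USE_ONNX)
  // ONNX Runtime backend: silero.cc creates an Ort::Session from this file.
  std::string model_path = "../../src/silero_vad/data/silero_vad.onnx";
  std::cout << "ONNX Runtime backend, model: " << model_path << std::endl;
#else
#error "Define either USE_TORCH or USE_ONNX"
#endif
  return 0;
}
```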
## Run the Program

To run the program, use the following command:

`./silero <sample.wav> <SampleRate> <threshold>`

`./silero aepyx.wav 16000 0.5`

`./silero aepyx_8k.wav 8000 0.5`

Detected speech segments are printed as `{'start': ..., 'end': ...}`, in seconds by default (or in samples if `print_as_samples` is enabled in `main.cc`).

The sample file aepyx.wav is part of the VoxConverse dataset.
File details: aepyx.wav is a 16 kHz, 16-bit audio file.
File details: aepyx_8k.wav is an 8 kHz, 16-bit audio file.
examples/c++/aepyx.wav (new binary file, not shown)
examples/c++/aepyx_8k.wav (new binary file, not shown)
examples/c++/main.cc (new file, 61 lines)
@@ -0,0 +1,61 @@
#include <iostream>
#include "silero.h"
#include "wav.h"

int main(int argc, char* argv[]) {

  if (argc != 4) {
    std::cerr << "Usage : " << argv[0] << " <wav.path> <SampleRate> <Threshold>" << std::endl;
    std::cerr << "Usage : " << argv[0] << " sample.wav 16000 0.5" << std::endl;
    return 1;
  }

  std::string wav_path = argv[1];
  float sample_rate = std::stof(argv[2]);
  float threshold = std::stof(argv[3]);

  if (sample_rate != 16000 && sample_rate != 8000) {
    std::cerr << "Unsupported sample rate (only 16000 or 8000)." << std::endl;
    return 1;
  }

  // Load model
#ifdef USE_TORCH
  std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
#elif USE_ONNX
  std::string model_path = "../../src/silero_vad/data/silero_vad.onnx";
#endif
  silero::VadIterator vad(model_path);

  vad.threshold = threshold;        // (default: 0.5)
  vad.sample_rate = sample_rate;    // 16000 Hz or 8000 Hz (default: 16000)
  vad.print_as_samples = false;     // if true, print timestamps in samples; otherwise in seconds
                                    // (default: false)

  vad.SetVariables();

  // Read wav
  wav::WavReader wav_reader(wav_path);
  std::vector<float> input_wav(wav_reader.num_samples());

  for (int i = 0; i < wav_reader.num_samples(); i++) {
    input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
  }

  vad.SpeechProbs(input_wav);

  std::vector<silero::Interval> speeches = vad.GetSpeechTimestamps();
  for (const auto& speech : speeches) {
    if (vad.print_as_samples) {
      std::cout << "{'start': " << static_cast<int>(speech.start) << ", 'end': " << static_cast<int>(speech.end) << "}" << std::endl;
    } else {
      std::cout << "{'start': " << speech.start << ", 'end': " << speech.end << "}" << std::endl;
    }
  }

  return 0;
}
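A usage note, not part of the committed sources: `SpeechProbs()` only needs a `std::vector<float>` of mono samples in the [-1, 1] range, so the WAV reader is optional. A minimal sketch of driving the same API from an in-memory buffer, assuming the build flags and model paths used above:

```cpp
// Hypothetical caller: runs the VAD on an in-memory buffer of mono float
// samples (range [-1, 1]) instead of reading a WAV file. Assumes the same
// build flags and model path conventions as main.cc.
#include <iostream>
#include <vector>
#include "silero.h"

std::vector<silero::Interval> DetectSpeech(std::vector<float>& samples,
                                           int sample_rate /* 8000 or 16000 */) {
#if defined(USE_TORCH)
  silero::VadIterator vad("../../src/silero_vad/data/silero_vad.jit");
#elif defined(USE_ONNX)
  silero::VadIterator vad("../../src/silero_vad/data/silero_vad.onnx");
#else
#error "Define either USE_TORCH or USE_ONNX"
#endif
  vad.sample_rate = sample_rate;
  vad.threshold = 0.5f;
  vad.print_as_samples = false;  // timestamps returned in seconds
  vad.SetVariables();            // derive window/padding sizes from the fields above

  vad.SpeechProbs(samples);      // run the model over fixed-size windows
  return vad.GetSpeechTimestamps();
}

int main() {
  std::vector<float> buffer(16000, 0.0f);  // one second of silence at 16 kHz
  for (const auto& seg : DetectSpeech(buffer, 16000)) {
    std::cout << seg.start << " - " << seg.end << std::endl;
  }
  return 0;
}
```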
examples/c++/silero.cc (new file, 273 lines)
@@ -0,0 +1,273 @@
// silero.cc
// Author      : NathanJHLee
// Created On  : 2025-11-10
// Description : silero 6.2 system for onnx-runtime (C++) and torch-script (C++)
// Version     : 1.3

#include "silero.h"

namespace silero {

#ifdef USE_TORCH
VadIterator::VadIterator(const std::string &model_path,
                         float threshold,
                         int sample_rate,
                         int window_size_ms,
                         int speech_pad_ms,
                         int min_silence_duration_ms,
                         int min_speech_duration_ms,
                         int max_duration_merge_ms,
                         bool print_as_samples)
    : threshold(threshold), sample_rate(sample_rate), window_size_ms(window_size_ms),
      speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
      min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
      print_as_samples(print_as_samples) {
  init_torch_model(model_path);
}

VadIterator::~VadIterator() {
}

void VadIterator::init_torch_model(const std::string& model_path) {
  at::set_num_threads(1);
  model = torch::jit::load(model_path);

  model.eval();
  torch::NoGradGuard no_grad;
  std::cout << "Silero libtorch-Model loaded successfully" << std::endl;
}

void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
  int num_samples = input_wav.size();
  int num_chunks = num_samples / window_size_samples;
  int remainder_samples = num_samples % window_size_samples;
  total_sample_size += num_samples;

  std::vector<torch::Tensor> chunks;

  for (int i = 0; i < num_chunks; i++) {
    float* chunk_start = input_wav.data() + i * window_size_samples;
    torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
    chunks.push_back(chunk);

    // Zero-pad the trailing partial chunk, if any.
    if (i == num_chunks - 1 && remainder_samples > 0) {
      int remaining_samples = num_samples - num_chunks * window_size_samples;
      float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;
      torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remaining_samples}, torch::kFloat32);
      torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples - remaining_samples}, torch::kFloat32)}, 1);
      chunks.push_back(padded_chunk);
    }
  }

  if (!chunks.empty()) {
    std::vector<torch::Tensor> outputs;
    torch::Tensor batched_chunks = torch::stack(chunks);
    for (size_t i = 0; i < chunks.size(); i++) {
      torch::NoGradGuard no_grad;
      std::vector<torch::jit::IValue> inputs;
      inputs.push_back(batched_chunks[i]);
      inputs.push_back(sample_rate);
      torch::Tensor output = model.forward(inputs).toTensor();
      outputs.push_back(output);
    }
    torch::Tensor all_outputs = torch::stack(outputs);
    for (size_t i = 0; i < chunks.size(); i++) {
      float output_f = all_outputs[i].item<float>();
      outputs_prob.push_back(output_f);
      // To print the probabilities computed by libtorch:
      // std::cout << "Chunk " << i << " prob: " << output_f << "\n";
    }
  }
}

#elif USE_ONNX

VadIterator::VadIterator(const std::string &model_path,
                         float threshold,
                         int sample_rate,
                         int window_size_ms,
                         int speech_pad_ms,
                         int min_silence_duration_ms,
                         int min_speech_duration_ms,
                         int max_duration_merge_ms,
                         bool print_as_samples)
    : sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms),
      speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
      min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
      print_as_samples(print_as_samples),
      env(ORT_LOGGING_LEVEL_ERROR, "Vad"), session_options(), session(nullptr), allocator(),
      memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU)), context_samples(64),
      _context(64, 0.0f), current_sample(0), size_state(2 * 1 * 128),
      input_node_names({"input", "state", "sr"}), output_node_names({"output", "stateN"}),
      state_node_dims{2, 1, 128}, sr_node_dims{1} {
  init_onnx_model(model_path);
}

VadIterator::~VadIterator() {
}

void VadIterator::init_onnx_model(const std::string& model_path) {
  int inter_threads = 1;
  int intra_threads = 1;
  session_options.SetIntraOpNumThreads(intra_threads);
  session_options.SetInterOpNumThreads(inter_threads);
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
  session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
  std::cout << "Silero onnx-Model loaded successfully" << std::endl;
}

float VadIterator::predict(const std::vector<float>& data_chunk) {
  // Build the model input by concatenating _context with the current chunk.
  std::vector<float> new_data(effective_window_size, 0.0f);
  std::copy(_context.begin(), _context.end(), new_data.begin());
  std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
  input = new_data;

  Ort::Value input_ort = Ort::Value::CreateTensor<float>(
      memory_info, input.data(), input.size(), input_node_dims, 2);
  Ort::Value state_ort = Ort::Value::CreateTensor<float>(
      memory_info, _state.data(), _state.size(), state_node_dims, 3);
  Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
      memory_info, sr.data(), sr.size(), sr_node_dims, 1);
  ort_inputs.clear();
  ort_inputs.push_back(std::move(input_ort));
  ort_inputs.push_back(std::move(state_ort));
  ort_inputs.push_back(std::move(sr_ort));

  ort_outputs = session->Run(
      Ort::RunOptions{ nullptr },
      input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
      output_node_names.data(), output_node_names.size());

  float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];  // first ONNX output: speech probability

  float* stateN = ort_outputs[1].GetTensorMutableData<float>();  // second ONNX output: updated recurrent state
  std::memcpy(_state.data(), stateN, size_state * sizeof(float));

  // Update _context: keep the last context_samples of new_data.
  std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());

  return speech_prob;
}

void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
  reset_states();
  total_sample_size = static_cast<int>(input_wav.size());
  for (size_t j = 0; j < static_cast<size_t>(total_sample_size); j += window_size_samples) {
    if (j + window_size_samples > static_cast<size_t>(total_sample_size))
      break;
    std::vector<float> chunk(input_wav.begin() + j, input_wav.begin() + j + window_size_samples);
    float speech_prob = predict(chunk);
    outputs_prob.push_back(speech_prob);
  }
}

#endif

void VadIterator::reset_states() {
  triggered = false;
  current_sample = 0;
  temp_end = 0;
  outputs_prob.clear();
  total_sample_size = 0;

#ifdef USE_TORCH
  model.run_method("reset_states");  // Reset model states if applicable
#elif USE_ONNX
  std::memset(_state.data(), 0, _state.size() * sizeof(float));
  std::fill(_context.begin(), _context.end(), 0.0f);
#endif
}

std::vector<Interval> VadIterator::GetSpeechTimestamps() {
  std::vector<Interval> speeches = DoVad();
  if (!print_as_samples) {
    for (auto& speech : speeches) {
      speech.start /= sample_rate;
      speech.end /= sample_rate;
    }
  }
  return speeches;
}

void VadIterator::SetVariables() {
  // Initialize internal engine parameters
  init_engine(window_size_ms);
}

void VadIterator::init_engine(int window_size_ms) {
  min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
  speech_pad_samples = sample_rate * speech_pad_ms / 1000;
  window_size_samples = sample_rate / 1000 * window_size_ms;
  min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
#ifdef USE_ONNX
  // ONNX only
  context_samples = window_size_samples / 8;
  _context.assign(context_samples, 0.0f);

  effective_window_size = window_size_samples + context_samples;  // e.g. 512 + 64 = 576 samples
  input_node_dims[0] = 1;
  input_node_dims[1] = effective_window_size;
  _state.resize(size_state);
  sr.resize(1);
  sr[0] = sample_rate;
#endif
}

std::vector<Interval> VadIterator::DoVad() {
  std::vector<Interval> speeches;
  for (size_t i = 0; i < outputs_prob.size(); ++i) {
    float speech_prob = outputs_prob[i];
    current_sample += window_size_samples;
    if (speech_prob >= threshold && temp_end != 0) {
      temp_end = 0;
    }

    if (speech_prob >= threshold) {
      if (!triggered) {
        triggered = true;
        Interval segment;
        segment.start = std::max(0, current_sample - speech_pad_samples - window_size_samples);
        speeches.push_back(segment);
      }
    } else {
      if (triggered) {
        if (speech_prob < threshold - 0.15f) {
          if (temp_end == 0) {
            temp_end = current_sample;
          }
          if (current_sample - temp_end >= min_silence_samples) {
            Interval& segment = speeches.back();
            segment.end = temp_end + speech_pad_samples - window_size_samples;
            temp_end = 0;
            triggered = false;
          }
        }
      }
    }
  }

  if (triggered) {
    std::cout << "Finalizing active speech segment at stream end." << std::endl;
    Interval& segment = speeches.back();
    segment.end = total_sample_size;
    triggered = false;
  }
  speeches.erase(std::remove_if(speeches.begin(), speeches.end(),
      [this](const Interval& speech) {
        return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
      }), speeches.end());

  reset_states();
  return speeches;
}

}  // namespace silero
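For reference, with the default `window_size_ms = 32` the sizes derived in `init_engine()` work out as below. This is a worked example only, matching the `512 + 64 = 576` comment in the ONNX branch:

```cpp
// Worked example of the sizes computed in init_engine() (window_size_ms = 32):
//   16 kHz: window_size_samples = 16000 / 1000 * 32 = 512
//           context_samples     = 512 / 8           = 64   (ONNX only)
//           effective_window    = 512 + 64           = 576
//    8 kHz: window_size_samples =  8000 / 1000 * 32 = 256
//           context_samples     = 256 / 8            = 32
//           effective_window    = 256 + 32            = 288
static_assert(16000 / 1000 * 32 == 512, "16 kHz window size");
static_assert(16000 / 1000 * 32 + 16000 / 1000 * 32 / 8 == 576, "16 kHz effective window");
static_assert(8000 / 1000 * 32 + 8000 / 1000 * 32 / 8 == 288, "8 kHz effective window");
```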
examples/c++/silero.h (new file, 123 lines)
@@ -0,0 +1,123 @@
#ifndef SILERO_H
#define SILERO_H

// silero.h
// Author      : NathanJHLee
// Created On  : 2025-11-10
// Description : silero 6.2 system for onnx-runtime (C++) and torch-script (C++)
// Version     : 1.3

#include <string>
#include <vector>
#include <iostream>
#include <fstream>
#include <chrono>
#include <algorithm>
#include <cstring>
#include <memory>

#ifdef USE_TORCH
#include <torch/torch.h>
#include <torch/script.h>
#elif USE_ONNX
#include "onnxruntime_cxx_api.h"
#endif

namespace silero {

struct Interval {
  float start;
  float end;
  int numberOfSubseg;

  void initialize() {
    start = 0;
    end = 0;
    numberOfSubseg = 0;
  }
};

class VadIterator {
 public:
  VadIterator(const std::string &model_path,
              float threshold = 0.5,
              int sample_rate = 16000,
              int window_size_ms = 32,
              int speech_pad_ms = 30,
              int min_silence_duration_ms = 100,
              int min_speech_duration_ms = 250,
              int max_duration_merge_ms = 300,
              bool print_as_samples = false);
  ~VadIterator();

  // Batch (non-streaming) interface (for backward compatibility)
  void SpeechProbs(std::vector<float>& input_wav);
  std::vector<Interval> GetSpeechTimestamps();
  void SetVariables();

  // Public parameters (can be modified by the user)
  float threshold;
  int sample_rate;
  int window_size_ms;
  int min_speech_duration_ms;
  int max_duration_merge_ms;
  bool print_as_samples;

 private:
#ifdef USE_TORCH
  torch::jit::script::Module model;
  void init_torch_model(const std::string& model_path);
#elif USE_ONNX
  Ort::Env env;                                // environment object
  Ort::SessionOptions session_options;         // session options
  std::shared_ptr<Ort::Session> session;       // ONNX session
  Ort::AllocatorWithDefaultOptions allocator;  // default allocator
  Ort::MemoryInfo memory_info;                 // memory info (CPU)

  void init_onnx_model(const std::string& model_path);
  float predict(const std::vector<float>& data_chunk);

  int context_samples;                         // e.g. 64 samples
  std::vector<float> _context;                 // initialized to all zeros
  int effective_window_size;

  // ONNX input/output buffers and node names
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_node_names;
  std::vector<float> input;
  unsigned int size_state;                     // fixed value: 2 * 1 * 128
  std::vector<float> _state;
  std::vector<int64_t> sr;
  int64_t input_node_dims[2];                  // [1, effective_window_size]
  const int64_t state_node_dims[3];            // [2, 1, 128]
  const int64_t sr_node_dims[1];               // [1]
  std::vector<Ort::Value> ort_outputs;
  std::vector<const char*> output_node_names;  // default: ["output", "stateN"]
#endif
  std::vector<float> outputs_prob;             // used in batch mode
  int min_silence_samples;
  int min_speech_samples;
  int speech_pad_samples;
  int window_size_samples;
  int duration_merge_samples;
  int current_sample = 0;
  int total_sample_size = 0;
  int min_silence_duration_ms;
  int speech_pad_ms;
  bool triggered = false;
  int temp_end = 0;
  int global_end = 0;
  int erase_tail_count = 0;

  void init_engine(int window_size_ms);
  void reset_states();
  std::vector<Interval> DoVad();
};

}  // namespace silero

#endif  // SILERO_H
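A side note, not part of this commit: every tuning parameter listed above can also be passed at construction time instead of being assigned through the public members afterwards. A hypothetical sketch under the ONNX build, using the defaults declared in `silero.h`:

```cpp
#include "silero.h"

int main() {
  // Hypothetical construction with explicit parameters, mirroring the
  // declaration order and default values in silero.h.
  silero::VadIterator vad("../../src/silero_vad/data/silero_vad.onnx",
                          /*threshold=*/0.5f,
                          /*sample_rate=*/16000,
                          /*window_size_ms=*/32,
                          /*speech_pad_ms=*/30,
                          /*min_silence_duration_ms=*/100,
                          /*min_speech_duration_ms=*/250,
                          /*max_duration_merge_ms=*/300,
                          /*print_as_samples=*/false);
  vad.SetVariables();  // still required before SpeechProbs()/GetSpeechTimestamps()
  return 0;
}
```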
examples/c++/wav.h (new file, 237 lines)
@@ -0,0 +1,237 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <string>

// #include "utils/log.h"

namespace wav {

struct WavHeader {
  char riff[4];  // "riff"
  unsigned int size;
  char wav[4];   // "WAVE"
  char fmt[4];   // "fmt "
  unsigned int fmt_size;
  uint16_t format;
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;
  char data[4];  // "data"
  unsigned int data_size;
};

class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) { Open(filename); }

  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open the file for reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;

    std::cout << "num_channel_    :" << num_channel_ << std::endl;
    std::cout << "sample_rate_    :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples     :" << num_data << std::endl;
    std::cout << "num_data_size   :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        char sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 16: {
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32: {
        if (header.format == 1) {  // S32
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            data_[i] = static_cast<float>(sample) / 32768;
          }
        } else if (header.format == 3) {  // IEEE float
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        } else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};

class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "w");
    // init char 'riff' 'WAVE' 'fmt ' 'data'
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;
  int num_samples_;  // total float points in data_
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};

}  // namespace wav

#endif  // FRONTEND_WAV_H_