mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
add c++ inference based on libtorch
This commit is contained in:
35
examples/cpp_libtorch/ReadMe
Normal file
35
examples/cpp_libtorch/ReadMe
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
This is the source code for Silero-VAD 5.1 in C++, based on LibTorch.
|
||||||
|
The primary implementation is the CPU version, and you should compare its results with the Python version.
|
||||||
|
|
||||||
|
In addition, Batch and CUDA inference options are also available if you want to explore further.
|
||||||
|
Note that when using batch inference, the speech probabilities might slightly differ from the standard version, likely due to differences in caching.
|
||||||
|
Unlike processing individual inputs, batch inference may not be able to use the cache from previous chunks.
|
||||||
|
Nevertheless, batch inference provides significantly faster processing.
|
||||||
|
For optimal performance, carefully adjust the threshold when using batch inference.
|
||||||
|
|
||||||
|
#Requirements:
|
||||||
|
GCC 11.4.0 (GCC >= 5.1)
|
||||||
|
LibTorch 1.13.0(Other versions are also acceptable)
|
||||||
|
|
||||||
|
#Download Libtorch:
|
||||||
|
#cpu
|
||||||
|
$wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
|
||||||
|
$unzip libtorch-shared-with-deps-1.13.0+cpu.zip
|
||||||
|
|
||||||
|
#cuda
|
||||||
|
$wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
|
||||||
|
$unzip libtorch-shared-with-deps-1.13.0+cu116.zip
|
||||||
|
|
||||||
|
#compile:
|
||||||
|
#cpu
|
||||||
|
$g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
|
||||||
|
#cuda
|
||||||
|
$g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
|
||||||
|
|
||||||
|
#option to add
|
||||||
|
-DUSE_BATCH
|
||||||
|
-DUSE_GPU
|
||||||
|
|
||||||
|
# Run:
|
||||||
|
./silero aepyx.wav 0.5 #The sample file 'aepyx.wav' is part of the Voxconverse dataset.
|
||||||
|
#aepyx.wav : 16kHz, 16-bit
|
||||||
BIN
examples/cpp_libtorch/aepyx.wav
Normal file
BIN
examples/cpp_libtorch/aepyx.wav
Normal file
Binary file not shown.
51
examples/cpp_libtorch/main.cc
Normal file
51
examples/cpp_libtorch/main.cc
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#include <iostream>

#include "silero_torch.h"
#include "wav.h"

// Command-line driver: runs Silero VAD over a wav file and prints the
// detected speech segments, e.g.  ./silero aepyx.wav 0.5
int main(int argc, char* argv[]) {
  if (argc != 3) {
    std::cerr<<"Usage : "<<argv[0]<<" <wav.path> threshold"<<std::endl;
    std::cerr<<"Usage : "<<argv[0]<<" sample.wav 0.38"<<std::endl;
    return 1;
  }

  const std::string wav_path = argv[1];
  const float threshold = std::stof(argv[2]);

  // Load the TorchScript VAD model and tune the detection parameters.
  const std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
  silero::VadIterator vad(model_path);
  vad.threshold = threshold;
  vad.min_speech_duration_ms = 255;
  vad.max_duration_merge_ms = 300;
  vad.print_as_samples = true;  // if true, timestamps are printed as sample indices (default: false)

  // Read the wav file into a float buffer.
  wav::WavReader wav_reader(wav_path);
  const int num_samples = wav_reader.num_samples();
  std::vector<float> input_wav(num_samples);
  const float* samples = wav_reader.data();
  for (int i = 0; i < num_samples; ++i) {
    input_wav[i] = static_cast<float>(samples[i]);
  }

  // Run inference, then fetch and print the detected segments.
  vad.SpeechProbs(input_wav);

  const std::vector<silero::Interval> speeches = vad.GetSpeechTimestamps();
  for (const auto& speech : speeches) {
    if (vad.print_as_samples) {
      std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
    } else {
      std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
    }
  }

  return 0;
}
|
||||||
|
|
||||||
|
|
||||||
283
examples/cpp_libtorch/silero_torch.cc
Normal file
283
examples/cpp_libtorch/silero_torch.cc
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
//Author : Nathan Lee
|
||||||
|
//Created On : 2024-11-18
|
||||||
|
//Description : silero 5.1 system for torch-script(c++).
|
||||||
|
//Version : 1.0
|
||||||
|
//Contact : junghan4242@gmail.com
|
||||||
|
|
||||||
|
|
||||||
|
#include "silero_torch.h"
|
||||||
|
|
||||||
|
namespace silero {
|
||||||
|
|
||||||
|
// Constructs the iterator: loads the TorchScript model from model_path and
// precomputes the sample-based window/threshold parameters (init_engine).
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
    // Member-init list ordered to match the declaration order in the header.
    : threshold(threshold),
      sample_rate(sample_rate),
      min_speech_duration_ms(min_speech_duration_ms),
      max_duration_merge_ms(max_duration_merge_ms),
      print_as_samples(print_as_samples),
      min_silence_duration_ms(min_silence_duration_ms),
      speech_pad_ms(speech_pad_ms) {
  init_torch_model(model_path);
  init_engine(window_size_ms);
}

VadIterator::~VadIterator() = default;
|
||||||
|
|
||||||
|
|
||||||
|
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
|
||||||
|
// Set the sample rate (must match the model's expected sample rate)
|
||||||
|
// Process the waveform in chunks of 512 samples
|
||||||
|
int num_samples = input_wav.size();
|
||||||
|
int num_chunks = num_samples / window_size_samples;
|
||||||
|
int remainder_samples = num_samples % window_size_samples;
|
||||||
|
|
||||||
|
total_sample_size += num_samples;
|
||||||
|
|
||||||
|
torch::Tensor output;
|
||||||
|
std::vector<torch::Tensor> chunks;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_chunks; i++) {
|
||||||
|
|
||||||
|
float* chunk_start = input_wav.data() + i *window_size_samples;
|
||||||
|
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
|
||||||
|
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
|
||||||
|
chunks.push_back(chunk);
|
||||||
|
|
||||||
|
|
||||||
|
if(i==num_chunks-1 && remainder_samples>0){//마지막 chunk && 나머지가 존재
|
||||||
|
int remaining_samples = num_samples - num_chunks * window_size_samples;
|
||||||
|
//std::cout<<"Remainder size : "<<remaining_samples;
|
||||||
|
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
|
||||||
|
|
||||||
|
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
|
||||||
|
torch::kFloat32);
|
||||||
|
// Pad the remainder chunk to match window_size_samples
|
||||||
|
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
|
||||||
|
- remaining_samples}, torch::kFloat32)}, 1);
|
||||||
|
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
|
||||||
|
|
||||||
|
chunks.push_back(padded_chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!chunks.empty()) {
|
||||||
|
|
||||||
|
#ifdef USE_BATCH
|
||||||
|
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
|
||||||
|
//batched_chunks = batched_chunks.squeeze(1);
|
||||||
|
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
|
||||||
|
|
||||||
|
#ifdef USE_GPU
|
||||||
|
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
|
||||||
|
#endif
|
||||||
|
// Prepare input for model
|
||||||
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
inputs.push_back(batched_chunks); // Batch of chunks
|
||||||
|
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
|
||||||
|
|
||||||
|
// Run inference on the batch
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
torch::Tensor output = model.forward(inputs).toTensor();
|
||||||
|
#ifdef USE_GPU
|
||||||
|
output = output.to(at::kCPU); // Move the output back to CPU once
|
||||||
|
#endif
|
||||||
|
// Collect output probabilities
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
float output_f = output[i].item<float>();
|
||||||
|
outputs_prob.push_back(output_f);
|
||||||
|
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
std::vector<torch::Tensor> outputs;
|
||||||
|
torch::Tensor batched_chunks = torch::stack(chunks);
|
||||||
|
#ifdef USE_GPU
|
||||||
|
batched_chunks = batched_chunks.to(at::kCUDA);
|
||||||
|
#endif
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
torch::NoGradGuard no_grad;
|
||||||
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
inputs.push_back(batched_chunks[i]);
|
||||||
|
inputs.push_back(sample_rate);
|
||||||
|
|
||||||
|
torch::Tensor output = model.forward(inputs).toTensor();
|
||||||
|
outputs.push_back(output);
|
||||||
|
}
|
||||||
|
torch::Tensor all_outputs = torch::stack(outputs);
|
||||||
|
#ifdef USE_GPU
|
||||||
|
all_outputs = all_outputs.to(at::kCPU);
|
||||||
|
#endif
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
float output_f = all_outputs[i].item<float>();
|
||||||
|
outputs_prob.push_back(output_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
|
||||||
|
std::vector<Interval> speeches = DoVad();
|
||||||
|
|
||||||
|
#ifdef USE_BATCH
|
||||||
|
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
|
||||||
|
//It could be better get reasonable output because of distorted probs.
|
||||||
|
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
|
||||||
|
std::vector<Interval> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
|
||||||
|
if(!print_as_samples){
|
||||||
|
for (auto& speech : speeches_merge) { //samples to second
|
||||||
|
speech.start /= sample_rate;
|
||||||
|
speech.end /= sample_rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return speeches_merge;
|
||||||
|
#else
|
||||||
|
|
||||||
|
if(!print_as_samples){
|
||||||
|
for (auto& speech : speeches) { //samples to second
|
||||||
|
speech.start /= sample_rate;
|
||||||
|
speech.end /= sample_rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return speeches;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts every millisecond-based configuration value into a sample count
// at the configured sample_rate. Called once from the constructor.
void VadIterator::init_engine(int window_size_ms) {
	window_size_samples = sample_rate / 1000 * window_size_ms;
	min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
	min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
	speech_pad_samples = sample_rate * speech_pad_ms / 1000;
}
|
||||||
|
|
||||||
|
// Loads the TorchScript VAD model and prepares it for inference.
// With -DUSE_GPU the model is moved to GPU 0; if CUDA is unavailable a
// std::runtime_error is thrown (construction fails).
//
// Fixes vs. the original:
//  - removed the unreachable `model.to(at::kCPU)` that followed the throw;
//  - removed a no-op `torch::NoGradGuard` whose scope ended immediately at
//    the end of this function (inference already uses its own guard).
void VadIterator::init_torch_model(const std::string& model_path) {
	// Single intra-op thread: the model is tiny, so this keeps latency
	// predictable and avoids oversubscription.
	at::set_num_threads(1);

	model = torch::jit::load(model_path);

#ifdef USE_GPU
	if (!torch::cuda::is_available()) {
		std::cout<<"CUDA is not available! Please check your GPU settings"<<std::endl;
		throw std::runtime_error("CUDA is not available!");
	} else {
		std::cout<<"CUDA available! Running on '0'th GPU"<<std::endl;
		model.to(at::Device(at::kCUDA, 0)); // select the 0'th device
	}
#endif

	model.eval();
	std::cout << "Model loaded successfully"<<std::endl;
}
|
||||||
|
|
||||||
|
// Clears all per-stream VAD state so the next SpeechProbs() call starts
// from scratch; also resets the model's internal recurrent cache via the
// TorchScript "reset_states" method.
void VadIterator::reset_states() {
	triggered = false;
	temp_end = 0;
	current_sample = 0;
	total_sample_size = 0;
	outputs_prob.clear();
	model.run_method("reset_states");
}
|
||||||
|
|
||||||
|
// Converts the buffered per-chunk probabilities (outputs_prob) into speech
// intervals using a hysteresis scheme: a segment opens when the probability
// reaches `threshold`, and closes only after it stays below
// `threshold - 0.15` for at least `min_silence_samples`. Interval bounds
// are sample indices. Consumes the buffered state: reset_states() is called
// before returning.
std::vector<Interval> VadIterator::DoVad() {
	std::vector<Interval> speeches;

	for (size_t i = 0; i < outputs_prob.size(); ++i) {
		float speech_prob = outputs_prob[i];
		// current_sample points at the END of chunk i after this increment.
		current_sample += window_size_samples;

		// Probability recovered while a tentative silence was pending:
		// cancel the pending end marker.
		if (speech_prob >= threshold && temp_end != 0) {
			temp_end = 0;
		}

		// Rising edge: open a new segment, padded backwards by
		// speech_pad_samples (clamped at 0).
		if (speech_prob >= threshold && !triggered) {
			triggered = true;
			Interval segment;
			segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
			speeches.push_back(segment);
			continue;
		}

		// Falling edge (with 0.15 hysteresis below threshold): mark a
		// tentative segment end, but only commit it once the silence has
		// lasted min_silence_samples.
		if (speech_prob < threshold - 0.15f && triggered) {
			if (temp_end == 0) {
				temp_end = current_sample;
			}

			if (current_sample - temp_end < min_silence_samples) {
				// Silence not long enough yet: keep the segment open.
				continue;
			} else {
				Interval& segment = speeches.back();
				segment.end = temp_end + speech_pad_samples - window_size_samples;
				temp_end = 0;
				triggered = false;
			}
		}
	}

	// Audio ended while still inside a segment: close it at the end of the
	// input. (Translated from the original Korean note: if only the very
	// last frame crosses the threshold, the segment opens and closes at
	// nearly the same sample — the minimum-length filter below should drop
	// such degenerate segments.)
	if (triggered) {
		std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
		Interval& segment = speeches.back();
		segment.end = total_sample_size; // close the last interval at end of input
		triggered = false;               // reset the VAD trigger state
	}

	// Drop segments shorter than min_speech_samples. (Translated from the
	// original Korean note: the padding added to start/end above is peeled
	// back off before measuring the effective speech length.)
	speeches.erase(
			std::remove_if(
				speeches.begin(),
				speeches.end(),
				[this](const Interval& speech) {
					return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
					// min_speech_samples is e.g. 4000 samples (0.25 s at 16 kHz)
				}
			),
			speeches.end()
	);

	reset_states();
	return speeches;
}
|
||||||
|
|
||||||
|
std::vector<Interval> VadIterator::mergeSpeeches(const std::vector<Interval>& speeches, int duration_merge_samples) {
|
||||||
|
std::vector<Interval> mergedSpeeches;
|
||||||
|
|
||||||
|
if (speeches.empty()) {
|
||||||
|
return mergedSpeeches; // 빈 벡터 반환
|
||||||
|
}
|
||||||
|
|
||||||
|
// 첫 번째 구간으로 초기화
|
||||||
|
Interval currentSegment = speeches[0];
|
||||||
|
|
||||||
|
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
|
||||||
|
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침
|
||||||
|
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
|
||||||
|
// 현재 구간의 끝점을 업데이트
|
||||||
|
currentSegment.end = speeches[i].end;
|
||||||
|
} else {
|
||||||
|
// 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작
|
||||||
|
mergedSpeeches.push_back(currentSegment);
|
||||||
|
currentSegment = speeches[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 마지막 구간 추가
|
||||||
|
mergedSpeeches.push_back(currentSegment);
|
||||||
|
|
||||||
|
return mergedSpeeches;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
79
examples/cpp_libtorch/silero_torch.h
Normal file
79
examples/cpp_libtorch/silero_torch.h
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
//Author : Nathan Lee
|
||||||
|
//Created On : 2024-11-18
|
||||||
|
//Description : silero 5.1 system for torch-script(c++).
|
||||||
|
//Version : 1.0
|
||||||
|
//Contact : junghan4242@gmail.com
|
||||||
|
|
||||||
|
#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H

// Fix vs. the original: removed a duplicate `#include <memory>` and grouped
// the includes (C++ stdlib first, then libtorch).

#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

namespace silero{

// Plain sample-index pair. NOTE(review): currently unused by VadIterator
// (which returns Interval); kept to avoid breaking external users.
struct SpeechSegment{
	int start;
	int end;
};

// One detected speech span. Values are sample indices as produced by the
// segmenter; GetSpeechTimestamps() converts them to seconds unless
// print_as_samples is set.
struct Interval {
	float start;
	float end;
};

// Streaming wrapper around the Silero VAD 5.1 TorchScript model.
// Typical use: construct, optionally adjust the public knobs, feed 16 kHz
// float PCM to SpeechProbs(), then call GetSpeechTimestamps().
class VadIterator{
public:

	VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
			int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
			int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
	~VadIterator();

	// Runs the model over input_wav, appending one probability per window
	// to the internal buffer.
	void SpeechProbs(std::vector<float>& input_wav);
	// Segments the buffered probabilities into intervals and resets state.
	std::vector<silero::Interval> GetSpeechTimestamps();

	// Tunable knobs; may be changed after construction, before SpeechProbs().
	float threshold;             // speech probability threshold
	int sample_rate;             // Hz; the model expects 16000
	int min_speech_duration_ms;  // segments shorter than this are dropped
	int max_duration_merge_ms;   // batch mode only: max gap to merge across
	bool print_as_samples;       // report sample indices instead of seconds

private:
	torch::jit::script::Module model;
	std::vector<float> outputs_prob;  // per-chunk speech probabilities

	// Millisecond settings converted to sample counts by init_engine().
	int min_silence_samples;
	int min_speech_samples;
	int speech_pad_samples;
	int window_size_samples;
	int duration_merge_samples;

	int current_sample = 0;      // running cursor while segmenting
	int total_sample_size=0;     // total samples seen since last reset

	int min_silence_duration_ms;
	int speech_pad_ms;
	bool triggered = false;      // true while inside an open speech segment
	int temp_end = 0;            // tentative end of the current segment

	void init_engine(int window_size_ms);
	void init_torch_model(const std::string& model_path);
	void reset_states();
	std::vector<Interval> DoVad();
	std::vector<Interval> mergeSpeeches(const std::vector<Interval>& speeches, int duration_merge_samples);

};

}
#endif // SILERO_TORCH_H
|
||||||
235
examples/cpp_libtorch/wav.h
Normal file
235
examples/cpp_libtorch/wav.h
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
// Copyright (c) 2016 Personal (Binbin Zhang)
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef FRONTEND_WAV_H_
|
||||||
|
#define FRONTEND_WAV_H_
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// #include "utils/log.h"
|
||||||
|
|
||||||
|
namespace wav {
|
||||||
|
|
||||||
|
// On-disk layout of a canonical 44-byte RIFF/WAVE header.
// NOTE(review): this struct is fread/fwrite/memcpy'd directly, so it must
// match the file layout byte-for-byte. With the 2- and 4-byte members
// ordered as below, common compilers insert no padding, but there is no
// static_assert(sizeof(WavHeader) == 44) to guarantee it — confirm on new
// toolchains.
struct WavHeader {
  char riff[4];                  // chunk id "RIFF"
  unsigned int size;             // file size minus 8 bytes
  char wav[4];                   // format "WAVE"
  char fmt[4];                   // sub-chunk id "fmt "
  unsigned int fmt_size;         // fmt sub-chunk size (16 for plain PCM)
  uint16_t format;               // format tag: 1 = integer PCM, 3 = IEEE float
  uint16_t channels;             // number of interleaved channels
  unsigned int sample_rate;      // samples per second per channel
  unsigned int bytes_per_second; // sample_rate * block_size
  uint16_t block_size;           // bytes per frame (all channels)
  uint16_t bit;                  // bits per sample
  char data[4];                  // sub-chunk id "data"
  unsigned int data_size;        // payload size in bytes
};
|
||||||
|
|
||||||
|
// Reads a RIFF/WAVE file into an owned float buffer (data_), converting
// 8/16/32-bit integer PCM and 32-bit IEEE float samples to float.
// Channels remain interleaved in data_; num_samples() is per channel.
// NOTE(review): fread return values are unchecked, and `fp` leaks on the
// early-return error paths (no fclose before `return false`).
class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  // Convenience constructor; the bool result of Open() is discarded here.
  explicit WavReader(const std::string& filename) { Open(filename); }

  // Parses the header, skips non-"data" sub-chunks, then decodes the
  // payload into data_. Returns false on open/format errors.
  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");  // open for binary reading
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    fread(&header, 1, sizeof(header), fp);
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      return false;
    } else if (header.fmt_size > 16) {
      // Extended fmt chunk: seek past the extra bytes, then re-read the
      // 8-byte sub-chunk id + size into header.data/header.data_size.
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }
    // check "riff" "WAVE" "fmt " "data"

    // Skip any sub-chunks between "fmt" and "data". Usually there will
    // be a single "fact" sub chunk, but on Windows there can also be a
    // "list" sub chunk.
    while (0 != strncmp(header.data, "data", 4)) {
      // We will just ignore the data in these chunks.
      fseek(fp, header.data_size, SEEK_CUR);
      // read next sub chunk
      fread(header.data, 8, sizeof(char), fp);
    }

    // Some writers leave data_size as 0; derive it from the file length.
    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    // num_data counts individual sample values across ALL channels.
    int num_data = header.data_size / (bits_per_sample_ / 8);
    data_ = new float[num_data];  // Create 1-dim array
    num_samples_ = num_data / num_channel_;  // per-channel sample count

    std::cout << "num_channel_ :" << num_channel_ << std::endl;
    std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples :" << num_data << std::endl;
    std::cout << "num_data_size :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        // NOTE(review): 8-bit WAV is conventionally UNSIGNED and would be
        // scaled by /128 after recentering; this reads signed chars and
        // divides by 32768, producing values in roughly [-0.004, 0.004].
        // Kept as-is (matches upstream) — confirm before relying on it.
        char sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(char), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 16: {
        // 16-bit signed PCM normalized to [-1, 1).
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32:
      {
        if (header.format == 1) //S32
        {
          // NOTE(review): dividing 32-bit PCM by 32768 leaves values far
          // outside [-1, 1] (full-scale S32 / 2^15 ≈ ±65536). Kept as-is
          // (matches upstream) — confirm intended scaling.
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            data_[i] = static_cast<float>(sample) / 32768;
          }
        }
        else if (header.format == 3) // IEEE-float
        {
          // 32-bit float samples are copied through unscaled.
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        }
        else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  // Borrowed pointer into the interleaved float buffer; owned by this
  // object and freed in the destructor.
  const float* data() const { return data_; }

 private:
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
  int num_samples_;  // sample points per channel
  float* data_;
};
|
||||||
|
|
||||||
|
// Writes interleaved float samples out as a RIFF/WAVE file, truncating each
// value to 8/16/32-bit integer PCM according to bits_per_sample.
// NOTE(review): data_ is borrowed — the caller must keep it alive for the
// lifetime of this object. Samples are cast, not rescaled, so callers must
// pass values already in the target integer range.
class WavWriter {
 public:
  WavWriter(const float* data, int num_samples, int num_channel,
            int sample_rate, int bits_per_sample)
      : data_(data),
        num_samples_(num_samples),
        num_channel_(num_channel),
        sample_rate_(sample_rate),
        bits_per_sample_(bits_per_sample) {}

  // Builds the 44-byte header from a canonical template, patches in the
  // stream parameters, then writes header + samples.
  // NOTE(review): opened with mode "w" (text) — on Windows this corrupts
  // binary output via newline translation; "wb" is the portable choice.
  // fopen failure and fwrite results are also unchecked.
  void Write(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "w");
    // Template header with 'RIFF' 'WAVE' 'fmt ' 'data' ids pre-filled;
    // size fields are zeroed and patched below.
    WavHeader header;
    char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
                           0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
                           0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                           0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
    memcpy(&header, wav_header, sizeof(header));
    header.channels = num_channel_;
    header.bit = bits_per_sample_;
    header.sample_rate = sample_rate_;
    header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
    header.size = sizeof(header) - 8 + header.data_size;
    header.bytes_per_second =
        sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
    header.block_size = num_channel_ * (bits_per_sample_ / 8);

    fwrite(&header, 1, sizeof(header), fp);

    // Write frames in interleaved order, truncating floats to the target
    // integer width.
    for (int i = 0; i < num_samples_; ++i) {
      for (int j = 0; j < num_channel_; ++j) {
        switch (bits_per_sample_) {
          case 8: {
            char sample = static_cast<char>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 16: {
            int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
          case 32: {
            int sample = static_cast<int>(data_[i * num_channel_ + j]);
            fwrite(&sample, 1, sizeof(sample), fp);
            break;
          }
        }
      }
    }
    fclose(fp);
  }

 private:
  const float* data_;     // borrowed, not owned
  int num_samples_;       // frames per channel
  int num_channel_;
  int sample_rate_;
  int bits_per_sample_;
};
|
||||||
|
|
||||||
|
}  // namespace wav
|
||||||
|
|
||||||
|
#endif // FRONTEND_WAV_H_
|
||||||
Reference in New Issue
Block a user