mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
add c++ inference based on libtorch
This commit is contained in:
35
examples/cpp_libtorch/ReadMe
Normal file
35
examples/cpp_libtorch/ReadMe
Normal file
@@ -0,0 +1,35 @@
|
||||
This is the source code for Silero-VAD 5.1 in C++, based on LibTorch.
|
||||
The primary implementation is the CPU version, and you should compare its results with the Python version.
|
||||
|
||||
In addition, Batch and CUDA inference options are also available if you want to explore further.
|
||||
Note that when using batch inference, the speech probabilities might slightly differ from the standard version, likely due to differences in caching.
|
||||
Unlike processing individual inputs, batch inference may not be able to use the cache from previous chunks.
|
||||
Nevertheless, batch inference provides significantly faster processing.
|
||||
For optimal performance, carefully adjust the threshold when using batch inference.
|
||||
|
||||
#Requirements:
|
||||
GCC 11.4.0 (GCC >= 5.1)
|
||||
LibTorch 1.13.0(Other versions are also acceptable)
|
||||
|
||||
#Download Libtorch:
|
||||
#cpu
|
||||
$wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
|
||||
$unzip libtorch-shared-with-deps-1.13.0+cpu.zip
|
||||
|
||||
#cuda
|
||||
$wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
|
||||
$unzip libtorch-shared-with-deps-1.13.0+cu116.zip
|
||||
|
||||
#compile:
|
||||
#cpu
|
||||
$g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
|
||||
#cuda
|
||||
$g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
|
||||
|
||||
#option to add
|
||||
-DUSE_BATCH
|
||||
-DUSE_GPU
|
||||
|
||||
# Run:
|
||||
./silero aepyx.wav 0.5 #The sample file 'aepyx.wav' is part of the Voxconverse dataset.
|
||||
#aepyx.wav : 16kHz, 16-bit
|
||||
BIN
examples/cpp_libtorch/aepyx.wav
Normal file
BIN
examples/cpp_libtorch/aepyx.wav
Normal file
Binary file not shown.
51
examples/cpp_libtorch/main.cc
Normal file
51
examples/cpp_libtorch/main.cc
Normal file
@@ -0,0 +1,51 @@
|
||||
#include <iostream>
|
||||
#include "silero_torch.h"
|
||||
#include "wav.h"
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
if(argc != 3){
|
||||
std::cerr<<"Usage : "<<argv[0]<<" <wav.path> threshold"<<std::endl;
|
||||
std::cerr<<"Usage : "<<argv[0]<<" sample.wav 0.38"<<std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string wav_path = argv[1];
|
||||
float threshold = std::stof(argv[2]);
|
||||
|
||||
|
||||
//Load Model
|
||||
std::string model_path = "../../src/silero_vad/data/silero_vad.jit";
|
||||
silero::VadIterator vad(model_path);
|
||||
vad.threshold=threshold;
|
||||
vad.min_speech_duration_ms=255;
|
||||
vad.max_duration_merge_ms=300;
|
||||
vad.print_as_samples=true; //if true, it prints time-stamp with sample numbers.
|
||||
//(Default:false)
|
||||
|
||||
// Read wav
|
||||
wav::WavReader wav_reader(wav_path);
|
||||
std::vector<float> input_wav(wav_reader.num_samples());
|
||||
|
||||
for (int i = 0; i < wav_reader.num_samples(); i++)
|
||||
{
|
||||
input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
|
||||
}
|
||||
|
||||
vad.SpeechProbs(input_wav);
|
||||
|
||||
std::vector<silero::Interval> speeches = vad.GetSpeechTimestamps();
|
||||
for(const auto& speech : speeches){
|
||||
if(vad.print_as_samples){
|
||||
std::cout<<"{'start': "<<static_cast<int>(speech.start)<<", 'end': "<<static_cast<int>(speech.end)<<"}"<<std::endl;
|
||||
}
|
||||
else{
|
||||
std::cout<<"{'start': "<<speech.start<<", 'end': "<<speech.end<<"}"<<std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
283
examples/cpp_libtorch/silero_torch.cc
Normal file
283
examples/cpp_libtorch/silero_torch.cc
Normal file
@@ -0,0 +1,283 @@
|
||||
//Author : Nathan Lee
|
||||
//Created On : 2024-11-18
|
||||
//Description : silero 5.1 system for torch-script(c++).
|
||||
//Version : 1.0
|
||||
//Contact : junghan4242@gmail.com
|
||||
|
||||
|
||||
#include "silero_torch.h"
|
||||
|
||||
namespace silero {
|
||||
|
||||
// Constructs the VAD: stores configuration, loads the TorchScript model, and
// derives sample-count parameters from the millisecond-based settings.
//
// Fix: the member-initializer list is now written in the same order as the
// member declarations in silero_torch.h. Members are always initialized in
// declaration order regardless of the order written here, so the old mismatched
// order was misleading and triggered -Wreorder warnings.
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate,
                         int window_size_ms, int speech_pad_ms, int min_silence_duration_ms,
                         int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
    : threshold(threshold),
      sample_rate(sample_rate),
      min_speech_duration_ms(min_speech_duration_ms),
      max_duration_merge_ms(max_duration_merge_ms),
      print_as_samples(print_as_samples),
      min_silence_duration_ms(min_silence_duration_ms),
      speech_pad_ms(speech_pad_ms)
{
    init_torch_model(model_path);
    init_engine(window_size_ms);
}

VadIterator::~VadIterator() {
}
|
||||
|
||||
|
||||
void VadIterator::SpeechProbs(std::vector<float>& input_wav){
|
||||
// Set the sample rate (must match the model's expected sample rate)
|
||||
// Process the waveform in chunks of 512 samples
|
||||
int num_samples = input_wav.size();
|
||||
int num_chunks = num_samples / window_size_samples;
|
||||
int remainder_samples = num_samples % window_size_samples;
|
||||
|
||||
total_sample_size += num_samples;
|
||||
|
||||
torch::Tensor output;
|
||||
std::vector<torch::Tensor> chunks;
|
||||
|
||||
for (int i = 0; i < num_chunks; i++) {
|
||||
|
||||
float* chunk_start = input_wav.data() + i *window_size_samples;
|
||||
torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
|
||||
//std::cout<<"chunk size : "<<chunk.sizes()<<std::endl;
|
||||
chunks.push_back(chunk);
|
||||
|
||||
|
||||
if(i==num_chunks-1 && remainder_samples>0){//마지막 chunk && 나머지가 존재
|
||||
int remaining_samples = num_samples - num_chunks * window_size_samples;
|
||||
//std::cout<<"Remainder size : "<<remaining_samples;
|
||||
float* chunk_start_remainder = input_wav.data() + num_chunks *window_size_samples;
|
||||
|
||||
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1,remaining_samples},
|
||||
torch::kFloat32);
|
||||
// Pad the remainder chunk to match window_size_samples
|
||||
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples
|
||||
- remaining_samples}, torch::kFloat32)}, 1);
|
||||
//std::cout<<", padded_chunk size : "<<padded_chunk.size(1)<<std::endl;
|
||||
|
||||
chunks.push_back(padded_chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (!chunks.empty()) {
|
||||
|
||||
#ifdef USE_BATCH
|
||||
torch::Tensor batched_chunks = torch::stack(chunks); // Stack all chunks into a single tensor
|
||||
//batched_chunks = batched_chunks.squeeze(1);
|
||||
batched_chunks = torch::cat({batched_chunks.squeeze(1)});
|
||||
|
||||
#ifdef USE_GPU
|
||||
batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
|
||||
#endif
|
||||
// Prepare input for model
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
inputs.push_back(batched_chunks); // Batch of chunks
|
||||
inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
|
||||
|
||||
// Run inference on the batch
|
||||
torch::NoGradGuard no_grad;
|
||||
torch::Tensor output = model.forward(inputs).toTensor();
|
||||
#ifdef USE_GPU
|
||||
output = output.to(at::kCPU); // Move the output back to CPU once
|
||||
#endif
|
||||
// Collect output probabilities
|
||||
for (int i = 0; i < chunks.size(); i++) {
|
||||
float output_f = output[i].item<float>();
|
||||
outputs_prob.push_back(output_f);
|
||||
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
|
||||
}
|
||||
#else
|
||||
|
||||
std::vector<torch::Tensor> outputs;
|
||||
torch::Tensor batched_chunks = torch::stack(chunks);
|
||||
#ifdef USE_GPU
|
||||
batched_chunks = batched_chunks.to(at::kCUDA);
|
||||
#endif
|
||||
for (int i = 0; i < chunks.size(); i++) {
|
||||
torch::NoGradGuard no_grad;
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
inputs.push_back(batched_chunks[i]);
|
||||
inputs.push_back(sample_rate);
|
||||
|
||||
torch::Tensor output = model.forward(inputs).toTensor();
|
||||
outputs.push_back(output);
|
||||
}
|
||||
torch::Tensor all_outputs = torch::stack(outputs);
|
||||
#ifdef USE_GPU
|
||||
all_outputs = all_outputs.to(at::kCPU);
|
||||
#endif
|
||||
for (int i = 0; i < chunks.size(); i++) {
|
||||
float output_f = all_outputs[i].item<float>();
|
||||
outputs_prob.push_back(output_f);
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
|
||||
std::vector<Interval> speeches = DoVad();
|
||||
|
||||
#ifdef USE_BATCH
|
||||
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
|
||||
//It could be better get reasonable output because of distorted probs.
|
||||
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
|
||||
std::vector<Interval> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
|
||||
if(!print_as_samples){
|
||||
for (auto& speech : speeches_merge) { //samples to second
|
||||
speech.start /= sample_rate;
|
||||
speech.end /= sample_rate;
|
||||
}
|
||||
}
|
||||
|
||||
return speeches_merge;
|
||||
#else
|
||||
|
||||
if(!print_as_samples){
|
||||
for (auto& speech : speeches) { //samples to second
|
||||
speech.start /= sample_rate;
|
||||
speech.end /= sample_rate;
|
||||
}
|
||||
}
|
||||
|
||||
return speeches;
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
// Derives all sample-count parameters from their millisecond equivalents.
// The exact integer-division order is preserved deliberately: for
// window_size_samples the division happens first (sample_rate / 1000 *
// window_size_ms), which matters at sample rates not divisible by 1000.
void VadIterator::init_engine(int window_size_ms) {
    min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
    speech_pad_samples = sample_rate * speech_pad_ms / 1000;
    window_size_samples = sample_rate / 1000 * window_size_ms;
    min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}
|
||||
|
||||
// Loads the TorchScript model and prepares it for inference.
// Throws std::runtime_error when compiled with USE_GPU but CUDA is unavailable.
void VadIterator::init_torch_model(const std::string& model_path) {
    at::set_num_threads(1);  // single-threaded intra-op: VAD windows are tiny
    model = torch::jit::load(model_path);

#ifdef USE_GPU
    if (!torch::cuda::is_available()) {
        // Fix: the original also called model.to(at::kCPU) after this throw,
        // which was unreachable dead code; fail loudly instead.
        std::cout << "CUDA is not available! Please check your GPU settings" << std::endl;
        throw std::runtime_error("CUDA is not available!");
    }
    std::cout << "CUDA available! Running on '0'th GPU" << std::endl;
    model.to(at::Device(at::kCUDA, 0));  // select 0'th device
#endif

    model.eval();
    torch::NoGradGuard no_grad;
    std::cout << "Model loaded successfully" << std::endl;
}
|
||||
|
||||
// Clears all per-utterance state so the iterator can process a new waveform.
void VadIterator::reset_states() {
    triggered = false;
    temp_end = 0;
    current_sample = 0;
    total_sample_size = 0;
    outputs_prob.clear();
    model.run_method("reset_states");  // also reset the model's internal state
}
|
||||
|
||||
std::vector<Interval> VadIterator::DoVad() {
|
||||
std::vector<Interval> speeches;
|
||||
|
||||
for (size_t i = 0; i < outputs_prob.size(); ++i) {
|
||||
float speech_prob = outputs_prob[i];
|
||||
//std::cout << speech_prob << std::endl;
|
||||
//std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
|
||||
//std::cout << speech_prob << " ";
|
||||
current_sample += window_size_samples;
|
||||
|
||||
if (speech_prob >= threshold && temp_end != 0) {
|
||||
temp_end = 0;
|
||||
}
|
||||
|
||||
if (speech_prob >= threshold && !triggered) {
|
||||
triggered = true;
|
||||
Interval segment;
|
||||
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
|
||||
speeches.push_back(segment);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (speech_prob < threshold - 0.15f && triggered) {
|
||||
if (temp_end == 0) {
|
||||
temp_end = current_sample;
|
||||
}
|
||||
|
||||
if (current_sample - temp_end < min_silence_samples) {
|
||||
continue;
|
||||
} else {
|
||||
Interval& segment = speeches.back();
|
||||
segment.end = temp_end + speech_pad_samples - window_size_samples;
|
||||
temp_end = 0;
|
||||
triggered = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나?
|
||||
std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
|
||||
Interval& segment = speeches.back();
|
||||
segment.end = total_sample_size; // 현재 샘플을 마지막 구간의 종료 시간으로 설정
|
||||
triggered = false; // VAD 상태 초기화
|
||||
}
|
||||
|
||||
speeches.erase(
|
||||
std::remove_if(
|
||||
speeches.begin(),
|
||||
speeches.end(),
|
||||
[this](const Interval& speech) {
|
||||
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
|
||||
//min_speech_samples is 4000samples(0.25sec)
|
||||
//여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함.
|
||||
}
|
||||
),
|
||||
speeches.end()
|
||||
);
|
||||
|
||||
|
||||
//std::cout<<std::endl;
|
||||
//std::cout<<"outputs_prob.size : "<<outputs_prob.size()<<std::endl;
|
||||
|
||||
reset_states();
|
||||
return speeches;
|
||||
}
|
||||
|
||||
std::vector<Interval> VadIterator::mergeSpeeches(const std::vector<Interval>& speeches, int duration_merge_samples) {
|
||||
std::vector<Interval> mergedSpeeches;
|
||||
|
||||
if (speeches.empty()) {
|
||||
return mergedSpeeches; // 빈 벡터 반환
|
||||
}
|
||||
|
||||
// 첫 번째 구간으로 초기화
|
||||
Interval currentSegment = speeches[0];
|
||||
|
||||
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
|
||||
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침
|
||||
if (speeches[i].start - currentSegment.end < duration_merge_samples) {
|
||||
// 현재 구간의 끝점을 업데이트
|
||||
currentSegment.end = speeches[i].end;
|
||||
} else {
|
||||
// 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작
|
||||
mergedSpeeches.push_back(currentSegment);
|
||||
currentSegment = speeches[i];
|
||||
}
|
||||
}
|
||||
|
||||
// 마지막 구간 추가
|
||||
mergedSpeeches.push_back(currentSegment);
|
||||
|
||||
return mergedSpeeches;
|
||||
}
|
||||
|
||||
}
|
||||
79
examples/cpp_libtorch/silero_torch.h
Normal file
79
examples/cpp_libtorch/silero_torch.h
Normal file
@@ -0,0 +1,79 @@
|
||||
//Author : Nathan Lee
|
||||
//Created On : 2024-11-18
|
||||
//Description : silero 5.1 system for torch-script(c++).
|
||||
//Version : 1.0
|
||||
//Contact : junghan4242@gmail.com
|
||||
|
||||
#ifndef SILERO_TORCH_H
#define SILERO_TORCH_H

// Fix: removed a duplicate #include <memory>; headers grouped and sorted.
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

namespace silero{

// Speech interval expressed in whole samples.
// NOTE(review): not referenced by the visible implementation; kept for API
// compatibility with existing users.
struct SpeechSegment{
  int start;
  int end;
};

// Speech interval. Holds sample indices right after detection; converted in
// place to seconds by GetSpeechTimestamps() unless print_as_samples is set.
struct Interval {
  float start;
  float end;
};

// Wrapper around the Silero TorchScript VAD model: feed a waveform with
// SpeechProbs(), then read detected intervals with GetSpeechTimestamps().
class VadIterator{
public:
  VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
              int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
              int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
  ~VadIterator();

  // Computes one speech probability per window of input_wav.
  void SpeechProbs(std::vector<float>& input_wav);
  // Converts the accumulated probabilities into speech intervals.
  std::vector<silero::Interval> GetSpeechTimestamps();

  // Tunable post-processing parameters; callers may set these directly
  // after construction (see main.cc).
  float threshold;
  int sample_rate;
  int min_speech_duration_ms;
  int max_duration_merge_ms;
  bool print_as_samples;

private:
  torch::jit::script::Module model;
  std::vector<float> outputs_prob;   // per-window speech probabilities
  int min_silence_samples;
  int min_speech_samples;
  int speech_pad_samples;
  int window_size_samples;
  int duration_merge_samples;
  int current_sample = 0;            // running sample position during DoVad()

  int total_sample_size = 0;         // total samples seen since last reset

  int min_silence_duration_ms;
  int speech_pad_ms;
  bool triggered = false;            // true while inside a speech segment
  int temp_end = 0;                  // candidate segment end while in silence

  void init_engine(int window_size_ms);
  void init_torch_model(const std::string& model_path);
  void reset_states();
  std::vector<Interval> DoVad();
  std::vector<Interval> mergeSpeeches(const std::vector<Interval>& speeches, int duration_merge_samples);

};

}
#endif // SILERO_TORCH_H
|
||||
235
examples/cpp_libtorch/wav.h
Normal file
235
examples/cpp_libtorch/wav.h
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (c) 2016 Personal (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef FRONTEND_WAV_H_
|
||||
#define FRONTEND_WAV_H_
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
// #include "utils/log.h"
|
||||
|
||||
namespace wav {
|
||||
|
||||
// Binary layout of a canonical 44-byte RIFF/WAVE header.
struct WavHeader {
  char riff[4];                  // "RIFF"
  unsigned int size;             // file size - 8
  char wav[4];                   // "WAVE"
  char fmt[4];                   // "fmt "
  unsigned int fmt_size;         // fmt chunk size (>= 16 for PCM)
  uint16_t format;               // 1 = integer PCM, 3 = IEEE float
  uint16_t channels;
  unsigned int sample_rate;
  unsigned int bytes_per_second;
  uint16_t block_size;
  uint16_t bit;                  // bits per sample
  char data[4];                  // current sub-chunk tag; "data" when found
  unsigned int data_size;        // current sub-chunk payload size in bytes
};

// Reads a PCM wav file into a float buffer normalized to roughly [-1, 1).
//
// Fixes vs. the original:
//  - the explicit constructor now initializes data_ before Open() (it was
//    read uninitialized otherwise);
//  - Open() no longer leaks the FILE* on error paths, and frees any
//    previously loaded buffer before allocating a new one;
//  - the wav header read is checked;
//  - 8-bit decoding treats samples as unsigned with a 128 bias (per the WAV
//    PCM spec); the original read a signed char and divided by 32768,
//    yielding nearly-silent output.
class WavReader {
 public:
  WavReader() : data_(nullptr) {}
  explicit WavReader(const std::string& filename) : data_(nullptr) { Open(filename); }

  // Opens and fully decodes `filename`. Returns false on any error.
  bool Open(const std::string& filename) {
    FILE* fp = fopen(filename.c_str(), "rb");
    if (NULL == fp) {
      std::cout << "Error in read " << filename;
      return false;
    }

    WavHeader header;
    if (fread(&header, 1, sizeof(header), fp) != sizeof(header)) {
      printf("WaveData: file too short for a wav header.\n");
      fclose(fp);
      return false;
    }
    if (header.fmt_size < 16) {
      printf("WaveData: expect PCM format data "
             "to have fmt chunk of at least size 16.\n");
      fclose(fp);  // fix: fp was leaked on this path
      return false;
    } else if (header.fmt_size > 16) {
      // Oversized fmt chunk: seek past the extra bytes and re-read the
      // next sub-chunk tag + size into header.data / header.data_size.
      int offset = 44 - 8 + header.fmt_size - 16;
      fseek(fp, offset, SEEK_SET);
      fread(header.data, 8, sizeof(char), fp);
    }

    // Skip any sub-chunks between "fmt" and "data". Usually there will be a
    // single "fact" sub-chunk, but on Windows there can also be a "list".
    while (0 != strncmp(header.data, "data", 4)) {
      fseek(fp, header.data_size, SEEK_CUR);
      if (fread(header.data, 8, sizeof(char), fp) != 8) {
        printf("WaveData: no data chunk found.\n");
        fclose(fp);  // fix: fail (and close) instead of looping at EOF
        return false;
      }
    }

    // Some writers leave data_size zero; infer it from the file length.
    if (header.data_size == 0) {
      int offset = ftell(fp);
      fseek(fp, 0, SEEK_END);
      header.data_size = ftell(fp) - offset;
      fseek(fp, offset, SEEK_SET);
    }

    num_channel_ = header.channels;
    sample_rate_ = header.sample_rate;
    bits_per_sample_ = header.bit;
    int num_data = header.data_size / (bits_per_sample_ / 8);
    delete[] data_;  // fix: avoid leaking a previously loaded buffer
    data_ = new float[num_data];  // interleaved samples, all channels
    num_samples_ = num_data / num_channel_;

    std::cout << "num_channel_ :" << num_channel_ << std::endl;
    std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
    std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
    std::cout << "num_samples :" << num_data << std::endl;
    std::cout << "num_data_size :" << header.data_size << std::endl;

    switch (bits_per_sample_) {
      case 8: {
        // 8-bit WAV PCM is unsigned with a 128 bias.
        uint8_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(sample), fp);
          data_[i] = (static_cast<float>(sample) - 128.0f) / 128.0f;
        }
        break;
      }
      case 16: {
        int16_t sample;
        for (int i = 0; i < num_data; ++i) {
          fread(&sample, 1, sizeof(int16_t), fp);
          data_[i] = static_cast<float>(sample) / 32768;
        }
        break;
      }
      case 32: {
        if (header.format == 1) {  // S32 integer PCM
          int sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(int), fp);
            // NOTE(review): dividing S32 by 32768 preserves the original
            // scaling (values land in [-65536, 65536), not [-1, 1)).
            data_[i] = static_cast<float>(sample) / 32768;
          }
        } else if (header.format == 3) {  // IEEE float
          float sample;
          for (int i = 0; i < num_data; ++i) {
            fread(&sample, 1, sizeof(float), fp);
            data_[i] = static_cast<float>(sample);
          }
        } else {
          printf("unsupported quantization bits\n");
        }
        break;
      }
      default:
        printf("unsupported quantization bits\n");
        break;
    }

    fclose(fp);
    return true;
  }

  int num_channel() const { return num_channel_; }
  int sample_rate() const { return sample_rate_; }
  int bits_per_sample() const { return bits_per_sample_; }
  int num_samples() const { return num_samples_; }

  ~WavReader() {
    delete[] data_;
  }

  const float* data() const { return data_; }

 private:
  int num_channel_ = 0;
  int sample_rate_ = 0;
  int bits_per_sample_ = 0;
  int num_samples_ = 0;  // sample points per channel
  float* data_;
};
|
||||
|
||||
class WavWriter {
|
||||
public:
|
||||
WavWriter(const float* data, int num_samples, int num_channel,
|
||||
int sample_rate, int bits_per_sample)
|
||||
: data_(data),
|
||||
num_samples_(num_samples),
|
||||
num_channel_(num_channel),
|
||||
sample_rate_(sample_rate),
|
||||
bits_per_sample_(bits_per_sample) {}
|
||||
|
||||
void Write(const std::string& filename) {
|
||||
FILE* fp = fopen(filename.c_str(), "w");
|
||||
// init char 'riff' 'WAVE' 'fmt ' 'data'
|
||||
WavHeader header;
|
||||
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
|
||||
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
|
||||
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
|
||||
memcpy(&header, wav_header, sizeof(header));
|
||||
header.channels = num_channel_;
|
||||
header.bit = bits_per_sample_;
|
||||
header.sample_rate = sample_rate_;
|
||||
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
|
||||
header.size = sizeof(header) - 8 + header.data_size;
|
||||
header.bytes_per_second =
|
||||
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
|
||||
header.block_size = num_channel_ * (bits_per_sample_ / 8);
|
||||
|
||||
fwrite(&header, 1, sizeof(header), fp);
|
||||
|
||||
for (int i = 0; i < num_samples_; ++i) {
|
||||
for (int j = 0; j < num_channel_; ++j) {
|
||||
switch (bits_per_sample_) {
|
||||
case 8: {
|
||||
char sample = static_cast<char>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
case 32: {
|
||||
int sample = static_cast<int>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fclose(fp);
|
||||
}
|
||||
|
||||
private:
|
||||
const float* data_;
|
||||
int num_samples_; // total float points in data_
|
||||
int num_channel_;
|
||||
int sample_rate_;
|
||||
int bits_per_sample_;
|
||||
};
|
||||
|
||||
}  // namespace wav
|
||||
|
||||
#endif // FRONTEND_WAV_H_
|
||||
Reference in New Issue
Block a user