Changed some source.

This commit is contained in:
Nathan Lee
2024-11-22 06:21:49 +00:00
parent 05e380c1de
commit 0189ebd8af
5 changed files with 43 additions and 41 deletions

View File

@@ -2,7 +2,6 @@
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0
//Contact : junghan4242@gmail.com
#include "silero_torch.h"
@@ -10,10 +9,10 @@
namespace silero {
VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms, int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms, int max_duration_merge_ms, bool print_as_samples)
:sample_rate(sample_rate), threshold(threshold), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
:sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
init_torch_model(model_path);
init_engine(window_size_ms);
//init_engine(window_size_ms);
}
VadIterator::~VadIterator(){
}
@@ -117,14 +116,14 @@ namespace silero {
}
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
std::vector<Interval> speeches = DoVad();
std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
std::vector<SpeechSegment> speeches = DoVad();
#ifdef USE_BATCH
//When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp.
//It could be better get reasonable output because of distorted probs.
duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
std::vector<Interval> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
if(!print_as_samples){
for (auto& speech : speeches_merge) { //samples to second
speech.start /= sample_rate;
@@ -147,6 +146,9 @@ namespace silero {
#endif
}
void VadIterator::SetVariables(){
init_engine(window_size_ms);
}
void VadIterator::init_engine(int window_size_ms) {
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
@@ -186,8 +188,8 @@ namespace silero {
total_sample_size = 0;
}
std::vector<Interval> VadIterator::DoVad() {
std::vector<Interval> speeches;
std::vector<SpeechSegment> VadIterator::DoVad() {
std::vector<SpeechSegment> speeches;
for (size_t i = 0; i < outputs_prob.size(); ++i) {
float speech_prob = outputs_prob[i];
@@ -202,7 +204,7 @@ namespace silero {
if (speech_prob >= threshold && !triggered) {
triggered = true;
Interval segment;
SpeechSegment segment;
segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
speeches.push_back(segment);
continue;
@@ -216,7 +218,7 @@ namespace silero {
if (current_sample - temp_end < min_silence_samples) {
continue;
} else {
Interval& segment = speeches.back();
SpeechSegment& segment = speeches.back();
segment.end = temp_end + speech_pad_samples - window_size_samples;
temp_end = 0;
triggered = false;
@@ -226,7 +228,7 @@ namespace silero {
if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나?
std::cout<<"when last triggered is keep working until last Probs"<<std::endl;
Interval& segment = speeches.back();
SpeechSegment& segment = speeches.back();
segment.end = total_sample_size; // 현재 샘플을 마지막 구간의 종료 시간으로 설정
triggered = false; // VAD 상태 초기화
}
@@ -235,7 +237,7 @@ namespace silero {
std::remove_if(
speeches.begin(),
speeches.end(),
[this](const Interval& speech) {
[this](const SpeechSegment& speech) {
return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
//min_speech_samples is 4000samples(0.25sec)
//여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함.
@@ -252,15 +254,15 @@ namespace silero {
return speeches;
}
std::vector<Interval> VadIterator::mergeSpeeches(const std::vector<Interval>& speeches, int duration_merge_samples) {
std::vector<Interval> mergedSpeeches;
std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
std::vector<SpeechSegment> mergedSpeeches;
if (speeches.empty()) {
return mergedSpeeches; // 빈 벡터 반환
}
// 첫 번째 구간으로 초기화
Interval currentSegment = speeches[0];
SpeechSegment currentSegment = speeches[0];
for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터
// 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침