mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
* Add all this code * add code * Fix demo --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
import asyncio
|
|
import logging
|
|
import re
|
|
from typing import Literal
|
|
|
|
import numpy as np
|
|
|
|
from .reply_on_pause import (
|
|
AlgoOptions,
|
|
AppState,
|
|
ReplyFnGenerator,
|
|
ReplyOnPause,
|
|
SileroVadOptions,
|
|
)
|
|
from .speech_to_text import get_stt_model
|
|
from .utils import audio_to_float32, create_message
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ReplyOnStopWordsState(AppState):
|
|
stop_word_detected: bool = False
|
|
post_stop_word_buffer: np.ndarray | None = None
|
|
started_talking_pre_stop_word: bool = False
|
|
|
|
def new(self):
|
|
return ReplyOnStopWordsState()
|
|
|
|
|
|
class ReplyOnStopWords(ReplyOnPause):
|
|
def __init__(
|
|
self,
|
|
fn: ReplyFnGenerator,
|
|
stop_words: list[str],
|
|
algo_options: AlgoOptions | None = None,
|
|
model_options: SileroVadOptions | None = None,
|
|
can_interrupt: bool = True,
|
|
expected_layout: Literal["mono", "stereo"] = "mono",
|
|
output_sample_rate: int = 24000,
|
|
output_frame_size: int = 480,
|
|
input_sample_rate: int = 48000,
|
|
):
|
|
super().__init__(
|
|
fn,
|
|
algo_options=algo_options,
|
|
model_options=model_options,
|
|
can_interrupt=can_interrupt,
|
|
expected_layout=expected_layout,
|
|
output_sample_rate=output_sample_rate,
|
|
output_frame_size=output_frame_size,
|
|
input_sample_rate=input_sample_rate,
|
|
)
|
|
self.stop_words = stop_words
|
|
self.state = ReplyOnStopWordsState()
|
|
self.stt_model = get_stt_model("moonshine/base")
|
|
|
|
def stop_word_detected(self, text: str) -> bool:
|
|
for stop_word in self.stop_words:
|
|
stop_word = stop_word.lower().strip().split(" ")
|
|
if bool(
|
|
re.search(
|
|
r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"[.,!?]*\b",
|
|
text.lower(),
|
|
)
|
|
):
|
|
logger.debug("Stop word detected: %s", stop_word)
|
|
return True
|
|
return False
|
|
|
|
async def _send_stopword(
|
|
self,
|
|
):
|
|
if self.channel:
|
|
self.channel.send(create_message("stopword", ""))
|
|
logger.debug("Sent stopword")
|
|
|
|
def send_stopword(self):
|
|
asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
|
|
|
|
def determine_pause( # type: ignore
|
|
self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
|
|
) -> bool:
|
|
"""Take in the stream, determine if a pause happened"""
|
|
import librosa
|
|
|
|
duration = len(audio) / sampling_rate
|
|
|
|
if duration >= self.algo_options.audio_chunk_duration:
|
|
if not state.stop_word_detected:
|
|
audio_f32 = audio_to_float32((sampling_rate, audio))
|
|
audio_rs = librosa.resample(
|
|
audio_f32, orig_sr=sampling_rate, target_sr=16000
|
|
)
|
|
if state.post_stop_word_buffer is None:
|
|
state.post_stop_word_buffer = audio_rs
|
|
else:
|
|
state.post_stop_word_buffer = np.concatenate(
|
|
(state.post_stop_word_buffer, audio_rs)
|
|
)
|
|
if len(state.post_stop_word_buffer) / 16000 > 2:
|
|
state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
|
|
dur_vad, chunks = self.model.vad(
|
|
(16000, state.post_stop_word_buffer),
|
|
self.model_options,
|
|
return_chunks=True,
|
|
)
|
|
text = self.stt_model.stt_for_chunks(
|
|
(16000, state.post_stop_word_buffer), chunks
|
|
)
|
|
logger.debug(f"STT: {text}")
|
|
state.stop_word_detected = self.stop_word_detected(text)
|
|
if state.stop_word_detected:
|
|
logger.debug("Stop word detected")
|
|
self.send_stopword()
|
|
state.buffer = None
|
|
else:
|
|
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
|
logger.debug("VAD duration: %s", dur_vad)
|
|
if (
|
|
dur_vad > self.algo_options.started_talking_threshold
|
|
and not state.started_talking
|
|
and state.stop_word_detected
|
|
):
|
|
state.started_talking = True
|
|
logger.debug("Started talking")
|
|
if state.started_talking:
|
|
if state.stream is None:
|
|
state.stream = audio
|
|
else:
|
|
state.stream = np.concatenate((state.stream, audio))
|
|
state.buffer = None
|
|
if (
|
|
dur_vad < self.algo_options.speech_threshold
|
|
and state.started_talking
|
|
and state.stop_word_detected
|
|
):
|
|
return True
|
|
return False
|
|
|
|
def reset(self):
|
|
super().reset()
|
|
self.generator = None
|
|
self.event.clear()
|
|
self.state = ReplyOnStopWordsState()
|
|
|
|
def copy(self):
|
|
return ReplyOnStopWords(
|
|
self.fn,
|
|
self.stop_words,
|
|
self.algo_options,
|
|
self.model_options,
|
|
self.can_interrupt,
|
|
self.expected_layout,
|
|
self.output_sample_rate,
|
|
self.output_frame_size,
|
|
self.input_sample_rate,
|
|
)
|