Files
gradio-webrtc/backend/fastrtc/reply_on_stopwords.py
Freddy Boulton 6ea54777af ReplyOnPause and ReplyOnStopWords can be interrupted (#119)
* Add all this code

* add code

* Fix demo

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
2025-03-03 21:47:16 -05:00

158 lines
5.3 KiB
Python

import asyncio
import logging
import re
from typing import Literal
import numpy as np
from .reply_on_pause import (
AlgoOptions,
AppState,
ReplyFnGenerator,
ReplyOnPause,
SileroVadOptions,
)
from .speech_to_text import get_stt_model
from .utils import audio_to_float32, create_message
# Logger named after this module; all debug output in this file goes through it.
logger = logging.getLogger(name=__name__)
class ReplyOnStopWordsState(AppState):
    """Conversation state for ``ReplyOnStopWords``, extending ``AppState``."""

    # Becomes True once a stop word has been recognised in transcribed audio.
    stop_word_detected: bool = False
    # Rolling window of resampled audio transcribed while waiting for a stop
    # word (despite the name, it is filled *before* detection).
    post_stop_word_buffer: np.ndarray | None = None
    # NOTE(review): not read or written anywhere in this file; confirm it is
    # used elsewhere before relying on it.
    started_talking_pre_stop_word: bool = False

    def new(self) -> "ReplyOnStopWordsState":
        """Return a fresh, default-constructed state object."""
        fresh_state = ReplyOnStopWordsState()
        return fresh_state
class ReplyOnStopWords(ReplyOnPause):
    """Pause-based reply handler that only arms after hearing a stop word.

    Until one of ``stop_words`` is recognised, each audio chunk is resampled
    to 16 kHz and appended to a rolling two-second window. That window is
    passed through VAD and the STT model and then searched for a stop word.
    Once a stop word is found, a ``"stopword"`` message is sent over the data
    channel. From then on the ``ReplyOnPause`` VAD thresholds decide when the
    user has stopped talking.
    """

    # Sample rate the stop-word window is resampled to before VAD/STT.
    _STT_SAMPLE_RATE = 16000
    # Length, in seconds, of the rolling window scanned for stop words.
    _STOP_WORD_WINDOW_S = 2

    def __init__(
        self,
        fn: ReplyFnGenerator,
        stop_words: list[str],
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        can_interrupt: bool = True,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 48000,
    ):
        """Create the handler.

        Args:
            fn: reply generator, forwarded unchanged to ``ReplyOnPause``.
            stop_words: phrases that must be heard before pause detection
                arms. Matching is case-insensitive; the words of a multi-word
                phrase may be separated by any whitespace.
            algo_options, model_options, can_interrupt, expected_layout,
            output_sample_rate, output_frame_size, input_sample_rate:
                forwarded unchanged to ``ReplyOnPause``.

        Note: the ``"moonshine/base"`` STT model is requested here, so it is
        requested again for every instance created by ``copy()``.
        """
        super().__init__(
            fn,
            algo_options=algo_options,
            model_options=model_options,
            can_interrupt=can_interrupt,
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            output_frame_size=output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.stop_words = stop_words
        self.state = ReplyOnStopWordsState()
        self.stt_model = get_stt_model("moonshine/base")

    def stop_word_detected(self, text: str) -> bool:
        """Return True if any configured stop word occurs in ``text``.

        Matching is case-insensitive and anchored on word boundaries. The
        words of a multi-word phrase may be separated by any run of
        whitespace, and the phrase may be followed by ``.,!?`` punctuation.
        Blank stop words are ignored; their pattern would otherwise match
        almost any transcript.
        """
        lowered = text.lower()
        for stop_word in self.stop_words:
            # split() rather than split(" "): repeated spaces in the configured
            # phrase must not create empty tokens. Empty tokens would make the
            # pattern demand several whitespace runs between the words.
            words = stop_word.lower().split()
            if not words:
                continue
            pattern = r"\b" + r"\s+".join(map(re.escape, words)) + r"[.,!?]*\b"
            if re.search(pattern, lowered):
                logger.debug("Stop word detected: %s", stop_word)
                return True
        return False

    async def _send_stopword(
        self,
    ):
        """Send a ``"stopword"`` message over ``self.channel`` if it is set."""
        if self.channel:
            self.channel.send(create_message("stopword", ""))
            logger.debug("Sent stopword")

    def send_stopword(self):
        """Schedule ``_send_stopword`` on ``self.loop`` from synchronous code.

        ``run_coroutine_threadsafe`` allows this to be called from a thread
        other than the one running ``self.loop``.
        """
        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

    def _scan_for_stop_word(
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ):
        """Add ``audio`` to the rolling window and check it for a stop word.

        The chunk is resampled to ``_STT_SAMPLE_RATE`` and appended to
        ``state.post_stop_word_buffer``, which is trimmed to the last
        ``_STOP_WORD_WINDOW_S`` seconds. The window is then run through VAD and
        STT. ``state.stop_word_detected`` is updated. On a match, the stopword
        message is sent and ``state.buffer`` is cleared.

        Returns:
            The VAD duration reported for the whole window.
        """
        import librosa  # local import: only this code path needs librosa

        rate = self._STT_SAMPLE_RATE
        audio_rs = librosa.resample(
            audio_to_float32((sampling_rate, audio)),
            orig_sr=sampling_rate,
            target_sr=rate,
        )
        if state.post_stop_word_buffer is None:
            state.post_stop_word_buffer = audio_rs
        else:
            state.post_stop_word_buffer = np.concatenate(
                (state.post_stop_word_buffer, audio_rs)
            )
        # Keep only the most recent window so the STT input stays bounded.
        max_samples = rate * self._STOP_WORD_WINDOW_S
        if len(state.post_stop_word_buffer) > max_samples:
            state.post_stop_word_buffer = state.post_stop_word_buffer[-max_samples:]

        dur_vad, chunks = self.model.vad(
            (rate, state.post_stop_word_buffer),
            self.model_options,
            return_chunks=True,
        )
        text = self.stt_model.stt_for_chunks(
            (rate, state.post_stop_word_buffer), chunks
        )
        logger.debug("STT: %s", text)
        state.stop_word_detected = self.stop_word_detected(text)
        if state.stop_word_detected:
            logger.debug("Stop word detected")
            self.send_stopword()
            state.buffer = None
        return dur_vad

    def determine_pause(  # type: ignore
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ) -> bool:
        """Take in the stream, determine if a pause happened.

        Chunks shorter than ``algo_options.audio_chunk_duration`` are ignored
        and return False.
        - Before a stop word, the chunk goes through ``_scan_for_stop_word``.
        - After a stop word, VAD runs on the chunk itself.
        - Once the VAD duration exceeds ``started_talking_threshold`` after a
          stop word, the chunk is appended to ``state.stream``.

        Returns:
            True when speech started after a stop word and the latest VAD
            duration is below ``algo_options.speech_threshold``; else False.
        """
        duration = len(audio) / sampling_rate
        if duration < self.algo_options.audio_chunk_duration:
            return False

        if state.stop_word_detected:
            dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
        else:
            # This call may itself flip state.stop_word_detected to True.
            dur_vad = self._scan_for_stop_word(audio, sampling_rate, state)
        logger.debug("VAD duration: %s", dur_vad)

        if (
            dur_vad > self.algo_options.started_talking_threshold
            and not state.started_talking
            and state.stop_word_detected
        ):
            state.started_talking = True
            logger.debug("Started talking")
        if state.started_talking:
            if state.stream is None:
                state.stream = audio
            else:
                state.stream = np.concatenate((state.stream, audio))
            state.buffer = None
        if (
            dur_vad < self.algo_options.speech_threshold
            and state.started_talking
            and state.stop_word_detected
        ):
            return True
        return False

    def reset(self):
        """Reset the parent handler, drop the generator and start fresh state."""
        super().reset()
        self.generator = None
        self.event.clear()
        self.state = ReplyOnStopWordsState()

    def copy(self):
        """Return a new handler built with this handler's configuration.

        The new instance gets its own fresh state from ``__init__``. The
        ``stop_words`` list object is shared with this instance.
        """
        # Keyword arguments keep this correct if the constructor's optional
        # parameters are ever reordered.
        return ReplyOnStopWords(
            self.fn,
            self.stop_words,
            algo_options=self.algo_options,
            model_options=self.model_options,
            can_interrupt=self.can_interrupt,
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
            input_sample_rate=self.input_sample_rate,
        )