gradio-webrtc/backend/gradio_webrtc/reply_on_pause.py
freddyaboulton e0b11c1cb0 fix frame_size
2024-10-25 17:20:31 -07:00


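"""ReplyOnPause stream handler.

Buffers incoming audio, runs Silero VAD over it, and invokes a user-supplied
generator with the full utterance once the speaker pauses.
"""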
from dataclasses import dataclass
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Callable, Generator, Literal, cast

import numpy as np

from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.webrtc import StreamHandler

logger = getLogger(__name__)
counter = 0


@lru_cache
def get_vad_model() -> SileroVADModel:
    """Returns the VAD model instance."""
    return SileroVADModel()


@dataclass
class AlgoOptions:
    """Algorithm options."""

    # Seconds of audio to accumulate before each VAD pass.
    audio_chunk_duration: float = 0.6
    # Seconds of detected speech in a chunk needed to mark the user as talking.
    started_talking_threshold: float = 0.2
    # Once talking, a chunk with less speech than this counts as a pause.
    speech_threshold: float = 0.1


@dataclass
class AppState:
    """Mutable per-utterance state for pause detection."""

    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    responding: bool = False
    stopped: bool = False
    buffer: np.ndarray | None = None


# A reply function takes (sampling_rate, audio) for the completed utterance and
# yields output frames, optionally tagged with a channel layout.
ReplyFnGenerator = Callable[
    [tuple[int, np.ndarray]],
    Generator[
        tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
        None,
        None,
    ],
]


class ReplyOnPause(StreamHandler):
    def __init__(
        self,
        fn: ReplyFnGenerator,
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ):
        super().__init__(expected_layout, output_sample_rate, output_frame_size)
        self.expected_layout: Literal["mono", "stereo"] = expected_layout
        self.output_sample_rate = output_sample_rate
        self.output_frame_size = output_frame_size
        self.model = get_vad_model()
        self.fn = fn
        self.event = Event()
        self.state = AppState()
        self.generator = None
        self.model_options = model_options
        self.algo_options = algo_options or AlgoOptions()

    def copy(self):
        return ReplyOnPause(
            self.fn,
            self.algo_options,
            self.model_options,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
        )

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: AppState
    ) -> bool:
        """Take in the stream, determine if a pause happened."""
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
            logger.debug("VAD duration: %s", dur_vad)
            if (
                dur_vad > self.algo_options.started_talking_threshold
                and not state.started_talking
            ):
                state.started_talking = True
                logger.debug("Started talking")
            if state.started_talking:
                if state.stream is None:
                    state.stream = audio
                else:
                    state.stream = np.concatenate((state.stream, audio))
                state.buffer = None
            if dur_vad < self.algo_options.speech_threshold and state.started_talking:
                return True
        return False

    def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
        frame_rate, array = audio
        array = np.squeeze(array)
        if not state.sampling_rate:
            state.sampling_rate = frame_rate
        if state.buffer is None:
            state.buffer = array
        else:
            state.buffer = np.concatenate((state.buffer, array))

        pause_detected = self.determine_pause(
            state.buffer, state.sampling_rate, state
        )
        state.pause_detected = pause_detected

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if self.state.responding:
            return
        self.process_audio(frame, self.state)
        if self.state.pause_detected:
            self.event.set()

    def reset(self):
        self.generator = None
        self.event.clear()
        self.state = AppState()

    def emit(self):
        if not self.event.is_set():
            return None
        if not self.generator:
            audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
            self.generator = self.fn((self.state.sampling_rate, audio))
            self.state.responding = True
        try:
            return next(self.generator)
        except StopIteration:
            self.reset()
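

A minimal usage sketch, assuming the `WebRTC` component and its `stream()` method
from this package's README; `echo()` is a hypothetical reply generator that plays
the completed utterance back to the caller:

import gradio as gr
import numpy as np
from gradio_webrtc import ReplyOnPause, WebRTC

def echo(audio: tuple[int, np.ndarray]):
    # Hypothetical reply generator: yield back the caller's own utterance.
    sampling_rate, array = audio
    yield (sampling_rate, array, "mono")

with gr.Blocks() as demo:
    audio = WebRTC(mode="send-receive", modality="audio")
    # ReplyOnPause wraps echo() so it only runs once a pause is detected.
    audio.stream(fn=ReplyOnPause(echo), inputs=[audio], outputs=[audio])

demo.launch()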