import asyncio
import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from dataclasses import dataclass, field
from logging import getLogger
from threading import Event
from typing import Any, Literal, cast

import numpy as np
from numpy.typing import NDArray

from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
from .tracks import EmitType, StreamHandler
from .utils import AdditionalOutputs, WebRTCData, create_message, split_output

logger = getLogger(__name__)


@dataclass
class AlgoOptions:
    """
    Algorithm options.

    Attributes:
        audio_chunk_duration: Duration in seconds of audio chunks passed to the VAD model.
        started_talking_threshold: If a chunk contains more than this many seconds of speech, the user is considered to have started talking.
        speech_threshold: If, after the user started talking, a chunk contains less than this many seconds of speech, the user is considered to have stopped talking.
        max_continuous_speech_s: Maximum duration of accumulated speech, in seconds, before the handler is triggered even if the VAD model does not detect a pause.
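
    Example:
        A minimal sketch (the values below are illustrative, not recommended
        defaults):

            opts = AlgoOptions(
                audio_chunk_duration=0.6,       # run VAD on 0.6 s chunks
                started_talking_threshold=0.2,  # >0.2 s of speech: user is talking
                speech_threshold=0.1,           # <0.1 s of speech: pause candidate
                max_continuous_speech_s=30.0,   # force a reply after 30 s of speech
            )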
    """

    audio_chunk_duration: float = 0.6
    started_talking_threshold: float = 0.2
    speech_threshold: float = 0.1
    max_continuous_speech_s: float = float("inf")


@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    responding: bool = False
    stopped: bool = False
    buffer: np.ndarray | None = None
    responded_audio: bool = False
    interrupted: asyncio.Event = field(default_factory=asyncio.Event)

    def new(self):
        return AppState()


ReplyFnGenerator = (
    Callable[
        [tuple[int, NDArray[np.int16]], Any],
        Generator[EmitType, None, None],
    ]
    | Callable[
        [tuple[int, NDArray[np.int16]]],
        Generator[EmitType, None, None],
    ]
    | Callable[
        [tuple[int, NDArray[np.int16]]],
        AsyncGenerator[EmitType, None],
    ]
    | Callable[
        [tuple[int, NDArray[np.int16]], Any],
        AsyncGenerator[EmitType, None],
    ]
    | Callable[
        [WebRTCData],
        Generator[EmitType, None, None],
    ]
    | Callable[
        [WebRTCData, Any],
        AsyncGenerator[EmitType, None],
    ]
)
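
# For reference, the simplest conforming reply function is a synchronous
# generator that takes ``(sample_rate, audio)`` and yields output to emit.
# A minimal sketch (the echo behavior is purely illustrative):
#
#     def echo(audio: tuple[int, NDArray[np.int16]]):
#         sample_rate, array = audio
#         yield (sample_rate, array)  # play the user's speech back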


async def iterate(generator: Generator) -> Any:
    """Advances a synchronous generator by one step from async code."""
    return next(generator)


class ReplyOnPause(StreamHandler):
    """
    A stream handler that processes incoming audio, detects pauses,
    and triggers a reply function (`fn`) when a pause is detected.

    This handler accumulates audio chunks, uses a Voice Activity Detection (VAD)
    model to determine speech segments, and identifies pauses based on configurable
    thresholds. Once a pause is detected after speech has started, it calls the
    provided generator function `fn` with the accumulated audio.

    It can optionally run a `startup_fn` at the beginning and supports interruption
    of the reply function if new audio arrives.

    Attributes:
        fn (ReplyFnGenerator): The generator function to call when a pause is detected.
        startup_fn (Callable | None): An optional function to run at startup.
        algo_options (AlgoOptions): Configuration for the pause detection algorithm.
        model_options (ModelOptions | None): Configuration for the VAD model.
        can_interrupt (bool): Whether incoming audio can interrupt the `fn` execution.
        expected_layout (Literal["mono", "stereo"]): Expected audio channel layout.
        output_sample_rate (int): Sample rate for the output audio from `fn`.
        input_sample_rate (int): Expected sample rate of the input audio.
        model (PauseDetectionModel): The VAD model instance.
        state (AppState): The current state of the pause detection logic.
        generator (Generator | AsyncGenerator | None): The active generator instance from `fn`.
        event (Event): Threading event used to signal pause detection.
        loop (asyncio.AbstractEventLoop): The asyncio event loop.
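
    Example:
        A minimal sketch of a handler that replies with synthesized audio
        (`synthesize_reply` is hypothetical; wiring the handler into a stream
        depends on the surrounding package):

            def respond(audio: tuple[int, NDArray[np.int16]]):
                # `synthesize_reply` stands in for a real TTS/LLM pipeline
                for chunk in synthesize_reply(audio):
                    yield chunk

            handler = ReplyOnPause(respond, algo_options=AlgoOptions())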
    """

    def __init__(
        self,
        fn: ReplyFnGenerator,
        startup_fn: Callable | None = None,
        algo_options: AlgoOptions | None = None,
        model_options: ModelOptions | None = None,
        can_interrupt: bool = True,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int | None = None,  # Deprecated
        input_sample_rate: int = 48000,
        model: PauseDetectionModel | None = None,
        needs_args: bool = False,
    ):
        """
        Initializes the ReplyOnPause handler.

        Args:
            fn: The generator function to execute upon pause detection.
                It receives `(sample_rate, audio_array)` and optionally `*args`.
            startup_fn: An optional function to run once at the beginning.
            algo_options: Options for the pause detection algorithm.
            model_options: Options for the VAD model.
            can_interrupt: If True, incoming audio during `fn` execution
                will stop the generator and process the new audio.
            expected_layout: Expected input audio layout ('mono' or 'stereo').
            output_sample_rate: The sample rate expected for audio yielded by `fn`.
            output_frame_size: Deprecated.
            input_sample_rate: The expected sample rate of incoming audio.
            model: An optional pre-initialized VAD model instance.
            needs_args: Whether the reply function expects additional arguments.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.can_interrupt = can_interrupt
        self.expected_layout: Literal["mono", "stereo"] = expected_layout
        self.model = model or get_silero_model()
        self.fn = fn
        self.is_async = inspect.isasyncgenfunction(fn)
        self.event = Event()
        self.state = AppState()
        self.generator: (
            Generator[EmitType, None, None] | AsyncGenerator[EmitType, None] | None
        ) = None
        self.model_options = model_options
        self.algo_options = algo_options or AlgoOptions()
        self.startup_fn = startup_fn
        self.needs_args = needs_args

    @property
    def _needs_additional_inputs(self) -> bool:
        """Checks if the reply function `fn` expects additional arguments."""
        return len(inspect.signature(self.fn).parameters) > 1 or self.needs_args

    def start_up(self):
        """
        Executes the startup function `startup_fn` if provided.

        Waits for additional arguments if `_needs_additional_inputs` is True
        before calling `startup_fn`. Sets the `event` after completion.
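
        Example:
            A hypothetical greeting generator, so the handler speaks first
            (`greeting_audio` is illustrative):

                def greet():
                    yield (24000, greeting_audio)  # pre-rendered greeting clip

                handler = ReplyOnPause(respond, startup_fn=greet)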
        """
        if self.startup_fn:
            if self._needs_additional_inputs:
                self.wait_for_args_sync()
                args = self.latest_args[1:]
            else:
                args = ()
            self.generator = self.startup_fn(*args)
            self.event.set()

    def copy(self):
        """Creates a new instance of ReplyOnPause with the same configuration."""
        return ReplyOnPause(
            self.fn,
            self.startup_fn,
            self.algo_options,
            self.model_options,
            self.can_interrupt,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
            self.model,
            self.needs_args,
        )

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: AppState
    ) -> bool:
        """
        Analyzes an audio chunk to detect if a significant pause occurred after speech.

        Uses the VAD model to measure speech duration within the chunk. Updates the
        application state (`state`) regarding whether talking has started and
        accumulates speech segments.

        Args:
            audio: The numpy array containing the audio chunk.
            sampling_rate: The sample rate of the audio chunk.
            state: The current application state.

        Returns:
            True if a pause satisfying the configured thresholds is detected
            after speech has started, False otherwise.
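
        Example:
            With the default thresholds, a 0.6 s buffer in which the VAD
            measures 0.25 s of speech marks the user as talking; a later
            0.6 s buffer with under 0.1 s of speech is reported as a pause.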
        """
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
            logger.debug("VAD duration: %s", dur_vad)
            if (
                dur_vad > self.algo_options.started_talking_threshold
                and not state.started_talking
            ):
                state.started_talking = True
                logger.debug("Started talking")
                self.send_message_sync(create_message("log", "started_talking"))
            if state.started_talking:
                if state.stream is None:
                    state.stream = audio
                else:
                    state.stream = np.concatenate((state.stream, audio))

                # Check if the continuous speech limit has been reached
                current_duration = len(state.stream) / sampling_rate
                if current_duration >= self.algo_options.max_continuous_speech_s:
                    return True
            state.buffer = None

            # Check if a pause has been detected by the VAD model
            if dur_vad < self.algo_options.speech_threshold and state.started_talking:
                return True
        return False

    def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
        """
        Processes an incoming audio frame.

        Appends the frame to the buffer, runs pause detection on the buffer,
        and updates the application state.

        Args:
            audio: A tuple containing the sample rate and the audio frame data.
            state: The current application state to update.
        """
        frame_rate, array = audio
        array = np.squeeze(array)
        if not state.sampling_rate:
            state.sampling_rate = frame_rate
        if state.buffer is None:
            state.buffer = array
        else:
            state.buffer = np.concatenate((state.buffer, array))

        pause_detected = self.determine_pause(
            state.buffer, state.sampling_rate, state
        )
        state.pause_detected = pause_detected

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """
        Receives an audio frame from the stream.

        Processes the audio frame using `process_audio`. If a pause is detected,
        it sets the `event`. If interruption is enabled and a reply is ongoing,
        it closes the current generator and clears the processing queue.

        Args:
            frame: A tuple containing the sample rate and the audio frame data.
        """
        if self.state.responding and not self.can_interrupt:
            return
        self.process_audio(frame, self.state)
        if self.state.pause_detected:
            self.event.set()
            if self.can_interrupt and self.state.responding:
                self._close_generator()
                self.generator = None
            if self.can_interrupt:
                self.clear_queue()

    def _close_generator(self):
        """
        Safely closes the active reply generator (`self.generator`).

        Handles both synchronous and asynchronous generators, ensuring proper
        resource cleanup (e.g., calling `aclose()` or `close()`).
        Logs any errors during closure.
        """
        if self.generator is None:
            return

        try:
            if self.is_async:
                # For async generators, we need to call aclose()
                if hasattr(self.generator, "aclose"):
                    asyncio.run_coroutine_threadsafe(
                        cast(AsyncGenerator[EmitType, None], self.generator).aclose(),
                        self.loop,
                    ).result(timeout=1.0)  # Add timeout to prevent blocking
            else:
                # For sync generators, we can just exhaust it or close it
                if hasattr(self.generator, "close"):
                    cast(Generator[EmitType, None, None], self.generator).close()
        except Exception as e:
            logger.debug(f"Error closing generator: {e}")

    def reset(self):
        """
        Resets the handler state to its initial condition.

        Clears accumulated audio, resets state flags, closes any active generator,
        and clears the event flag. Also handles resetting argument state for phone mode.
        """
        super().reset()
        if self.phone_mode:
            self.args_set.set()
        self.generator = None
        self.event.clear()
        self.state = AppState()

    def trigger_response(self):
        """
        Manually triggers the response generation process.

        Sets the event flag, effectively simulating a pause detection.
        Initializes the stream buffer if it's empty.
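
        Example:
            Forcing a reply without waiting for a pause (how this is invoked,
            e.g. from a UI event, depends on the surrounding application):

                handler.trigger_response()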
        """
        self.event.set()
        if self.state.stream is None:
            self.state.stream = np.array([], dtype=np.int16)

    async def async_iterate(self, generator) -> EmitType:
        """Helper function to get the next item from an async generator."""
        return await anext(generator)

    def emit(self):
        """
        Produces the next output chunk from the reply generator (`fn`).

        This method is called repeatedly after a pause is detected (event is set).
        If the generator is not already running, it initializes it by calling `fn`
        with the accumulated audio and any required additional arguments.
        It then yields the next item from the generator. Handles both sync and
        async generators. Resets the state upon generator completion or error.

        Returns:
            The next output item from the generator, or None if no pause event
            has occurred or the generator is exhausted.

        Raises:
            Exception: Re-raises exceptions occurring within the `fn` generator.
        """
        if not self.event.is_set():
            return None
        else:
            if not self.generator:
                self.send_message_sync(create_message("log", "pause_detected"))
                if self._needs_additional_inputs and not self.phone_mode:
                    self.wait_for_args_sync()
                else:
                    self.latest_args = [None]
                    self.args_set.set()
                logger.debug("Creating generator")
                if self.state.stream is not None and self.state.stream.size > 0:
                    audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
                else:
                    audio = np.array([[]], dtype=np.int16)
                if isinstance(self.latest_args[0], WebRTCData):
                    self.latest_args[0].audio = (self.state.sampling_rate, audio)
                else:
                    self.latest_args[0] = (self.state.sampling_rate, audio)
                self.generator = self.fn(*self.latest_args)  # type: ignore
                logger.debug("Latest args: %s", self.latest_args)
                self.state = self.state.new()
            self.state.responding = True
            try:
                if self.is_async:
                    output = asyncio.run_coroutine_threadsafe(
                        self.async_iterate(self.generator), self.loop
                    ).result()
                else:
                    output = next(self.generator)  # type: ignore
                audio, additional_outputs = split_output(output)
                if audio is not None:
                    self.send_message_sync(create_message("log", "response_starting"))
                    self.state.responded_audio = True
                if self.phone_mode:
                    if isinstance(additional_outputs, AdditionalOutputs):
                        self.latest_args = [None] + list(additional_outputs.args)
                return output
            except (StopIteration, StopAsyncIteration):
                if not self.state.responded_audio:
                    self.send_message_sync(create_message("log", "response_starting"))
                self.reset()
            except Exception as e:
                import traceback

                traceback.print_exc()
                logger.debug("Error in ReplyOnPause: %s", e)
                self.reset()
                raise e