Add API Reference and llms.txt (#256)
* stream api reference
* docs
* Add code
* Add code
* code
@@ -65,6 +65,34 @@ async def iterate(generator: Generator) -> Any:
 
 
 class ReplyOnPause(StreamHandler):
+    """
+    A stream handler that processes incoming audio, detects pauses,
+    and triggers a reply function (`fn`) when a pause is detected.
+
+    This handler accumulates audio chunks, uses a Voice Activity Detection (VAD)
+    model to determine speech segments, and identifies pauses based on configurable
+    thresholds. Once a pause is detected after speech has started, it calls the
+    provided generator function `fn` with the accumulated audio.
+
+    It can optionally run a `startup_fn` at the beginning and supports interruption
+    of the reply function if new audio arrives.
+
+    Attributes:
+        fn (ReplyFnGenerator): The generator function to call when a pause is detected.
+        startup_fn (Callable | None): An optional function to run at startup.
+        algo_options (AlgoOptions): Configuration for the pause detection algorithm.
+        model_options (ModelOptions | None): Configuration for the VAD model.
+        can_interrupt (bool): Whether incoming audio can interrupt the `fn` execution.
+        expected_layout (Literal["mono", "stereo"]): Expected audio channel layout.
+        output_sample_rate (int): Sample rate for the output audio from `fn`.
+        input_sample_rate (int): Expected sample rate of the input audio.
+        model (PauseDetectionModel): The VAD model instance.
+        state (AppState): The current state of the pause detection logic.
+        generator (Generator | AsyncGenerator | None): The active generator instance from `fn`.
+        event (Event): Threading event used to signal pause detection.
+        loop (asyncio.AbstractEventLoop): The asyncio event loop.
+    """
+
     def __init__(
         self,
         fn: ReplyFnGenerator,
@@ -78,6 +106,23 @@ class ReplyOnPause(StreamHandler):
         input_sample_rate: int = 48000,
         model: PauseDetectionModel | None = None,
     ):
+        """
+        Initializes the ReplyOnPause handler.
+
+        Args:
+            fn: The generator function to execute upon pause detection.
+                It receives `(sample_rate, audio_array)` and optionally `*args`.
+            startup_fn: An optional function to run once at the beginning.
+            algo_options: Options for the pause detection algorithm.
+            model_options: Options for the VAD model.
+            can_interrupt: If True, incoming audio during `fn` execution
+                will stop the generator and process the new audio.
+            expected_layout: Expected input audio layout ('mono' or 'stereo').
+            output_sample_rate: The sample rate expected for audio yielded by `fn`.
+            output_frame_size: Deprecated.
+            input_sample_rate: The expected sample rate of incoming audio.
+            model: An optional pre-initialized VAD model instance.
+        """
         super().__init__(
             expected_layout,
             output_sample_rate,
@@ -100,9 +145,16 @@ class ReplyOnPause(StreamHandler):
 
     @property
     def _needs_additional_inputs(self) -> bool:
+        """Checks if the reply function `fn` expects additional arguments."""
         return len(inspect.signature(self.fn).parameters) > 1
 
     def start_up(self):
+        """
+        Executes the startup function `startup_fn` if provided.
+
+        Waits for additional arguments if `_needs_additional_inputs` is True
+        before calling `startup_fn`. Sets the `event` after completion.
+        """
         if self.startup_fn:
             if self._needs_additional_inputs:
                 self.wait_for_args_sync()
@@ -113,6 +165,7 @@ class ReplyOnPause(StreamHandler):
             self.event.set()
 
     def copy(self):
+        """Creates a new instance of ReplyOnPause with the same configuration."""
         return ReplyOnPause(
             self.fn,
             self.startup_fn,
@@ -129,7 +182,22 @@ class ReplyOnPause(StreamHandler):
     def determine_pause(
         self, audio: np.ndarray, sampling_rate: int, state: AppState
     ) -> bool:
-        """Take in the stream, determine if a pause happened"""
+        """
+        Analyzes an audio chunk to detect if a significant pause occurred after speech.
+
+        Uses the VAD model to measure speech duration within the chunk. Updates the
+        application state (`state`) regarding whether talking has started and
+        accumulates speech segments.
+
+        Args:
+            audio: The numpy array containing the audio chunk.
+            sampling_rate: The sample rate of the audio chunk.
+            state: The current application state.
+
+        Returns:
+            True if a pause satisfying the configured thresholds is detected
+            after speech has started, False otherwise.
+        """
         duration = len(audio) / sampling_rate
 
         if duration >= self.algo_options.audio_chunk_duration:
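The hunk cuts off just as the thresholding begins. As a rough sketch of the decision the docstring describes (illustrative only, assuming `AlgoOptions` exposes `started_talking_threshold` and `speech_threshold` fields as in the public FastRTC docs; `dur_vad` stands for the seconds of speech the VAD model found in the chunk):

    def pause_decision_sketch(dur_vad: float, state, algo_options) -> bool:
        # Mark that the caller has started speaking once enough speech is heard.
        if dur_vad > algo_options.started_talking_threshold and not state.started_talking:
            state.started_talking = True
        # After speech has started, a chunk that is mostly non-speech counts as a pause.
        return state.started_talking and dur_vad < algo_options.speech_threshold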
@@ -152,6 +220,16 @@ class ReplyOnPause(StreamHandler):
         return False
 
     def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
+        """
+        Processes an incoming audio frame.
+
+        Appends the frame to the buffer, runs pause detection on the buffer,
+        and updates the application state.
+
+        Args:
+            audio: A tuple containing the sample rate and the audio frame data.
+            state: The current application state to update.
+        """
         frame_rate, array = audio
         array = np.squeeze(array)
         if not state.sampling_rate:
@@ -167,6 +245,16 @@ class ReplyOnPause(StreamHandler):
         state.pause_detected = pause_detected
 
     def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        """
+        Receives an audio frame from the stream.
+
+        Processes the audio frame using `process_audio`. If a pause is detected,
+        it sets the `event`. If interruption is enabled and a reply is ongoing,
+        it closes the current generator and clears the processing queue.
+
+        Args:
+            frame: A tuple containing the sample rate and the audio frame data.
+        """
         if self.state.responding and not self.can_interrupt:
             return
         self.process_audio(frame, self.state)
@@ -179,7 +267,13 @@ class ReplyOnPause(StreamHandler):
                 self.clear_queue()
 
     def _close_generator(self):
-        """Properly close the generator to ensure resources are released."""
+        """
+        Safely closes the active reply generator (`self.generator`).
+
+        Handles both synchronous and asynchronous generators, ensuring proper
+        resource cleanup (e.g., calling `aclose()` or `close()`).
+        Logs any errors during closure.
+        """
         if self.generator is None:
             return
 
@@ -199,6 +293,12 @@ class ReplyOnPause(StreamHandler):
             logger.debug(f"Error closing generator: {e}")
 
     def reset(self):
+        """
+        Resets the handler state to its initial condition.
+
+        Clears accumulated audio, resets state flags, closes any active generator,
+        and clears the event flag. Also handles resetting argument state for phone mode.
+        """
         super().reset()
         if self.phone_mode:
             self.args_set.set()
@@ -207,14 +307,37 @@ class ReplyOnPause(StreamHandler):
         self.state = AppState()
 
     def trigger_response(self):
+        """
+        Manually triggers the response generation process.
+
+        Sets the event flag, effectively simulating a pause detection.
+        Initializes the stream buffer if it's empty.
+        """
         self.event.set()
         if self.state.stream is None:
            self.state.stream = np.array([], dtype=np.int16)
 
     async def async_iterate(self, generator) -> EmitType:
+        """Helper function to get the next item from an async generator."""
         return await anext(generator)
 
     def emit(self):
+        """
+        Produces the next output chunk from the reply generator (`fn`).
+
+        This method is called repeatedly after a pause is detected (event is set).
+        If the generator is not already running, it initializes it by calling `fn`
+        with the accumulated audio and any required additional arguments.
+        It then yields the next item from the generator. Handles both sync and
+        async generators. Resets the state upon generator completion or error.
+
+        Returns:
+            The next output item from the generator, or None if no pause event
+            has occurred or the generator is exhausted.
+
+        Raises:
+            Exception: Re-raises exceptions occurring within the `fn` generator.
+        """
         if not self.event.is_set():
             return None
         else:
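Taken together, `ReplyOnPause` wraps VAD-based turn detection around a user-supplied generator. A minimal usage sketch (assuming the package exports `ReplyOnPause` and `Stream` at the top level, as the FastRTC docs show):

    import numpy as np
    from fastrtc import ReplyOnPause, Stream

    def echo(audio: tuple[int, np.ndarray]):
        # Receives the full utterance once a pause is detected; yield
        # (sample_rate, array) chunks to stream audio back to the client.
        sample_rate, array = audio
        yield sample_rate, array

    stream = Stream(handler=ReplyOnPause(echo), modality="audio", mode="send-receive")
    stream.ui.launch()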
@@ -20,15 +20,33 @@ logger = logging.getLogger(__name__)
 
 
 class ReplyOnStopWordsState(AppState):
+    """Extends AppState to include state specific to stop word detection."""
+
     stop_word_detected: bool = False
     post_stop_word_buffer: np.ndarray | None = None
     started_talking_pre_stop_word: bool = False
 
     def new(self):
+        """Creates a new instance of ReplyOnStopWordsState."""
         return ReplyOnStopWordsState()
 
 
 class ReplyOnStopWords(ReplyOnPause):
+    """
+    A stream handler that extends ReplyOnPause to trigger based on stop words
+    followed by a pause.
+
+    This handler listens to the incoming audio stream and performs Speech-to-Text (STT)
+    to detect predefined stop words. Once a stop word is detected, it waits for a
+    subsequent pause in speech (using the VAD model) before triggering the reply
+    function (`fn`) with the audio recorded *after* the stop word.
+
+    Attributes:
+        stop_words (list[str]): A list of words or phrases that trigger the pause detection.
+        state (ReplyOnStopWordsState): The current state of the stop word and pause detection logic.
+        stt_model: The Speech-to-Text model instance used for detecting stop words.
+    """
+
     def __init__(
         self,
         fn: ReplyFnGenerator,
@@ -43,6 +61,25 @@ class ReplyOnStopWords(ReplyOnPause):
         input_sample_rate: int = 48000,
         model: PauseDetectionModel | None = None,
     ):
+        """
+        Initializes the ReplyOnStopWords handler.
+
+        Args:
+            fn: The generator function to execute upon stop word and pause detection.
+                It receives `(sample_rate, audio_array)` and optionally `*args`.
+            stop_words: A list of strings (words or phrases) to listen for.
+                Detection is case-insensitive and ignores punctuation.
+            startup_fn: An optional function to run once at the beginning.
+            algo_options: Options for the pause detection algorithm (used after stop word).
+            model_options: Options for the VAD model.
+            can_interrupt: If True, incoming audio during `fn` execution
+                will stop the generator and process the new audio.
+            expected_layout: Expected input audio layout ('mono' or 'stereo').
+            output_sample_rate: The sample rate expected for audio yielded by `fn`.
+            output_frame_size: Deprecated.
+            input_sample_rate: The expected sample rate of incoming audio.
+            model: An optional pre-initialized VAD model instance.
+        """
         super().__init__(
             fn,
             algo_options=algo_options,
@@ -60,6 +97,18 @@ class ReplyOnStopWords(ReplyOnPause):
         self.stt_model = get_stt_model("moonshine/base")
 
     def stop_word_detected(self, text: str) -> bool:
+        """
+        Checks if any of the configured stop words are present in the text.
+
+        Performs a case-insensitive search, treating multi-word stop phrases
+        correctly and ignoring basic punctuation.
+
+        Args:
+            text: The text transcribed from the audio.
+
+        Returns:
+            True if a stop word is found, False otherwise.
+        """
         for stop_word in self.stop_words:
             stop_word = stop_word.lower().strip().split(" ")
             if bool(
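The boolean expression is truncated by the hunk; a standalone sketch of equivalent case-insensitive, punctuation-tolerant phrase matching (a reconstruction for illustration, not the exact expression in the source):

    import re

    def contains_stop_word(text: str, stop_words: list[str]) -> bool:
        normalized = re.sub(r"[^\w\s]", "", text.lower())  # drop punctuation
        for stop_word in stop_words:
            words = stop_word.lower().strip().split(" ")
            # Match the phrase as whole words, tolerating any whitespace between them.
            pattern = r"\b" + r"\s+".join(map(re.escape, words)) + r"\b"
            if re.search(pattern, normalized):
                return True
        return False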
@@ -75,17 +124,36 @@ class ReplyOnStopWords(ReplyOnPause):
     async def _send_stopword(
         self,
     ):
+        """Internal async method to send a 'stopword' message via the channel."""
         if self.channel:
             self.channel.send(create_message("stopword", ""))
             logger.debug("Sent stopword")
 
     def send_stopword(self):
+        """Sends a 'stopword' message asynchronously via the communication channel."""
         asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
 
     def determine_pause(  # type: ignore
         self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
     ) -> bool:
-        """Take in the stream, determine if a pause happened"""
+        """
+        Analyzes an audio chunk to detect stop words and subsequent pauses.
+
+        Overrides the `ReplyOnPause.determine_pause` method.
+        First, it performs STT on the audio buffer to detect stop words.
+        Once a stop word is detected (`state.stop_word_detected` is True), it then
+        uses the VAD model (similar to `ReplyOnPause`) to detect a pause in the
+        audio *following* the stop word.
+
+        Args:
+            audio: The numpy array containing the audio chunk.
+            sampling_rate: The sample rate of the audio chunk.
+            state: The current application state (ReplyOnStopWordsState).
+
+        Returns:
+            True if a stop word has been detected and a subsequent pause
+            satisfying the configured thresholds is detected, False otherwise.
+        """
         import librosa
 
         duration = len(audio) / sampling_rate
@@ -142,12 +210,19 @@ class ReplyOnStopWords(ReplyOnPause):
         return False
 
     def reset(self):
+        """
+        Resets the handler state to its initial condition.
+
+        Clears accumulated audio, resets state flags (including stop word state),
+        closes any active generator, and clears the event flag.
+        """
         super().reset()
         self.generator = None
         self.event.clear()
         self.state = ReplyOnStopWordsState()
 
     def copy(self):
+        """Creates a new instance of ReplyOnStopWords with the same configuration."""
         return ReplyOnStopWords(
             self.fn,
             self.stop_words,
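Usage mirrors `ReplyOnPause`, with the stop-word list as the second argument (a sketch, assuming a top-level `ReplyOnStopWords` export and reusing the `echo` generator from the earlier example):

    from fastrtc import ReplyOnStopWords, Stream

    stream = Stream(
        handler=ReplyOnStopWords(echo, stop_words=["computer"]),
        modality="audio",
        mode="send-receive",
    )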
@@ -36,6 +36,10 @@ class Body(BaseModel):
 
 
 class UIArgs(TypedDict):
+    """
+    UI customization arguments for the Gradio Blocks UI of the Stream class
+    """
+
     title: NotRequired[str]
     """Title of the demo"""
     subtitle: NotRequired[str]
@@ -56,6 +60,34 @@ class UIArgs(TypedDict):
 
 
 class Stream(WebRTCConnectionMixin):
+    """
+    Define an audio or video stream with a built-in UI, mountable on a FastAPI app.
+
+    This class encapsulates the logic for handling real-time communication (WebRTC)
+    streams, including setting up peer connections, managing tracks, generating
+    a Gradio user interface, and integrating with FastAPI for API endpoints.
+    It supports different modes (send, receive, send-receive) and modalities
+    (audio, video, audio-video), and can optionally handle additional Gradio
+    input/output components alongside the stream. It also provides functionality
+    for telephone integration via the FastPhone method.
+
+    Attributes:
+        mode (Literal["send-receive", "receive", "send"]): The direction of the stream.
+        modality (Literal["video", "audio", "audio-video"]): The type of media stream.
+        rtp_params (dict[str, Any] | None): Parameters for RTP encoding.
+        event_handler (HandlerType): The main function to process stream data.
+        concurrency_limit (int): The maximum number of concurrent connections allowed.
+        time_limit (float | None): Time limit in seconds for the event handler execution.
+        allow_extra_tracks (bool): Whether to allow extra tracks beyond the specified modality.
+        additional_output_components (list[Component] | None): Extra Gradio output components.
+        additional_input_components (list[Component] | None): Extra Gradio input components.
+        additional_outputs_handler (Callable | None): Handler for additional outputs.
+        track_constraints (dict[str, Any] | None): Constraints for media tracks (e.g., resolution).
+        webrtc_component (WebRTC): The underlying Gradio WebRTC component instance.
+        rtc_configuration (dict[str, Any] | None): Configuration for the RTCPeerConnection (e.g., ICE servers).
+        _ui (Blocks): The Gradio Blocks UI instance.
+    """
+
     def __init__(
         self,
         handler: HandlerType,
@@ -73,6 +105,28 @@ class Stream(WebRTCConnectionMixin):
         additional_outputs: list[Component] | None = None,
         ui_args: UIArgs | None = None,
     ):
+        """
+        Initialize the Stream instance.
+
+        Args:
+            handler: The function to handle incoming stream data and return output data.
+            additional_outputs_handler: An optional function to handle updates to additional output components.
+            mode: The direction of the stream ('send', 'receive', or 'send-receive').
+            modality: The type of media ('video', 'audio', or 'audio-video').
+            concurrency_limit: Maximum number of concurrent connections. 'default' maps to 1.
+            time_limit: Maximum execution time for the handler function in seconds.
+            allow_extra_tracks: If True, allows connections with tracks not matching the modality.
+            rtp_params: Optional dictionary of RTP encoding parameters.
+            rtc_configuration: Optional dictionary for RTCPeerConnection configuration (e.g., ICE servers).
+                Required when deploying on Colab or Spaces.
+            track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
+            additional_inputs: Optional list of extra Gradio input components.
+            additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`.
+            ui_args: Optional dictionary to customize the default UI appearance (title, subtitle, icon, etc.).
+
+        Raises:
+            ValueError: If `additional_outputs` are provided without `additional_outputs_handler`.
+        """
        WebRTCConnectionMixin.__init__(self)
        self.mode = mode
        self.modality = modality
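A construction sketch exercising the additional input/output plumbing described in the Args above (the component choices are illustrative, and the exact `additional_outputs_handler` signature should be checked against the docs):

    import gradio as gr

    stream = Stream(
        handler=ReplyOnPause(echo),
        modality="audio",
        mode="send-receive",
        additional_inputs=[gr.Textbox(label="System prompt")],
        additional_outputs=[gr.Textbox(label="Transcript")],
        # Required whenever additional_outputs are given (signature assumed here).
        additional_outputs_handler=lambda old, new: new,
    )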
@@ -97,6 +151,18 @@ class Stream(WebRTCConnectionMixin):
         self._ui.launch = self._wrap_gradio_launch(self._ui.launch)
 
     def mount(self, app: FastAPI, path: str = ""):
+        """
+        Mount the stream's API endpoints onto a FastAPI application.
+
+        This method adds the necessary routes (`/webrtc/offer`, `/telephone/handler`,
+        `/telephone/incoming`, `/websocket/offer`) to the provided FastAPI app,
+        prefixed with the optional `path`. It also injects a startup message
+        into the app's lifespan.
+
+        Args:
+            app: The FastAPI application instance.
+            path: An optional URL prefix for the mounted routes.
+        """
         from fastapi import APIRouter
 
         router = APIRouter(prefix=path)
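Mounting turns the stream into routes on an existing FastAPI app; a short sketch:

    from fastapi import FastAPI

    app = FastAPI()
    stream.mount(app, path="/stream")
    # The app now serves POST /stream/webrtc/offer plus the telephone and
    # websocket endpoints listed in the docstring. Run with, e.g.:
    #   uvicorn mymodule:app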
@@ -109,7 +175,18 @@ class Stream(WebRTCConnectionMixin):
         app.include_router(router)
 
     @staticmethod
-    def print_error(env: Literal["colab", "spaces"]):
+    def _print_error(env: Literal["colab", "spaces"]):
+        """
+        Print an error message and raise RuntimeError for missing rtc_configuration.
+
+        Used internally when running in Colab or Spaces without necessary WebRTC setup.
+
+        Args:
+            env: The environment ('colab' or 'spaces') where the error occurred.
+
+        Raises:
+            RuntimeError: Always raised after printing the error message.
+        """
         import click
 
         print(
@@ -125,14 +202,34 @@ class Stream(WebRTCConnectionMixin):
         )
 
     def _check_colab_or_spaces(self):
+        """
+        Check if running in Colab or Spaces and if rtc_configuration is missing.
+
+        Calls `_print_error` if the conditions are met.
+
+        Raises:
+            RuntimeError: If running in Colab/Spaces without `rtc_configuration`.
+        """
         from gradio.utils import colab_check, get_space
 
         if colab_check() and not self.rtc_configuration:
-            self.print_error("colab")
+            self._print_error("colab")
         if get_space() and not self.rtc_configuration:
-            self.print_error("spaces")
+            self._print_error("spaces")
 
     def _wrap_gradio_launch(self, callable):
+        """
+        Wrap the Gradio launch method to inject environment checks.
+
+        Ensures that `_check_colab_or_spaces` is called during the application
+        lifespan when `Blocks.launch()` is invoked.
+
+        Args:
+            callable: The original `gradio.Blocks.launch` method.
+
+        Returns:
+            A wrapped version of the launch method.
+        """
         import contextlib
 
         def wrapper(*args, **kwargs):
@@ -158,6 +255,15 @@ class Stream(WebRTCConnectionMixin):
     def _inject_startup_message(
         self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
     ):
+        """
+        Create a FastAPI lifespan context manager to print startup messages and check environment.
+
+        Args:
+            lifespan: An optional existing lifespan context manager to wrap.
+
+        Returns:
+            An async context manager function suitable for `FastAPI(lifespan=...)`.
+        """
         import contextlib
 
         import click
@@ -186,7 +292,26 @@ class Stream(WebRTCConnectionMixin):
     def _generate_default_ui(
         self,
         ui_args: UIArgs | None = None,
-    ):
+    ) -> Blocks:
+        """
+        Generate the default Gradio UI based on mode, modality, and arguments.
+
+        Constructs a `gradio.Blocks` interface with the appropriate WebRTC component
+        and any specified additional input/output components.
+
+        Args:
+            ui_args: Optional dictionary containing UI customization arguments
+                (title, subtitle, icon, etc.).
+
+        Returns:
+            A `gradio.Blocks` instance representing the generated UI.
+
+        Raises:
+            ValueError: If `additional_outputs` are provided without
+                `additional_outputs_handler`.
+            ValueError: If the combination of `mode` and `modality` is invalid
+                or not supported for UI generation.
+        """
         ui_args = ui_args or {}
         same_components = []
         additional_input_components = self.additional_input_components or []
@@ -590,18 +715,55 @@ class Stream(WebRTCConnectionMixin):
 
     @property
     def ui(self) -> Blocks:
+        """
+        Get the Gradio Blocks UI instance associated with this stream.
+
+        Returns:
+            The `gradio.Blocks` UI instance.
+        """
         return self._ui
 
     @ui.setter
     def ui(self, blocks: Blocks):
+        """
+        Set a custom Gradio Blocks UI for this stream.
+
+        Args:
+            blocks: The `gradio.Blocks` instance to use as the UI.
+        """
         self._ui = blocks
 
     async def offer(self, body: Body):
+        """
+        Handle an incoming WebRTC offer via HTTP POST.
+
+        Processes the SDP offer and ICE candidates from the client to establish
+        a WebRTC connection.
+
+        Args:
+            body: A Pydantic model containing the SDP offer, optional ICE candidate,
+                type ('offer'), and a unique WebRTC ID.
+
+        Returns:
+            A dictionary containing the SDP answer generated by the server.
+        """
         return await self.handle_offer(
             body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
         )
 
     async def handle_incoming_call(self, request: Request):
+        """
+        Handle incoming telephone calls (e.g., via Twilio).
+
+        Generates TwiML instructions to connect the incoming call to the
+        WebSocket handler (`/telephone/handler`) for audio streaming.
+
+        Args:
+            request: The FastAPI Request object for the incoming call webhook.
+
+        Returns:
+            An HTMLResponse containing the TwiML instructions as XML.
+        """
         from twilio.twiml.voice_response import Connect, VoiceResponse
 
         response = VoiceResponse()
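A client-side signaling sketch for the `offer` endpoint above (field names inferred from the `Body` description of sdp, type, and webrtc_id, so treat the exact schema as an assumption):

    import httpx

    # offer_sdp comes from the client's RTCPeerConnection (createOffer).
    response = httpx.post(
        "http://localhost:8000/webrtc/offer",
        json={"sdp": offer_sdp, "type": "offer", "webrtc_id": "a-unique-id"},
    )
    answer = response.json()  # SDP answer to apply as the remote description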
@@ -613,6 +775,12 @@ class Stream(WebRTCConnectionMixin):
         return HTMLResponse(content=str(response), media_type="application/xml")
 
     async def telephone_handler(self, websocket: WebSocket):
+        """
+        The WebSocket endpoint for streaming audio over a Twilio phone call.
+
+        Args:
+            websocket: The incoming WebSocket connection object.
+        """
         handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = True
 
@@ -636,6 +804,15 @@ class Stream(WebRTCConnectionMixin):
         await ws.handle_websocket(websocket)
 
     async def websocket_offer(self, websocket: WebSocket):
+        """
+        Handle WebRTC signaling over a WebSocket connection.
+
+        Provides an alternative to the HTTP POST `/webrtc/offer` endpoint for
+        exchanging SDP offers/answers and ICE candidates via WebSocket messages.
+
+        Args:
+            websocket: The incoming WebSocket connection object.
+        """
         handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = False
 
@@ -670,6 +847,25 @@ class Stream(WebRTCConnectionMixin):
         port: int = 8000,
         **kwargs,
     ):
+        """
+        Launch the FastPhone service for telephone integration.
+
+        Starts a local FastAPI server, mounts the stream, creates a public tunnel
+        (using Gradio's tunneling), registers the tunnel URL with the FastPhone
+        backend service, and prints the assigned phone number and access code.
+        This allows users to call the phone number and interact with the stream handler.
+
+        Args:
+            token: Optional Hugging Face Hub token for authentication with the
+                FastPhone service. If None, attempts to find one automatically.
+            host: The local host address to bind the server to.
+            port: The local port to bind the server to.
+            **kwargs: Additional keyword arguments passed to `uvicorn.run`.
+
+        Raises:
+            httpx.HTTPStatusError: If registration with the FastPhone service fails.
+            RuntimeError: If running in Colab/Spaces without `rtc_configuration`.
+        """
         import atexit
         import inspect
         import secrets
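The corresponding public entry point is the `fastphone()` method (name per the FastRTC docs; the signature is truncated above, so the arguments shown are assumptions):

    # Starts the server, opens a tunnel, and prints a phone number plus access code.
    stream.fastphone(host="127.0.0.1", port=8000)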
@@ -236,14 +236,48 @@ class VideoCallback(VideoStreamTrack):
 
 
 class StreamHandlerBase(ABC):
+    """
+    Base class for handling media streams in FastRTC.
+
+    Provides common attributes and methods for managing stream state,
+    communication channels, and basic configuration. This class is intended
+    to be subclassed by concrete stream handlers like `StreamHandler` or
+    `AsyncStreamHandler`.
+
+    Attributes:
+        expected_layout (Literal["mono", "stereo"]): The expected channel layout
+            of the input audio ('mono' or 'stereo').
+        output_sample_rate (int): The target sample rate for the output audio.
+        output_frame_size (int): The desired number of samples per output audio frame.
+        input_sample_rate (int): The expected sample rate of the input audio.
+        channel (DataChannel | None): The WebRTC data channel for communication.
+        channel_set (asyncio.Event): Event indicating if the data channel is set.
+        args_set (asyncio.Event): Event indicating if additional arguments are set.
+        latest_args (str | list[Any]): Stores the latest arguments received.
+        loop (asyncio.AbstractEventLoop): The asyncio event loop.
+        _resampler (av.AudioResampler | None): Internal audio resampler instance.
+        _clear_queue (Callable | None): Callback to clear the processing queue.
+        phone_mode (bool): Flag indicating if operating in telephone mode.
+    """
+
     def __init__(
         self,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
-        output_frame_size: int | None = None,
+        output_frame_size: int | None = None,  # Deprecated
         input_sample_rate: int = 48000,
         fps: int = 30,
     ) -> None:
+        """
+        Initializes the StreamHandlerBase.
+
+        Args:
+            expected_layout: Expected input audio layout ('mono' or 'stereo').
+            output_sample_rate: Target output audio sample rate.
+            output_frame_size: Deprecated. Frame size is now derived from sample rate.
+            input_sample_rate: Expected input audio sample rate.
+            fps: The desired frame rate for the output audio.
+        """
         self.expected_layout = expected_layout
         self.output_sample_rate = output_sample_rate
         self.input_sample_rate = input_sample_rate
@@ -302,6 +336,12 @@ class StreamHandlerBase(ABC):
         self._phone_mode = value
 
     def set_channel(self, channel: DataChannel):
+        """
+        Sets the data channel for communication and signals readiness.
+
+        Args:
+            channel: The WebRTC DataChannel instance.
+        """
         self._channel = channel
         self.channel_set.set()
 
@@ -328,11 +368,25 @@ class StreamHandlerBase(ABC):
             traceback.print_exc()
 
     async def send_message(self, msg: str):
+        """
+        Asynchronously sends a message over the data channel.
+
+        Args:
+            msg: The string message to send.
+        """
         if self.channel:
             self.channel.send(msg)
             logger.debug("Sent msg %s", msg)
 
     def send_message_sync(self, msg: str):
+        """
+        Synchronously sends a message over the data channel.
+
+        Runs the async `send_message` in the event loop and waits for completion.
+
+        Args:
+            msg: The string message to send.
+        """
         try:
             asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
             logger.debug("Sent msg %s", msg)
@@ -340,17 +394,36 @@ class StreamHandlerBase(ABC):
             logger.debug("Exception sending msg %s", e)
 
     def set_args(self, args: list[Any]):
+        """
+        Sets additional arguments received (e.g., from UI components).
+
+        Args:
+            args: A list of arguments.
+        """
         logger.debug("setting args in audio callback %s", args)
         self.latest_args = ["__webrtc_value__"] + list(args)
         self.args_set.set()
 
     def reset(self):
+        """Resets the argument set event."""
         self.args_set.clear()
 
     def shutdown(self):
+        """Placeholder for shutdown logic. Subclasses can override."""
         pass
 
     def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
+        """
+        Resamples an incoming audio frame to the target format and sample rate.
+
+        Initializes the resampler on the first call.
+
+        Args:
+            frame: The input AudioFrame.
+
+        Yields:
+            Resampled AudioFrame(s).
+        """
         if self._resampler is None:
             self._resampler = av.AudioResampler(  # type: ignore
                 format="s16",
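The resampler here is PyAV's; a standalone sketch of the same pattern (in recent PyAV versions `AudioResampler.resample` returns a list of frames, hence the loop; `input_frame` and `handle` are placeholders):

    import av

    resampler = av.AudioResampler(format="s16", layout="mono", rate=24000)
    for resampled in resampler.resample(input_frame):  # input_frame: av.AudioFrame
        handle(resampled)  # downstream processing of each resampled frame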
@@ -372,36 +445,102 @@ AudioEmitType = EmitType
 
 
 class StreamHandler(StreamHandlerBase):
+    """
+    Abstract base class for synchronous stream handlers.
+
+    Inherits from `StreamHandlerBase` and defines the core synchronous interface
+    for processing audio streams. Subclasses must implement `receive`, `emit`,
+    and `copy`.
+    """
+
     @abstractmethod
     def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
+        """
+        Process an incoming audio frame synchronously.
+
+        Args:
+            frame: A tuple containing the sample rate (int) and the audio data
+                as a numpy array (int16).
+        """
         pass
 
     @abstractmethod
     def emit(self) -> EmitType:
+        """
+        Produce the next output chunk synchronously.
+
+        This method is called to generate the output to be sent back over the stream.
+
+        Returns:
+            An output item conforming to `EmitType`, which could be audio data,
+            additional outputs, control signals (like `CloseStream`), or None.
+        """
         pass
 
     @abstractmethod
     def copy(self) -> StreamHandler:
+        """
+        Create a copy of this stream handler instance.
+
+        Used to create a new handler for each connection.
+
+        Returns:
+            A new instance of the concrete StreamHandler subclass.
+        """
         pass
 
     def start_up(self):
+        """Optional synchronous startup logic. Can be overridden by subclasses."""
         pass
 
 
 class AsyncStreamHandler(StreamHandlerBase):
+    """
+    Abstract base class for asynchronous stream handlers.
+
+    Inherits from `StreamHandlerBase` and defines the core asynchronous interface
+    for processing audio streams using `async`/`await`. Subclasses must implement
+    `receive`, `emit`, and `copy`. The `start_up` method must also be a coroutine.
+    """
+
     @abstractmethod
     async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
+        """
+        Process an incoming audio frame asynchronously.
+
+        Args:
+            frame: A tuple containing the sample rate (int) and the audio data
+                as a numpy array (int16).
+        """
         pass
 
     @abstractmethod
     async def emit(self) -> EmitType:
+        """
+        Produce the next output chunk asynchronously.
+
+        This coroutine is called to generate the output to be sent back over the stream.
+
+        Returns:
+            An output item conforming to `EmitType`, which could be audio data,
+            additional outputs, control signals (like `CloseStream`), or None.
+        """
         pass
 
     @abstractmethod
     def copy(self) -> AsyncStreamHandler:
+        """
+        Create a copy of this asynchronous stream handler instance.
+
+        Used to create a new handler for each connection.
+
+        Returns:
+            A new instance of the concrete AsyncStreamHandler subclass.
+        """
         pass
 
     async def start_up(self):
+        """Optional asynchronous startup logic. Must be a coroutine (async def)."""
         pass
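A minimal concrete subclass showing how the three abstract methods fit together (an illustrative echo handler, not code from the repository):

    class EchoHandler(StreamHandler):
        def __init__(self):
            super().__init__()
            self.buffer: list[tuple[int, npt.NDArray[np.int16]]] = []

        def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
            self.buffer.append(frame)  # stash incoming audio

        def emit(self) -> EmitType:
            return self.buffer.pop(0) if self.buffer else None  # play it back

        def copy(self) -> "EchoHandler":
            return EchoHandler()  # fresh state for each connection

The async variant has the same shape, except `receive`, `emit`, and `start_up` are coroutines defined on `AsyncStreamHandler`.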
@@ -409,30 +548,88 @@ StreamHandlerImpl = StreamHandler | AsyncStreamHandler
 
 
 class AudioVideoStreamHandler(StreamHandler):
+    """
+    Abstract base class for synchronous handlers processing both audio and video.
+
+    Inherits from `StreamHandler` (synchronous audio) and adds abstract methods
+    for handling video frames synchronously. Subclasses must implement the audio
+    methods (`receive`, `emit`) and the video methods (`video_receive`, `video_emit`),
+    as well as `copy`.
+    """
+
     @abstractmethod
     def video_receive(self, frame: VideoFrame) -> None:
+        """
+        Process an incoming video frame synchronously.
+
+        Args:
+            frame: The incoming aiortc `VideoFrame`.
+        """
         pass
 
     @abstractmethod
     def video_emit(self) -> VideoEmitType:
+        """
+        Produce the next output video frame synchronously.
+
+        Returns:
+            An output item conforming to `VideoEmitType`, typically a numpy array
+            representing the video frame, or None.
+        """
         pass
 
     @abstractmethod
     def copy(self) -> AudioVideoStreamHandler:
+        """
+        Create a copy of this audio-video stream handler instance.
+
+        Returns:
+            A new instance of the concrete AudioVideoStreamHandler subclass.
+        """
         pass
 
 
 class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
+    """
+    Abstract base class for asynchronous handlers processing both audio and video.
+
+    Inherits from `AsyncStreamHandler` (asynchronous audio) and adds abstract
+    coroutines for handling video frames asynchronously. Subclasses must implement
+    the async audio methods (`receive`, `emit`, `start_up`) and the async video
+    methods (`video_receive`, `video_emit`), as well as `copy`.
+    """
+
     @abstractmethod
     async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
+        """
+        Process an incoming video frame asynchronously.
+
+        Args:
+            frame: The video frame data as a numpy array (float32).
+                Note: The type hint differs from the synchronous version.
+                Consider standardizing if possible.
+        """
         pass
 
     @abstractmethod
     async def video_emit(self) -> VideoEmitType:
+        """
+        Produce the next output video frame asynchronously.
+
+        Returns:
+            An output item conforming to `VideoEmitType`, typically a numpy array
+            representing the video frame, or None.
+        """
         pass
 
     @abstractmethod
     def copy(self) -> AsyncAudioVideoStreamHandler:
+        """
+        Create a copy of this asynchronous audio-video stream handler instance.
+
+        Returns:
+            A new instance of the concrete AsyncAudioVideoStreamHandler subclass.
+        """
         pass