diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py
index 2963e16..f1a72cf 100644
--- a/backend/gradio_webrtc/__init__.py
+++ b/backend/gradio_webrtc/__init__.py
@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
 from .speech_to_text import stt, stt_for_chunks
 from .utils import (
     AdditionalOutputs,
+    Warning,
+    WebRTCError,
     aggregate_bytes_to_16bit,
     async_aggregate_bytes_to_16bit,
     audio_to_bytes,
@@ -35,4 +37,6 @@ __all__ = [
     "stt_for_chunks",
     "StreamHandler",
     "WebRTC",
+    "WebRTCError",
+    "Warning",
 ]
diff --git a/backend/gradio_webrtc/reply_on_stopwords.py b/backend/gradio_webrtc/reply_on_stopwords.py
index 2611538..0f7f9ac 100644
--- a/backend/gradio_webrtc/reply_on_stopwords.py
+++ b/backend/gradio_webrtc/reply_on_stopwords.py
@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
     ) -> bool:
         """Take in the stream, determine if a pause happened"""
         import librosa
+
         duration = len(audio) / sampling_rate
 
         if duration >= self.algo_options.audio_chunk_duration:
diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py
index b150490..895655e 100644
--- a/backend/gradio_webrtc/utils.py
+++ b/backend/gradio_webrtc/utils.py
@@ -1,8 +1,10 @@
 import asyncio
 import fractions
 import io
+import json
 import logging
 import tempfile
+from contextvars import ContextVar
 from typing import Any, Callable, Protocol, TypedDict, cast
 
 import av
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
     def send(self, message: str) -> None: ...
 
 
+current_channel: ContextVar[DataChannel | None] = ContextVar(
+    "current_channel", default=None
+)
+
+
+def _send_log(message: str, type: str) -> None:
+    async def _send(channel: DataChannel) -> None:
+        channel.send(
+            json.dumps(
+                {
+                    "type": type,
+                    "message": message,
+                }
+            )
+        )
+
+    if channel := current_channel.get():
+        try:
+            loop = asyncio.get_running_loop()
+            asyncio.run_coroutine_threadsafe(_send(channel), loop)
+        except RuntimeError:
+            asyncio.run(_send(channel))
+
+
+def Warning(  # noqa: N802
+    message: str = "Warning issued.",
+):
+    """
+    Send a warning message that is displayed in the UI of the application.
+
+    Parameters
+    ----------
+    message : str
+        The warning message to send
+
+    Returns
+    -------
+    None
+    """
+    _send_log(message, "warning")
+
+
+class WebRTCError(Exception):
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        _send_log(message, "error")
+
+
 def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
     if isinstance(data, AdditionalOutputs):
         return None, data
diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py
index 778db02..158872e 100644
--- a/backend/gradio_webrtc/webrtc.py
+++ b/backend/gradio_webrtc/webrtc.py
@@ -44,6 +44,7 @@ from gradio_client import handle_file
 from gradio_webrtc.utils import (
     AdditionalOutputs,
     DataChannel,
+    current_channel,
     player_worker_decode,
     split_output,
 )
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
         self.set_additional_outputs = set_additional_outputs
         self.thread_quit = asyncio.Event()
         self.mode = mode
+        self.channel_set = asyncio.Event()
 
     def set_channel(self, channel: DataChannel):
         self.channel = channel
+        current_channel.set(channel)
+        self.channel_set.set()
 
     def set_args(self, args: list[Any]):
         self.latest_args = ["__webrtc_value__"] + list(args)
@@ -122,6 +126,12 @@
         logger.debug("video callback stop")
         self.thread_quit.set()
 
+    async def wait_for_channel(self):
+        if not self.channel_set.is_set():
+            await self.channel_set.wait()
+        if current_channel.get() != self.channel:
+            current_channel.set(self.channel)
+
     async def recv(self):
         try:
             try:
@@ -129,6 +139,8 @@
             except MediaStreamError:
                 self.stop()
                 return
+
+            await self.wait_for_channel()
             frame_array = frame.to_ndarray(format="bgr24")
 
             if self.latest_args == "not_set":
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
         self._channel: DataChannel | None = None
         self._loop: asyncio.AbstractEventLoop
         self.args_set = asyncio.Event()
+        self.channel_set = asyncio.Event()
 
     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -191,6 +204,7 @@
 
     def set_channel(self, channel: DataChannel):
         self._channel = channel
+        self.channel_set.set()
 
     async def fetch_args(
         self,
@@ -203,6 +217,9 @@
         await self.fetch_args()
         await self.args_set.wait()
 
+    def wait_for_args_sync(self):
+        asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
+
     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
         self.latest_args = ["__webrtc_value__"] + list(args)
@@ -275,6 +292,7 @@
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
+        super().__init__()
         self.track = track
         self.event_handler = event_handler
         self.current_timestamp = 0
@@ -286,7 +304,6 @@
         self.last_timestamp = 0
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs
-        super().__init__()
 
     def set_channel(self, channel: DataChannel):
         self.channel = channel
@@ -295,6 +312,10 @@
     def set_args(self, args: list[Any]):
         self.event_handler.set_args(args)
 
+    def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
+        current_channel.set(self.event_handler.channel)
+        return cast(Callable, self.event_handler.receive)(frame)
+
     async def process_input_frames(self) -> None:
         while not self.thread_quit.is_set():
             try:
@@ -307,7 +328,7 @@
                     )
                 else:
                     await anyio.to_thread.run_sync(
-                        self.event_handler.receive, (frame.sample_rate, numpy_array)
+                        self.event_handler_receive, (frame.sample_rate, numpy_array)
                     )
             except MediaStreamError:
                 logger.debug("MediaStreamError in process_input_frames")
@@ -342,7 +363,13 @@
         if self.readyState != "live":
             raise MediaStreamError
 
+        if not self.event_handler.channel_set.is_set():
+            await self.event_handler.channel_set.wait()
+        if current_channel.get() != self.event_handler.channel:
+            current_channel.set(self.event_handler.channel)
+
         self.start()
+
         frame = await self.queue.get()
         logger.debug("frame %s", frame)
@@ -415,7 +442,7 @@
             self.generator = cast(
                 Generator[Any, None, Any], self.event_handler(*self.latest_args)
             )
-
+        current_channel.set(self.channel)
         try:
             next_array, outputs = split_output(next(self.generator))
             if (
@@ -470,6 +497,7 @@
 
     def next(self) -> tuple[int, np.ndarray] | None:
         self.args_set.wait()
+        current_channel.set(self.channel)
         if self.generator is None:
             self.generator = self.event_handler(*self.latest_args)
         if self.generator is not None:
@@ -946,6 +974,7 @@
         answer = await pc.createAnswer()
         await pc.setLocalDescription(answer)  # type: ignore
         logger.debug("done handling offer about to return")
+        await asyncio.sleep(0.1)
 
         return {
             "sdp": pc.localDescription.sdp,
diff --git a/docs/additional-outputs.md b/docs/additional-outputs.md
deleted file mode 100644
index e69de29..0000000
diff --git a/docs/faq.md b/docs/faq.md
index 7881b6a..f3a69f6 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -23,4 +23,45 @@ You can disable this via the `track_constraints` (see [advanced configuration](.
         mode="send-receive",
         modality="audio",
     )
+```
+
+## How to raise errors in the UI
+
+You can raise `WebRTCError` in order for an error message to show up on the user's screen. This is similar to how `gr.Error` works.
+
+Here is a simple example:
+
+```python
+import time
+
+import gradio as gr
+import numpy as np
+from gradio_webrtc import WebRTC, WebRTCError
+from pydub import AudioSegment
+
+
+def generation(num_steps):
+    for _ in range(num_steps):
+        segment = AudioSegment.from_file("cantina.wav")
+        yield (
+            segment.frame_rate,
+            np.array(segment.get_array_of_samples()).reshape(1, -1),
+        )
+        time.sleep(3.5)
+    raise WebRTCError("This is a test error")
+
+
+with gr.Blocks() as demo:
+    audio = WebRTC(
+        label="Stream",
+        mode="receive",
+        modality="audio",
+    )
+    num_steps = gr.Slider(
+        label="Number of Steps",
+        minimum=1,
+        maximum=10,
+        step=1,
+        value=5,
+    )
+    button = gr.Button("Generate")
+
+    audio.stream(
+        fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
+    )
+
+demo.launch()
+```
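+
+## How to display warnings in the UI
+
+This release also exports a `Warning` helper alongside `WebRTCError`. Calling it from inside a handler sends a message that is surfaced as a warning in the UI, similar to `gr.Warning`, without stopping the stream. It only has an effect while a handler is running, since that is when a data channel to the client is available. A minimal sketch (the `detection` handler and its `conf_threshold` input are hypothetical):
+
+```python
+import numpy as np
+from gradio_webrtc import Warning  # shadows the builtin Warning on purpose
+
+
+def detection(image: np.ndarray, conf_threshold: float) -> np.ndarray:
+    # Sent over the data channel and shown as a warning in the client UI.
+    if conf_threshold < 0.1:
+        Warning("Confidence threshold is very low; expect noisy detections.")
+    return image
+```
\ No newline at end of file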
diff --git a/docs/user-guide.md b/docs/user-guide.md
index bb7a61a..cb1ecb7 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -161,6 +161,145 @@ abstraction that gives you arbitrary control over how the input audio stream and
 1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
 2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
+
+### Async Stream Handlers
+
+It is also possible to create asynchronous stream handlers, which is especially convenient for calling the async APIs of major LLM providers, like Google and OpenAI. The main difference is that `receive` and `emit` are defined with `async def`.
+
+Here is a complete example that uses an `AsyncStreamHandler` to talk to the Google Gemini real-time API:
+
+=== "Code"
+    ``` py title="AsyncStreamHandler"
+
+    import asyncio
+    import base64
+    import logging
+    import os
+
+    import gradio as gr
+    import numpy as np
+    from google import genai
+    from gradio_webrtc import (
+        AsyncStreamHandler,
+        WebRTC,
+        async_aggregate_bytes_to_16bit,
+        get_twilio_turn_credentials,
+    )
+
+    class GeminiHandler(AsyncStreamHandler):
+        def __init__(
+            self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
+        ) -> None:
+            super().__init__(
+                expected_layout,
+                output_sample_rate,
+                output_frame_size,
+                input_sample_rate=16000,
+            )
+            self.client: genai.Client | None = None
+            self.input_queue = asyncio.Queue()
+            self.output_queue = asyncio.Queue()
+            self.quit = asyncio.Event()
+
+        def copy(self) -> "GeminiHandler":
+            return GeminiHandler(
+                expected_layout=self.expected_layout,
+                output_sample_rate=self.output_sample_rate,
+                output_frame_size=self.output_frame_size,
+            )
+
+        async def stream(self):
+            while not self.quit.is_set():
+                audio = await self.input_queue.get()
+                yield audio
+
+        async def connect(self, api_key: str):
+            client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
+            config = {"response_modalities": ["AUDIO"]}
+            async with client.aio.live.connect(
+                model="gemini-2.0-flash-exp", config=config
+            ) as session:
+                async for audio in session.start_stream(
+                    stream=self.stream(), mime_type="audio/pcm"
+                ):
+                    if audio.data:
+                        yield audio.data
+
+        async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+            _, array = frame
+            array = array.squeeze()
+            audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
+            self.input_queue.put_nowait(audio_message)
+
+        async def generator(self):
+            async for audio_response in async_aggregate_bytes_to_16bit(
+                self.connect(api_key=self.latest_args[1])
+            ):
+                self.output_queue.put_nowait(audio_response)
+
+        async def emit(self):
+            if not self.args_set.is_set():
+                await self.wait_for_args()
+                asyncio.create_task(self.generator())
+
+            array = await self.output_queue.get()
+            return (self.output_sample_rate, array)
+
+        def shutdown(self) -> None:
+            self.quit.set()
+
+    with gr.Blocks() as demo:
+        gr.HTML(
+            """
+
+            """
+        )
+        with gr.Row() as api_key_row:
+            api_key = gr.Textbox(
+                label="API Key",
+                placeholder="Enter your API Key",
+                value=os.getenv("GOOGLE_API_KEY", ""),
+                type="password",
+            )
+        with gr.Row(visible=False) as row:
+            webrtc = WebRTC(
+                label="Audio",
+                modality="audio",
+                mode="send-receive",
+                rtc_configuration=get_twilio_turn_credentials(),
+                pulse_color="rgb(35, 157, 225)",
+                icon_button_color="rgb(35, 157, 225)",
+                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+            )
+
+        webrtc.stream(
+            GeminiHandler(),
+            inputs=[webrtc, api_key],
+            outputs=[webrtc],
+            time_limit=90,
+            concurrency_limit=2,
+        )
+        api_key.submit(
+            lambda: (gr.update(visible=False), gr.update(visible=True)),
+            None,
+            [api_key_row, row],
+        )
+
+    demo.launch()
+    ```
+
+### Accessing Other Component Values from a StreamHandler
+
+In the Gemini demo above, you'll notice that we have the user input their Google API key, which is stored in a `gr.Textbox` component.
+We can access the value of this component via the `latest_args` property of the `StreamHandler`. `latest_args` is a list storing the value of each component in the WebRTC `stream` event's `inputs` parameter. The value of the `WebRTC` component itself is at index 0 and is always the placeholder string `__webrtc_value__`.
+
+To fetch the latest values from the user, we `await self.wait_for_args()`; in a synchronous `StreamHandler`, we call `self.wait_for_args_sync()` instead, as in the sketch below.
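+
+The sketch assumes a hypothetical `gain` slider as the second component in `inputs`, so `latest_args[1]` holds its value once the arguments have been fetched:
+
+```python
+import numpy as np
+from gradio_webrtc import StreamHandler
+
+
+class GainHandler(StreamHandler):
+    def __init__(self) -> None:
+        super().__init__()
+        self.frames: list[tuple[int, np.ndarray]] = []
+
+    def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        if not self.args_set.is_set():
+            # Block this worker thread until the `inputs` values arrive.
+            self.wait_for_args_sync()
+        sample_rate, array = frame
+        gain = self.latest_args[1]  # index 0 is the dummy WebRTC value
+        self.frames.append((sample_rate, (array * gain).astype(array.dtype)))
+
+    def emit(self) -> tuple[int, np.ndarray] | None:
+        # Must not block: return None when no frame is ready.
+        return self.frames.pop(0) if self.frames else None
+
+    def copy(self) -> "GainHandler":
+        return GainHandler()
+
+    def shutdown(self) -> None:
+        pass
+```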
+
 
 ### Server-To-Client Only
 
 To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
diff --git a/frontend/Index.svelte b/frontend/Index.svelte
index 6a20159..b3901ec 100644
--- a/frontend/Index.svelte
+++ b/frontend/Index.svelte
@@ -38,7 +38,11 @@
     export let icon_button_color: string = "var(--color-accent)";
     export let pulse_color: string = "var(--color-accent)";
 
-    const on_change_cb = (msg: "change" | "tick") => {
+    const on_change_cb = (msg: "change" | "tick" | any) => {
+        if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
+            gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message);
+        }
         gradio.dispatch(msg === "change" ? "state_change" : "tick");
     }
 
@@ -93,6 +97,7 @@
             i18n={gradio.i18n}
             on:tick={() => gradio.dispatch("tick")}
             on:error={({ detail }) => gradio.dispatch("error", detail)}
+
         />
     {:else if (mode === "send-receive" || mode == "send") && modality === "video"}