Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)

* Code

* test

* code

* user guide
Freddy Boulton
2024-12-23 15:21:10 -05:00
committed by GitHub
parent e057fc1502
commit 5812fd5aeb
11 changed files with 289 additions and 7 deletions


@@ -23,4 +23,45 @@ You can disable this via the `track_constraints` (see [advanced configuration](.
mode="send-receive",
modality="audio",
)
```
## How to raise errors in the UI
You can raise `WebRTCError` to show an error message on the user's screen. This is similar to how `gr.Error` works.
Here is a simple example:
```python
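import time

import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, WebRTCError
from pydub import AudioSegment  # assumption: AudioSegment is pydub's class


def generation(num_steps):
    for _ in range(num_steps):
        # Replace this path with an audio file that exists on your machine.
        segment = AudioSegment.from_file(
            "/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
        )
        # Yield (sample_rate, samples), with samples shaped (1, num_samples).
        yield (
            segment.frame_rate,
            np.array(segment.get_array_of_samples()).reshape(1, -1),
        )
        time.sleep(3.5)
    # Raised after the loop finishes; the message is shown in the UI.
    raise WebRTCError("This is a test error")


with gr.Blocks() as demo:
    audio = WebRTC(
        label="Stream",
        mode="receive",
        modality="audio",
    )
    num_steps = gr.Slider(
        label="Number of Steps",
        minimum=1,
        maximum=10,
        step=1,
        value=5,
    )
    button = gr.Button("Generate")
    audio.stream(
        fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
    )

demo.launch()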
```
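To look at the control flow in isolation, here is a minimal sketch that runs without Gradio or an audio file. `DemoError`, `fake_generation` and `consume` are hypothetical stand-ins and not part of `gradio_webrtc`. The sketch only shows the plain Python behavior the example relies on: every chunk yielded before the `raise` has already reached the consumer by the time the error surfaces.

```python
class DemoError(Exception):
    """Hypothetical stand-in for WebRTCError so the sketch has no dependencies."""


def fake_generation(num_steps):
    # Yield one placeholder chunk per step, then fail after the loop,
    # mirroring the shape of `generation` in the example above.
    for step in range(num_steps):
        yield f"chunk-{step}"
    raise DemoError("This is a test error")


def consume(generator):
    """Collect chunks until the generator finishes or raises DemoError."""
    received, error = [], None
    try:
        for chunk in generator:
            received.append(chunk)
    except DemoError as exc:
        error = str(exc)
    return received, error


received, error = consume(fake_generation(3))
print(received, error)
```

In the real component, the error message would be displayed in the UI rather than captured in a variable.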


@@ -161,6 +161,145 @@ abstraction that gives you arbitrary control over how the input audio stream and
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
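The non-blocking rule for `emit` can be sketched with a plain queue. The class below is a toy that deliberately does not subclass the real `StreamHandler`, so it runs without `gradio_webrtc`. The name `QueueBackedHandler` and the frame format are assumptions made for illustration.

```python
import queue


class QueueBackedHandler:
    """Toy handler that echoes frames; not a gradio_webrtc.StreamHandler subclass."""

    def __init__(self):
        self.frames = queue.Queue()

    def receive(self, frame):
        # Called for every incoming frame; store it so emit can echo it back.
        self.frames.put(frame)

    def emit(self):
        # Never wait for data: return None when no frame is ready.
        try:
            return self.frames.get_nowait()
        except queue.Empty:
            return None

    def copy(self):
        # Each user gets a fresh handler with its own, empty queue.
        return QueueBackedHandler()


handler = QueueBackedHandler()
first = handler.emit()  # nothing received yet
handler.receive((48000, [0, 1, 2]))
second = handler.emit()  # the frame stored by receive
third = handler.copy().emit()  # the copy does not share the original's queue
```

Using `get_nowait` instead of `get` is what keeps `emit` from stalling when no frame is available.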
### Async Stream Handlers
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
Here is a complete example that uses `AsyncStreamHandler` with the Google Gemini real-time API:
=== "Code"
``` py title="AsyncStreamHandler"
import asyncio
import base64
import logging
import os

import gradio as gr
import numpy as np
from google import genai
from gradio_webrtc import (
    AsyncStreamHandler,
    WebRTC,
    async_aggregate_bytes_to_16bit,
    get_twilio_turn_credentials,
)


class GeminiHandler(AsyncStreamHandler):
    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.client: genai.Client | None = None
        self.input_queue = asyncio.Queue()
        self.output_queue = asyncio.Queue()
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def stream(self):
        while not self.quit.is_set():
            audio = await self.input_queue.get()
            yield audio

    async def connect(self, api_key: str):
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
        config = {"response_modalities": ["AUDIO"]}
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    yield audio.data

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
        self.input_queue.put_nowait(audio_message)

    async def generator(self):
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.connect(api_key=self.latest_args[1])
        ):
            self.output_queue.put_nowait(audio_response)

    async def emit(self):
        if not self.args_set.is_set():
            await self.wait_for_args()
            asyncio.create_task(self.generator())

        array = await self.output_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        self.quit.set()


with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style='text-align: center'>
            <h1>Gen AI SDK Voice Chat</h1>
            <p>Speak with Gemini using real-time audio streaming</p>
            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
        </div>
        """
    )
    with gr.Row() as api_key_row:
        api_key = gr.Textbox(
            label="API Key",
            placeholder="Enter your API Key",
            value=os.getenv("GOOGLE_API_KEY", ""),
            type="password",
        )
    with gr.Row(visible=False) as row:
        webrtc = WebRTC(
            label="Audio",
            modality="audio",
            mode="send-receive",
            rtc_configuration=get_twilio_turn_credentials(),
            pulse_color="rgb(35, 157, 225)",
            icon_button_color="rgb(35, 157, 225)",
            icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        )

        webrtc.stream(
            GeminiHandler(),
            inputs=[webrtc, api_key],
            outputs=[webrtc],
            time_limit=90,
            concurrency_limit=2,
        )
    api_key.submit(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [api_key_row, row],
    )

demo.launch()
```
### Accessing Other Component Values from a StreamHandler
In the Gemini demo above, you'll notice that we have the user enter their Google API key. This is stored in a `gr.Textbox` component.
We can access the value of this component via the `latest_args` attribute of the `StreamHandler`. `latest_args` is a list that stores the value of each component passed in the `inputs` parameter of the WebRTC `stream` event. The value of the `WebRTC` component is at index 0 and is always the dummy string `__webrtc_value__`.
To fetch the latest value from the user, however, we first `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()`.
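The waiting pattern can be sketched with a plain `asyncio.Event`. `ToyHandler` below is a hypothetical stand-in that borrows the attribute names `latest_args`, `args_set` and `wait_for_args` from the demo. It is not the library's implementation, and `set_args` and `read_api_key` are invented names for this illustration.

```python
import asyncio


class ToyHandler:
    """Hypothetical stand-in that mimics the attribute names used in the demo."""

    def __init__(self):
        self.latest_args = []
        self.args_set = asyncio.Event()

    def set_args(self, args):
        # In this sketch, the "framework" calls this once the input values arrive.
        self.latest_args = list(args)
        self.args_set.set()

    async def wait_for_args(self):
        await self.args_set.wait()

    async def read_api_key(self):
        # Same guard as in the demo's emit: wait until the inputs are set.
        if not self.args_set.is_set():
            await self.wait_for_args()
        # Index 0 is the WebRTC placeholder; index 1 is the textbox value.
        return self.latest_args[1]


async def main():
    handler = ToyHandler()
    reader = asyncio.create_task(handler.read_api_key())
    await asyncio.sleep(0)  # let the reader start and block on the event
    handler.set_args(["__webrtc_value__", "my-api-key"])
    return await reader


api_key = asyncio.run(main())
```

The indices follow the order of the `inputs` list, so with `inputs=[webrtc, api_key]` the textbox value is at index 1.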
### Server-To-Client Only
To stream only from the server to the client, implement a Python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.