Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)

* Code

* test

* code

* user guide
This commit is contained in:
Freddy Boulton
2024-12-23 15:21:10 -05:00
committed by GitHub
parent e057fc1502
commit 5812fd5aeb
11 changed files with 289 additions and 7 deletions

View File

@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import (
AdditionalOutputs,
Warning,
WebRTCError,
aggregate_bytes_to_16bit,
async_aggregate_bytes_to_16bit,
audio_to_bytes,
@@ -35,4 +37,6 @@ __all__ = [
"stt_for_chunks",
"StreamHandler",
"WebRTC",
"WebRTCError",
"Warning",
]

View File

@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
) -> bool:
"""Take in the stream, determine if a pause happened"""
import librosa
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:

View File

@@ -1,8 +1,10 @@
import asyncio
import fractions
import io
import json
import logging
import tempfile
from contextvars import ContextVar
from typing import Any, Callable, Protocol, TypedDict, cast
import av
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
def send(self, message: str) -> None: ...
current_channel: ContextVar[DataChannel | None] = ContextVar(
"current_channel", default=None
)
def _send_log(message: str, type: str) -> None:
    """Send a log message to the client over the current data channel.

    The payload is a JSON object with ``"type"`` and ``"message"`` keys.
    Nothing is sent when no channel is stored in the ``current_channel``
    context variable.

    Parameters
    ----------
    message : str
        Text of the message.
    type : str
        Value placed under the payload's ``"type"`` key.

    Returns
    -------
    None
    """

    async def _send(channel: DataChannel) -> None:
        channel.send(json.dumps({"type": type, "message": message}))

    channel = current_channel.get()
    if not channel:
        return

    # Only the loop lookup signals "no running loop"; keep the try body to
    # that single call so unrelated RuntimeErrors are not misread.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: drive the send ourselves.
        asyncio.run(_send(channel))
    else:
        asyncio.run_coroutine_threadsafe(_send(channel), loop)
def Warning(  # noqa: N802
    message: str = "Warning issued.",
) -> None:
    """
    Send a warning message that is displayed in the UI of the application.

    Parameters
    ----------
    message : str
        The warning message to send

    Returns
    -------
    None
    """
    _send_log(message, "warning")
class WebRTCError(Exception):
    """Exception whose message is also sent to the client as an error.

    Constructing the exception forwards ``message`` to ``_send_log`` with
    the type ``"error"``.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Sent at construction time, i.e. before the exception is raised.
        _send_log(message, "error")
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, AdditionalOutputs):
return None, data

View File

@@ -44,6 +44,7 @@ from gradio_client import handle_file
from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
current_channel,
player_worker_decode,
split_output,
)
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.thread_quit = asyncio.Event()
self.mode = mode
self.channel_set = asyncio.Event()
def set_channel(self, channel: DataChannel) -> None:
    """Store the data channel, publish it in ``current_channel`` and set ``channel_set``."""
    self.channel = channel
    current_channel.set(channel)
    # Wakes anything awaiting ``channel_set.wait()``.
    self.channel_set.set()
def set_args(self, args: list[Any]):
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
logger.debug("video callback stop")
self.thread_quit.set()
async def wait_for_channel(self):
    """Wait until ``channel_set`` is set, then make ``self.channel`` current.

    ``current_channel`` is only re-set when the value visible in this
    context differs from ``self.channel``.
    """
    # Event.wait() returns immediately when the event is already set.
    await self.channel_set.wait()
    active_channel = current_channel.get()
    if active_channel != self.channel:
        current_channel.set(self.channel)
async def recv(self):
try:
try:
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
except MediaStreamError:
self.stop()
return
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
self._channel: DataChannel | None = None
self._loop: asyncio.AbstractEventLoop
self.args_set = asyncio.Event()
self.channel_set = asyncio.Event()
@property
def loop(self) -> asyncio.AbstractEventLoop:
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
def set_channel(self, channel: DataChannel) -> None:
    """Store the data channel on the handler and set the ``channel_set`` event."""
    self._channel = channel
    self.channel_set.set()
async def fetch_args(
self,
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
await self.fetch_args()
await self.args_set.wait()
def wait_for_args_sync(self):
    """Blocking counterpart of ``wait_for_args``.

    Schedules ``self.wait_for_args()`` on ``self.loop`` and blocks the
    calling thread until that coroutine has finished.
    """
    pending = asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop)
    pending.result()
def set_args(self, args: list[Any]):
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__()
self.track = track
self.event_handler = event_handler
self.current_timestamp = 0
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__()
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
def set_args(self, args: list[Any]) -> None:
    """Forward ``args`` to the wrapped event handler's ``set_args``."""
    self.event_handler.set_args(args)
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
    """Call the event handler's ``receive`` with the handler's channel made current.

    ``current_channel`` is set to ``self.event_handler.channel`` before
    ``receive`` is invoked with ``frame``; whatever ``receive`` returns is
    passed back to the caller.
    """
    handler = self.event_handler
    current_channel.set(handler.channel)
    receive = cast(Callable, handler.receive)
    return receive(frame)
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
)
else:
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
self.event_handler_receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames")
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
current_channel.set(self.channel)
try:
next_array, outputs = split_output(next(self.generator))
if (
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
def next(self) -> tuple[int, np.ndarray] | None:
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
@@ -946,6 +974,7 @@ class WebRTC(Component):
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
logger.debug("done handling offer about to return")
await asyncio.sleep(0.1)
return {
"sdp": pc.localDescription.sdp,

View File

@@ -23,4 +23,45 @@ You can disable this via the `track_constraints` (see [advanced configuration](.
mode="send-receive",
modality="audio",
)
```
## How to raise errors in the UI
You can raise `WebRTCError` to show an error message on the user's screen. This is similar to how `gr.Error` works.
Here is a simple example:
```python
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
time.sleep(3.5)
raise WebRTCError("This is a test error")
with gr.Blocks() as demo:
audio = WebRTC(
label="Stream",
mode="receive",
modality="audio",
)
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
)
demo.launch()
```

View File

@@ -161,6 +161,145 @@ abstraction that gives you arbitrary control over how the input audio stream and
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
### Async Stream Handlers
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
Here is a complete example of using `AsyncStreamHandler` with the Google Gemini real-time API:
=== "Code"
``` py title="AsyncStreamHandler"
import asyncio
import base64
import logging
import os
import gradio as gr
import numpy as np
from google import genai
from gradio_webrtc import (
AsyncStreamHandler,
WebRTC,
async_aggregate_bytes_to_16bit,
get_twilio_turn_credentials,
)
class GeminiHandler(AsyncStreamHandler):
def __init__(
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
) -> None:
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=16000,
)
self.client: genai.Client | None = None
self.input_queue = asyncio.Queue()
self.output_queue = asyncio.Queue()
self.quit = asyncio.Event()
def copy(self) -> "GeminiHandler":
return GeminiHandler(
expected_layout=self.expected_layout,
output_sample_rate=self.output_sample_rate,
output_frame_size=self.output_frame_size,
)
async def stream(self):
while not self.quit.is_set():
audio = await self.input_queue.get()
yield audio
async def connect(self, api_key: str):
client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
config = {"response_modalities": ["AUDIO"]}
async with client.aio.live.connect(
model="gemini-2.0-flash-exp", config=config
) as session:
async for audio in session.start_stream(
stream=self.stream(), mime_type="audio/pcm"
):
if audio.data:
yield audio.data
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
_, array = frame
array = array.squeeze()
audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
self.input_queue.put_nowait(audio_message)
async def generator(self):
async for audio_response in async_aggregate_bytes_to_16bit(
self.connect(api_key=self.latest_args[1])
):
self.output_queue.put_nowait(audio_response)
async def emit(self):
if not self.args_set.is_set():
await self.wait_for_args()
asyncio.create_task(self.generator())
array = await self.output_queue.get()
return (self.output_sample_rate, array)
def shutdown(self) -> None:
self.quit.set()
with gr.Blocks() as demo:
gr.HTML(
"""
<div style='text-align: center'>
<h1>Gen AI SDK Voice Chat</h1>
<p>Speak with Gemini using real-time audio streaming</p>
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
</div>
"""
)
with gr.Row() as api_key_row:
api_key = gr.Textbox(
label="API Key",
placeholder="Enter your API Key",
value=os.getenv("GOOGLE_API_KEY", ""),
type="password",
)
with gr.Row(visible=False) as row:
webrtc = WebRTC(
label="Audio",
modality="audio",
mode="send-receive",
rtc_configuration=get_twilio_turn_credentials(),
pulse_color="rgb(35, 157, 225)",
icon_button_color="rgb(35, 157, 225)",
icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
)
webrtc.stream(
GeminiHandler(),
inputs=[webrtc, api_key],
outputs=[webrtc],
time_limit=90,
concurrency_limit=2,
)
api_key.submit(
lambda: (gr.update(visible=False), gr.update(visible=True)),
None,
[api_key_row, row],
)
demo.launch()
```
### Accessing Other Component Values from a StreamHandler
In the Gemini demo above, you'll notice that we have the user input their Google API key. This is stored in a `gr.Textbox` component.
We can access the value of this component via the `latest_args` attribute of the `StreamHandler`. `latest_args` is a list storing the values of each component in the `inputs` parameter of the WebRTC `stream` event. The value of the `WebRTC` component is at index 0, and it is always the dummy string `__webrtc_value__`.
However, in order to fetch the latest value from the user, we first `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()` instead.
### Server-To-Client Only
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.

View File

@@ -38,7 +38,11 @@
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
const on_change_cb = (msg: "change" | "tick") => {
// Relay a message received on the data channel to the matching gradio event.
const on_change_cb = (msg: "change" | "tick" | any) => {
    if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
        // Server-side log messages: errors map to "error", everything else to "warning".
        gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message);
        // A log message is neither a change nor a tick, so stop here.
        return;
    }
    gradio.dispatch(msg === "change" ? "state_change" : "tick");
}
@@ -93,6 +97,7 @@
i18n={gradio.i18n}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
<Video
@@ -141,6 +146,7 @@
{pulse_color}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
on:warning={({ detail }) => gradio.dispatch("warning", detail)}
/>
{/if}
</Block>

View File

@@ -46,6 +46,7 @@
});
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
console.log("msg", msg);
if (msg === "stopword") {
console.log("stopword recognized");
stopword_recognized = true;
@@ -53,6 +54,7 @@
stopword_recognized = false;
}, 3000);
} else {
console.log("calling on_change_cb with msg", msg);
on_change_cb(msg);
}
};

View File

@@ -64,13 +64,22 @@ export async function start(
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
let event_json;
try {
event_json = JSON.parse(event.data);
} catch (e) {
console.debug("Error parsing JSON")
}
console.log("event_json", event_json);
if (
event.data === "change" ||
event.data === "tick" ||
event.data === "stopword"
event.data === "stopword" ||
event_json?.type === "warning" ||
event_json?.type === "error"
) {
console.debug(`${event.data} event received`);
on_change_cb(event.data);
on_change_cb(event_json ?? event.data);
}
};

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.25"
version = "0.0.27"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"