mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)
* Code * test * code * user guide
This commit is contained in:
@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
|
||||
from .speech_to_text import stt, stt_for_chunks
|
||||
from .utils import (
|
||||
AdditionalOutputs,
|
||||
Warning,
|
||||
WebRTCError,
|
||||
aggregate_bytes_to_16bit,
|
||||
async_aggregate_bytes_to_16bit,
|
||||
audio_to_bytes,
|
||||
@@ -35,4 +37,6 @@ __all__ = [
|
||||
"stt_for_chunks",
|
||||
"StreamHandler",
|
||||
"WebRTC",
|
||||
"WebRTCError",
|
||||
"Warning",
|
||||
]
|
||||
|
||||
@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
) -> bool:
|
||||
"""Take in the stream, determine if a pause happened"""
|
||||
import librosa
|
||||
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import fractions
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
|
||||
current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||
"current_channel", default=None
|
||||
)
|
||||
|
||||
|
||||
def _send_log(message: str, type: str) -> None:
|
||||
async def _send(channel: DataChannel) -> None:
|
||||
channel.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": type,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if channel := current_channel.get():
|
||||
print("channel", channel)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
asyncio.run_coroutine_threadsafe(_send(channel), loop)
|
||||
except RuntimeError:
|
||||
asyncio.run(_send(channel))
|
||||
|
||||
|
||||
def Warning( # noqa: N802
|
||||
message: str = "Warning issued.",
|
||||
):
|
||||
"""
|
||||
Send a warning message that is deplayed in the UI of the application.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : str
|
||||
The warning message to send
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
_send_log(message, "warning")
|
||||
|
||||
|
||||
class WebRTCError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
_send_log(message, "error")
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, AdditionalOutputs):
|
||||
return None, data
|
||||
|
||||
@@ -44,6 +44,7 @@ from gradio_client import handle_file
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
current_channel,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
current_channel.set(channel)
|
||||
self.channel_set.set()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
logger.debug("video callback stop")
|
||||
self.thread_quit.set()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
if not self.channel_set.is_set():
|
||||
await self.channel_set.wait()
|
||||
if current_channel.get() != self.channel:
|
||||
current_channel.set(self.channel)
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
try:
|
||||
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
await self.wait_for_channel()
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
if self.latest_args == "not_set":
|
||||
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
|
||||
self._channel: DataChannel | None = None
|
||||
self._loop: asyncio.AbstractEventLoop
|
||||
self.args_set = asyncio.Event()
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self._channel = channel
|
||||
self.channel_set.set()
|
||||
|
||||
async def fetch_args(
|
||||
self,
|
||||
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
|
||||
await self.fetch_args()
|
||||
await self.args_set.wait()
|
||||
|
||||
def wait_for_args_sync(self):
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
logger.debug("setting args in audio callback %s", args)
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.current_timestamp = 0
|
||||
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
def set_args(self, args: list[Any]):
|
||||
self.event_handler.set_args(args)
|
||||
|
||||
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
return cast(Callable, self.event_handler.receive)(frame)
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
)
|
||||
else:
|
||||
await anyio.to_thread.run_sync(
|
||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||
self.event_handler_receive, (frame.sample_rate, numpy_array)
|
||||
)
|
||||
except MediaStreamError:
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
|
||||
self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
|
||||
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.generator = cast(
|
||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
)
|
||||
|
||||
current_channel.set(self.channel)
|
||||
try:
|
||||
next_array, outputs = split_output(next(self.generator))
|
||||
if (
|
||||
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
if self.generator is None:
|
||||
self.generator = self.event_handler(*self.latest_args)
|
||||
if self.generator is not None:
|
||||
@@ -946,6 +974,7 @@ class WebRTC(Component):
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
logger.debug("done handling offer about to return")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
|
||||
41
docs/faq.md
41
docs/faq.md
@@ -23,4 +23,45 @@ You can disable this via the `track_constraints` (see [advanced configuration](.
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
```
|
||||
|
||||
## How to raise errors in the UI
|
||||
|
||||
You can raise `WebRTCError` in order for an error message to show up in the user's screen. This is similar to how `gr.Error` works.
|
||||
|
||||
Here is a simple example:
|
||||
|
||||
```python
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file(
|
||||
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
|
||||
)
|
||||
yield (
|
||||
segment.frame_rate,
|
||||
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||
)
|
||||
time.sleep(3.5)
|
||||
raise WebRTCError("This is a test error")
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
)
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
@@ -161,6 +161,145 @@ abstraction that gives you arbitrary control over how the input audio stream and
|
||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
||||
|
||||
|
||||
### Async Stream Handlers
|
||||
|
||||
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
|
||||
|
||||
Here is a complete example of using `AsyncStreamHandler` for using the Google Gemini real time API:
|
||||
|
||||
=== "Code"
|
||||
``` py title="AsyncStreamHandler"
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from google import genai
|
||||
from gradio_webrtc import (
|
||||
AsyncStreamHandler,
|
||||
WebRTC,
|
||||
async_aggregate_bytes_to_16bit,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
|
||||
class GeminiHandler(AsyncStreamHandler):
|
||||
def __init__(
|
||||
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
|
||||
) -> None:
|
||||
super().__init__(
|
||||
expected_layout,
|
||||
output_sample_rate,
|
||||
output_frame_size,
|
||||
input_sample_rate=16000,
|
||||
)
|
||||
self.client: genai.Client | None = None
|
||||
self.input_queue = asyncio.Queue()
|
||||
self.output_queue = asyncio.Queue()
|
||||
self.quit = asyncio.Event()
|
||||
|
||||
def copy(self) -> "GeminiHandler":
|
||||
return GeminiHandler(
|
||||
expected_layout=self.expected_layout,
|
||||
output_sample_rate=self.output_sample_rate,
|
||||
output_frame_size=self.output_frame_size,
|
||||
)
|
||||
|
||||
async def stream(self):
|
||||
while not self.quit.is_set():
|
||||
audio = await self.input_queue.get()
|
||||
yield audio
|
||||
|
||||
async def connect(self, api_key: str):
|
||||
client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
|
||||
config = {"response_modalities": ["AUDIO"]}
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
async for audio in session.start_stream(
|
||||
stream=self.stream(), mime_type="audio/pcm"
|
||||
):
|
||||
if audio.data:
|
||||
yield audio.data
|
||||
|
||||
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
_, array = frame
|
||||
array = array.squeeze()
|
||||
audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
|
||||
self.input_queue.put_nowait(audio_message)
|
||||
|
||||
async def generator(self):
|
||||
async for audio_response in async_aggregate_bytes_to_16bit(
|
||||
self.connect(api_key=self.latest_args[1])
|
||||
):
|
||||
self.output_queue.put_nowait(audio_response)
|
||||
|
||||
async def emit(self):
|
||||
if not self.args_set.is_set():
|
||||
await self.wait_for_args()
|
||||
asyncio.create_task(self.generator())
|
||||
|
||||
array = await self.output_queue.get()
|
||||
return (self.output_sample_rate, array)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.quit.set()
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<div style='text-align: center'>
|
||||
<h1>Gen AI SDK Voice Chat</h1>
|
||||
<p>Speak with Gemini using real-time audio streaming</p>
|
||||
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
with gr.Row() as api_key_row:
|
||||
api_key = gr.Textbox(
|
||||
label="API Key",
|
||||
placeholder="Enter your API Key",
|
||||
value=os.getenv("GOOGLE_API_KEY", ""),
|
||||
type="password",
|
||||
)
|
||||
with gr.Row(visible=False) as row:
|
||||
webrtc = WebRTC(
|
||||
label="Audio",
|
||||
modality="audio",
|
||||
mode="send-receive",
|
||||
rtc_configuration=get_twilio_turn_credentials(),
|
||||
pulse_color="rgb(35, 157, 225)",
|
||||
icon_button_color="rgb(35, 157, 225)",
|
||||
icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
|
||||
)
|
||||
|
||||
webrtc.stream(
|
||||
GeminiHandler(),
|
||||
inputs=[webrtc, api_key],
|
||||
outputs=[webrtc],
|
||||
time_limit=90,
|
||||
concurrency_limit=2,
|
||||
)
|
||||
api_key.submit(
|
||||
lambda: (gr.update(visible=False), gr.update(visible=True)),
|
||||
None,
|
||||
[api_key_row, row],
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
### Accessing Other Component Values from a StreamHandler
|
||||
|
||||
In the gemini demo above, you'll notice that we have the user input their google API key. This is stored in a `gr.Textbox` parameter.
|
||||
We can access the value of this component via the `latest_args` prop of the `StreamHandler`. The `latest_args` is a list storing the values of each component in the WebRTC `stream` event `inputs` parameter. The value of the `WebRTC` component is the 0th index and it's always the dummy string `__webrtc_value__`.
|
||||
|
||||
In order to fetch the latest value from the user however, we `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()`.
|
||||
|
||||
|
||||
### Server-To-Client Only
|
||||
|
||||
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
|
||||
|
||||
@@ -38,7 +38,11 @@
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
|
||||
const on_change_cb = (msg: "change" | "tick") => {
|
||||
const on_change_cb = (msg: "change" | "tick" | any) => {
|
||||
if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
|
||||
console.log("dispatching info", msg.message);
|
||||
gradio.dispatch(msg?.type === "error"? "error": "warning", msg.message);
|
||||
}
|
||||
gradio.dispatch(msg === "change" ? "state_change" : "tick");
|
||||
}
|
||||
|
||||
@@ -93,6 +97,7 @@
|
||||
i18n={gradio.i18n}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
|
||||
/>
|
||||
{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
|
||||
<Video
|
||||
@@ -141,6 +146,7 @@
|
||||
{pulse_color}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
on:warning={({ detail }) => gradio.dispatch("warning", detail)}
|
||||
/>
|
||||
{/if}
|
||||
</Block>
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
});
|
||||
|
||||
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
|
||||
console.log("msg", msg);
|
||||
if (msg === "stopword") {
|
||||
console.log("stopword recognized");
|
||||
stopword_recognized = true;
|
||||
@@ -53,6 +54,7 @@
|
||||
stopword_recognized = false;
|
||||
}, 3000);
|
||||
} else {
|
||||
console.log("calling on_change_cb with msg", msg);
|
||||
on_change_cb(msg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,13 +64,22 @@ export async function start(
|
||||
|
||||
data_channel.onmessage = (event) => {
|
||||
console.debug("Received message:", event.data);
|
||||
let event_json;
|
||||
try {
|
||||
event_json = JSON.parse(event.data);
|
||||
} catch (e) {
|
||||
console.debug("Error parsing JSON")
|
||||
}
|
||||
console.log("event_json", event_json);
|
||||
if (
|
||||
event.data === "change" ||
|
||||
event.data === "tick" ||
|
||||
event.data === "stopword"
|
||||
event.data === "stopword" ||
|
||||
event_json?.type === "warning" ||
|
||||
event_json?.type === "error"
|
||||
) {
|
||||
console.debug(`${event.data} event received`);
|
||||
on_change_cb(event.data);
|
||||
on_change_cb(event_json ?? event.data);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "gradio_webrtc"
|
||||
version = "0.0.25"
|
||||
version = "0.0.27"
|
||||
description = "Stream images in realtime with webrtc"
|
||||
readme = "README.md"
|
||||
license = "apache-2.0"
|
||||
|
||||
Reference in New Issue
Block a user