mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)
* Code * test * code * user guide
This commit is contained in:
@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
|
|||||||
from .speech_to_text import stt, stt_for_chunks
|
from .speech_to_text import stt, stt_for_chunks
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
|
Warning,
|
||||||
|
WebRTCError,
|
||||||
aggregate_bytes_to_16bit,
|
aggregate_bytes_to_16bit,
|
||||||
async_aggregate_bytes_to_16bit,
|
async_aggregate_bytes_to_16bit,
|
||||||
audio_to_bytes,
|
audio_to_bytes,
|
||||||
@@ -35,4 +37,6 @@ __all__ = [
|
|||||||
"stt_for_chunks",
|
"stt_for_chunks",
|
||||||
"StreamHandler",
|
"StreamHandler",
|
||||||
"WebRTC",
|
"WebRTC",
|
||||||
|
"WebRTCError",
|
||||||
|
"Warning",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Take in the stream, determine if a pause happened"""
|
"""Take in the stream, determine if a pause happened"""
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
duration = len(audio) / sampling_rate
|
duration = len(audio) / sampling_rate
|
||||||
|
|
||||||
if duration >= self.algo_options.audio_chunk_duration:
|
if duration >= self.algo_options.audio_chunk_duration:
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import fractions
|
import fractions
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Protocol, TypedDict, cast
|
from typing import Any, Callable, Protocol, TypedDict, cast
|
||||||
|
|
||||||
import av
|
import av
|
||||||
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
|
|||||||
def send(self, message: str) -> None: ...
|
def send(self, message: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||||
|
"current_channel", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _send_log(message: str, type: str) -> None:
|
||||||
|
async def _send(channel: DataChannel) -> None:
|
||||||
|
channel.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": type,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel := current_channel.get():
|
||||||
|
print("channel", channel)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
asyncio.run_coroutine_threadsafe(_send(channel), loop)
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(_send(channel))
|
||||||
|
|
||||||
|
|
||||||
|
def Warning( # noqa: N802
|
||||||
|
message: str = "Warning issued.",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Send a warning message that is deplayed in the UI of the application.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio : str
|
||||||
|
The warning message to send
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
_send_log(message, "warning")
|
||||||
|
|
||||||
|
|
||||||
|
class WebRTCError(Exception):
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
_send_log(message, "error")
|
||||||
|
|
||||||
|
|
||||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||||
if isinstance(data, AdditionalOutputs):
|
if isinstance(data, AdditionalOutputs):
|
||||||
return None, data
|
return None, data
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from gradio_client import handle_file
|
|||||||
from gradio_webrtc.utils import (
|
from gradio_webrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
DataChannel,
|
DataChannel,
|
||||||
|
current_channel,
|
||||||
player_worker_decode,
|
player_worker_decode,
|
||||||
split_output,
|
split_output,
|
||||||
)
|
)
|
||||||
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
self.thread_quit = asyncio.Event()
|
self.thread_quit = asyncio.Event()
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
self.channel_set = asyncio.Event()
|
||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
|
current_channel.set(channel)
|
||||||
|
self.channel_set.set()
|
||||||
|
|
||||||
def set_args(self, args: list[Any]):
|
def set_args(self, args: list[Any]):
|
||||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||||
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
logger.debug("video callback stop")
|
logger.debug("video callback stop")
|
||||||
self.thread_quit.set()
|
self.thread_quit.set()
|
||||||
|
|
||||||
|
async def wait_for_channel(self):
|
||||||
|
if not self.channel_set.is_set():
|
||||||
|
await self.channel_set.wait()
|
||||||
|
if current_channel.get() != self.channel:
|
||||||
|
current_channel.set(self.channel)
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
except MediaStreamError:
|
except MediaStreamError:
|
||||||
self.stop()
|
self.stop()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
await self.wait_for_channel()
|
||||||
frame_array = frame.to_ndarray(format="bgr24")
|
frame_array = frame.to_ndarray(format="bgr24")
|
||||||
|
|
||||||
if self.latest_args == "not_set":
|
if self.latest_args == "not_set":
|
||||||
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
|
|||||||
self._channel: DataChannel | None = None
|
self._channel: DataChannel | None = None
|
||||||
self._loop: asyncio.AbstractEventLoop
|
self._loop: asyncio.AbstractEventLoop
|
||||||
self.args_set = asyncio.Event()
|
self.args_set = asyncio.Event()
|
||||||
|
self.channel_set = asyncio.Event()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop(self) -> asyncio.AbstractEventLoop:
|
def loop(self) -> asyncio.AbstractEventLoop:
|
||||||
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
|
|||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self._channel = channel
|
self._channel = channel
|
||||||
|
self.channel_set.set()
|
||||||
|
|
||||||
async def fetch_args(
|
async def fetch_args(
|
||||||
self,
|
self,
|
||||||
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
|
|||||||
await self.fetch_args()
|
await self.fetch_args()
|
||||||
await self.args_set.wait()
|
await self.args_set.wait()
|
||||||
|
|
||||||
|
def wait_for_args_sync(self):
|
||||||
|
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
|
||||||
|
|
||||||
def set_args(self, args: list[Any]):
|
def set_args(self, args: list[Any]):
|
||||||
logger.debug("setting args in audio callback %s", args)
|
logger.debug("setting args in audio callback %s", args)
|
||||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||||
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.track = track
|
self.track = track
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.current_timestamp = 0
|
self.current_timestamp = 0
|
||||||
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self.last_timestamp = 0
|
self.last_timestamp = 0
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
def set_args(self, args: list[Any]):
|
def set_args(self, args: list[Any]):
|
||||||
self.event_handler.set_args(args)
|
self.event_handler.set_args(args)
|
||||||
|
|
||||||
|
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||||
|
current_channel.set(self.event_handler.channel)
|
||||||
|
return cast(Callable, self.event_handler.receive)(frame)
|
||||||
|
|
||||||
async def process_input_frames(self) -> None:
|
async def process_input_frames(self) -> None:
|
||||||
while not self.thread_quit.is_set():
|
while not self.thread_quit.is_set():
|
||||||
try:
|
try:
|
||||||
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await anyio.to_thread.run_sync(
|
await anyio.to_thread.run_sync(
|
||||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
self.event_handler_receive, (frame.sample_rate, numpy_array)
|
||||||
)
|
)
|
||||||
except MediaStreamError:
|
except MediaStreamError:
|
||||||
logger.debug("MediaStreamError in process_input_frames")
|
logger.debug("MediaStreamError in process_input_frames")
|
||||||
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
if self.readyState != "live":
|
if self.readyState != "live":
|
||||||
raise MediaStreamError
|
raise MediaStreamError
|
||||||
|
|
||||||
|
if not self.event_handler.channel_set.is_set():
|
||||||
|
await self.event_handler.channel_set.wait()
|
||||||
|
if current_channel.get() != self.event_handler.channel:
|
||||||
|
current_channel.set(self.event_handler.channel)
|
||||||
|
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
frame = await self.queue.get()
|
frame = await self.queue.get()
|
||||||
logger.debug("frame %s", frame)
|
logger.debug("frame %s", frame)
|
||||||
|
|
||||||
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
self.generator = cast(
|
self.generator = cast(
|
||||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||||
)
|
)
|
||||||
|
current_channel.set(self.channel)
|
||||||
try:
|
try:
|
||||||
next_array, outputs = split_output(next(self.generator))
|
next_array, outputs = split_output(next(self.generator))
|
||||||
if (
|
if (
|
||||||
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
|
|
||||||
def next(self) -> tuple[int, np.ndarray] | None:
|
def next(self) -> tuple[int, np.ndarray] | None:
|
||||||
self.args_set.wait()
|
self.args_set.wait()
|
||||||
|
current_channel.set(self.channel)
|
||||||
if self.generator is None:
|
if self.generator is None:
|
||||||
self.generator = self.event_handler(*self.latest_args)
|
self.generator = self.event_handler(*self.latest_args)
|
||||||
if self.generator is not None:
|
if self.generator is not None:
|
||||||
@@ -946,6 +974,7 @@ class WebRTC(Component):
|
|||||||
answer = await pc.createAnswer()
|
answer = await pc.createAnswer()
|
||||||
await pc.setLocalDescription(answer) # type: ignore
|
await pc.setLocalDescription(answer) # type: ignore
|
||||||
logger.debug("done handling offer about to return")
|
logger.debug("done handling offer about to return")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"sdp": pc.localDescription.sdp,
|
"sdp": pc.localDescription.sdp,
|
||||||
|
|||||||
41
docs/faq.md
41
docs/faq.md
@@ -23,4 +23,45 @@ You can disable this via the `track_constraints` (see [advanced configuration](.
|
|||||||
mode="send-receive",
|
mode="send-receive",
|
||||||
modality="audio",
|
modality="audio",
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to raise errors in the UI
|
||||||
|
|
||||||
|
You can raise `WebRTCError` in order for an error message to show up in the user's screen. This is similar to how `gr.Error` works.
|
||||||
|
|
||||||
|
Here is a simple example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def generation(num_steps):
|
||||||
|
for _ in range(num_steps):
|
||||||
|
segment = AudioSegment.from_file(
|
||||||
|
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
|
||||||
|
)
|
||||||
|
yield (
|
||||||
|
segment.frame_rate,
|
||||||
|
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||||
|
)
|
||||||
|
time.sleep(3.5)
|
||||||
|
raise WebRTCError("This is a test error")
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
audio = WebRTC(
|
||||||
|
label="Stream",
|
||||||
|
mode="receive",
|
||||||
|
modality="audio",
|
||||||
|
)
|
||||||
|
num_steps = gr.Slider(
|
||||||
|
label="Number of Steps",
|
||||||
|
minimum=1,
|
||||||
|
maximum=10,
|
||||||
|
step=1,
|
||||||
|
value=5,
|
||||||
|
)
|
||||||
|
button = gr.Button("Generate")
|
||||||
|
|
||||||
|
audio.stream(
|
||||||
|
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
|
||||||
|
)
|
||||||
|
|
||||||
|
demo.launch()
|
||||||
```
|
```
|
||||||
@@ -161,6 +161,145 @@ abstraction that gives you arbitrary control over how the input audio stream and
|
|||||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
||||||
|
|
||||||
|
|
||||||
|
### Async Stream Handlers
|
||||||
|
|
||||||
|
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
|
||||||
|
|
||||||
|
Here is a complete example of using `AsyncStreamHandler` for using the Google Gemini real time API:
|
||||||
|
|
||||||
|
=== "Code"
|
||||||
|
``` py title="AsyncStreamHandler"
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
from google import genai
|
||||||
|
from gradio_webrtc import (
|
||||||
|
AsyncStreamHandler,
|
||||||
|
WebRTC,
|
||||||
|
async_aggregate_bytes_to_16bit,
|
||||||
|
get_twilio_turn_credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
class GeminiHandler(AsyncStreamHandler):
|
||||||
|
def __init__(
|
||||||
|
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
expected_layout,
|
||||||
|
output_sample_rate,
|
||||||
|
output_frame_size,
|
||||||
|
input_sample_rate=16000,
|
||||||
|
)
|
||||||
|
self.client: genai.Client | None = None
|
||||||
|
self.input_queue = asyncio.Queue()
|
||||||
|
self.output_queue = asyncio.Queue()
|
||||||
|
self.quit = asyncio.Event()
|
||||||
|
|
||||||
|
def copy(self) -> "GeminiHandler":
|
||||||
|
return GeminiHandler(
|
||||||
|
expected_layout=self.expected_layout,
|
||||||
|
output_sample_rate=self.output_sample_rate,
|
||||||
|
output_frame_size=self.output_frame_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stream(self):
|
||||||
|
while not self.quit.is_set():
|
||||||
|
audio = await self.input_queue.get()
|
||||||
|
yield audio
|
||||||
|
|
||||||
|
async def connect(self, api_key: str):
|
||||||
|
client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
|
||||||
|
config = {"response_modalities": ["AUDIO"]}
|
||||||
|
async with client.aio.live.connect(
|
||||||
|
model="gemini-2.0-flash-exp", config=config
|
||||||
|
) as session:
|
||||||
|
async for audio in session.start_stream(
|
||||||
|
stream=self.stream(), mime_type="audio/pcm"
|
||||||
|
):
|
||||||
|
if audio.data:
|
||||||
|
yield audio.data
|
||||||
|
|
||||||
|
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||||
|
_, array = frame
|
||||||
|
array = array.squeeze()
|
||||||
|
audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
|
||||||
|
self.input_queue.put_nowait(audio_message)
|
||||||
|
|
||||||
|
async def generator(self):
|
||||||
|
async for audio_response in async_aggregate_bytes_to_16bit(
|
||||||
|
self.connect(api_key=self.latest_args[1])
|
||||||
|
):
|
||||||
|
self.output_queue.put_nowait(audio_response)
|
||||||
|
|
||||||
|
async def emit(self):
|
||||||
|
if not self.args_set.is_set():
|
||||||
|
await self.wait_for_args()
|
||||||
|
asyncio.create_task(self.generator())
|
||||||
|
|
||||||
|
array = await self.output_queue.get()
|
||||||
|
return (self.output_sample_rate, array)
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
self.quit.set()
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
gr.HTML(
|
||||||
|
"""
|
||||||
|
<div style='text-align: center'>
|
||||||
|
<h1>Gen AI SDK Voice Chat</h1>
|
||||||
|
<p>Speak with Gemini using real-time audio streaming</p>
|
||||||
|
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with gr.Row() as api_key_row:
|
||||||
|
api_key = gr.Textbox(
|
||||||
|
label="API Key",
|
||||||
|
placeholder="Enter your API Key",
|
||||||
|
value=os.getenv("GOOGLE_API_KEY", ""),
|
||||||
|
type="password",
|
||||||
|
)
|
||||||
|
with gr.Row(visible=False) as row:
|
||||||
|
webrtc = WebRTC(
|
||||||
|
label="Audio",
|
||||||
|
modality="audio",
|
||||||
|
mode="send-receive",
|
||||||
|
rtc_configuration=get_twilio_turn_credentials(),
|
||||||
|
pulse_color="rgb(35, 157, 225)",
|
||||||
|
icon_button_color="rgb(35, 157, 225)",
|
||||||
|
icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
webrtc.stream(
|
||||||
|
GeminiHandler(),
|
||||||
|
inputs=[webrtc, api_key],
|
||||||
|
outputs=[webrtc],
|
||||||
|
time_limit=90,
|
||||||
|
concurrency_limit=2,
|
||||||
|
)
|
||||||
|
api_key.submit(
|
||||||
|
lambda: (gr.update(visible=False), gr.update(visible=True)),
|
||||||
|
None,
|
||||||
|
[api_key_row, row],
|
||||||
|
)
|
||||||
|
|
||||||
|
demo.launch()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Accessing Other Component Values from a StreamHandler
|
||||||
|
|
||||||
|
In the gemini demo above, you'll notice that we have the user input their google API key. This is stored in a `gr.Textbox` parameter.
|
||||||
|
We can access the value of this component via the `latest_args` prop of the `StreamHandler`. The `latest_args` is a list storing the values of each component in the WebRTC `stream` event `inputs` parameter. The value of the `WebRTC` component is the 0th index and it's always the dummy string `__webrtc_value__`.
|
||||||
|
|
||||||
|
In order to fetch the latest value from the user however, we `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()`.
|
||||||
|
|
||||||
|
|
||||||
### Server-To-Client Only
|
### Server-To-Client Only
|
||||||
|
|
||||||
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
|
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
|
||||||
|
|||||||
@@ -38,7 +38,11 @@
|
|||||||
export let icon_button_color: string = "var(--color-accent)";
|
export let icon_button_color: string = "var(--color-accent)";
|
||||||
export let pulse_color: string = "var(--color-accent)";
|
export let pulse_color: string = "var(--color-accent)";
|
||||||
|
|
||||||
const on_change_cb = (msg: "change" | "tick") => {
|
const on_change_cb = (msg: "change" | "tick" | any) => {
|
||||||
|
if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
|
||||||
|
console.log("dispatching info", msg.message);
|
||||||
|
gradio.dispatch(msg?.type === "error"? "error": "warning", msg.message);
|
||||||
|
}
|
||||||
gradio.dispatch(msg === "change" ? "state_change" : "tick");
|
gradio.dispatch(msg === "change" ? "state_change" : "tick");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,6 +97,7 @@
|
|||||||
i18n={gradio.i18n}
|
i18n={gradio.i18n}
|
||||||
on:tick={() => gradio.dispatch("tick")}
|
on:tick={() => gradio.dispatch("tick")}
|
||||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||||
|
|
||||||
/>
|
/>
|
||||||
{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
|
{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
|
||||||
<Video
|
<Video
|
||||||
@@ -141,6 +146,7 @@
|
|||||||
{pulse_color}
|
{pulse_color}
|
||||||
on:tick={() => gradio.dispatch("tick")}
|
on:tick={() => gradio.dispatch("tick")}
|
||||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||||
|
on:warning={({ detail }) => gradio.dispatch("warning", detail)}
|
||||||
/>
|
/>
|
||||||
{/if}
|
{/if}
|
||||||
</Block>
|
</Block>
|
||||||
|
|||||||
@@ -46,6 +46,7 @@
|
|||||||
});
|
});
|
||||||
|
|
||||||
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
|
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
|
||||||
|
console.log("msg", msg);
|
||||||
if (msg === "stopword") {
|
if (msg === "stopword") {
|
||||||
console.log("stopword recognized");
|
console.log("stopword recognized");
|
||||||
stopword_recognized = true;
|
stopword_recognized = true;
|
||||||
@@ -53,6 +54,7 @@
|
|||||||
stopword_recognized = false;
|
stopword_recognized = false;
|
||||||
}, 3000);
|
}, 3000);
|
||||||
} else {
|
} else {
|
||||||
|
console.log("calling on_change_cb with msg", msg);
|
||||||
on_change_cb(msg);
|
on_change_cb(msg);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -64,13 +64,22 @@ export async function start(
|
|||||||
|
|
||||||
data_channel.onmessage = (event) => {
|
data_channel.onmessage = (event) => {
|
||||||
console.debug("Received message:", event.data);
|
console.debug("Received message:", event.data);
|
||||||
|
let event_json;
|
||||||
|
try {
|
||||||
|
event_json = JSON.parse(event.data);
|
||||||
|
} catch (e) {
|
||||||
|
console.debug("Error parsing JSON")
|
||||||
|
}
|
||||||
|
console.log("event_json", event_json);
|
||||||
if (
|
if (
|
||||||
event.data === "change" ||
|
event.data === "change" ||
|
||||||
event.data === "tick" ||
|
event.data === "tick" ||
|
||||||
event.data === "stopword"
|
event.data === "stopword" ||
|
||||||
|
event_json?.type === "warning" ||
|
||||||
|
event_json?.type === "error"
|
||||||
) {
|
) {
|
||||||
console.debug(`${event.data} event received`);
|
console.debug(`${event.data} event received`);
|
||||||
on_change_cb(event.data);
|
on_change_cb(event_json ?? event.data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "gradio_webrtc"
|
name = "gradio_webrtc"
|
||||||
version = "0.0.25"
|
version = "0.0.27"
|
||||||
description = "Stream images in realtime with webrtc"
|
description = "Stream images in realtime with webrtc"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "apache-2.0"
|
license = "apache-2.0"
|
||||||
|
|||||||
Reference in New Issue
Block a user