diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index 1924616..f99f83b 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -1,4 +1,5 @@ from .reply_on_pause import ReplyOnPause +from .utils import AdditionalOutputs from .webrtc import StreamHandler, WebRTC -__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] +__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"] diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index c380fbe..38f62af 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -1,9 +1,10 @@ import asyncio import fractions import logging -from typing import Callable +from typing import Any, Callable, Protocol, cast import av +import numpy as np logger = logging.getLogger(__name__) @@ -11,10 +12,38 @@ logger = logging.getLogger(__name__) AUDIO_PTIME = 0.020 +class AdditionalOutputs: + def __init__(self, *args) -> None: + self.args = args + + +class DataChannel(Protocol): + def send(self, message: str) -> None: ... + + +def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]: + if isinstance(data, tuple): + # handle the bare audio case + if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray): + return data, None + if not len(data) == 2: + raise ValueError( + "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs." + ) + if not isinstance(data[-1], AdditionalOutputs): + raise ValueError( + "The last element of the tuple must be an instance of AdditionalOutputs." + ) + return data[0], cast(AdditionalOutputs, data[1]) + return data, None + + async def player_worker_decode( next_frame: Callable, queue: asyncio.Queue, thread_quit: asyncio.Event, + channel: Callable[[], DataChannel | None] | None, + set_additional_outputs: Callable | None, quit_on_none: bool = False, sample_rate: int = 48000, frame_size: int = int(48000 * AUDIO_PTIME), @@ -31,7 +60,17 @@ async def player_worker_decode( while not thread_quit.is_set(): try: # Get next frame - frame = await asyncio.wait_for(next_frame(), timeout=60) + frame, outputs = split_output( + await asyncio.wait_for(next_frame(), timeout=60) + ) + if ( + isinstance(outputs, AdditionalOutputs) + and set_additional_outputs + and channel + and channel() + ): + set_additional_outputs(outputs) + cast(DataChannel, channel()).send("change") if frame is None: if quit_on_none: @@ -65,7 +104,7 @@ async def player_worker_decode( processed_frame.time_base = audio_time_base audio_samples += processed_frame.samples await queue.put(processed_frame) - logger.debug("Queue size utils.py: %s", queue.qsize()) + logger.debug("Queue size utils.py: %s", queue.qsize()) except (TimeoutError, asyncio.TimeoutError): logger.warning( diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 3f82197..1ac17ff 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -11,7 +11,19 @@ import traceback from abc import ABC, abstractmethod from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Concatenate, + Generator, + Iterable, + Literal, + ParamSpec, + Sequence, + TypeVar, + cast, +) import anyio.to_thread import av @@ -29,7 +41,12 @@ from gradio import wasm_utils from gradio.components.base import Component, server from gradio_client import handle_file -from gradio_webrtc.utils import player_worker_decode +from gradio_webrtc.utils import ( + AdditionalOutputs, + DataChannel, + player_worker_decode, + split_output, +) if TYPE_CHECKING: from gradio.blocks import Block @@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack): self, track: MediaStreamTrack, event_handler: Callable, + channel: DataChannel | None = None, + set_additional_outputs: Callable | None = None, ) -> None: super().__init__() # don't forget this! self.track = track self.event_handler = event_handler self.latest_args: str | list[Any] = "not_set" + self.channel = channel + self.set_additional_outputs = set_additional_outputs def add_frame_to_payload( self, args: list[Any], frame: np.ndarray | None @@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack): args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array) - array = self.event_handler(*args) + array, outputs = split_output(self.event_handler(*args)) + if ( + isinstance(outputs, AdditionalOutputs) + and self.set_additional_outputs + and self.channel + ): + self.set_additional_outputs(outputs) + self.channel.send("change") new_frame = self.array_to_frame(array) if frame: @@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack): self, track: MediaStreamTrack, event_handler: StreamHandler, + channel: DataChannel | None = None, + set_additional_outputs: Callable | None = None, ) -> None: self.track = track self.event_handler = event_handler @@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack): self._start: float | None = None self.has_started = False self.last_timestamp = 0 + self.channel = channel + self.set_additional_outputs = set_additional_outputs super().__init__() async def process_input_frames(self) -> None: @@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack): callable, self.queue, self.thread_quit, + lambda: self.channel, + self.set_additional_outputs, False, self.event_handler.output_sample_rate, self.event_handler.output_frame_size, @@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack): def __init__( self, event_handler: Callable, + channel: DataChannel | None = None, + set_additional_outputs: Callable | None = None, ) -> None: super().__init__() # don't forget this! self.event_handler = event_handler self.args_set = asyncio.Event() self.latest_args: str | list[Any] = "not_set" self.generator: Generator[Any, None, Any] | None = None + self.channel = channel + self.set_additional_outputs = set_additional_outputs def array_to_frame(self, array: np.ndarray) -> VideoFrame: return VideoFrame.from_ndarray(array, format="bgr24") @@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack): ) try: - next_array = next(self.generator) + next_array, outputs = split_output(next(self.generator)) + if ( + isinstance(outputs, AdditionalOutputs) + and self.set_additional_outputs + and self.channel + ): + self.set_additional_outputs(outputs) + self.channel.send("change") except StopIteration: self.stop() return @@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack): def __init__( self, event_handler: Callable, + channel: DataChannel | None = None, + set_additional_outputs: Callable | None = None, ) -> None: self.generator: Generator[Any, None, Any] | None = None self.event_handler = event_handler @@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack): self.args_set = threading.Event() self.queue = asyncio.Queue() self.thread_quit = asyncio.Event() + self.channel = channel + self.set_additional_outputs = set_additional_outputs self.has_started = False self._start: float | None = None super().__init__() @@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack): callable, self.queue, self.thread_quit, + lambda: self.channel, + self.set_additional_outputs, True, ) ) @@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack): super().stop() +# For the return type +R = TypeVar("R") +# For the parameter specification +P = ParamSpec("P") + + class WebRTC(Component): """ Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). @@ -369,8 +426,10 @@ class WebRTC(Component): connections: dict[ str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback ] = {} + data_channels: dict[str, DataChannel] = {} + additional_outputs: dict[str, AdditionalOutputs] = {} - EVENTS = ["tick"] + EVENTS = ["tick", "state_change"] def __init__( self, @@ -470,6 +529,14 @@ class WebRTC(Component): value=value, ) + def set_additional_outputs( + self, webrtc_id: str + ) -> Callable[[AdditionalOutputs], None]: + def set_outputs(outputs: AdditionalOutputs): + self.additional_outputs[webrtc_id] = outputs + + return set_outputs + def preprocess(self, payload: str) -> str: """ Parameters: @@ -498,6 +565,38 @@ class WebRTC(Component): self.connections[webrtc_id].latest_args = list(args) self.connections[webrtc_id].args_set.set() # type: ignore + def change( + self, + fn: Callable[Concatenate[P], R], + inputs: Block | Sequence[Block] | set[Block] | None = None, + outputs: Block | Sequence[Block] | set[Block] | None = None, + js: str | None = None, + concurrency_limit: int | None | Literal["default"] = "default", + concurrency_id: str | None = None, + ): + inputs = inputs or [] + if inputs and not isinstance(inputs, Iterable): + inputs = [inputs] + inputs = list(inputs) + + def handler(webrtc_id: str, *args): + if webrtc_id in self.additional_outputs: + return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore + return ( + tuple([None for _ in range(len(outputs))]) + if isinstance(outputs, Iterable) + else None + ) + + return self.state_change( # type: ignore + fn=handler, + inputs=[self] + cast(list, inputs), + outputs=outputs, + js=js, + concurrency_limit=concurrency_limit, + concurrency_id=concurrency_id, + ) + def stream( self, fn: Callable[..., Any] | StreamHandler | None = None, @@ -599,6 +698,8 @@ class WebRTC(Component): pc = RTCPeerConnection() self.pcs.add(pc) + set_outputs = self.set_additional_outputs(body["webrtc_id"]) + @pc.on("iceconnectionstatechange") async def on_iceconnectionstatechange(): logger.debug("ICE connection state change %s", pc.iceConnectionState) @@ -627,27 +728,61 @@ class WebRTC(Component): cb = VideoCallback( relay.subscribe(track), event_handler=cast(Callable, self.event_handler), + set_additional_outputs=set_outputs, ) elif self.modality == "audio": cb = AudioCallback( relay.subscribe(track), event_handler=cast(StreamHandler, self.event_handler).copy(), + set_additional_outputs=set_outputs, ) self.connections[body["webrtc_id"]] = cb + if body["webrtc_id"] in self.data_channels: + self.connections[body["webrtc_id"]].channel = self.data_channels[ + body["webrtc_id"] + ] logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) if self.mode == "receive": if self.modality == "video": - cb = ServerToClientVideo(cast(Callable, self.event_handler)) + cb = ServerToClientVideo( + cast(Callable, self.event_handler), + set_additional_outputs=set_outputs, + ) elif self.modality == "audio": - cb = ServerToClientAudio(cast(Callable, self.event_handler)) + cb = ServerToClientAudio( + cast(Callable, self.event_handler), + set_additional_outputs=set_outputs, + ) logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) self.connections[body["webrtc_id"]] = cb cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None)) + @pc.on("datachannel") + def on_datachannel(channel): + print("data channel established") + logger.debug(f"Data channel established: {channel.label}") + + self.data_channels[body["webrtc_id"]] = channel + + async def set_channel(webrtc_id: str): + print("webrtc_id", webrtc_id) + while not self.connections.get(webrtc_id): + await asyncio.sleep(0.05) + print("setting channel") + self.connections[webrtc_id].channel = channel + + asyncio.create_task(set_channel(body["webrtc_id"])) + + @channel.on("message") + def on_message(message): + logger.debug(f"Received message: {message}") + if channel.readyState == "open": + channel.send(f"Server received: {message}") + # handle offer await pc.setRemoteDescription(offer) diff --git a/frontend/Index.svelte b/frontend/Index.svelte index 242841e..2ea0362 100644 --- a/frontend/Index.svelte +++ b/frontend/Index.svelte @@ -34,6 +34,10 @@ export let mode: "send-receive" | "receive" = "send-receive"; export let track_constraints: MediaTrackConstraints = {}; + const on_change_cb = () => { + gradio.dispatch("state_change"); + } + let dragging = false; $: console.log("value", value); @@ -63,6 +67,7 @@ {#if mode == "receive" && modality === "video"} gradio.dispatch("clear")} on:play={() => gradio.dispatch("play")} on:pause={() => gradio.dispatch("pause")} @@ -109,6 +116,7 @@ {:else if mode === "send-receive" && modality === "audio"} void; - $: console.log("time_limit", time_limit); + let _time_limit: number | null = null; export let server: { offer: (body: any) => Promise; @@ -41,12 +41,15 @@ const dispatch = createEventDispatcher<{ tick: undefined; + state_change: undefined; error: string play: undefined; stop: undefined; }>(); + + onMount(() => { window.setInterval(() => { if (stream_state == "open") { @@ -103,7 +106,7 @@ } if (stream == null) return; - start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => { + start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => { pc = connection; }).catch(() => { console.info("catching") diff --git a/frontend/shared/InteractiveVideo.svelte b/frontend/shared/InteractiveVideo.svelte index 5794c70..1971c97 100644 --- a/frontend/shared/InteractiveVideo.svelte +++ b/frontend/shared/InteractiveVideo.svelte @@ -21,6 +21,7 @@ }; export let rtc_configuration: Object; export let track_constraints: MediaTrackConstraints = {}; + export let on_change_cb: () => void; const dispatch = createEventDispatcher<{ change: FileData | null; @@ -50,6 +51,7 @@ {include_audio} {time_limit} {track_constraints} + {on_change_cb} on:error on:start_recording on:stop_recording diff --git a/frontend/shared/StaticAudio.svelte b/frontend/shared/StaticAudio.svelte index 6ed36cc..d2155e4 100644 --- a/frontend/shared/StaticAudio.svelte +++ b/frontend/shared/StaticAudio.svelte @@ -17,6 +17,7 @@ export let show_label = true; export let rtc_configuration: Object | null = null; export let i18n: I18nFormatter; + export let on_change_cb: () => void; export let server: { offer: (body: any) => Promise; @@ -68,7 +69,7 @@ } ) let stream = null; - start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => { + start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => { pc = connection; }).catch(() => { console.info("catching") diff --git a/frontend/shared/StaticVideo.svelte b/frontend/shared/StaticVideo.svelte index 9b5ce1a..cb1c14b 100644 --- a/frontend/shared/StaticVideo.svelte +++ b/frontend/shared/StaticVideo.svelte @@ -13,6 +13,7 @@ export let label: string | undefined = undefined; export let show_label = true; export let rtc_configuration: Object | null = null; + export let on_change_cb: () => void; export let server: { offer: (body: any) => Promise; }; @@ -59,7 +60,7 @@ } } ) - start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => { + start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => { pc = connection; }).catch(() => { console.log("catching") diff --git a/frontend/shared/Webcam.svelte b/frontend/shared/Webcam.svelte index 9c2931d..59f81f2 100644 --- a/frontend/shared/Webcam.svelte +++ b/frontend/shared/Webcam.svelte @@ -24,6 +24,7 @@ let _time_limit: number | null = null; export let time_limit: number | null = null; let stream_state: "open" | "waiting" | "closed" = "closed"; + export let on_change_cb: () => void; const _webrtc_id = Math.random().toString(36).substring(2); export const modify_stream: (state: "open" | "closed" | "waiting") => void = ( @@ -139,7 +140,7 @@ ) stream_state = "waiting" webrtc_id = Math.random().toString(36).substring(2); - start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => { + start(stream, pc, video_source, server.offer, webrtc_id, "video", on_change_cb).then((connection) => { pc = connection; }).catch(() => { console.info("catching") diff --git a/frontend/shared/webrtc_utils.ts b/frontend/shared/webrtc_utils.ts index d0168a9..1079c34 100644 --- a/frontend/shared/webrtc_utils.ts +++ b/frontend/shared/webrtc_utils.ts @@ -44,8 +44,24 @@ export function createPeerConnection(pc, node) { return pc; } -export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") { +export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, + modality: "video" | "audio" = "video", on_change_cb: () => void = () => {}) { pc = createPeerConnection(pc, node); + const data_channel = pc.createDataChannel("text"); + + data_channel.onopen = () => { + console.debug("Data channel is open"); + data_channel.send("handshake"); + }; + + data_channel.onmessage = (event) => { + console.debug("Received message:", event.data); + if (event.data === "change") { + console.debug("Change event received"); + on_change_cb(); + } + }; + if (stream) { stream.getTracks().forEach((track) => { console.debug("Track stream callback", track);