diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 9ab41d1..9cfbd18 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast, Generator from aiortc import RTCPeerConnection, RTCSessionDescription @@ -22,6 +22,7 @@ from gradio.components.base import Component, server if TYPE_CHECKING: from gradio.components import Timer from gradio.blocks import Block + from gradio.events import Dependency if wasm_utils.IS_WASM: @@ -91,6 +92,67 @@ class VideoCallback(VideoStreamTrack): traceback.print_exc() +class ServerToClientVideo(VideoStreamTrack): + """ + This works for streaming input and output + """ + + kind = "video" + + def __init__( + self, + event_handler: Callable, + ) -> None: + super().__init__() # don't forget this! + self.event_handler = event_handler + self.latest_args: str | list[Any] = "not_set" + self.generator: Generator[Any, None, Any] | None = None + + def add_frame_to_payload( + self, args: list[Any], frame: np.ndarray | None + ) -> list[Any]: + new_args = [] + for val in args: + if isinstance(val, str) and val == "__webrtc_value__": + new_args.append(frame) + else: + new_args.append(val) + return new_args + + def array_to_frame(self, array: np.ndarray) -> VideoFrame: + return VideoFrame.from_ndarray(array, format="bgr24") + + async def recv(self): + try: + + pts, time_base = await self.next_timestamp() + if self.latest_args == "not_set": + frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8)) + frame.pts = pts + frame.time_base = time_base + return frame + elif self.generator is None: + self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args)) + + try: + next_array = next(self.generator) + except StopIteration: + print("exception") + self.stop() + return + + print("pts", pts) + print("time_base", time_base) + next_frame = self.array_to_frame(next_array) + next_frame.pts = pts + next_frame.time_base = time_base + return next_frame + except Exception as e: + print(e) + import traceback + traceback.print_exc() + + class WebRTC(Component): """ Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). @@ -104,7 +166,7 @@ class WebRTC(Component): pcs: set[RTCPeerConnection] = set([]) relay = MediaRelay() - connections: dict[str, VideoCallback] = {} + connections: dict[str, VideoCallback | ServerToClientVideo] = {} EVENTS = ["tick"] @@ -129,6 +191,7 @@ class WebRTC(Component): mirror_webcam: bool = True, rtc_configuration: dict[str, Any] | None = None, time_limit: float | None = None, + mode: Literal["video-in-out", "video-out"] = "video-in-out", ): """ Parameters: @@ -166,6 +229,7 @@ class WebRTC(Component): self.mirror_webcam = mirror_webcam self.concurrency_limit = 1 self.rtc_configuration = rtc_configuration + self.mode = mode self.event_handler: Callable | None = None super().__init__( label=label, @@ -200,11 +264,14 @@ class WebRTC(Component): Returns: VideoData object containing the video and subtitle files. """ - return "__webrtc_value__" + return value def set_output(self, webrtc_id: str, *args): if webrtc_id in self.connections: - self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args) + if self.mode == "video-in-out": + self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args) + elif self.mode == "video-out": + self.connections[webrtc_id].latest_args = list(args) def stream( self, @@ -215,6 +282,7 @@ class WebRTC(Component): concurrency_limit: int | None | Literal["default"] = "default", concurrency_id: str | None = None, time_limit: float | None = None, + trigger: Dependency | None = None, ): from gradio.blocks import Block @@ -223,34 +291,57 @@ class WebRTC(Component): if isinstance(outputs, Block): outputs = [outputs] - if cast(list[Block], inputs)[0] != self: - raise ValueError( - "In the webrtc stream event, the first input component must be the WebRTC component." - ) - - if ( - len(cast(list[Block], outputs)) != 1 - and cast(list[Block], outputs)[0] != self - ): - raise ValueError( - "In the webrtc stream event, the only output component must be the WebRTC component." - ) - self.concurrency_limit = ( 1 if concurrency_limit in ["default", None] else concurrency_limit ) self.event_handler = fn self.time_limit = time_limit - return self.tick( # type: ignore - self.set_output, - inputs=inputs, - outputs=None, - concurrency_id=concurrency_id, - concurrency_limit=None, - stream_every=0.5, - time_limit=None, - js=js, - ) + + if self.mode == "video-in-out": + + if cast(list[Block], inputs)[0] != self: + raise ValueError( + "In the webrtc stream event, the first input component must be the WebRTC component." + ) + + if ( + len(cast(list[Block], outputs)) != 1 + and cast(list[Block], outputs)[0] != self + ): + raise ValueError( + "In the webrtc stream event, the only output component must be the WebRTC component." + ) + return self.tick( # type: ignore + self.set_output, + inputs=inputs, + outputs=None, + concurrency_id=concurrency_id, + concurrency_limit=None, + stream_every=0.5, + time_limit=None, + js=js, + ) + elif self.mode == "video-out": + if self in cast(list[Block], inputs): + raise ValueError( + "In the video-out stream event, the WebRTC component cannot be an input." + ) + if ( + len(cast(list[Block], outputs)) != 1 + and cast(list[Block], outputs)[0] != self + ): + raise ValueError( + "In the video-out stream, the only output component must be the WebRTC component." + ) + if trigger is None: + raise ValueError( + "In the video-out stream event, the trigger parameter must be provided" + ) + trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self) + self.tick( + self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id + ) + @staticmethod async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): @@ -293,6 +384,12 @@ class WebRTC(Component): ) self.connections[body["webrtc_id"]] = cb pc.addTrack(cb) + + if self.mode == "video-out": + cb = ServerToClientVideo(cast(Callable, self.event_handler)) + pc.addTrack(cb) + self.connections[body["webrtc_id"]] = cb + # handle offer await pc.setRemoteDescription(offer) diff --git a/frontend/Index.svelte b/frontend/Index.svelte index d6e9b89..0dcbeaf 100644 --- a/frontend/Index.svelte +++ b/frontend/Index.svelte @@ -5,11 +5,12 @@ import Video from "./shared/InteractiveVideo.svelte"; import { StatusTracker } from "@gradio/statustracker"; import type { LoadingStatus } from "@gradio/statustracker"; + import StaticVideo from "./shared/StaticVideo.svelte"; export let elem_id = ""; export let elem_classes: string[] = []; export let visible = true; - export let value: string; + export let value: string = "__webrtc_value__"; export let label: string; export let root: string; @@ -27,22 +28,7 @@ export let gradio; export let rtc_configuration: Object; export let time_limit: number | null = null; - // export let gradio: Gradio<{ - // change: never; - // clear: never; - // play: never; - // pause: never; - // upload: never; - // stop: never; - // end: never; - // start_recording: never; - // stop_recording: never; - // share: ShareData; - // error: string; - // warning: string; - // clear_status: LoadingStatus; - // tick: never; - // }>; + export let mode: "video-in-out" | "video-out" = "video-in-out"; let dragging = false; @@ -71,30 +57,40 @@ on:clear_status={() => gradio.dispatch("clear_status", loading_status)} /> - + {#if mode === "video-out"} + gradio.dispatch("tick")} + on:error={({ detail }) => gradio.dispatch("error", detail)} + /> + {:else} + + {/if} - diff --git a/frontend/shared/StaticVideo.svelte b/frontend/shared/StaticVideo.svelte new file mode 100644 index 0000000..1fd3db5 --- /dev/null +++ b/frontend/shared/StaticVideo.svelte @@ -0,0 +1,124 @@ + + +
+ + {#if value === "__webrtc_value__"} + + {/if} + +
+ + + + diff --git a/frontend/shared/Webcam.svelte b/frontend/shared/Webcam.svelte index 6c58007..11b5fbc 100644 --- a/frontend/shared/Webcam.svelte +++ b/frontend/shared/Webcam.svelte @@ -161,7 +161,6 @@ window.setInterval(() => { if (stream_state == "open") { - console.log("dispatching tick"); dispatch("tick"); } }, stream_every * 1000);