Add code for server to client case

This commit is contained in:
freddyaboulton
2024-10-04 17:28:26 -07:00
parent 56817f71aa
commit 9d28441995
4 changed files with 287 additions and 71 deletions

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast, Generator
from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc import RTCPeerConnection, RTCSessionDescription
@@ -22,6 +22,7 @@ from gradio.components.base import Component, server
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Timer from gradio.components import Timer
from gradio.blocks import Block from gradio.blocks import Block
from gradio.events import Dependency
if wasm_utils.IS_WASM: if wasm_utils.IS_WASM:
@@ -91,6 +92,67 @@ class VideoCallback(VideoStreamTrack):
traceback.print_exc() traceback.print_exc()
class ServerToClientVideo(VideoStreamTrack):
    """
    Video track whose frames are produced on the server and sent to the client.

    Until ``latest_args`` is replaced with a list of arguments, ``recv`` emits
    black 640x480 placeholder frames. Once arguments are available,
    ``event_handler(*latest_args)`` is called a single time and each subsequent
    ``recv`` sends the next array it yields.
    """

    kind = "video"

    def __init__(
        self,
        event_handler: Callable,
    ) -> None:
        """
        Parameters:
            event_handler: callable invoked as ``event_handler(*latest_args)``; its result is iterated with ``next`` to obtain one array per frame.
        """
        super().__init__()  # don't forget this!
        self.event_handler = event_handler
        # Sentinel string until the arguments for the handler have been set.
        self.latest_args: str | list[Any] = "not_set"
        # Created lazily on the first recv() after latest_args is populated.
        self.generator: Generator[Any, None, Any] | None = None

    def add_frame_to_payload(
        self, args: list[Any], frame: np.ndarray | None
    ) -> list[Any]:
        """
        Return a copy of ``args`` in which every ``"__webrtc_value__"``
        placeholder string is replaced by ``frame``.
        """
        return [
            frame if isinstance(val, str) and val == "__webrtc_value__" else val
            for val in args
        ]

    def array_to_frame(self, array: np.ndarray) -> VideoFrame:
        """Wrap ``array`` in a ``VideoFrame``, interpreting it as ``bgr24``."""
        return VideoFrame.from_ndarray(array, format="bgr24")

    async def recv(self):
        """
        Produce the next frame of the stream.

        Returns:
            A ``VideoFrame`` stamped with the pts/time_base from ``next_timestamp()``.
        Raises:
            MediaStreamError: when the generator is exhausted (the track is stopped first) or when ``next_timestamp()`` reports that the track is no longer live.
        """
        # Local import: this class only needs the name inside recv().
        from aiortc.mediastreams import MediaStreamError

        try:
            pts, time_base = await self.next_timestamp()

            if self.latest_args == "not_set":
                # No arguments yet: keep the connection fed with a black frame.
                array = np.zeros((480, 640, 3), dtype=np.uint8)
            else:
                if self.generator is None:
                    self.generator = cast(
                        Generator[Any, None, Any],
                        self.event_handler(*self.latest_args),
                    )
                # NOTE(review): next() runs synchronously inside this coroutine,
                # so a slow generator holds up the event loop for that long.
                try:
                    array = next(self.generator)
                except StopIteration:
                    # Returning None here would hand a non-frame to the RTP
                    # sender; signal end-of-stream the way aiortc tracks do.
                    self.stop()
                    raise MediaStreamError from None

            frame = self.array_to_frame(array)
            frame.pts = pts
            frame.time_base = time_base
            return frame
        except MediaStreamError:
            # End-of-stream is expected control flow; do not log or swallow it.
            raise
        except Exception:
            import traceback

            # Keep the diagnostic output, but re-raise so the caller never
            # receives an implicit None in place of a frame.
            traceback.print_exc()
            raise
class WebRTC(Component): class WebRTC(Component):
""" """
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -104,7 +166,7 @@ class WebRTC(Component):
pcs: set[RTCPeerConnection] = set([]) pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay() relay = MediaRelay()
connections: dict[str, VideoCallback] = {} connections: dict[str, VideoCallback | ServerToClientVideo] = {}
EVENTS = ["tick"] EVENTS = ["tick"]
@@ -129,6 +191,7 @@ class WebRTC(Component):
mirror_webcam: bool = True, mirror_webcam: bool = True,
rtc_configuration: dict[str, Any] | None = None, rtc_configuration: dict[str, Any] | None = None,
time_limit: float | None = None, time_limit: float | None = None,
mode: Literal["video-in-out", "video-out"] = "video-in-out",
): ):
""" """
Parameters: Parameters:
@@ -166,6 +229,7 @@ class WebRTC(Component):
self.mirror_webcam = mirror_webcam self.mirror_webcam = mirror_webcam
self.concurrency_limit = 1 self.concurrency_limit = 1
self.rtc_configuration = rtc_configuration self.rtc_configuration = rtc_configuration
self.mode = mode
self.event_handler: Callable | None = None self.event_handler: Callable | None = None
super().__init__( super().__init__(
label=label, label=label,
@@ -200,11 +264,14 @@ class WebRTC(Component):
Returns: Returns:
VideoData object containing the video and subtitle files. VideoData object containing the video and subtitle files.
""" """
return "__webrtc_value__" return value
def set_output(self, webrtc_id: str, *args): def set_output(self, webrtc_id: str, *args):
if webrtc_id in self.connections: if webrtc_id in self.connections:
if self.mode == "video-in-out":
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args) self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
elif self.mode == "video-out":
self.connections[webrtc_id].latest_args = list(args)
def stream( def stream(
self, self,
@@ -215,6 +282,7 @@ class WebRTC(Component):
concurrency_limit: int | None | Literal["default"] = "default", concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None, concurrency_id: str | None = None,
time_limit: float | None = None, time_limit: float | None = None,
trigger: Dependency | None = None,
): ):
from gradio.blocks import Block from gradio.blocks import Block
@@ -223,6 +291,14 @@ class WebRTC(Component):
if isinstance(outputs, Block): if isinstance(outputs, Block):
outputs = [outputs] outputs = [outputs]
self.concurrency_limit = (
1 if concurrency_limit in ["default", None] else concurrency_limit
)
self.event_handler = fn
self.time_limit = time_limit
if self.mode == "video-in-out":
if cast(list[Block], inputs)[0] != self: if cast(list[Block], inputs)[0] != self:
raise ValueError( raise ValueError(
"In the webrtc stream event, the first input component must be the WebRTC component." "In the webrtc stream event, the first input component must be the WebRTC component."
@@ -235,12 +311,6 @@ class WebRTC(Component):
raise ValueError( raise ValueError(
"In the webrtc stream event, the only output component must be the WebRTC component." "In the webrtc stream event, the only output component must be the WebRTC component."
) )
self.concurrency_limit = (
1 if concurrency_limit in ["default", None] else concurrency_limit
)
self.event_handler = fn
self.time_limit = time_limit
return self.tick( # type: ignore return self.tick( # type: ignore
self.set_output, self.set_output,
inputs=inputs, inputs=inputs,
@@ -251,6 +321,27 @@ class WebRTC(Component):
time_limit=None, time_limit=None,
js=js, js=js,
) )
elif self.mode == "video-out":
if self in cast(list[Block], inputs):
raise ValueError(
"In the video-out stream event, the WebRTC component cannot be an input."
)
if (
len(cast(list[Block], outputs)) != 1
and cast(list[Block], outputs)[0] != self
):
raise ValueError(
"In the video-out stream, the only output component must be the WebRTC component."
)
if trigger is None:
raise ValueError(
"In the video-out stream event, the trigger parameter must be provided"
)
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
self.tick(
self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
)
@staticmethod @staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -294,6 +385,12 @@ class WebRTC(Component):
self.connections[body["webrtc_id"]] = cb self.connections[body["webrtc_id"]] = cb
pc.addTrack(cb) pc.addTrack(cb)
if self.mode == "video-out":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
# handle offer # handle offer
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)

View File

@@ -5,11 +5,12 @@
import Video from "./shared/InteractiveVideo.svelte"; import Video from "./shared/InteractiveVideo.svelte";
import { StatusTracker } from "@gradio/statustracker"; import { StatusTracker } from "@gradio/statustracker";
import type { LoadingStatus } from "@gradio/statustracker"; import type { LoadingStatus } from "@gradio/statustracker";
import StaticVideo from "./shared/StaticVideo.svelte";
export let elem_id = ""; export let elem_id = "";
export let elem_classes: string[] = []; export let elem_classes: string[] = [];
export let visible = true; export let visible = true;
export let value: string; export let value: string = "__webrtc_value__";
export let label: string; export let label: string;
export let root: string; export let root: string;
@@ -27,22 +28,7 @@
export let gradio; export let gradio;
export let rtc_configuration: Object; export let rtc_configuration: Object;
export let time_limit: number | null = null; export let time_limit: number | null = null;
// export let gradio: Gradio<{ export let mode: "video-in-out" | "video-out" = "video-in-out";
// change: never;
// clear: never;
// play: never;
// pause: never;
// upload: never;
// stop: never;
// end: never;
// start_recording: never;
// stop_recording: never;
// share: ShareData;
// error: string;
// warning: string;
// clear_status: LoadingStatus;
// tick: never;
// }>;
let dragging = false; let dragging = false;
@@ -71,13 +57,23 @@
on:clear_status={() => gradio.dispatch("clear_status", loading_status)} on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/> />
{#if mode === "video-out"}
<StaticVideo
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else}
<Video <Video
bind:value={value} bind:value={value}
{label} {label}
{show_label} {show_label}
active_source={"webcam"} active_source={"webcam"}
include_audio={false} include_audio={false}
{root}
{server} {server}
{rtc_configuration} {rtc_configuration}
{time_limit} {time_limit}
@@ -96,5 +92,5 @@
> >
<UploadText i18n={gradio.i18n} type="video" /> <UploadText i18n={gradio.i18n} type="video" />
</Video> </Video>
{/if}
</Block> </Block>
<!-- {/if} -->

View File

@@ -0,0 +1,124 @@
<script lang="ts">
	import { createEventDispatcher, afterUpdate, tick, onDestroy } from "svelte";
	import {
		BlockLabel,
		Empty
	} from "@gradio/atoms";
	import { Video } from "@gradio/icons";

	import { start, stop } from "./webrtc_utils";

	// Component value: "start_webrtc_stream" asks this component to open a
	// connection; it is then overwritten with this component's webrtc id.
	export let value: string | null = null;
	export let label: string | undefined = undefined;
	export let show_label = true;
	// Optional RTCPeerConnection configuration; a STUN fallback is used when unset.
	export let rtc_configuration: Object | null = null;
	export let server: {
		offer: (body: any) => Promise<any>;
	};

	let video_element: HTMLVideoElement;

	// Random id passed to start() to identify this connection.
	let _webrtc_id = Math.random().toString(36).substring(2);

	let pc: RTCPeerConnection;

	const dispatch = createEventDispatcher<{
		error: string;
		tick: undefined;
	}>();

	let stream_state = "closed";

	// Dispatch "tick" once per second while the stream is open.
	const tick_interval = window.setInterval(() => {
		if (stream_state == "open") {
			dispatch("tick");
		}
	}, 1000);

	// Without this the timer keeps firing after the component is destroyed.
	onDestroy(() => {
		window.clearInterval(tick_interval);
	});

	$: if (value === "start_webrtc_stream") {
		// Replace the trigger value so this block does not run again for it.
		value = _webrtc_id;
		const fallback_config = {
			iceServers: [
				{
					urls: 'stun:stun.l.google.com:19302'
				}
			]
		};
		const configuration = rtc_configuration || fallback_config;
		pc = new RTCPeerConnection(configuration);
		pc.addEventListener("connectionstatechange",
			async (event) => {
				switch (pc.connectionState) {
					case "connected":
						stream_state = "open";
						break;
					case "disconnected":
						// Stop ticking once the connection is gone.
						stream_state = "closed";
						stop(pc);
						break;
					default:
						break;
				}
			}
		)
		start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
			pc = connection;
		}).catch(() => {
			dispatch("error", "Too many concurrent users. Come back later!");
		});
	}

</script>
<div class="wrap">
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
{#if value === "__webrtc_value__"}
<Empty unpadded_box={true} size="large"><Video /></Empty>
{/if}
<video
class:hidden={value === "__webrtc_value__"}
bind:this={video_element}
autoplay={true}
on:loadeddata={dispatch.bind(null, "loadeddata")}
on:click={dispatch.bind(null, "click")}
on:play={dispatch.bind(null, "play")}
on:pause={dispatch.bind(null, "pause")}
on:ended={dispatch.bind(null, "ended")}
on:mouseover={dispatch.bind(null, "mouseover")}
on:mouseout={dispatch.bind(null, "mouseout")}
on:focus={dispatch.bind(null, "focus")}
on:blur={dispatch.bind(null, "blur")}
on:load
data-testid={$$props["data-testid"]}
crossorigin="anonymous"
>
<track kind="captions" />
</video>
</div>
<style>
.hidden {
display: none;
}
.wrap {
position: relative;
background-color: var(--background-fill-secondary);
height: var(--size-full);
width: var(--size-full);
border-radius: var(--radius-xl);
}
.wrap :global(video) {
height: var(--size-full);
width: var(--size-full);
}
</style>

View File

@@ -161,7 +161,6 @@
window.setInterval(() => { window.setInterval(() => {
if (stream_state == "open") { if (stream_state == "open") {
console.log("dispatching tick");
dispatch("tick"); dispatch("tick");
} }
}, stream_every * 1000); }, stream_every * 1000);