mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add code for server to client case
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
|
||||||
|
|
||||||
|
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||||
@@ -22,6 +22,7 @@ from gradio.components.base import Component, server
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Timer
|
from gradio.components import Timer
|
||||||
from gradio.blocks import Block
|
from gradio.blocks import Block
|
||||||
|
from gradio.events import Dependency
|
||||||
|
|
||||||
|
|
||||||
if wasm_utils.IS_WASM:
|
if wasm_utils.IS_WASM:
|
||||||
@@ -91,6 +92,67 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
class ServerToClientVideo(VideoStreamTrack):
|
||||||
|
"""
|
||||||
|
This works for streaming input and output
|
||||||
|
"""
|
||||||
|
|
||||||
|
kind = "video"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
event_handler: Callable,
|
||||||
|
) -> None:
|
||||||
|
super().__init__() # don't forget this!
|
||||||
|
self.event_handler = event_handler
|
||||||
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
|
self.generator: Generator[Any, None, Any] | None = None
|
||||||
|
|
||||||
|
def add_frame_to_payload(
|
||||||
|
self, args: list[Any], frame: np.ndarray | None
|
||||||
|
) -> list[Any]:
|
||||||
|
new_args = []
|
||||||
|
for val in args:
|
||||||
|
if isinstance(val, str) and val == "__webrtc_value__":
|
||||||
|
new_args.append(frame)
|
||||||
|
else:
|
||||||
|
new_args.append(val)
|
||||||
|
return new_args
|
||||||
|
|
||||||
|
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||||
|
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
try:
|
||||||
|
|
||||||
|
pts, time_base = await self.next_timestamp()
|
||||||
|
if self.latest_args == "not_set":
|
||||||
|
frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
|
||||||
|
frame.pts = pts
|
||||||
|
frame.time_base = time_base
|
||||||
|
return frame
|
||||||
|
elif self.generator is None:
|
||||||
|
self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args))
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_array = next(self.generator)
|
||||||
|
except StopIteration:
|
||||||
|
print("exception")
|
||||||
|
self.stop()
|
||||||
|
return
|
||||||
|
|
||||||
|
print("pts", pts)
|
||||||
|
print("time_base", time_base)
|
||||||
|
next_frame = self.array_to_frame(next_array)
|
||||||
|
next_frame.pts = pts
|
||||||
|
next_frame.time_base = time_base
|
||||||
|
return next_frame
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
class WebRTC(Component):
|
class WebRTC(Component):
|
||||||
"""
|
"""
|
||||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||||
@@ -104,7 +166,7 @@ class WebRTC(Component):
|
|||||||
|
|
||||||
pcs: set[RTCPeerConnection] = set([])
|
pcs: set[RTCPeerConnection] = set([])
|
||||||
relay = MediaRelay()
|
relay = MediaRelay()
|
||||||
connections: dict[str, VideoCallback] = {}
|
connections: dict[str, VideoCallback | ServerToClientVideo] = {}
|
||||||
|
|
||||||
EVENTS = ["tick"]
|
EVENTS = ["tick"]
|
||||||
|
|
||||||
@@ -129,6 +191,7 @@ class WebRTC(Component):
|
|||||||
mirror_webcam: bool = True,
|
mirror_webcam: bool = True,
|
||||||
rtc_configuration: dict[str, Any] | None = None,
|
rtc_configuration: dict[str, Any] | None = None,
|
||||||
time_limit: float | None = None,
|
time_limit: float | None = None,
|
||||||
|
mode: Literal["video-in-out", "video-out"] = "video-in-out",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -166,6 +229,7 @@ class WebRTC(Component):
|
|||||||
self.mirror_webcam = mirror_webcam
|
self.mirror_webcam = mirror_webcam
|
||||||
self.concurrency_limit = 1
|
self.concurrency_limit = 1
|
||||||
self.rtc_configuration = rtc_configuration
|
self.rtc_configuration = rtc_configuration
|
||||||
|
self.mode = mode
|
||||||
self.event_handler: Callable | None = None
|
self.event_handler: Callable | None = None
|
||||||
super().__init__(
|
super().__init__(
|
||||||
label=label,
|
label=label,
|
||||||
@@ -200,11 +264,14 @@ class WebRTC(Component):
|
|||||||
Returns:
|
Returns:
|
||||||
VideoData object containing the video and subtitle files.
|
VideoData object containing the video and subtitle files.
|
||||||
"""
|
"""
|
||||||
return "__webrtc_value__"
|
return value
|
||||||
|
|
||||||
def set_output(self, webrtc_id: str, *args):
|
def set_output(self, webrtc_id: str, *args):
|
||||||
if webrtc_id in self.connections:
|
if webrtc_id in self.connections:
|
||||||
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
if self.mode == "video-in-out":
|
||||||
|
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
||||||
|
elif self.mode == "video-out":
|
||||||
|
self.connections[webrtc_id].latest_args = list(args)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
@@ -215,6 +282,7 @@ class WebRTC(Component):
|
|||||||
concurrency_limit: int | None | Literal["default"] = "default",
|
concurrency_limit: int | None | Literal["default"] = "default",
|
||||||
concurrency_id: str | None = None,
|
concurrency_id: str | None = None,
|
||||||
time_limit: float | None = None,
|
time_limit: float | None = None,
|
||||||
|
trigger: Dependency | None = None,
|
||||||
):
|
):
|
||||||
from gradio.blocks import Block
|
from gradio.blocks import Block
|
||||||
|
|
||||||
@@ -223,34 +291,57 @@ class WebRTC(Component):
|
|||||||
if isinstance(outputs, Block):
|
if isinstance(outputs, Block):
|
||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
|
|
||||||
if cast(list[Block], inputs)[0] != self:
|
|
||||||
raise ValueError(
|
|
||||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
len(cast(list[Block], outputs)) != 1
|
|
||||||
and cast(list[Block], outputs)[0] != self
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.concurrency_limit = (
|
self.concurrency_limit = (
|
||||||
1 if concurrency_limit in ["default", None] else concurrency_limit
|
1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||||
)
|
)
|
||||||
self.event_handler = fn
|
self.event_handler = fn
|
||||||
self.time_limit = time_limit
|
self.time_limit = time_limit
|
||||||
return self.tick( # type: ignore
|
|
||||||
self.set_output,
|
if self.mode == "video-in-out":
|
||||||
inputs=inputs,
|
|
||||||
outputs=None,
|
if cast(list[Block], inputs)[0] != self:
|
||||||
concurrency_id=concurrency_id,
|
raise ValueError(
|
||||||
concurrency_limit=None,
|
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||||
stream_every=0.5,
|
)
|
||||||
time_limit=None,
|
|
||||||
js=js,
|
if (
|
||||||
)
|
len(cast(list[Block], outputs)) != 1
|
||||||
|
and cast(list[Block], outputs)[0] != self
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||||
|
)
|
||||||
|
return self.tick( # type: ignore
|
||||||
|
self.set_output,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=None,
|
||||||
|
concurrency_id=concurrency_id,
|
||||||
|
concurrency_limit=None,
|
||||||
|
stream_every=0.5,
|
||||||
|
time_limit=None,
|
||||||
|
js=js,
|
||||||
|
)
|
||||||
|
elif self.mode == "video-out":
|
||||||
|
if self in cast(list[Block], inputs):
|
||||||
|
raise ValueError(
|
||||||
|
"In the video-out stream event, the WebRTC component cannot be an input."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
len(cast(list[Block], outputs)) != 1
|
||||||
|
and cast(list[Block], outputs)[0] != self
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"In the video-out stream, the only output component must be the WebRTC component."
|
||||||
|
)
|
||||||
|
if trigger is None:
|
||||||
|
raise ValueError(
|
||||||
|
"In the video-out stream event, the trigger parameter must be provided"
|
||||||
|
)
|
||||||
|
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
|
||||||
|
self.tick(
|
||||||
|
self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||||
@@ -293,6 +384,12 @@ class WebRTC(Component):
|
|||||||
)
|
)
|
||||||
self.connections[body["webrtc_id"]] = cb
|
self.connections[body["webrtc_id"]] = cb
|
||||||
pc.addTrack(cb)
|
pc.addTrack(cb)
|
||||||
|
|
||||||
|
if self.mode == "video-out":
|
||||||
|
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
||||||
|
pc.addTrack(cb)
|
||||||
|
self.connections[body["webrtc_id"]] = cb
|
||||||
|
|
||||||
|
|
||||||
# handle offer
|
# handle offer
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|||||||
@@ -5,11 +5,12 @@
|
|||||||
import Video from "./shared/InteractiveVideo.svelte";
|
import Video from "./shared/InteractiveVideo.svelte";
|
||||||
import { StatusTracker } from "@gradio/statustracker";
|
import { StatusTracker } from "@gradio/statustracker";
|
||||||
import type { LoadingStatus } from "@gradio/statustracker";
|
import type { LoadingStatus } from "@gradio/statustracker";
|
||||||
|
import StaticVideo from "./shared/StaticVideo.svelte";
|
||||||
|
|
||||||
export let elem_id = "";
|
export let elem_id = "";
|
||||||
export let elem_classes: string[] = [];
|
export let elem_classes: string[] = [];
|
||||||
export let visible = true;
|
export let visible = true;
|
||||||
export let value: string;
|
export let value: string = "__webrtc_value__";
|
||||||
|
|
||||||
export let label: string;
|
export let label: string;
|
||||||
export let root: string;
|
export let root: string;
|
||||||
@@ -27,22 +28,7 @@
|
|||||||
export let gradio;
|
export let gradio;
|
||||||
export let rtc_configuration: Object;
|
export let rtc_configuration: Object;
|
||||||
export let time_limit: number | null = null;
|
export let time_limit: number | null = null;
|
||||||
// export let gradio: Gradio<{
|
export let mode: "video-in-out" | "video-out" = "video-in-out";
|
||||||
// change: never;
|
|
||||||
// clear: never;
|
|
||||||
// play: never;
|
|
||||||
// pause: never;
|
|
||||||
// upload: never;
|
|
||||||
// stop: never;
|
|
||||||
// end: never;
|
|
||||||
// start_recording: never;
|
|
||||||
// stop_recording: never;
|
|
||||||
// share: ShareData;
|
|
||||||
// error: string;
|
|
||||||
// warning: string;
|
|
||||||
// clear_status: LoadingStatus;
|
|
||||||
// tick: never;
|
|
||||||
// }>;
|
|
||||||
|
|
||||||
let dragging = false;
|
let dragging = false;
|
||||||
|
|
||||||
@@ -71,30 +57,40 @@
|
|||||||
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
|
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<Video
|
{#if mode === "video-out"}
|
||||||
bind:value={value}
|
<StaticVideo
|
||||||
{label}
|
bind:value={value}
|
||||||
{show_label}
|
{label}
|
||||||
active_source={"webcam"}
|
{show_label}
|
||||||
include_audio={false}
|
{server}
|
||||||
{root}
|
{rtc_configuration}
|
||||||
{server}
|
on:tick={() => gradio.dispatch("tick")}
|
||||||
{rtc_configuration}
|
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||||
{time_limit}
|
/>
|
||||||
on:clear={() => gradio.dispatch("clear")}
|
{:else}
|
||||||
on:play={() => gradio.dispatch("play")}
|
<Video
|
||||||
on:pause={() => gradio.dispatch("pause")}
|
bind:value={value}
|
||||||
on:upload={() => gradio.dispatch("upload")}
|
{label}
|
||||||
on:stop={() => gradio.dispatch("stop")}
|
{show_label}
|
||||||
on:end={() => gradio.dispatch("end")}
|
active_source={"webcam"}
|
||||||
on:start_recording={() => gradio.dispatch("start_recording")}
|
include_audio={false}
|
||||||
on:stop_recording={() => gradio.dispatch("stop_recording")}
|
{server}
|
||||||
on:tick={() => gradio.dispatch("tick")}
|
{rtc_configuration}
|
||||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
{time_limit}
|
||||||
i18n={gradio.i18n}
|
on:clear={() => gradio.dispatch("clear")}
|
||||||
stream_handler={(...args) => gradio.client.stream(...args)}
|
on:play={() => gradio.dispatch("play")}
|
||||||
>
|
on:pause={() => gradio.dispatch("pause")}
|
||||||
<UploadText i18n={gradio.i18n} type="video" />
|
on:upload={() => gradio.dispatch("upload")}
|
||||||
</Video>
|
on:stop={() => gradio.dispatch("stop")}
|
||||||
|
on:end={() => gradio.dispatch("end")}
|
||||||
|
on:start_recording={() => gradio.dispatch("start_recording")}
|
||||||
|
on:stop_recording={() => gradio.dispatch("stop_recording")}
|
||||||
|
on:tick={() => gradio.dispatch("tick")}
|
||||||
|
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||||
|
i18n={gradio.i18n}
|
||||||
|
stream_handler={(...args) => gradio.client.stream(...args)}
|
||||||
|
>
|
||||||
|
<UploadText i18n={gradio.i18n} type="video" />
|
||||||
|
</Video>
|
||||||
|
{/if}
|
||||||
</Block>
|
</Block>
|
||||||
<!-- {/if} -->
|
|
||||||
|
|||||||
124
frontend/shared/StaticVideo.svelte
Normal file
124
frontend/shared/StaticVideo.svelte
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { createEventDispatcher, afterUpdate, tick } from "svelte";
|
||||||
|
import {
|
||||||
|
BlockLabel,
|
||||||
|
Empty
|
||||||
|
} from "@gradio/atoms";
|
||||||
|
import { Video } from "@gradio/icons";
|
||||||
|
|
||||||
|
import { start, stop } from "./webrtc_utils";
|
||||||
|
|
||||||
|
|
||||||
|
export let value: string | null = null;
|
||||||
|
export let label: string | undefined = undefined;
|
||||||
|
export let show_label = true;
|
||||||
|
export let rtc_configuration: Object | null = null;
|
||||||
|
export let server: {
|
||||||
|
offer: (body: any) => Promise<any>;
|
||||||
|
};
|
||||||
|
|
||||||
|
let video_element: HTMLVideoElement;
|
||||||
|
|
||||||
|
let _webrtc_id = Math.random().toString(36).substring(2);
|
||||||
|
|
||||||
|
let pc: RTCPeerConnection;
|
||||||
|
|
||||||
|
const dispatch = createEventDispatcher<{
|
||||||
|
error: string;
|
||||||
|
tick: undefined;
|
||||||
|
}>();
|
||||||
|
|
||||||
|
let stream_state = "closed";
|
||||||
|
window.setInterval(() => {
|
||||||
|
if (stream_state == "open") {
|
||||||
|
dispatch("tick");
|
||||||
|
}
|
||||||
|
}, 1000);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
$: console.log("static video value", value);
|
||||||
|
$: if( value === "start_webrtc_stream") {
|
||||||
|
value = _webrtc_id;
|
||||||
|
const fallback_config = {
|
||||||
|
iceServers: [
|
||||||
|
{
|
||||||
|
urls: 'stun:stun.l.google.com:19302'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
};
|
||||||
|
const configuration = rtc_configuration || fallback_config;
|
||||||
|
console.log("config", configuration);
|
||||||
|
pc = new RTCPeerConnection(configuration);
|
||||||
|
pc.addEventListener("connectionstatechange",
|
||||||
|
async (event) => {
|
||||||
|
switch(pc.connectionState) {
|
||||||
|
case "connected":
|
||||||
|
console.log("connected");
|
||||||
|
stream_state = "open";
|
||||||
|
break;
|
||||||
|
case "disconnected":
|
||||||
|
console.log("closed");
|
||||||
|
stop(pc);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
|
||||||
|
pc = connection;
|
||||||
|
}).catch(() => {
|
||||||
|
console.log("catching")
|
||||||
|
dispatch("error", "Too many concurrent users. Come back later!");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="wrap">
|
||||||
|
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
|
||||||
|
{#if value === "__webrtc_value__"}
|
||||||
|
<Empty unpadded_box={true} size="large"><Video /></Empty>
|
||||||
|
{/if}
|
||||||
|
<video
|
||||||
|
class:hidden={value === "__webrtc_value__"}
|
||||||
|
bind:this={video_element}
|
||||||
|
autoplay={true}
|
||||||
|
on:loadeddata={dispatch.bind(null, "loadeddata")}
|
||||||
|
on:click={dispatch.bind(null, "click")}
|
||||||
|
on:play={dispatch.bind(null, "play")}
|
||||||
|
on:pause={dispatch.bind(null, "pause")}
|
||||||
|
on:ended={dispatch.bind(null, "ended")}
|
||||||
|
on:mouseover={dispatch.bind(null, "mouseover")}
|
||||||
|
on:mouseout={dispatch.bind(null, "mouseout")}
|
||||||
|
on:focus={dispatch.bind(null, "focus")}
|
||||||
|
on:blur={dispatch.bind(null, "blur")}
|
||||||
|
on:load
|
||||||
|
data-testid={$$props["data-testid"]}
|
||||||
|
crossorigin="anonymous"
|
||||||
|
>
|
||||||
|
<track kind="captions" />
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<style>
|
||||||
|
.hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.wrap {
|
||||||
|
position: relative;
|
||||||
|
background-color: var(--background-fill-secondary);
|
||||||
|
height: var(--size-full);
|
||||||
|
width: var(--size-full);
|
||||||
|
border-radius: var(--radius-xl);
|
||||||
|
}
|
||||||
|
.wrap :global(video) {
|
||||||
|
height: var(--size-full);
|
||||||
|
width: var(--size-full);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -161,7 +161,6 @@
|
|||||||
|
|
||||||
window.setInterval(() => {
|
window.setInterval(() => {
|
||||||
if (stream_state == "open") {
|
if (stream_state == "open") {
|
||||||
console.log("dispatching tick");
|
|
||||||
dispatch("tick");
|
dispatch("tick");
|
||||||
}
|
}
|
||||||
}, stream_every * 1000);
|
}, stream_every * 1000);
|
||||||
|
|||||||
Reference in New Issue
Block a user