mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add code for server to client case
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
|
||||
|
||||
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
@@ -22,6 +22,7 @@ from gradio.components.base import Component, server
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Timer
|
||||
from gradio.blocks import Block
|
||||
from gradio.events import Dependency
|
||||
|
||||
|
||||
if wasm_utils.IS_WASM:
|
||||
@@ -91,6 +92,67 @@ class VideoCallback(VideoStreamTrack):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class ServerToClientVideo(VideoStreamTrack):
|
||||
"""
|
||||
This works for streaming input and output
|
||||
"""
|
||||
|
||||
kind = "video"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.event_handler = event_handler
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
|
||||
def add_frame_to_payload(
|
||||
self, args: list[Any], frame: np.ndarray | None
|
||||
) -> list[Any]:
|
||||
new_args = []
|
||||
for val in args:
|
||||
if isinstance(val, str) and val == "__webrtc_value__":
|
||||
new_args.append(frame)
|
||||
else:
|
||||
new_args.append(val)
|
||||
return new_args
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
|
||||
pts, time_base = await self.next_timestamp()
|
||||
if self.latest_args == "not_set":
|
||||
frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
|
||||
frame.pts = pts
|
||||
frame.time_base = time_base
|
||||
return frame
|
||||
elif self.generator is None:
|
||||
self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args))
|
||||
|
||||
try:
|
||||
next_array = next(self.generator)
|
||||
except StopIteration:
|
||||
print("exception")
|
||||
self.stop()
|
||||
return
|
||||
|
||||
print("pts", pts)
|
||||
print("time_base", time_base)
|
||||
next_frame = self.array_to_frame(next_array)
|
||||
next_frame.pts = pts
|
||||
next_frame.time_base = time_base
|
||||
return next_frame
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class WebRTC(Component):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
@@ -104,7 +166,7 @@ class WebRTC(Component):
|
||||
|
||||
pcs: set[RTCPeerConnection] = set([])
|
||||
relay = MediaRelay()
|
||||
connections: dict[str, VideoCallback] = {}
|
||||
connections: dict[str, VideoCallback | ServerToClientVideo] = {}
|
||||
|
||||
EVENTS = ["tick"]
|
||||
|
||||
@@ -129,6 +191,7 @@ class WebRTC(Component):
|
||||
mirror_webcam: bool = True,
|
||||
rtc_configuration: dict[str, Any] | None = None,
|
||||
time_limit: float | None = None,
|
||||
mode: Literal["video-in-out", "video-out"] = "video-in-out",
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -166,6 +229,7 @@ class WebRTC(Component):
|
||||
self.mirror_webcam = mirror_webcam
|
||||
self.concurrency_limit = 1
|
||||
self.rtc_configuration = rtc_configuration
|
||||
self.mode = mode
|
||||
self.event_handler: Callable | None = None
|
||||
super().__init__(
|
||||
label=label,
|
||||
@@ -200,11 +264,14 @@ class WebRTC(Component):
|
||||
Returns:
|
||||
VideoData object containing the video and subtitle files.
|
||||
"""
|
||||
return "__webrtc_value__"
|
||||
return value
|
||||
|
||||
def set_output(self, webrtc_id: str, *args):
|
||||
if webrtc_id in self.connections:
|
||||
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
||||
if self.mode == "video-in-out":
|
||||
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
||||
elif self.mode == "video-out":
|
||||
self.connections[webrtc_id].latest_args = list(args)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@@ -215,6 +282,7 @@ class WebRTC(Component):
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
time_limit: float | None = None,
|
||||
trigger: Dependency | None = None,
|
||||
):
|
||||
from gradio.blocks import Block
|
||||
|
||||
@@ -223,34 +291,57 @@ class WebRTC(Component):
|
||||
if isinstance(outputs, Block):
|
||||
outputs = [outputs]
|
||||
|
||||
if cast(list[Block], inputs)[0] != self:
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||
)
|
||||
|
||||
if (
|
||||
len(cast(list[Block], outputs)) != 1
|
||||
and cast(list[Block], outputs)[0] != self
|
||||
):
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||
)
|
||||
|
||||
self.concurrency_limit = (
|
||||
1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||
)
|
||||
self.event_handler = fn
|
||||
self.time_limit = time_limit
|
||||
return self.tick( # type: ignore
|
||||
self.set_output,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
stream_every=0.5,
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
|
||||
if self.mode == "video-in-out":
|
||||
|
||||
if cast(list[Block], inputs)[0] != self:
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||
)
|
||||
|
||||
if (
|
||||
len(cast(list[Block], outputs)) != 1
|
||||
and cast(list[Block], outputs)[0] != self
|
||||
):
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||
)
|
||||
return self.tick( # type: ignore
|
||||
self.set_output,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
stream_every=0.5,
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
elif self.mode == "video-out":
|
||||
if self in cast(list[Block], inputs):
|
||||
raise ValueError(
|
||||
"In the video-out stream event, the WebRTC component cannot be an input."
|
||||
)
|
||||
if (
|
||||
len(cast(list[Block], outputs)) != 1
|
||||
and cast(list[Block], outputs)[0] != self
|
||||
):
|
||||
raise ValueError(
|
||||
"In the video-out stream, the only output component must be the WebRTC component."
|
||||
)
|
||||
if trigger is None:
|
||||
raise ValueError(
|
||||
"In the video-out stream event, the trigger parameter must be provided"
|
||||
)
|
||||
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
|
||||
self.tick(
|
||||
self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||
@@ -293,6 +384,12 @@ class WebRTC(Component):
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
pc.addTrack(cb)
|
||||
|
||||
if self.mode == "video-out":
|
||||
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
|
||||
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
Reference in New Issue
Block a user