This commit is contained in:
freddyaboulton
2024-09-26 12:38:42 -04:00
parent 11c828edb5
commit d5f5db5f9b
10 changed files with 626 additions and 126 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
@@ -142,6 +143,7 @@ class WebRTC(Component):
min_length: int | None = None,
max_length: int | None = None,
rtc_configuration: dict[str, Any] | None = None,
time_limit: float | None = None,
):
"""
Parameters:
@@ -173,6 +175,7 @@ class WebRTC(Component):
streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
"""
self.time_limit = time_limit
self.height = height
self.width = width
self.mirror_webcam = mirror_webcam
@@ -227,37 +230,52 @@ class WebRTC(Component):
if webrtc_id in self.connections:
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
def webrtc_stream(
def stream(
self,
fn: Callable[..., Any] | None = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
stream_every: float = 0.5,
time_limit: float | None = None):
if inputs[0] != self:
raise ValueError("In the webrtc_stream event, the first input component must be the WebRTC component.")
from gradio.blocks import Block
if isinstance(inputs, Block):
inputs = [inputs]
if isinstance(outputs, Block):
outputs = [outputs]
if cast(list[Block], inputs)[0] != self:
raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
self.event_handler = fn
self.time_limit = time_limit
return self.tick(self.set_output,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
concurrency_limit=None,
stream_every=stream_every,
stream_every=0.5,
time_limit=None,
js=js
)
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
await asyncio.sleep(time_limit)
await pc.close()
@server
async def offer(self, body):
if len(self.connections) >= self.concurrency_limit:
if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"}
offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
@@ -278,6 +296,9 @@ class WebRTC(Component):
await pc.close()
self.connections.pop(body['webrtc_id'], None)
self.pcs.discard(pc)
if pc.connectionState == "connected":
if self.time_limit is not None:
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
@pc.on("track")
def on_track(track):