mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
from .webrtc import WebRTC
|
||||
|
||||
__all__ = ['WebRTC']
|
||||
__all__ = ["WebRTC"]
|
||||
|
||||
@@ -12,7 +12,7 @@ from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaRelay
|
||||
from aiortc import VideoStreamTrack
|
||||
from aiortc.mediastreams import MediaStreamError
|
||||
from aiortc.contrib.media import VideoFrame # type: ignore
|
||||
from aiortc.contrib.media import VideoFrame # type: ignore
|
||||
from gradio_client import handle_file
|
||||
import numpy as np
|
||||
|
||||
@@ -29,7 +29,6 @@ if wasm_utils.IS_WASM:
|
||||
raise ValueError("Not supported in gradio-lite!")
|
||||
|
||||
|
||||
|
||||
class VideoCallback(VideoStreamTrack):
|
||||
"""
|
||||
This works for streaming input and output
|
||||
@@ -47,7 +46,9 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.event_handler = event_handler
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
|
||||
def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]:
|
||||
def add_frame_to_payload(
|
||||
self, args: list[Any], frame: np.ndarray | None
|
||||
) -> list[Any]:
|
||||
new_args = []
|
||||
for val in args:
|
||||
if isinstance(val, str) and val == "__webrtc_value__":
|
||||
@@ -55,14 +56,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
else:
|
||||
new_args.append(val)
|
||||
return new_args
|
||||
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
|
||||
try:
|
||||
frame = await self.track.recv()
|
||||
except MediaStreamError:
|
||||
@@ -70,12 +69,10 @@ class VideoCallback(VideoStreamTrack):
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
if self.latest_args == "not_set":
|
||||
print("args not set")
|
||||
return frame
|
||||
|
||||
|
||||
args = self.add_frame_to_payload(self.latest_args, frame_array)
|
||||
|
||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||
|
||||
array = self.event_handler(*args)
|
||||
|
||||
new_frame = self.array_to_frame(array)
|
||||
@@ -110,18 +107,11 @@ class WebRTC(Component):
|
||||
relay = MediaRelay()
|
||||
connections: dict[str, VideoCallback] = {}
|
||||
|
||||
EVENTS = [
|
||||
"tick"
|
||||
]
|
||||
EVENTS = ["tick"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str
|
||||
| Path
|
||||
| tuple[str | Path, str | Path | None]
|
||||
| Callable
|
||||
| None = None,
|
||||
*,
|
||||
value: None = None,
|
||||
height: int | str | None = None,
|
||||
width: int | str | None = None,
|
||||
label: str | None = None,
|
||||
@@ -138,10 +128,6 @@ class WebRTC(Component):
|
||||
render: bool = True,
|
||||
key: int | str | None = None,
|
||||
mirror_webcam: bool = True,
|
||||
show_share_button: bool | None = None,
|
||||
show_download_button: bool | None = None,
|
||||
min_length: int | None = None,
|
||||
max_length: int | None = None,
|
||||
rtc_configuration: dict[str, Any] | None = None,
|
||||
time_limit: float | None = None,
|
||||
):
|
||||
@@ -180,14 +166,6 @@ class WebRTC(Component):
|
||||
self.width = width
|
||||
self.mirror_webcam = mirror_webcam
|
||||
self.concurrency_limit = 1
|
||||
self.show_share_button = (
|
||||
(utils.get_space() is not None)
|
||||
if show_share_button is None
|
||||
else show_share_button
|
||||
)
|
||||
self.show_download_button = show_download_button
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.rtc_configuration = rtc_configuration
|
||||
self.event_handler: Callable | None = None
|
||||
super().__init__(
|
||||
@@ -216,7 +194,6 @@ class WebRTC(Component):
|
||||
"""
|
||||
return payload
|
||||
|
||||
|
||||
def postprocess(self, value: Any) -> str:
|
||||
"""
|
||||
Parameters:
|
||||
@@ -229,7 +206,7 @@ class WebRTC(Component):
|
||||
def set_output(self, webrtc_id: str, *args):
|
||||
if webrtc_id in self.connections:
|
||||
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
||||
|
||||
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable[..., Any] | None = None,
|
||||
@@ -238,8 +215,8 @@ class WebRTC(Component):
|
||||
js: str | None = None,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
time_limit: float | None = None):
|
||||
|
||||
time_limit: float | None = None,
|
||||
):
|
||||
from gradio.blocks import Block
|
||||
|
||||
if isinstance(inputs, Block):
|
||||
@@ -248,23 +225,33 @@ class WebRTC(Component):
|
||||
outputs = [outputs]
|
||||
|
||||
if cast(list[Block], inputs)[0] != self:
|
||||
raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||
)
|
||||
|
||||
if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
|
||||
raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
|
||||
if (
|
||||
len(cast(list[Block], outputs)) != 1
|
||||
and cast(list[Block], outputs)[0] != self
|
||||
):
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||
)
|
||||
|
||||
self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||
self.concurrency_limit = (
|
||||
1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||
)
|
||||
self.event_handler = fn
|
||||
self.time_limit = time_limit
|
||||
return self.tick(self.set_output,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
stream_every=0.5,
|
||||
time_limit=None,
|
||||
js=js
|
||||
)
|
||||
return self.tick( # type: ignore
|
||||
self.set_output,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
stream_every=0.5,
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||
@@ -273,11 +260,10 @@ class WebRTC(Component):
|
||||
|
||||
@server
|
||||
async def offer(self, body):
|
||||
|
||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||
return {"status": "failed"}
|
||||
|
||||
offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
|
||||
|
||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
@@ -287,14 +273,14 @@ class WebRTC(Component):
|
||||
print(pc.iceConnectionState)
|
||||
if pc.iceConnectionState == "failed":
|
||||
await pc.close()
|
||||
self.connections.pop(body['webrtc_id'], None)
|
||||
self.connections.pop(body["webrtc_id"], None)
|
||||
self.pcs.discard(pc)
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
if pc.connectionState in ["failed", "closed"]:
|
||||
await pc.close()
|
||||
self.connections.pop(body['webrtc_id'], None)
|
||||
self.connections.pop(body["webrtc_id"], None)
|
||||
self.pcs.discard(pc)
|
||||
if pc.connectionState == "connected":
|
||||
if self.time_limit is not None:
|
||||
@@ -303,12 +289,12 @@ class WebRTC(Component):
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
cb = VideoCallback(
|
||||
self.relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler)
|
||||
)
|
||||
self.connections[body['webrtc_id']] = cb
|
||||
self.relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler),
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
pc.addTrack(cb)
|
||||
|
||||
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
@@ -317,9 +303,9 @@ class WebRTC(Component):
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
|
||||
return {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
"type": pc.localDescription.type,
|
||||
}
|
||||
"sdp": pc.localDescription.sdp,
|
||||
"type": pc.localDescription.type,
|
||||
}
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
@@ -331,7 +317,5 @@ class WebRTC(Component):
|
||||
def example_value(self) -> Any:
|
||||
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
|
||||
|
||||
|
||||
def api_info(self) -> Any:
|
||||
return {"type": "number"}
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ else:
|
||||
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
image = cv2.resize(image, (model.input_width, model.input_height))
|
||||
new_image = model.detect_objects(image, conf_threshold)
|
||||
|
||||
@@ -8,7 +8,6 @@ from utils import draw_detections
|
||||
|
||||
class YOLOv10:
|
||||
def __init__(self, path):
|
||||
|
||||
# Initialize model
|
||||
self.initialize_model(path)
|
||||
|
||||
@@ -53,7 +52,11 @@ class YOLOv10:
|
||||
)
|
||||
|
||||
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
|
||||
boxes, scores, class_ids, = self.process_output(outputs, conf_threshold)
|
||||
(
|
||||
boxes,
|
||||
scores,
|
||||
class_ids,
|
||||
) = self.process_output(outputs, conf_threshold)
|
||||
return self.draw_detections(image, boxes, scores, class_ids)
|
||||
|
||||
def process_output(self, output, conf_threshold=0.3):
|
||||
@@ -83,7 +86,7 @@ class YOLOv10:
|
||||
boxes = self.rescale_boxes(boxes)
|
||||
|
||||
# Convert boxes to xyxy format
|
||||
#boxes = xywh2xyxy(boxes)
|
||||
# boxes = xywh2xyxy(boxes)
|
||||
|
||||
return boxes
|
||||
|
||||
@@ -98,10 +101,10 @@ class YOLOv10:
|
||||
)
|
||||
return boxes
|
||||
|
||||
def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
|
||||
return draw_detections(
|
||||
image, boxes, scores, class_ids, mask_alpha
|
||||
)
|
||||
def draw_detections(
|
||||
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
|
||||
):
|
||||
return draw_detections(image, boxes, scores, class_ids, mask_alpha)
|
||||
|
||||
def get_input_details(self):
|
||||
model_inputs = self.session.get_inputs()
|
||||
@@ -139,7 +142,6 @@ if __name__ == "__main__":
|
||||
# # Detect Objects
|
||||
combined_image = yolov8_detector.detect_objects(img)
|
||||
|
||||
|
||||
# Draw detections
|
||||
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("Output", combined_image)
|
||||
|
||||
@@ -164,7 +164,7 @@ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
|
||||
font_size = min([img_height, img_width]) * 0.0006
|
||||
text_thickness = int(min([img_height, img_width]) * 0.001)
|
||||
|
||||
#det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
|
||||
# det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
|
||||
|
||||
# Draw bounding boxes and labels of detections
|
||||
for class_id, box, score in zip(class_ids, boxes, scores):
|
||||
|
||||
Reference in New Issue
Block a user