diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py
index 6783709..306a6e5 100644
--- a/backend/gradio_webrtc/__init__.py
+++ b/backend/gradio_webrtc/__init__.py
@@ -1,4 +1,3 @@
-
 from .webrtc import WebRTC
 
-__all__ = ['WebRTC']
+__all__ = ["WebRTC"]
diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py
index 033eb29..ae692fa 100644
--- a/backend/gradio_webrtc/webrtc.py
+++ b/backend/gradio_webrtc/webrtc.py
@@ -12,7 +12,7 @@
 from aiortc import RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from aiortc import VideoStreamTrack
 from aiortc.mediastreams import MediaStreamError
-from aiortc.contrib.media import VideoFrame # type: ignore
+from aiortc.contrib.media import VideoFrame  # type: ignore
 from gradio_client import handle_file
 import numpy as np
@@ -29,7 +29,6 @@
 if wasm_utils.IS_WASM:
     raise ValueError("Not supported in gradio-lite!")
 
-
 class VideoCallback(VideoStreamTrack):
     """
     This works for streaming input and output
@@ -47,7 +46,9 @@ class VideoCallback(VideoStreamTrack):
         self.event_handler = event_handler
         self.latest_args: str | list[Any] = "not_set"
 
-    def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]:
+    def add_frame_to_payload(
+        self, args: list[Any], frame: np.ndarray | None
+    ) -> list[Any]:
         new_args = []
         for val in args:
             if isinstance(val, str) and val == "__webrtc_value__":
@@ -55,14 +56,12 @@ class VideoCallback(VideoStreamTrack):
             else:
                 new_args.append(val)
         return new_args
-
 
     def array_to_frame(self, array: np.ndarray) -> VideoFrame:
         return VideoFrame.from_ndarray(array, format="bgr24")
 
     async def recv(self):
         try:
-
             try:
                 frame = await self.track.recv()
             except MediaStreamError:
@@ -70,12 +69,10 @@ class VideoCallback(VideoStreamTrack):
             frame_array = frame.to_ndarray(format="bgr24")
 
             if self.latest_args == "not_set":
-                print("args not set")
                 return frame
-
-            args = self.add_frame_to_payload(self.latest_args, frame_array)
-
+            args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
+
             array = self.event_handler(*args)
 
             new_frame = self.array_to_frame(array)
 
@@ -110,18 +107,11 @@ class WebRTC(Component):
     relay = MediaRelay()
     connections: dict[str, VideoCallback] = {}
 
-    EVENTS = [
-        "tick"
-    ]
+    EVENTS = ["tick"]
 
     def __init__(
         self,
-        value: str
-        | Path
-        | tuple[str | Path, str | Path | None]
-        | Callable
-        | None = None,
-        *,
+        value: None = None,
         height: int | str | None = None,
         width: int | str | None = None,
         label: str | None = None,
@@ -138,10 +128,6 @@ class WebRTC(Component):
         render: bool = True,
         key: int | str | None = None,
         mirror_webcam: bool = True,
-        show_share_button: bool | None = None,
-        show_download_button: bool | None = None,
-        min_length: int | None = None,
-        max_length: int | None = None,
         rtc_configuration: dict[str, Any] | None = None,
         time_limit: float | None = None,
     ):
@@ -180,14 +166,6 @@ class WebRTC(Component):
         self.width = width
         self.mirror_webcam = mirror_webcam
         self.concurrency_limit = 1
-        self.show_share_button = (
-            (utils.get_space() is not None)
-            if show_share_button is None
-            else show_share_button
-        )
-        self.show_download_button = show_download_button
-        self.min_length = min_length
-        self.max_length = max_length
         self.rtc_configuration = rtc_configuration
         self.event_handler: Callable | None = None
         super().__init__(
@@ -216,7 +194,6 @@ class WebRTC(Component):
         """
         return payload
 
-
     def postprocess(self, value: Any) -> str:
         """
         Parameters:
@@ -229,7 +206,7 @@ class WebRTC(Component):
     def set_output(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
             self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
-    
+
     def stream(
         self,
         fn: Callable[..., Any] | None = None,
@@ -238,8 +215,8 @@ class WebRTC(Component):
         js: str | None = None,
         concurrency_limit: int | None | Literal["default"] = "default",
         concurrency_id: str | None = None,
-        time_limit: float | None = None):
-
+        time_limit: float | None = None,
+    ):
         from gradio.blocks import Block
 
         if isinstance(inputs, Block):
@@ -248,23 +225,33 @@ class WebRTC(Component):
             outputs = [outputs]
 
         if cast(list[Block], inputs)[0] != self:
-            raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
+            raise ValueError(
+                "In the webrtc stream event, the first input component must be the WebRTC component."
+            )
 
-        if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
-            raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
+        if (
+            len(cast(list[Block], outputs)) != 1
+            and cast(list[Block], outputs)[0] != self
+        ):
+            raise ValueError(
+                "In the webrtc stream event, the only output component must be the WebRTC component."
+            )
 
-        self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
+        self.concurrency_limit = (
+            1 if concurrency_limit in ["default", None] else concurrency_limit
+        )
         self.event_handler = fn
         self.time_limit = time_limit
-        return self.tick(self.set_output,
-            inputs=inputs,
-            outputs=None,
-            concurrency_id=concurrency_id,
-            concurrency_limit=None,
-            stream_every=0.5,
-            time_limit=None,
-            js=js
-        )
+        return self.tick(  # type: ignore
+            self.set_output,
+            inputs=inputs,
+            outputs=None,
+            concurrency_id=concurrency_id,
+            concurrency_limit=None,
+            stream_every=0.5,
+            time_limit=None,
+            js=js,
+        )
 
     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -273,11 +260,10 @@ class WebRTC(Component):
 
     @server
     async def offer(self, body):
-
         if len(self.connections) >= cast(int, self.concurrency_limit):
             return {"status": "failed"}
-
-        offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
+
+        offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
         pc = RTCPeerConnection()
         self.pcs.add(pc)
 
@@ -287,14 +273,14 @@ class WebRTC(Component):
             print(pc.iceConnectionState)
             if pc.iceConnectionState == "failed":
                 await pc.close()
-                self.connections.pop(body['webrtc_id'], None)
+                self.connections.pop(body["webrtc_id"], None)
                 self.pcs.discard(pc)
 
         @pc.on("connectionstatechange")
         async def on_connectionstatechange():
             if pc.connectionState in ["failed", "closed"]:
                 await pc.close()
-                self.connections.pop(body['webrtc_id'], None)
+                self.connections.pop(body["webrtc_id"], None)
                 self.pcs.discard(pc)
             if pc.connectionState == "connected":
                 if self.time_limit is not None:
@@ -303,12 +289,12 @@ class WebRTC(Component):
 
         @pc.on("track")
         def on_track(track):
             cb = VideoCallback(
-                    self.relay.subscribe(track),
-                    event_handler=cast(Callable, self.event_handler)
-                )
-            self.connections[body['webrtc_id']] = cb
+                self.relay.subscribe(track),
+                event_handler=cast(Callable, self.event_handler),
+            )
+            self.connections[body["webrtc_id"]] = cb
             pc.addTrack(cb)
-        
+
         # handle offer
         await pc.setRemoteDescription(offer)
@@ -317,9 +303,9 @@ class WebRTC(Component):
         await pc.setLocalDescription(answer)  # type: ignore
 
         return {
-                "sdp": pc.localDescription.sdp,
-                "type": pc.localDescription.type,
-            }
+            "sdp": pc.localDescription.sdp,
+            "type": pc.localDescription.type,
+        }
 
     def example_payload(self) -> Any:
         return {
@@ -331,7 +317,5 @@ class WebRTC(Component):
     def example_value(self) -> Any:
         return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
 
-
    def api_info(self) -> Any:
        return {"type": "number"}
-
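The VideoCallback.recv() cleanup above follows aiortc's usual transform-track pattern: pull a frame from the wrapped track, hand its ndarray to the user's handler, and re-wrap the result as a VideoFrame. A minimal self-contained sketch of that pattern, with a GaussianBlur stand-in for the event handler (the stand-in is illustrative, not part of this diff):

    import cv2
    from aiortc import MediaStreamTrack
    from av import VideoFrame


    class BlurTrack(MediaStreamTrack):
        kind = "video"

        def __init__(self, source: MediaStreamTrack):
            super().__init__()
            self.source = source

        async def recv(self) -> VideoFrame:
            frame = await self.source.recv()  # next decoded frame from the peer
            array = frame.to_ndarray(format="bgr24")  # av.VideoFrame -> numpy array
            array = cv2.GaussianBlur(array, (15, 15), 0)  # stand-in for the handler
            new_frame = VideoFrame.from_ndarray(array, format="bgr24")
            new_frame.pts = frame.pts  # preserve timing so playback stays in sync
            new_frame.time_base = frame.time_base
            return new_frame

Copying pts and time_base onto the outgoing frame keeps the processed stream aligned with the source clock; dropping them is a common cause of stuttering output.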
diff --git a/demo/app.py b/demo/app.py
index d2870bb..350258d 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -30,6 +30,7 @@
 else:
     rtc_configuration = None
 
+
 def detection(image, conf_threshold=0.3):
     image = cv2.resize(image, (model.input_width, model.input_height))
     new_image = model.detect_objects(image, conf_threshold)
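For context, demo/app.py feeds webcam frames through detection() using the component API introduced above. A minimal sketch of that wiring, assuming an identity handler and no TURN configuration (only detection() and rtc_configuration appear in this diff; the Blocks layout is illustrative):

    import gradio as gr
    from gradio_webrtc import WebRTC


    def detection(image, conf_threshold=0.3):
        # stand-in for the YOLOv10 handler in demo/app.py
        return image


    with gr.Blocks() as demo:
        stream = WebRTC(label="Stream", rtc_configuration=None)
        conf = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Confidence Threshold")
        # the first input and the only output must be the WebRTC component (enforced in stream())
        stream.stream(fn=detection, inputs=[stream, conf], outputs=[stream], time_limit=10)

    demo.launch()

Extra inputs such as the slider are forwarded through set_output(), and the "__webrtc_value__" placeholder is swapped for the current frame in add_frame_to_payload() before the handler runs.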
diff --git a/demo/inference.py b/demo/inference.py
index ff54736..1e4d9a8 100644
--- a/demo/inference.py
+++ b/demo/inference.py
@@ -8,7 +8,6 @@
 from utils import draw_detections
 
 class YOLOv10:
     def __init__(self, path):
-
         # Initialize model
         self.initialize_model(path)
@@ -53,7 +52,11 @@ class YOLOv10:
         )
         print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
 
-        boxes, scores, class_ids, = self.process_output(outputs, conf_threshold)
+        (
+            boxes,
+            scores,
+            class_ids,
+        ) = self.process_output(outputs, conf_threshold)
         return self.draw_detections(image, boxes, scores, class_ids)
 
     def process_output(self, output, conf_threshold=0.3):
@@ -83,7 +86,7 @@ class YOLOv10:
         boxes = self.rescale_boxes(boxes)
 
         # Convert boxes to xyxy format
-        #boxes = xywh2xyxy(boxes)
+        # boxes = xywh2xyxy(boxes)
 
         return boxes
 
@@ -98,10 +101,10 @@ class YOLOv10:
         )
         return boxes
 
-    def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
-        return draw_detections(
-            image, boxes, scores, class_ids, mask_alpha
-        )
+    def draw_detections(
+        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
+    ):
+        return draw_detections(image, boxes, scores, class_ids, mask_alpha)
 
     def get_input_details(self):
         model_inputs = self.session.get_inputs()
@@ -139,7 +142,6 @@ if __name__ == "__main__":
     # # Detect Objects
     combined_image = yolov8_detector.detect_objects(img)
 
-
     # Draw detections
     cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
     cv2.imshow("Output", combined_image)
diff --git a/demo/utils.py b/demo/utils.py
index cdd791c..8cdd1d2 100644
--- a/demo/utils.py
+++ b/demo/utils.py
@@ -164,7 +164,7 @@ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
     font_size = min([img_height, img_width]) * 0.0006
     text_thickness = int(min([img_height, img_width]) * 0.001)
 
-    #det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
+    # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
 
     # Draw bounding boxes and labels of detections
     for class_id, box, score in zip(class_ids, boxes, scores):
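The YOLOv10 wrapper in demo/inference.py is a thin layer over onnxruntime: get_input_details() reads names and shapes from the session, and inference produces the outputs unpacked above. A minimal sketch of that flow, with an illustrative model path and input shape:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("yolov10n.onnx", providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name

    # NCHW float32 tensor; 640x640 is illustrative, the real shape comes from the model
    tensor = np.random.rand(1, 3, 640, 640).astype(np.float32)
    outputs = session.run(None, {input_name: tensor})
    print([o.shape for o in outputs])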