This commit is contained in:
freddyaboulton
2024-09-26 12:41:18 -04:00
parent d5f5db5f9b
commit 9e0d3f5bbf
5 changed files with 59 additions and 73 deletions

View File

@@ -1,4 +1,3 @@
from .webrtc import WebRTC from .webrtc import WebRTC
__all__ = ['WebRTC'] __all__ = ["WebRTC"]

View File

@@ -12,7 +12,7 @@ from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay from aiortc.contrib.media import MediaRelay
from aiortc import VideoStreamTrack from aiortc import VideoStreamTrack
from aiortc.mediastreams import MediaStreamError from aiortc.mediastreams import MediaStreamError
from aiortc.contrib.media import VideoFrame # type: ignore from aiortc.contrib.media import VideoFrame # type: ignore
from gradio_client import handle_file from gradio_client import handle_file
import numpy as np import numpy as np
@@ -29,7 +29,6 @@ if wasm_utils.IS_WASM:
raise ValueError("Not supported in gradio-lite!") raise ValueError("Not supported in gradio-lite!")
class VideoCallback(VideoStreamTrack): class VideoCallback(VideoStreamTrack):
""" """
This works for streaming input and output This works for streaming input and output
@@ -47,7 +46,9 @@ class VideoCallback(VideoStreamTrack):
self.event_handler = event_handler self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]: def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
) -> list[Any]:
new_args = [] new_args = []
for val in args: for val in args:
if isinstance(val, str) and val == "__webrtc_value__": if isinstance(val, str) and val == "__webrtc_value__":
@@ -55,14 +56,12 @@ class VideoCallback(VideoStreamTrack):
else: else:
new_args.append(val) new_args.append(val)
return new_args return new_args
def array_to_frame(self, array: np.ndarray) -> VideoFrame: def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24") return VideoFrame.from_ndarray(array, format="bgr24")
async def recv(self): async def recv(self):
try: try:
try: try:
frame = await self.track.recv() frame = await self.track.recv()
except MediaStreamError: except MediaStreamError:
@@ -70,12 +69,10 @@ class VideoCallback(VideoStreamTrack):
frame_array = frame.to_ndarray(format="bgr24") frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set": if self.latest_args == "not_set":
print("args not set")
return frame return frame
args = self.add_frame_to_payload(self.latest_args, frame_array) args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array = self.event_handler(*args) array = self.event_handler(*args)
new_frame = self.array_to_frame(array) new_frame = self.array_to_frame(array)
@@ -110,18 +107,11 @@ class WebRTC(Component):
relay = MediaRelay() relay = MediaRelay()
connections: dict[str, VideoCallback] = {} connections: dict[str, VideoCallback] = {}
EVENTS = [ EVENTS = ["tick"]
"tick"
]
def __init__( def __init__(
self, self,
value: str value: None = None,
| Path
| tuple[str | Path, str | Path | None]
| Callable
| None = None,
*,
height: int | str | None = None, height: int | str | None = None,
width: int | str | None = None, width: int | str | None = None,
label: str | None = None, label: str | None = None,
@@ -138,10 +128,6 @@ class WebRTC(Component):
render: bool = True, render: bool = True,
key: int | str | None = None, key: int | str | None = None,
mirror_webcam: bool = True, mirror_webcam: bool = True,
show_share_button: bool | None = None,
show_download_button: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
rtc_configuration: dict[str, Any] | None = None, rtc_configuration: dict[str, Any] | None = None,
time_limit: float | None = None, time_limit: float | None = None,
): ):
@@ -180,14 +166,6 @@ class WebRTC(Component):
self.width = width self.width = width
self.mirror_webcam = mirror_webcam self.mirror_webcam = mirror_webcam
self.concurrency_limit = 1 self.concurrency_limit = 1
self.show_share_button = (
(utils.get_space() is not None)
if show_share_button is None
else show_share_button
)
self.show_download_button = show_download_button
self.min_length = min_length
self.max_length = max_length
self.rtc_configuration = rtc_configuration self.rtc_configuration = rtc_configuration
self.event_handler: Callable | None = None self.event_handler: Callable | None = None
super().__init__( super().__init__(
@@ -216,7 +194,6 @@ class WebRTC(Component):
""" """
return payload return payload
def postprocess(self, value: Any) -> str: def postprocess(self, value: Any) -> str:
""" """
Parameters: Parameters:
@@ -229,7 +206,7 @@ class WebRTC(Component):
def set_output(self, webrtc_id: str, *args): def set_output(self, webrtc_id: str, *args):
if webrtc_id in self.connections: if webrtc_id in self.connections:
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args) self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
def stream( def stream(
self, self,
fn: Callable[..., Any] | None = None, fn: Callable[..., Any] | None = None,
@@ -238,8 +215,8 @@ class WebRTC(Component):
js: str | None = None, js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default", concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None, concurrency_id: str | None = None,
time_limit: float | None = None): time_limit: float | None = None,
):
from gradio.blocks import Block from gradio.blocks import Block
if isinstance(inputs, Block): if isinstance(inputs, Block):
@@ -248,23 +225,33 @@ class WebRTC(Component):
outputs = [outputs] outputs = [outputs]
if cast(list[Block], inputs)[0] != self: if cast(list[Block], inputs)[0] != self:
raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.") raise ValueError(
"In the webrtc stream event, the first input component must be the WebRTC component."
)
if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self: if (
raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.") len(cast(list[Block], outputs)) != 1
and cast(list[Block], outputs)[0] != self
):
raise ValueError(
"In the webrtc stream event, the only output component must be the WebRTC component."
)
self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit self.concurrency_limit = (
1 if concurrency_limit in ["default", None] else concurrency_limit
)
self.event_handler = fn self.event_handler = fn
self.time_limit = time_limit self.time_limit = time_limit
return self.tick(self.set_output, return self.tick( # type: ignore
inputs=inputs, self.set_output,
outputs=None, inputs=inputs,
concurrency_id=concurrency_id, outputs=None,
concurrency_limit=None, concurrency_id=concurrency_id,
stream_every=0.5, concurrency_limit=None,
time_limit=None, stream_every=0.5,
js=js time_limit=None,
) js=js,
)
@staticmethod @staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -273,11 +260,10 @@ class WebRTC(Component):
@server @server
async def offer(self, body): async def offer(self, body):
if len(self.connections) >= cast(int, self.concurrency_limit): if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"} return {"status": "failed"}
offer = RTCSessionDescription(sdp=body['sdp'], type=body['type']) offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
pc = RTCPeerConnection() pc = RTCPeerConnection()
self.pcs.add(pc) self.pcs.add(pc)
@@ -287,14 +273,14 @@ class WebRTC(Component):
print(pc.iceConnectionState) print(pc.iceConnectionState)
if pc.iceConnectionState == "failed": if pc.iceConnectionState == "failed":
await pc.close() await pc.close()
self.connections.pop(body['webrtc_id'], None) self.connections.pop(body["webrtc_id"], None)
self.pcs.discard(pc) self.pcs.discard(pc)
@pc.on("connectionstatechange") @pc.on("connectionstatechange")
async def on_connectionstatechange(): async def on_connectionstatechange():
if pc.connectionState in ["failed", "closed"]: if pc.connectionState in ["failed", "closed"]:
await pc.close() await pc.close()
self.connections.pop(body['webrtc_id'], None) self.connections.pop(body["webrtc_id"], None)
self.pcs.discard(pc) self.pcs.discard(pc)
if pc.connectionState == "connected": if pc.connectionState == "connected":
if self.time_limit is not None: if self.time_limit is not None:
@@ -303,12 +289,12 @@ class WebRTC(Component):
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
cb = VideoCallback( cb = VideoCallback(
self.relay.subscribe(track), self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler) event_handler=cast(Callable, self.event_handler),
) )
self.connections[body['webrtc_id']] = cb self.connections[body["webrtc_id"]] = cb
pc.addTrack(cb) pc.addTrack(cb)
# handle offer # handle offer
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)
@@ -317,9 +303,9 @@ class WebRTC(Component):
await pc.setLocalDescription(answer) # type: ignore await pc.setLocalDescription(answer) # type: ignore
return { return {
"sdp": pc.localDescription.sdp, "sdp": pc.localDescription.sdp,
"type": pc.localDescription.type, "type": pc.localDescription.type,
} }
def example_payload(self) -> Any: def example_payload(self) -> Any:
return { return {
@@ -331,7 +317,5 @@ class WebRTC(Component):
def example_value(self) -> Any: def example_value(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4" return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
def api_info(self) -> Any: def api_info(self) -> Any:
return {"type": "number"} return {"type": "number"}

View File

@@ -30,6 +30,7 @@ else:
rtc_configuration = None rtc_configuration = None
def detection(image, conf_threshold=0.3): def detection(image, conf_threshold=0.3):
image = cv2.resize(image, (model.input_width, model.input_height)) image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold) new_image = model.detect_objects(image, conf_threshold)

View File

@@ -8,7 +8,6 @@ from utils import draw_detections
class YOLOv10: class YOLOv10:
def __init__(self, path): def __init__(self, path):
# Initialize model # Initialize model
self.initialize_model(path) self.initialize_model(path)
@@ -53,7 +52,11 @@ class YOLOv10:
) )
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms") print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
boxes, scores, class_ids, = self.process_output(outputs, conf_threshold) (
boxes,
scores,
class_ids,
) = self.process_output(outputs, conf_threshold)
return self.draw_detections(image, boxes, scores, class_ids) return self.draw_detections(image, boxes, scores, class_ids)
def process_output(self, output, conf_threshold=0.3): def process_output(self, output, conf_threshold=0.3):
@@ -83,7 +86,7 @@ class YOLOv10:
boxes = self.rescale_boxes(boxes) boxes = self.rescale_boxes(boxes)
# Convert boxes to xyxy format # Convert boxes to xyxy format
#boxes = xywh2xyxy(boxes) # boxes = xywh2xyxy(boxes)
return boxes return boxes
@@ -98,10 +101,10 @@ class YOLOv10:
) )
return boxes return boxes
def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4): def draw_detections(
return draw_detections( self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
image, boxes, scores, class_ids, mask_alpha ):
) return draw_detections(image, boxes, scores, class_ids, mask_alpha)
def get_input_details(self): def get_input_details(self):
model_inputs = self.session.get_inputs() model_inputs = self.session.get_inputs()
@@ -139,7 +142,6 @@ if __name__ == "__main__":
# # Detect Objects # # Detect Objects
combined_image = yolov8_detector.detect_objects(img) combined_image = yolov8_detector.detect_objects(img)
# Draw detections # Draw detections
cv2.namedWindow("Output", cv2.WINDOW_NORMAL) cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", combined_image) cv2.imshow("Output", combined_image)

View File

@@ -164,7 +164,7 @@ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
font_size = min([img_height, img_width]) * 0.0006 font_size = min([img_height, img_width]) * 0.0006
text_thickness = int(min([img_height, img_width]) * 0.001) text_thickness = int(min([img_height, img_width]) * 0.001)
#det_img = draw_masks(det_img, boxes, class_ids, mask_alpha) # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
# Draw bounding boxes and labels of detections # Draw bounding boxes and labels of detections
for class_id, box, score in zip(class_ids, boxes, scores): for class_id, box, score in zip(class_ids, boxes, scores):