freddyaboulton
2024-09-26 12:41:18 -04:00
parent d5f5db5f9b
commit 9e0d3f5bbf
5 changed files with 59 additions and 73 deletions

View File

@@ -1,4 +1,3 @@
from .webrtc import WebRTC
-__all__ = ['WebRTC']
+__all__ = ["WebRTC"]

View File

@@ -12,7 +12,7 @@ from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from aiortc import VideoStreamTrack
from aiortc.mediastreams import MediaStreamError
-from aiortc.contrib.media import VideoFrame # type: ignore
+from aiortc.contrib.media import VideoFrame  # type: ignore
from gradio_client import handle_file
import numpy as np
@@ -29,7 +29,6 @@ if wasm_utils.IS_WASM:
raise ValueError("Not supported in gradio-lite!")
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -47,7 +46,9 @@ class VideoCallback(VideoStreamTrack):
        self.event_handler = event_handler
        self.latest_args: str | list[Any] = "not_set"

-    def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]:
+    def add_frame_to_payload(
+        self, args: list[Any], frame: np.ndarray | None
+    ) -> list[Any]:
        new_args = []
        for val in args:
            if isinstance(val, str) and val == "__webrtc_value__":
@@ -55,14 +56,12 @@ class VideoCallback(VideoStreamTrack):
            else:
                new_args.append(val)
        return new_args

    def array_to_frame(self, array: np.ndarray) -> VideoFrame:
        return VideoFrame.from_ndarray(array, format="bgr24")

    async def recv(self):
-        try:
+        try:
            frame = await self.track.recv()
        except MediaStreamError:
@@ -70,12 +69,10 @@ class VideoCallback(VideoStreamTrack):
        frame_array = frame.to_ndarray(format="bgr24")
        if self.latest_args == "not_set":
            print("args not set")
            return frame

-        args = self.add_frame_to_payload(self.latest_args, frame_array)
+        args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
        array = self.event_handler(*args)
        new_frame = self.array_to_frame(array)
@@ -110,18 +107,11 @@ class WebRTC(Component):
    relay = MediaRelay()
    connections: dict[str, VideoCallback] = {}

-    EVENTS = [
-        "tick"
-    ]
+    EVENTS = ["tick"]

    def __init__(
        self,
-        value: str
-        | Path
-        | tuple[str | Path, str | Path | None]
-        | Callable
-        | None = None,
        *,
+        value: None = None,
        height: int | str | None = None,
        width: int | str | None = None,
        label: str | None = None,
@@ -138,10 +128,6 @@ class WebRTC(Component):
        render: bool = True,
        key: int | str | None = None,
        mirror_webcam: bool = True,
-        show_share_button: bool | None = None,
-        show_download_button: bool | None = None,
-        min_length: int | None = None,
-        max_length: int | None = None,
        rtc_configuration: dict[str, Any] | None = None,
        time_limit: float | None = None,
    ):
@@ -180,14 +166,6 @@ class WebRTC(Component):
        self.width = width
        self.mirror_webcam = mirror_webcam
        self.concurrency_limit = 1
-        self.show_share_button = (
-            (utils.get_space() is not None)
-            if show_share_button is None
-            else show_share_button
-        )
-        self.show_download_button = show_download_button
-        self.min_length = min_length
-        self.max_length = max_length
        self.rtc_configuration = rtc_configuration
        self.event_handler: Callable | None = None
        super().__init__(
@@ -216,7 +194,6 @@ class WebRTC(Component):
"""
return payload
def postprocess(self, value: Any) -> str:
"""
Parameters:
@@ -229,7 +206,7 @@ class WebRTC(Component):
    def set_output(self, webrtc_id: str, *args):
        if webrtc_id in self.connections:
            self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)

    def stream(
        self,
        fn: Callable[..., Any] | None = None,
@@ -238,8 +215,8 @@ class WebRTC(Component):
        js: str | None = None,
        concurrency_limit: int | None | Literal["default"] = "default",
        concurrency_id: str | None = None,
-        time_limit: float | None = None):
+        time_limit: float | None = None,
+    ):
        from gradio.blocks import Block

        if isinstance(inputs, Block):
@@ -248,23 +225,33 @@ class WebRTC(Component):
            outputs = [outputs]
        if cast(list[Block], inputs)[0] != self:
-            raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
+            raise ValueError(
+                "In the webrtc stream event, the first input component must be the WebRTC component."
+            )
-        if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
-            raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
+        if (
+            len(cast(list[Block], outputs)) != 1
+            and cast(list[Block], outputs)[0] != self
+        ):
+            raise ValueError(
+                "In the webrtc stream event, the only output component must be the WebRTC component."
+            )
-        self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
+        self.concurrency_limit = (
+            1 if concurrency_limit in ["default", None] else concurrency_limit
+        )
        self.event_handler = fn
        self.time_limit = time_limit
-        return self.tick(self.set_output,
-                         inputs=inputs,
-                         outputs=None,
-                         concurrency_id=concurrency_id,
-                         concurrency_limit=None,
-                         stream_every=0.5,
-                         time_limit=None,
-                         js=js
-                         )
+        return self.tick(  # type: ignore
+            self.set_output,
+            inputs=inputs,
+            outputs=None,
+            concurrency_id=concurrency_id,
+            concurrency_limit=None,
+            stream_every=0.5,
+            time_limit=None,
+            js=js,
+        )
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -273,11 +260,10 @@ class WebRTC(Component):
    @server
    async def offer(self, body):
        if len(self.connections) >= cast(int, self.concurrency_limit):
            return {"status": "failed"}

-        offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
+        offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
        pc = RTCPeerConnection()
        self.pcs.add(pc)
@@ -287,14 +273,14 @@ class WebRTC(Component):
            print(pc.iceConnectionState)
            if pc.iceConnectionState == "failed":
                await pc.close()
-                self.connections.pop(body['webrtc_id'], None)
+                self.connections.pop(body["webrtc_id"], None)
                self.pcs.discard(pc)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            if pc.connectionState in ["failed", "closed"]:
                await pc.close()
-                self.connections.pop(body['webrtc_id'], None)
+                self.connections.pop(body["webrtc_id"], None)
                self.pcs.discard(pc)
            if pc.connectionState == "connected":
                if self.time_limit is not None:
@@ -303,12 +289,12 @@ class WebRTC(Component):
@pc.on("track")
def on_track(track):
cb = VideoCallback(
self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler)
)
self.connections[body['webrtc_id']] = cb
self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
)
self.connections[body["webrtc_id"]] = cb
pc.addTrack(cb)
# handle offer
await pc.setRemoteDescription(offer)
@@ -317,9 +303,9 @@ class WebRTC(Component):
        await pc.setLocalDescription(answer)  # type: ignore
        return {
-            "sdp": pc.localDescription.sdp,
-            "type": pc.localDescription.type,
-        }
+            "sdp": pc.localDescription.sdp,
+            "type": pc.localDescription.type,
+        }
def example_payload(self) -> Any:
return {
@@ -331,7 +317,5 @@ class WebRTC(Component):
    def example_value(self) -> Any:
        return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"

    def api_info(self) -> Any:
        return {"type": "number"}

View File

@@ -30,6 +30,7 @@ else:
    rtc_configuration = None


def detection(image, conf_threshold=0.3):
    image = cv2.resize(image, (model.input_width, model.input_height))
    new_image = model.detect_objects(image, conf_threshold)

View File

@@ -8,7 +8,6 @@ from utils import draw_detections
class YOLOv10:
    def __init__(self, path):
        # Initialize model
        self.initialize_model(path)
@@ -53,7 +52,11 @@ class YOLOv10:
        )
        print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")

-        boxes, scores, class_ids, = self.process_output(outputs, conf_threshold)
+        (
+            boxes,
+            scores,
+            class_ids,
+        ) = self.process_output(outputs, conf_threshold)
        return self.draw_detections(image, boxes, scores, class_ids)

    def process_output(self, output, conf_threshold=0.3):
@@ -83,7 +86,7 @@ class YOLOv10:
        boxes = self.rescale_boxes(boxes)

        # Convert boxes to xyxy format
-        #boxes = xywh2xyxy(boxes)
+        # boxes = xywh2xyxy(boxes)
        return boxes
@@ -98,10 +101,10 @@ class YOLOv10:
        )
        return boxes

-    def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
-        return draw_detections(
-            image, boxes, scores, class_ids, mask_alpha
-        )
+    def draw_detections(
+        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
+    ):
+        return draw_detections(image, boxes, scores, class_ids, mask_alpha)

    def get_input_details(self):
        model_inputs = self.session.get_inputs()
@@ -139,7 +142,6 @@ if __name__ == "__main__":
    # # Detect Objects
    combined_image = yolov8_detector.detect_objects(img)

    # Draw detections
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", combined_image)
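A small driving sketch for the YOLOv10 wrapper above, mirroring the __main__ block in this diff; the weights and image paths are hypothetical:

import cv2

model = YOLOv10("yolov10n.onnx")  # hypothetical weights path; __init__ calls initialize_model()
img = cv2.imread("input.jpg")     # hypothetical test image

# detect_objects runs inference, unpacks (boxes, scores, class_ids)
# via process_output, and returns the frame with detections drawn on.
combined_image = model.detect_objects(img, conf_threshold=0.3)

cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", combined_image)
cv2.waitKey(0)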

View File

@@ -164,7 +164,7 @@ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
    font_size = min([img_height, img_width]) * 0.0006
    text_thickness = int(min([img_height, img_width]) * 0.001)

-    #det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
+    # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)

    # Draw bounding boxes and labels of detections
    for class_id, box, score in zip(class_ids, boxes, scores):
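The loop body is cut off in this hunk. Purely as an illustration of how the computed font_size and text_thickness typically feed into the OpenCV draw calls (the color and label format are assumptions, not the repository's exact code):

# Illustrative loop body only -- not taken from the diff.
x1, y1, x2, y2 = box.astype(int)
color = (0, 255, 0)  # the real code presumably looks up a per-class color
cv2.rectangle(det_img, (x1, y1), (x2, y2), color, 2)
label = f"{class_id}: {score:.2f}"
cv2.putText(det_img, label, (x1, y1 - 5),
            cv2.FONT_HERSHEY_SIMPLEX, font_size, color,
            max(text_thickness, 1))  # guard: thickness can round to 0 on small frames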