Merge pull request #1 from freddyaboulton/yolov8-inference

Use YOLOv10 for inference. Support `time_limit`, and rename the event from `webrtc_stream` to `stream`.
This commit is contained in:
Freddy Boulton
2024-09-26 12:52:53 -04:00
committed by GitHub
12 changed files with 582 additions and 427 deletions
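For orientation, here is a minimal sketch of the renamed event API, condensed from the updated README/demo diffed below. The `detection` handler and the slider's `maximum` value are stand-ins (the real demo runs YOLOv10 on every frame); the `WebRTC`, `stream`, and `time_limit` names come from the diff itself.

```python
import gradio as gr

from gradio_webrtc import WebRTC


def detection(image, conf_threshold=0.3):
    # Stand-in handler: the demo replaces this with YOLOv10 inference per frame.
    return image


with gr.Blocks() as demo:
    image = WebRTC(label="Stream", rtc_configuration=None)
    conf_threshold = gr.Slider(
        label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.30
    )
    # `stream` replaces the old `webrtc_stream` event; `time_limit` (seconds)
    # closes each peer connection server-side once the limit elapses.
    image.stream(
        fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
    )

if __name__ == "__main__":
    demo.launch()
```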

README.md

@@ -25,89 +25,41 @@ pip install gradio_webrtc
```python
import gradio as gr
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from gradio_webrtc import WebRTC
from pathlib import Path
from twilio.rest import Client
import os
from inference import YOLOv10
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
if account_sid and auth_token:
client = Client(account_sid, auth_token)
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
token = client.tokens.create()
CLASSES = [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
directory = Path(__file__).parent
MODEL = str((directory / "MobileNetSSD_deploy.caffemodel").resolve())
PROTOTXT = str((directory / "MobileNetSSD_deploy.prototxt.txt").resolve())
net = cv2.dnn.readNetFromCaffe(PROTOTXT, MODEL)
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def detection(image, conf_threshold=0.3):
blob = cv2.dnn.blobFromImage(
cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
)
net.setInput(blob)
detections = net.forward()
image = cv2.resize(image, (500, 500))
(h, w) = image.shape[:2]
labels = []
for i in np.arange(0, detections.shape[2]):
confidence = detections[0, 0, i, 2]
if confidence > conf_threshold:
# extract the index of the class label from the `detections`,
# then compute the (x, y)-coordinates of the bounding box for
# the object
idx = int(detections[0, 0, i, 1])
box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
(startX, startY, endX, endY) = box.astype("int")
# display the prediction
label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
labels.append(label)
cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
y = startY - 15 if startY - 15 > 15 else startY + 15
cv2.putText(
image, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2
)
return image
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))
css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
@@ -115,18 +67,20 @@ with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
""")
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
""")
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(label="Strean", rtc_configuration=rtc_configuration)
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
@@ -134,15 +88,12 @@ with gr.Blocks(css=css) as demo:
step=0.05,
value=0.30,
)
image.webrtc_stream(
fn=detection,
inputs=[image],
stream_every=0.05,
time_limit=30
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
if __name__ == '__main__':
if __name__ == "__main__":
demo.launch()
```
@@ -166,11 +117,7 @@ if __name__ == '__main__':
<td align="left" style="width: 25%;">
```python
str
| Path
| tuple[str | Path, str | Path | None]
| Callable
| None
None
```
</td>
@@ -386,58 +333,6 @@ bool
<td align="left">if True webcam will be mirrored. Default is True.</td>
</tr>
<tr>
<td align="left"><code>show_share_button</code></td>
<td align="left" style="width: 25%;">
```python
bool | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">if True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.</td>
</tr>
<tr>
<td align="left"><code>show_download_button</code></td>
<td align="left" style="width: 25%;">
```python
bool | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">if True, will show a download icon in the corner of the component that allows user to download the output. If False, icon does not appear. By default, it will be True for output components and False for input components.</td>
</tr>
<tr>
<td align="left"><code>min_length</code></td>
<td align="left" style="width: 25%;">
```python
int | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">the minimum length of video (in seconds) that the user can pass into the prediction function. If None, there is no minimum length.</td>
</tr>
<tr>
<td align="left"><code>max_length</code></td>
<td align="left" style="width: 25%;">
```python
int | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">the maximum length of video (in seconds) that the user can pass into the prediction function. If None, there is no maximum length.</td>
</tr>
<tr>
<td align="left"><code>rtc_configuration</code></td>
<td align="left" style="width: 25%;">
@@ -446,6 +341,19 @@ int | None
dict[str, Any] | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">None</td>
</tr>
<tr>
<td align="left"><code>time_limit</code></td>
<td align="left" style="width: 25%;">
```python
float | None
```
</td>
<td align="left"><code>None</code></td>
<td align="left">None</td>

@@ -1,4 +1,3 @@
from .webrtc import WebRTC
__all__ = ['WebRTC']
__all__ = ["WebRTC"]

@@ -2,21 +2,21 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from aiortc import VideoStreamTrack
from aiortc.mediastreams import MediaStreamError
from aiortc.contrib.media import VideoFrame # type: ignore
from aiortc.contrib.media import VideoFrame # type: ignore
from gradio_client import handle_file
import numpy as np
from gradio import utils, wasm_utils
from gradio import wasm_utils
from gradio.components.base import Component, server
if TYPE_CHECKING:
@@ -28,7 +28,6 @@ if wasm_utils.IS_WASM:
raise ValueError("Not supported in gradio-lite!")
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -46,7 +45,9 @@ class VideoCallback(VideoStreamTrack):
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]:
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
) -> list[Any]:
new_args = []
for val in args:
if isinstance(val, str) and val == "__webrtc_value__":
@@ -54,14 +55,12 @@ class VideoCallback(VideoStreamTrack):
else:
new_args.append(val)
return new_args
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
async def recv(self):
try:
try:
frame = await self.track.recv()
except MediaStreamError:
@@ -69,12 +68,10 @@ class VideoCallback(VideoStreamTrack):
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
print("args not set")
return frame
args = self.add_frame_to_payload(self.latest_args, frame_array)
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array = self.event_handler(*args)
new_frame = self.array_to_frame(array)
@@ -109,18 +106,11 @@ class WebRTC(Component):
relay = MediaRelay()
connections: dict[str, VideoCallback] = {}
EVENTS = [
"tick"
]
EVENTS = ["tick"]
def __init__(
self,
value: str
| Path
| tuple[str | Path, str | Path | None]
| Callable
| None = None,
*,
value: None = None,
height: int | str | None = None,
width: int | str | None = None,
label: str | None = None,
@@ -137,11 +127,8 @@ class WebRTC(Component):
render: bool = True,
key: int | str | None = None,
mirror_webcam: bool = True,
show_share_button: bool | None = None,
show_download_button: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
rtc_configuration: dict[str, Any] | None = None,
time_limit: float | None = None,
):
"""
Parameters:
@@ -173,18 +160,11 @@ class WebRTC(Component):
streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
"""
self.time_limit = time_limit
self.height = height
self.width = width
self.mirror_webcam = mirror_webcam
self.concurrency_limit = 1
self.show_share_button = (
(utils.get_space() is not None)
if show_share_button is None
else show_share_button
)
self.show_download_button = show_download_button
self.min_length = min_length
self.max_length = max_length
self.rtc_configuration = rtc_configuration
self.event_handler: Callable | None = None
super().__init__(
@@ -213,7 +193,6 @@ class WebRTC(Component):
"""
return payload
def postprocess(self, value: Any) -> str:
"""
Parameters:
@@ -226,40 +205,64 @@ class WebRTC(Component):
def set_output(self, webrtc_id: str, *args):
if webrtc_id in self.connections:
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
def webrtc_stream(
def stream(
self,
fn: Callable[..., Any] | None = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
stream_every: float = 0.5,
time_limit: float | None = None):
time_limit: float | None = None,
):
from gradio.blocks import Block
if inputs[0] != self:
raise ValueError("In the webrtc_stream event, the first input component must be the WebRTC component.")
self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
if isinstance(inputs, Block):
inputs = [inputs]
if isinstance(outputs, Block):
outputs = [outputs]
if cast(list[Block], inputs)[0] != self:
raise ValueError(
"In the webrtc stream event, the first input component must be the WebRTC component."
)
if (
len(cast(list[Block], outputs)) != 1
and cast(list[Block], outputs)[0] != self
):
raise ValueError(
"In the webrtc stream event, the only output component must be the WebRTC component."
)
self.concurrency_limit = (
1 if concurrency_limit in ["default", None] else concurrency_limit
)
self.event_handler = fn
return self.tick(self.set_output,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
concurrency_limit=None,
stream_every=stream_every,
time_limit=None,
js=js
)
self.time_limit = time_limit
return self.tick( # type: ignore
self.set_output,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
concurrency_limit=None,
stream_every=0.5,
time_limit=None,
js=js,
)
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
await asyncio.sleep(time_limit)
await pc.close()
@server
async def offer(self, body):
if len(self.connections) >= self.concurrency_limit:
if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"}
offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
pc = RTCPeerConnection()
self.pcs.add(pc)
@@ -269,25 +272,28 @@ class WebRTC(Component):
print(pc.iceConnectionState)
if pc.iceConnectionState == "failed":
await pc.close()
self.connections.pop(body['webrtc_id'], None)
self.connections.pop(body["webrtc_id"], None)
self.pcs.discard(pc)
@pc.on("connectionstatechange")
async def on_connectionstatechange():
if pc.connectionState in ["failed", "closed"]:
await pc.close()
self.connections.pop(body['webrtc_id'], None)
self.connections.pop(body["webrtc_id"], None)
self.pcs.discard(pc)
if pc.connectionState == "connected":
if self.time_limit is not None:
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
@pc.on("track")
def on_track(track):
cb = VideoCallback(
self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler)
)
self.connections[body['webrtc_id']] = cb
self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
)
self.connections[body["webrtc_id"]] = cb
pc.addTrack(cb)
# handle offer
await pc.setRemoteDescription(offer)
@@ -296,9 +302,9 @@ class WebRTC(Component):
await pc.setLocalDescription(answer) # type: ignore
return {
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type,
}
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type,
}
def example_payload(self) -> Any:
return {
@@ -310,7 +316,5 @@ class WebRTC(Component):
def example_value(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
def api_info(self) -> Any:
return {"type": "number"}

@@ -1,88 +1,40 @@
import gradio as gr
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from gradio_webrtc import WebRTC
from pathlib import Path
from twilio.rest import Client
import os
from inference import YOLOv10
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
if account_sid and auth_token:
client = Client(account_sid, auth_token)
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
token = client.tokens.create()
CLASSES = [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
directory = Path(__file__).parent
MODEL = str((directory / "MobileNetSSD_deploy.caffemodel").resolve())
PROTOTXT = str((directory / "MobileNetSSD_deploy.prototxt.txt").resolve())
net = cv2.dnn.readNetFromCaffe(PROTOTXT, MODEL)
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def detection(image, conf_threshold=0.3):
blob = cv2.dnn.blobFromImage(
cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
)
net.setInput(blob)
detections = net.forward()
image = cv2.resize(image, (500, 500))
(h, w) = image.shape[:2]
labels = []
for i in np.arange(0, detections.shape[2]):
confidence = detections[0, 0, i, 2]
if confidence > conf_threshold:
# extract the index of the class label from the `detections`,
# then compute the (x, y)-coordinates of the bounding box for
# the object
idx = int(detections[0, 0, i, 1])
box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
(startX, startY, endX, endY) = box.astype("int")
# display the prediction
label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
labels.append(label)
cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
y = startY - 15 if startY - 15 > 15 else startY + 15
cv2.putText(
image, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2
)
return image
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))
css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
@@ -90,18 +42,20 @@ with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
""")
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
""")
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(label="Strean", rtc_configuration=rtc_configuration)
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
@@ -109,13 +63,10 @@ with gr.Blocks(css=css) as demo:
step=0.05,
value=0.30,
)
image.webrtc_stream(
fn=detection,
inputs=[image],
stream_every=0.05,
time_limit=30
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
if __name__ == '__main__':
if __name__ == "__main__":
demo.launch()

@@ -1,45 +0,0 @@
from PIL import ImageDraw, ImageFont # type: ignore
import colorsys
def get_color(label):
# Simple hash function to generate consistent colors for each label
hash_value = hash(label)
hue = (hash_value % 100) / 100.0
saturation = 0.7
value = 0.9
rgb = colorsys.hsv_to_rgb(hue, saturation, value)
return tuple(int(x * 255) for x in rgb)
def draw_bounding_boxes(image, results: dict, model, threshold=0.3):
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
for score, label_id, box in zip(
results["scores"], results["labels"], results["boxes"]
):
if score > threshold:
label = model.config.id2label[label_id.item()]
box = [round(i, 2) for i in box.tolist()]
color = get_color(label)
# Draw bounding box
draw.rectangle(box, outline=color, width=3) # type: ignore
# Prepare text
text = f"{label}: {score:.2f}"
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
# Draw text background
draw.rectangle(
[box[0], box[1] - text_height - 4, box[0] + text_width, box[1]], # type: ignore
fill=color, # type: ignore
)
# Draw text
draw.text((box[0], box[1] - text_height - 4), text, fill="white", font=font)
return image

demo/inference.py Normal file

@@ -0,0 +1,148 @@
import time
import cv2
import numpy as np
import onnxruntime
from utils import draw_detections
class YOLOv10:
def __init__(self, path):
# Initialize model
self.initialize_model(path)
def __call__(self, image):
return self.detect_objects(image)
def initialize_model(self, path):
self.session = onnxruntime.InferenceSession(
path, providers=onnxruntime.get_available_providers()
)
# Get model info
self.get_input_details()
self.get_output_details()
def detect_objects(self, image, conf_threshold=0.3):
input_tensor = self.prepare_input(image)
# Perform inference on the image
new_image = self.inference(image, input_tensor, conf_threshold)
return new_image
def prepare_input(self, image):
self.img_height, self.img_width = image.shape[:2]
input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Resize input image
input_img = cv2.resize(input_img, (self.input_width, self.input_height))
# Scale input pixel values to 0 to 1
input_img = input_img / 255.0
input_img = input_img.transpose(2, 0, 1)
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
return input_tensor
def inference(self, image, input_tensor, conf_threshold=0.3):
start = time.perf_counter()
outputs = self.session.run(
self.output_names, {self.input_names[0]: input_tensor}
)
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
(
boxes,
scores,
class_ids,
) = self.process_output(outputs, conf_threshold)
return self.draw_detections(image, boxes, scores, class_ids)
def process_output(self, output, conf_threshold=0.3):
predictions = np.squeeze(output[0])
# Filter out object confidence scores below threshold
scores = predictions[:, 4]
predictions = predictions[scores > conf_threshold, :]
scores = scores[scores > conf_threshold]
if len(scores) == 0:
return [], [], []
# Get the class with the highest confidence
class_ids = np.argmax(predictions[:, 4:], axis=1)
# Get bounding boxes for each object
boxes = self.extract_boxes(predictions)
return boxes, scores, class_ids
def extract_boxes(self, predictions):
# Extract boxes from predictions
boxes = predictions[:, :4]
# Scale boxes to original image dimensions
boxes = self.rescale_boxes(boxes)
# Convert boxes to xyxy format
# boxes = xywh2xyxy(boxes)
return boxes
def rescale_boxes(self, boxes):
# Rescale boxes to original image dimensions
input_shape = np.array(
[self.input_width, self.input_height, self.input_width, self.input_height]
)
boxes = np.divide(boxes, input_shape, dtype=np.float32)
boxes *= np.array(
[self.img_width, self.img_height, self.img_width, self.img_height]
)
return boxes
def draw_detections(
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
):
return draw_detections(image, boxes, scores, class_ids, mask_alpha)
def get_input_details(self):
model_inputs = self.session.get_inputs()
self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
self.input_shape = model_inputs[0].shape
self.input_height = self.input_shape[2]
self.input_width = self.input_shape[3]
def get_output_details(self):
model_outputs = self.session.get_outputs()
self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
if __name__ == "__main__":
import requests
import tempfile
from huggingface_hub import hf_hub_download
model_file = hf_hub_download(
repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
)
yolov8_detector = YOLOv10(model_file)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
f.write(
requests.get(
"https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
).content
)
f.seek(0)
img = cv2.imread(f.name)
# # Detect Objects
combined_image = yolov8_detector.detect_objects(img)
# Draw detections
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", combined_image)
cv2.waitKey(0)

@@ -2,4 +2,5 @@ safetensors==0.4.3
opencv-python
twilio
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio-5.0.0b3-py3-none-any.whl
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio_webrtc-0.0.1-py3-none-any.whl
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio_webrtc-0.0.1-py3-none-any.whl
onnxruntime-gpu

File diff suppressed because one or more lines are too long

demo/utils.py Normal file

@@ -0,0 +1,237 @@
import numpy as np
import cv2
class_names = [
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
# Create a list of colors for each class where each color is a tuple of 3 integer values
rng = np.random.default_rng(3)
colors = rng.uniform(0, 255, size=(len(class_names), 3))
def nms(boxes, scores, iou_threshold):
# Sort by score
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# Pick the last box
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# Compute IoU of the picked box with the rest
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# Remove boxes with IoU over the threshold
keep_indices = np.where(ious < iou_threshold)[0]
# print(keep_indices.shape, sorted_indices.shape)
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
def multiclass_nms(boxes, scores, class_ids, iou_threshold):
unique_class_ids = np.unique(class_ids)
keep_boxes = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
keep_boxes.extend(class_indices[class_keep_boxes])
return keep_boxes
def compute_iou(box, boxes):
# Compute xmin, ymin, xmax, ymax for both boxes
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# Compute intersection area
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# Compute union area
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# Compute IoU
iou = intersection_area / union_area
return iou
def xywh2xyxy(x):
# Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
y = np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2
y[..., 1] = x[..., 1] - x[..., 3] / 2
y[..., 2] = x[..., 0] + x[..., 2] / 2
y[..., 3] = x[..., 1] + x[..., 3] / 2
return y
def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
det_img = image.copy()
img_height, img_width = image.shape[:2]
font_size = min([img_height, img_width]) * 0.0006
text_thickness = int(min([img_height, img_width]) * 0.001)
# det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
# Draw bounding boxes and labels of detections
for class_id, box, score in zip(class_ids, boxes, scores):
color = colors[class_id]
draw_box(det_img, box, color)
label = class_names[class_id]
caption = f"{label} {int(score * 100)}%"
draw_text(det_img, caption, box, color, font_size, text_thickness)
return det_img
def draw_box(
image: np.ndarray,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
thickness: int = 2,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
def draw_text(
image: np.ndarray,
text: str,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
font_size: float = 0.001,
text_thickness: int = 2,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
(tw, th), _ = cv2.getTextSize(
text=text,
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=font_size,
thickness=text_thickness,
)
th = int(th * 1.2)
cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
return cv2.putText(
image,
text,
(x1, y1),
cv2.FONT_HERSHEY_SIMPLEX,
font_size,
(255, 255, 255),
text_thickness,
cv2.LINE_AA,
)
def draw_masks(
image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
) -> np.ndarray:
mask_img = image.copy()
# Draw bounding boxes and labels of detections
for box, class_id in zip(boxes, classes):
color = colors[class_id]
x1, y1, x2, y2 = box.astype(int)
# Draw fill rectangle in mask image
cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
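As a quick sanity sketch of the IoU/NMS helpers above, using toy boxes (values chosen purely for illustration; the import assumes the module is importable as `utils`, as `demo/inference.py` does):

```python
import numpy as np

from utils import compute_iou, nms  # demo/utils.py

boxes = np.array(
    [
        [0, 0, 10, 10],
        [1, 1, 11, 11],    # overlaps the first box (IoU ≈ 0.68)
        [20, 20, 30, 30],  # disjoint from both
    ],
    dtype=np.float32,
)
scores = np.array([0.9, 0.8, 0.7])

print(compute_iou(boxes[0], boxes[1:]))       # ≈ [0.68, 0.0]
print(nms(boxes, scores, iou_threshold=0.5))  # keeps boxes 0 and 2; box 1 is suppressed
```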

@@ -1,7 +1,6 @@
<svelte:options accessors={true} />
<script lang="ts">
import { Block, UploadText } from "@gradio/atoms";
import Video from "./shared/InteractiveVideo.svelte";
import { StatusTracker } from "@gradio/statustracker";
@@ -27,6 +26,7 @@
export let min_width: number | undefined = undefined;
export let gradio;
export let rtc_configuration: Object;
export let time_limit: number | null = null;
// export let gradio: Gradio<{
// change: never;
// clear: never;
@@ -80,6 +80,7 @@
{root}
{server}
{rtc_configuration}
{time_limit}
on:clear={() => gradio.dispatch("clear")}
on:play={() => gradio.dispatch("play")}
on:pause={() => gradio.dispatch("pause")}

@@ -11,11 +11,11 @@
export let label: string | undefined = undefined;
export let show_label = true;
export let include_audio: boolean;
export let root: string;
export let i18n: I18nFormatter;
export let active_source: "webcam" | "upload" = "webcam";
export let handle_reset_value: () => void = () => {};
export let stream_handler: Client["stream"];
export let time_limit: number | null = null;
export let server: {
offer: (body: any) => Promise<any>;
};
@@ -44,9 +44,9 @@
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
<div data-testid="video" class="video-container">
<Webcam
{root}
{rtc_configuration}
{include_audio}
{time_limit}
on:error
on:start_recording
on:stop_recording

@@ -21,7 +21,8 @@
let video_source: HTMLVideoElement;
let available_video_devices: MediaDeviceInfo[] = [];
let selected_device: MediaDeviceInfo | null = null;
let time_limit: number | null = null;
let _time_limit: number | null = null;
export let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed";
export const webrtc_id = Math.random().toString(36).substring(2);
@@ -29,7 +30,7 @@
state: "open" | "closed" | "waiting"
) => {
if (state === "closed") {
time_limit = null;
_time_limit = null;
stream_state = "closed";
} else if (state === "waiting") {
stream_state = "waiting";
@@ -38,14 +39,8 @@
}
};
export const set_time_limit = (time: number): void => {
if (recording) time_limit = time;
};
let canvas: HTMLCanvasElement;
export let rtc_configuration: Object;
export let pending = false;
export let root = "";
export let stream_every = 1;
export let server: {
offer: (body: any) => Promise<any>;
@@ -133,10 +128,15 @@
console.log("config", configuration);
pc = new RTCPeerConnection(configuration);
pc.addEventListener("connectionstatechange",
(event) => {
async (event) => {
switch(pc.connectionState) {
case "connected":
stream_state = "open"
stream_state = "open";
_time_limit = time_limit;
break;
case "disconnected":
stream_state = "closed";
await access_webcam();
break;
default:
break;
@@ -196,7 +196,7 @@
</script>
<div class="wrap">
<StreamingBar {time_limit} />
<StreamingBar time_limit={_time_limit} />
<!-- svelte-ignore a11y-media-has-caption -->
<!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
<video