mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00

Commit: Add code
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import asyncio
 from collections.abc import Callable, Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, Optional, cast
@@ -142,6 +143,7 @@ class WebRTC(Component):
         min_length: int | None = None,
         max_length: int | None = None,
         rtc_configuration: dict[str, Any] | None = None,
+        time_limit: float | None = None,
     ):
         """
         Parameters:
@@ -173,6 +175,7 @@ class WebRTC(Component):
             streaming: when used as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an H.264 encoding. MP4 files are also accepted but they will be converted to H.264 encoding.
             watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
         """
+        self.time_limit = time_limit
         self.height = height
         self.width = width
         self.mirror_webcam = mirror_webcam
@@ -227,37 +230,52 @@ class WebRTC(Component):
         if webrtc_id in self.connections:
             self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
 
-    def webrtc_stream(
+    def stream(
         self,
         fn: Callable[..., Any] | None = None,
        inputs: Block | Sequence[Block] | set[Block] | None = None,
+        outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
         concurrency_limit: int | None | Literal["default"] = "default",
         concurrency_id: str | None = None,
-        stream_every: float = 0.5,
         time_limit: float | None = None):
 
-        if inputs[0] != self:
-            raise ValueError("In the webrtc_stream event, the first input component must be the WebRTC component.")
+        from gradio.blocks import Block
+
+        if isinstance(inputs, Block):
+            inputs = [inputs]
+        if isinstance(outputs, Block):
+            outputs = [outputs]
+
+        if cast(list[Block], inputs)[0] != self:
+            raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
+
+        if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
+            raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
+
         self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
         self.event_handler = fn
+        self.time_limit = time_limit
         return self.tick(self.set_output,
                          inputs=inputs,
                          outputs=None,
                          concurrency_id=concurrency_id,
                          concurrency_limit=None,
-                         stream_every=stream_every,
+                         stream_every=0.5,
                          time_limit=None,
                          js=js
         )
 
+    @staticmethod
+    async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
+        await asyncio.sleep(time_limit)
+        await pc.close()
+
     @server
     async def offer(self, body):
 
-        if len(self.connections) >= self.concurrency_limit:
+        if len(self.connections) >= cast(int, self.concurrency_limit):
             return {"status": "failed"}
 
 
         offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
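
For context, a minimal usage sketch of the renamed stream() API above. The call shape matches demo/app.py later in this commit; the flip() handler and component names here are illustrative, not part of the commit:

    import gradio as gr
    from gradio_webrtc import WebRTC

    def flip(frame):
        # toy per-frame handler: a numpy frame in, a numpy frame out
        return frame[::-1]

    with gr.Blocks() as demo:
        video = WebRTC(label="Stream")
        # per the new checks, the first input and the sole output
        # must both be the WebRTC component itself
        video.stream(fn=flip, inputs=[video], outputs=[video], time_limit=10)

    demo.launch()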
@@ -278,6 +296,9 @@ class WebRTC(Component):
             await pc.close()
             self.connections.pop(body['webrtc_id'], None)
             self.pcs.discard(pc)
+        if pc.connectionState == "connected":
+            if self.time_limit is not None:
+                asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
 
         @pc.on("track")
         def on_track(track):
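
The two hunks above implement the server-side time limit: once the peer connection reports "connected", offer() schedules a fire-and-forget task that sleeps for time_limit seconds and then closes the connection. A standalone sketch of that pattern, with a stand-in for aiortc's RTCPeerConnection:

    import asyncio

    class FakePC:
        # stand-in for aiortc's RTCPeerConnection, just to show the pattern
        async def close(self):
            print("peer connection closed")

    async def wait_for_time_limit(pc, time_limit: float):
        await asyncio.sleep(time_limit)
        await pc.close()

    async def main():
        pc = FakePC()
        # fire-and-forget, as in offer() above (the real handler never awaits it)
        asyncio.create_task(wait_for_time_limit(pc, time_limit=0.1))
        await asyncio.sleep(0.2)  # keep the loop alive long enough to see the close

    asyncio.run(main())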

demo/app.py (111 changed lines)
@@ -1,88 +1,42 @@
 import gradio as gr
 import cv2
-import numpy as np
+from huggingface_hub import hf_hub_download
 from gradio_webrtc import WebRTC
-from pathlib import Path
 from twilio.rest import Client
 import os
+from inference import YOLOv10
+
+model_file = hf_hub_download(
+    repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
+)
+
+model = YOLOv10(model_file)
 
 account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
 auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
-client = Client(account_sid, auth_token)
 
-token = client.tokens.create()
+if account_sid and auth_token:
+    client = Client(account_sid, auth_token)
 
-rtc_configuration = {
-    "iceServers": token.ice_servers,
-    "iceTransportPolicy": "relay",
-}
+    token = client.tokens.create()
 
-CLASSES = [
-    "background",
-    "aeroplane",
-    "bicycle",
-    "bird",
-    "boat",
-    "bottle",
-    "bus",
-    "car",
-    "cat",
-    "chair",
-    "cow",
-    "diningtable",
-    "dog",
-    "horse",
-    "motorbike",
-    "person",
-    "pottedplant",
-    "sheep",
-    "sofa",
-    "train",
-    "tvmonitor",
-]
-COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
+    rtc_configuration = {
+        "iceServers": token.ice_servers,
+        "iceTransportPolicy": "relay",
+    }
+else:
+    rtc_configuration = None
 
-directory = Path(__file__).parent
-
-MODEL = str((directory / "MobileNetSSD_deploy.caffemodel").resolve())
-PROTOTXT = str((directory / "MobileNetSSD_deploy.prototxt.txt").resolve())
-net = cv2.dnn.readNetFromCaffe(PROTOTXT, MODEL)
+rtc_configuration = None
 
 def detection(image, conf_threshold=0.3):
-    blob = cv2.dnn.blobFromImage(
-        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
-    )
-    net.setInput(blob)
-
-    detections = net.forward()
-    image = cv2.resize(image, (500, 500))
-    (h, w) = image.shape[:2]
-    labels = []
-    for i in np.arange(0, detections.shape[2]):
-        confidence = detections[0, 0, i, 2]
-
-        if confidence > conf_threshold:
-            # extract the index of the class label from the `detections`,
-            # then compute the (x, y)-coordinates of the bounding box for
-            # the object
-            idx = int(detections[0, 0, i, 1])
-            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
-            (startX, startY, endX, endY) = box.astype("int")
-
-            # display the prediction
-            label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
-            labels.append(label)
-            cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
-            y = startY - 15 if startY - 15 > 15 else startY + 15
-            cv2.putText(
-                image, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2
-            )
-    return image
+    image = cv2.resize(image, (model.input_width, model.input_height))
+    new_image = model.detect_objects(image, conf_threshold)
+    return cv2.resize(new_image, (500, 500))
 
-css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
+css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
 .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
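
The rewritten detection() is the entire per-frame contract of the stream event: a numpy frame comes in, an annotated numpy frame goes out. A self-contained stub with the same contract (no model; the 500x500 output size mirrors the resize in the diff):

    import cv2
    import numpy as np

    def detection_stub(image: np.ndarray, conf_threshold: float = 0.3) -> np.ndarray:
        # same ndarray-in / ndarray-out shape as demo/app.py's detection()
        out = cv2.resize(image, (500, 500))
        cv2.putText(out, f"conf > {conf_threshold}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
        return out

    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    assert detection_stub(frame).shape == (500, 500, 3)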
@@ -90,18 +44,20 @@ with gr.Blocks(css=css) as demo:
     gr.HTML(
         """
     <h1 style='text-align: center'>
-    YOLOv10 Webcam Stream
+    YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
     </h1>
-    """)
+    """
+    )
     gr.HTML(
         """
         <h3 style='text-align: center'>
         <a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
         </h3>
-        """)
+        """
+    )
     with gr.Column(elem_classes=["my-column"]):
         with gr.Group(elem_classes=["my-group"]):
-            image = WebRTC(label="Strean", rtc_configuration=rtc_configuration)
+            image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
             conf_threshold = gr.Slider(
                 label="Confidence Threshold",
                 minimum=0.0,
@@ -109,13 +65,10 @@ with gr.Blocks(css=css) as demo:
                 step=0.05,
                 value=0.30,
             )
 
-    image.webrtc_stream(
-        fn=detection,
-        inputs=[image],
-        stream_every=0.05,
-        time_limit=30
-    )
+    image.stream(
+        fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
+    )
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     demo.launch()

@@ -25,7 +25,7 @@ def draw_bounding_boxes(image, results: dict, model, threshold=0.3):
         color = get_color(label)
 
         # Draw bounding box
         draw.rectangle(box, outline=color, width=3)  # type: ignore
 
         # Prepare text
         text = f"{label}: {score:.2f}"
@@ -35,8 +35,8 @@ def draw_bounding_boxes(image, results: dict, model, threshold=0.3):
 
         # Draw text background
         draw.rectangle(
             [box[0], box[1] - text_height - 4, box[0] + text_width, box[1]],  # type: ignore
             fill=color,  # type: ignore
         )
 
         # Draw text

demo/inference.py (new file, 146 lines)
@@ -0,0 +1,146 @@
+import time
+
+import cv2
+import numpy as np
+import onnxruntime
+
+from utils import draw_detections
+
+
+class YOLOv10:
+    def __init__(self, path):
+        # Initialize model
+        self.initialize_model(path)
+
+    def __call__(self, image):
+        return self.detect_objects(image)
+
+    def initialize_model(self, path):
+        self.session = onnxruntime.InferenceSession(
+            path, providers=onnxruntime.get_available_providers()
+        )
+        # Get model info
+        self.get_input_details()
+        self.get_output_details()
+
+    def detect_objects(self, image, conf_threshold=0.3):
+        input_tensor = self.prepare_input(image)
+
+        # Perform inference on the image
+        new_image = self.inference(image, input_tensor, conf_threshold)
+
+        return new_image
+
+    def prepare_input(self, image):
+        self.img_height, self.img_width = image.shape[:2]
+
+        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Resize input image
+        input_img = cv2.resize(input_img, (self.input_width, self.input_height))
+
+        # Scale input pixel values to 0 to 1
+        input_img = input_img / 255.0
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+
+        return input_tensor
+
+    def inference(self, image, input_tensor, conf_threshold=0.3):
+        start = time.perf_counter()
+        outputs = self.session.run(
+            self.output_names, {self.input_names[0]: input_tensor}
+        )
+
+        print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
+        boxes, scores, class_ids, = self.process_output(outputs, conf_threshold)
+        return self.draw_detections(image, boxes, scores, class_ids)
+
+    def process_output(self, output, conf_threshold=0.3):
+        predictions = np.squeeze(output[0])
+
+        # Filter out object confidence scores below threshold
+        scores = predictions[:, 4]
+        predictions = predictions[scores > conf_threshold, :]
+        scores = scores[scores > conf_threshold]
+
+        if len(scores) == 0:
+            return [], [], []
+
+        # Get the class with the highest confidence
+        class_ids = np.argmax(predictions[:, 4:], axis=1)
+
+        # Get bounding boxes for each object
+        boxes = self.extract_boxes(predictions)
+
+        return boxes, scores, class_ids
+
+    def extract_boxes(self, predictions):
+        # Extract boxes from predictions
+        boxes = predictions[:, :4]
+
+        # Scale boxes to original image dimensions
+        boxes = self.rescale_boxes(boxes)
+
+        # Convert boxes to xyxy format
+        #boxes = xywh2xyxy(boxes)
+
+        return boxes
+
+    def rescale_boxes(self, boxes):
+        # Rescale boxes to original image dimensions
+        input_shape = np.array(
+            [self.input_width, self.input_height, self.input_width, self.input_height]
+        )
+        boxes = np.divide(boxes, input_shape, dtype=np.float32)
+        boxes *= np.array(
+            [self.img_width, self.img_height, self.img_width, self.img_height]
+        )
+        return boxes
+
+    def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
+        return draw_detections(
+            image, boxes, scores, class_ids, mask_alpha
+        )
+
+    def get_input_details(self):
+        model_inputs = self.session.get_inputs()
+        self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+
+        self.input_shape = model_inputs[0].shape
+        self.input_height = self.input_shape[2]
+        self.input_width = self.input_shape[3]
+
+    def get_output_details(self):
+        model_outputs = self.session.get_outputs()
+        self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+
+
+if __name__ == "__main__":
+    import requests
+    import tempfile
+    from huggingface_hub import hf_hub_download
+
+    model_file = hf_hub_download(
+        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
+    )
+
+    yolov8_detector = YOLOv10(model_file)
+
+    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
+        f.write(
+            requests.get(
+                "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
+            ).content
+        )
+        f.seek(0)
+        img = cv2.imread(f.name)
+
+    # # Detect Objects
+    combined_image = yolov8_detector.detect_objects(img)
+
+    # Draw detections
+    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+    cv2.imshow("Output", combined_image)
+    cv2.waitKey(0)
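
A quick sanity check of prepare_input()'s tensor layout above: BGR to RGB, resize, scale to [0, 1], HWC to NCHW, float32. The 640x640 input size here is an assumption (the real class reads it from the ONNX model via get_input_details()):

    import cv2
    import numpy as np

    def prepare_input(image, input_width=640, input_height=640):
        # mirrors YOLOv10.prepare_input with the model's input size assumed
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (input_width, input_height))
        img = img / 255.0
        img = img.transpose(2, 0, 1)
        return img[np.newaxis, :, :, :].astype(np.float32)

    tensor = prepare_input(np.zeros((480, 640, 3), dtype=np.uint8))
    assert tensor.shape == (1, 3, 640, 640) and tensor.dtype == np.float32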
@@ -2,4 +2,5 @@ safetensors==0.4.3
 opencv-python
 twilio
 https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio-5.0.0b3-py3-none-any.whl
 https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio_webrtc-0.0.1-py3-none-any.whl
+onnxruntime-gpu

demo/space.py (183 lines): file diff suppressed because one or more lines are too long

demo/utils.py (new file, 237 lines)
@@ -0,0 +1,237 @@
+import numpy as np
+import cv2
+
+class_names = [
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
+
+# Create a list of colors for each class where each color is a tuple of 3 integer values
+rng = np.random.default_rng(3)
+colors = rng.uniform(0, 255, size=(len(class_names), 3))
+
+
+def nms(boxes, scores, iou_threshold):
+    # Sort by score
+    sorted_indices = np.argsort(scores)[::-1]
+
+    keep_boxes = []
+    while sorted_indices.size > 0:
+        # Pick the last box
+        box_id = sorted_indices[0]
+        keep_boxes.append(box_id)
+
+        # Compute IoU of the picked box with the rest
+        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
+
+        # Remove boxes with IoU over the threshold
+        keep_indices = np.where(ious < iou_threshold)[0]
+
+        # print(keep_indices.shape, sorted_indices.shape)
+        sorted_indices = sorted_indices[keep_indices + 1]
+
+    return keep_boxes
+
+
+def multiclass_nms(boxes, scores, class_ids, iou_threshold):
+    unique_class_ids = np.unique(class_ids)
+
+    keep_boxes = []
+    for class_id in unique_class_ids:
+        class_indices = np.where(class_ids == class_id)[0]
+        class_boxes = boxes[class_indices, :]
+        class_scores = scores[class_indices]
+
+        class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
+        keep_boxes.extend(class_indices[class_keep_boxes])
+
+    return keep_boxes
+
+
+def compute_iou(box, boxes):
+    # Compute xmin, ymin, xmax, ymax for both boxes
+    xmin = np.maximum(box[0], boxes[:, 0])
+    ymin = np.maximum(box[1], boxes[:, 1])
+    xmax = np.minimum(box[2], boxes[:, 2])
+    ymax = np.minimum(box[3], boxes[:, 3])
+
+    # Compute intersection area
+    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
+
+    # Compute union area
+    box_area = (box[2] - box[0]) * (box[3] - box[1])
+    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+    union_area = box_area + boxes_area - intersection_area
+
+    # Compute IoU
+    iou = intersection_area / union_area
+
+    return iou
+
+
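A small worked example of the nms()/compute_iou() pair above, with hand-checkable numbers: boxes A and B overlap with IoU 81/119 (about 0.68), so B is suppressed at iou_threshold=0.5, while the disjoint box C survives:

    import numpy as np
    from utils import nms  # the demo/utils.py added in this commit

    boxes = np.array([
        [0, 0, 10, 10],    # box A, score 0.9
        [1, 1, 11, 11],    # box B, heavily overlapping A
        [20, 20, 30, 30],  # box C, disjoint from A and B
    ], dtype=np.float32)
    scores = np.array([0.9, 0.8, 0.7])

    keep = nms(boxes, scores, iou_threshold=0.5)
    print([int(i) for i in keep])  # -> [0, 2]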
+def xywh2xyxy(x):
+    # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
+    y = np.copy(x)
+    y[..., 0] = x[..., 0] - x[..., 2] / 2
+    y[..., 1] = x[..., 1] - x[..., 3] / 2
+    y[..., 2] = x[..., 0] + x[..., 2] / 2
+    y[..., 3] = x[..., 1] + x[..., 3] / 2
+    return y
+
+
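xywh2xyxy() is referenced (commented out) in inference.py's extract_boxes(). A one-line check of the center-format conversion:

    import numpy as np
    from utils import xywh2xyxy

    box = np.array([50.0, 40.0, 20.0, 10.0])  # center (50, 40), 20 wide, 10 tall
    print(xywh2xyxy(box))  # -> [40. 35. 60. 45.]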
+def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
+    det_img = image.copy()
+
+    img_height, img_width = image.shape[:2]
+    font_size = min([img_height, img_width]) * 0.0006
+    text_thickness = int(min([img_height, img_width]) * 0.001)
+
+    #det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
+
+    # Draw bounding boxes and labels of detections
+    for class_id, box, score in zip(class_ids, boxes, scores):
+        color = colors[class_id]
+
+        draw_box(det_img, box, color)
+
+        label = class_names[class_id]
+        caption = f"{label} {int(score * 100)}%"
+        draw_text(det_img, caption, box, color, font_size, text_thickness)
+
+    return det_img
+
+
+def draw_box(
+    image: np.ndarray,
+    box: np.ndarray,
+    color: tuple[int, int, int] = (0, 0, 255),
+    thickness: int = 2,
+) -> np.ndarray:
+    x1, y1, x2, y2 = box.astype(int)
+    return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
+
+
+def draw_text(
+    image: np.ndarray,
+    text: str,
+    box: np.ndarray,
+    color: tuple[int, int, int] = (0, 0, 255),
+    font_size: float = 0.001,
+    text_thickness: int = 2,
+) -> np.ndarray:
+    x1, y1, x2, y2 = box.astype(int)
+    (tw, th), _ = cv2.getTextSize(
+        text=text,
+        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+        fontScale=font_size,
+        thickness=text_thickness,
+    )
+    th = int(th * 1.2)
+
+    cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
+
+    return cv2.putText(
+        image,
+        text,
+        (x1, y1),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        font_size,
+        (255, 255, 255),
+        text_thickness,
+        cv2.LINE_AA,
+    )
+
+
+def draw_masks(
+    image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
+) -> np.ndarray:
+    mask_img = image.copy()
+
+    # Draw bounding boxes and labels of detections
+    for box, class_id in zip(boxes, classes):
+        color = colors[class_id]
+
+        x1, y1, x2, y2 = box.astype(int)
+
+        # Draw fill rectangle in mask image
+        cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
+
+    return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
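
A smoke test for the drawing helpers above, using one fabricated detection (class id 0 is "person" in class_names; the 1080p frame keeps the computed text_thickness at 1 rather than 0):

    import numpy as np
    from utils import draw_detections

    img = np.zeros((1080, 1920, 3), dtype=np.uint8)
    boxes = np.array([[200.0, 200.0, 800.0, 900.0]])
    scores = np.array([0.91])
    class_ids = np.array([0])  # "person"

    annotated = draw_detections(img, boxes, scores, class_ids)
    assert annotated.shape == img.shape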

@@ -1,7 +1,6 @@
 <svelte:options accessors={true} />
 
 <script lang="ts">
-
    import { Block, UploadText } from "@gradio/atoms";
    import Video from "./shared/InteractiveVideo.svelte";
    import { StatusTracker } from "@gradio/statustracker";
@@ -27,6 +26,7 @@
    export let min_width: number | undefined = undefined;
    export let gradio;
    export let rtc_configuration: Object;
+   export let time_limit: number | null = null;
    // export let gradio: Gradio<{
    //  change: never;
    //  clear: never;
@@ -80,6 +80,7 @@
    {root}
    {server}
    {rtc_configuration}
+   {time_limit}
    on:clear={() => gradio.dispatch("clear")}
    on:play={() => gradio.dispatch("play")}
    on:pause={() => gradio.dispatch("pause")}

@@ -11,11 +11,11 @@
    export let label: string | undefined = undefined;
    export let show_label = true;
    export let include_audio: boolean;
-   export let root: string;
    export let i18n: I18nFormatter;
    export let active_source: "webcam" | "upload" = "webcam";
    export let handle_reset_value: () => void = () => {};
    export let stream_handler: Client["stream"];
+   export let time_limit: number | null = null;
    export let server: {
        offer: (body: any) => Promise<any>;
    };
@@ -44,9 +44,9 @@
 <BlockLabel {show_label} Icon={Video} label={label || "Video"} />
 <div data-testid="video" class="video-container">
    <Webcam
-       {root}
        {rtc_configuration}
        {include_audio}
+       {time_limit}
        on:error
        on:start_recording
        on:stop_recording

@@ -21,7 +21,8 @@
    let video_source: HTMLVideoElement;
    let available_video_devices: MediaDeviceInfo[] = [];
    let selected_device: MediaDeviceInfo | null = null;
-   let time_limit: number | null = null;
+   let _time_limit: number | null = null;
+   export let time_limit: number | null = null;
    let stream_state: "open" | "waiting" | "closed" = "closed";
    export const webrtc_id = Math.random().toString(36).substring(2);
 
@@ -29,7 +30,7 @@
        state: "open" | "closed" | "waiting"
    ) => {
        if (state === "closed") {
-           time_limit = null;
+           _time_limit = null;
            stream_state = "closed";
        } else if (state === "waiting") {
            stream_state = "waiting";
@@ -38,14 +39,8 @@
        }
    };
 
-   export const set_time_limit = (time: number): void => {
-       if (recording) time_limit = time;
-   };
-
    let canvas: HTMLCanvasElement;
    export let rtc_configuration: Object;
-   export let pending = false;
-   export let root = "";
    export let stream_every = 1;
    export let server: {
        offer: (body: any) => Promise<any>;
@@ -133,10 +128,15 @@
        console.log("config", configuration);
        pc = new RTCPeerConnection(configuration);
        pc.addEventListener("connectionstatechange",
-           (event) => {
+           async (event) => {
                switch(pc.connectionState) {
                    case "connected":
-                       stream_state = "open"
+                       stream_state = "open";
+                       _time_limit = time_limit;
+                       break;
+                   case "disconnected":
+                       stream_state = "closed";
+                       await access_webcam();
                        break;
                    default:
                        break;
@@ -196,7 +196,7 @@
 </script>
 
 <div class="wrap">
-   <StreamingBar {time_limit} />
+   <StreamingBar time_limit={_time_limit} />
    <!-- svelte-ignore a11y-media-has-caption -->
    <!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
    <video
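
Tying the frontend changes back to the Python side: the browser creates an RTCPeerConnection, generates a random webrtc_id, and POSTs its SDP offer to the component's offer() endpoint. A sketch of the request body that endpoint reads (key names taken from body['sdp'], body['type'], and body['webrtc_id'] in the hunks above; the example values are illustrative):

    offer_body = {
        "sdp": "v=0\r\n...",    # the local RTCPeerConnection's offer SDP (elided)
        "type": "offer",
        "webrtc_id": "k3j2l1",  # random id, cf. Math.random().toString(36).substring(2)
    }
    # offer() answers with its own sdp/type, or returns {"status": "failed"}
    # once len(self.connections) has reached the concurrency limit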