Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git, synced 2026-02-05 09:59:22 +08:00
[feat] Update features

Sync code from fastrtc; add text support through the data channel; fix a Safari connection problem; support chat without a camera or mic.
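The client half of the data-channel text support appears in demo/object_detection/index.html below: the page opens a channel named "text" and interprets JSON events with types "error" and "send_input". As a rough sketch of what a server-side counterpart could look like — using aiortc, the WebRTC library fastrtc builds on, and mirroring the demo page's event names rather than reproducing fastrtc's actual implementation:

import json

from aiortc import RTCPeerConnection

pc = RTCPeerConnection()


@pc.on("datachannel")
def on_datachannel(channel):
    # Fires when the browser creates its "text" channel.
    @channel.on("message")
    def on_message(message):
        # Messages are JSON strings in this demo's protocol.
        event = json.loads(message)
        print("received over data channel:", event)

    # Ask the client to (re)send its additional inputs; index.html reacts to
    # the "send_input" event type by re-POSTing the confidence threshold.
    channel.send(json.dumps({"type": "send_input"}))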
demo/object_detection/README.md (new file, 15 lines)
@@ -0,0 +1,15 @@
---
title: Object Detection
emoji: 📸
colorFrom: purple
colorTo: red
sdk: gradio
sdk_version: 5.16.0
app_file: app.py
pinned: false
license: mit
short_description: Use YOLOv10 to detect objects in real-time
tags: [webrtc, websocket, gradio, secret|TWILIO_ACCOUNT_SID, secret|TWILIO_AUTH_TOKEN]
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/object_detection/app.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import json
from pathlib import Path

import cv2
import gradio as gr
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import Stream, get_twilio_turn_credentials
from gradio.utils import get_space
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field

try:
    from demo.object_detection.inference import YOLOv10
except (ImportError, ModuleNotFoundError):
    from inference import YOLOv10


cur_dir = Path(__file__).parent

model_file = hf_hub_download(
    repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)

model = YOLOv10(model_file)


def detection(image, conf_threshold=0.3):
    image = cv2.resize(image, (model.input_width, model.input_height))
    print("conf_threshold", conf_threshold)
    new_image = model.detect_objects(image, conf_threshold)
    return cv2.resize(new_image, (500, 500))


stream = Stream(
    handler=detection,
    modality="video",
    mode="send-receive",
    additional_inputs=[gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)],
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=2 if get_space() else None,
)

app = FastAPI()

stream.mount(app)


@app.get("/")
async def _():
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (cur_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


class InputData(BaseModel):
    webrtc_id: str
    conf_threshold: float = Field(ge=0, le=1)


@app.post("/input_hook")
async def _(data: InputData):
    stream.set_input(data.webrtc_id, data.conf_threshold)


if __name__ == "__main__":
    import os

    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
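Since the page drives the confidence threshold through the /input_hook route above, that route can also be exercised directly. A hypothetical client call (the webrtc_id value is illustrative; in practice it is the id the page generates when the stream starts):

import requests

resp = requests.post(
    "http://localhost:7860/input_hook",
    json={"webrtc_id": "abc123", "conf_threshold": 0.5},
)
resp.raise_for_status()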
demo/object_detection/index.html (new file, 340 lines)
@@ -0,0 +1,340 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Object Detection</title>
    <style>
        body {
            font-family: system-ui, -apple-system, sans-serif;
            background: linear-gradient(135deg, #2d2b52 0%, #191731 100%);
            color: white;
            margin: 0;
            padding: 20px;
            height: 100vh;
            box-sizing: border-box;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
        }

        .container {
            width: 100%;
            max-width: 800px;
            text-align: center;
            display: flex;
            flex-direction: column;
            align-items: center;
        }

        .video-container {
            width: 100%;
            max-width: 500px;
            aspect-ratio: 1/1;
            background: rgba(255, 255, 255, 0.1);
            border-radius: 12px;
            overflow: hidden;
            box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
            margin: 10px 0;
        }

        #video-output {
            width: 100%;
            height: 100%;
            object-fit: cover;
        }

        button {
            background: white;
            color: #2d2b52;
            border: none;
            padding: 12px 32px;
            border-radius: 24px;
            font-size: 16px;
            font-weight: 600;
            cursor: pointer;
            transition: all 0.3s ease;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
        }

        button:hover {
            transform: translateY(-2px);
            box-shadow: 0 6px 16px rgba(0, 0, 0, 0.2);
        }

        h1 {
            font-size: 2.5em;
            margin-bottom: 0.3em;
        }

        p {
            color: rgba(255, 255, 255, 0.8);
            margin-bottom: 1em;
        }

        .controls {
            display: flex;
            flex-direction: column;
            gap: 12px;
            align-items: center;
            margin-top: 10px;
        }

        .slider-container {
            width: 100%;
            max-width: 300px;
            display: flex;
            flex-direction: column;
            gap: 8px;
        }

        .slider-container label {
            color: rgba(255, 255, 255, 0.8);
            font-size: 14px;
        }

        input[type="range"] {
            width: 100%;
            height: 6px;
            -webkit-appearance: none;
            background: rgba(255, 255, 255, 0.1);
            border-radius: 3px;
            outline: none;
        }

        input[type="range"]::-webkit-slider-thumb {
            -webkit-appearance: none;
            width: 18px;
            height: 18px;
            background: white;
            border-radius: 50%;
            cursor: pointer;
        }

        /* Toast notification styles */
        .toast {
            position: fixed;
            top: 20px;
            left: 50%;
            transform: translateX(-50%);
            padding: 16px 24px;
            border-radius: 4px;
            font-size: 14px;
            z-index: 1000;
            display: none;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
        }

        .toast.error {
            background-color: #f44336;
            color: white;
        }

        .toast.warning {
            background-color: #ffd700;
            color: black;
        }
    </style>
</head>

<body>
    <!-- Toast element for error/warning notifications -->
    <div id="error-toast" class="toast"></div>
    <div class="container">
        <h1>Real-time Object Detection</h1>
        <p>Using YOLOv10 to detect objects in your webcam feed</p>
        <div class="video-container">
            <video id="video-output" autoplay playsinline></video>
        </div>
        <div class="controls">
            <div class="slider-container">
                <label>Confidence Threshold: <span id="conf-value">0.3</span></label>
                <input type="range" id="conf-threshold" min="0" max="1" step="0.01" value="0.3">
            </div>
            <button id="start-button">Start</button>
        </div>
    </div>

    <script>
        let peerConnection;
        let webrtc_id;
        const startButton = document.getElementById('start-button');
        const videoOutput = document.getElementById('video-output');
        const confThreshold = document.getElementById('conf-threshold');
        const confValue = document.getElementById('conf-value');

        // Update confidence value display
        confThreshold.addEventListener('input', (e) => {
            confValue.textContent = e.target.value;
            if (peerConnection) {
                updateConfThreshold(e.target.value);
            }
        });

        function updateConfThreshold(value) {
            fetch('/input_hook', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({
                    webrtc_id: webrtc_id,
                    conf_threshold: parseFloat(value)
                })
            });
        }

        function showError(message) {
            const toast = document.getElementById('error-toast');
            toast.textContent = message;
            toast.className = 'toast error';
            toast.style.display = 'block';

            // Hide toast after 5 seconds
            setTimeout(() => {
                toast.style.display = 'none';
            }, 5000);
        }

        async function setupWebRTC() {
            const config = __RTC_CONFIGURATION__;
            peerConnection = new RTCPeerConnection(config);

            const timeoutId = setTimeout(() => {
                const toast = document.getElementById('error-toast');
                toast.textContent = "Connection is taking longer than usual. Are you on a VPN?";
                toast.className = 'toast warning';
                toast.style.display = 'block';

                // Hide warning after 5 seconds
                setTimeout(() => {
                    toast.style.display = 'none';
                }, 5000);
            }, 5000);

            try {
                const stream = await navigator.mediaDevices.getUserMedia({
                    video: true
                });

                stream.getTracks().forEach(track => {
                    peerConnection.addTrack(track, stream);
                });

                peerConnection.addEventListener('track', (evt) => {
                    if (videoOutput && videoOutput.srcObject !== evt.streams[0]) {
                        videoOutput.srcObject = evt.streams[0];
                    }
                });

                const dataChannel = peerConnection.createDataChannel('text');
                dataChannel.onmessage = (event) => {
                    const eventJson = JSON.parse(event.data);
                    if (eventJson.type === "error") {
                        showError(eventJson.message);
                    } else if (eventJson.type === "send_input") {
                        updateConfThreshold(confThreshold.value);
                    }
                };

                const offer = await peerConnection.createOffer();
                await peerConnection.setLocalDescription(offer);

                // Wait for ICE gathering to complete before sending the offer
                await new Promise((resolve) => {
                    if (peerConnection.iceGatheringState === "complete") {
                        resolve();
                    } else {
                        const checkState = () => {
                            if (peerConnection.iceGatheringState === "complete") {
                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
                                resolve();
                            }
                        };
                        peerConnection.addEventListener("icegatheringstatechange", checkState);
                    }
                });

                webrtc_id = Math.random().toString(36).substring(7);

                const response = await fetch('/webrtc/offer', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sdp: peerConnection.localDescription.sdp,
                        type: peerConnection.localDescription.type,
                        webrtc_id: webrtc_id
                    })
                });

                const serverResponse = await response.json();

                if (serverResponse.status === 'failed') {
                    showError(serverResponse.meta.error === 'concurrency_limit_reached'
                        ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
                        : serverResponse.meta.error);
                    stop();
                    startButton.textContent = 'Start';
                    return;
                }

                await peerConnection.setRemoteDescription(serverResponse);

                // Send initial confidence threshold
                updateConfThreshold(confThreshold.value);

                peerConnection.addEventListener('connectionstatechange', () => {
                    if (peerConnection.connectionState === 'connected') {
                        clearTimeout(timeoutId);
                        const toast = document.getElementById('error-toast');
                        toast.style.display = 'none';
                    }
                });

            } catch (err) {
                clearTimeout(timeoutId);
                console.error('Error setting up WebRTC:', err);
                showError('Failed to establish connection. Please try again.');
                stop();
                startButton.textContent = 'Start';
            }
        }

        function stop() {
            if (peerConnection) {
                if (peerConnection.getTransceivers) {
                    peerConnection.getTransceivers().forEach(transceiver => {
                        if (transceiver.stop) {
                            transceiver.stop();
                        }
                    });
                }

                if (peerConnection.getSenders) {
                    peerConnection.getSenders().forEach(sender => {
                        if (sender.track && sender.track.stop) sender.track.stop();
                    });
                }

                setTimeout(() => {
                    peerConnection.close();
                }, 500);
            }

            videoOutput.srcObject = null;
        }

        startButton.addEventListener('click', () => {
            if (startButton.textContent === 'Start') {
                setupWebRTC();
                startButton.textContent = 'Stop';
            } else {
                stop();
                startButton.textContent = 'Start';
            }
        });
    </script>
</body>

</html>
demo/object_detection/inference.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import time

import cv2
import numpy as np
import onnxruntime

try:
    from demo.object_detection.utils import draw_detections
except (ImportError, ModuleNotFoundError):
    from utils import draw_detections


class YOLOv10:
    def __init__(self, path):
        # Initialize model
        self.initialize_model(path)

    def __call__(self, image):
        return self.detect_objects(image)

    def initialize_model(self, path):
        self.session = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers()
        )
        # Get model info
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image, conf_threshold=0.3):
        input_tensor = self.prepare_input(image)

        # Perform inference on the image
        new_image = self.inference(image, input_tensor, conf_threshold)

        return new_image

    def prepare_input(self, image):
        self.img_height, self.img_width = image.shape[:2]

        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize input image
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))

        # Scale input pixel values to 0 to 1
        input_img = input_img / 255.0
        input_img = input_img.transpose(2, 0, 1)
        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)

        return input_tensor

    def inference(self, image, input_tensor, conf_threshold=0.3):
        start = time.perf_counter()
        outputs = self.session.run(
            self.output_names, {self.input_names[0]: input_tensor}
        )

        print(f"Inference time: {(time.perf_counter() - start) * 1000:.2f} ms")
        boxes, scores, class_ids = self.process_output(outputs, conf_threshold)
        return self.draw_detections(image, boxes, scores, class_ids)

    def process_output(self, output, conf_threshold=0.3):
        predictions = np.squeeze(output[0])

        # Filter out object confidence scores below threshold
        scores = predictions[:, 4]
        predictions = predictions[scores > conf_threshold, :]
        scores = scores[scores > conf_threshold]

        if len(scores) == 0:
            return [], [], []

        # Get the class with the highest confidence
        class_ids = predictions[:, 5].astype(int)

        # Get bounding boxes for each object
        boxes = self.extract_boxes(predictions)

        return boxes, scores, class_ids

    def extract_boxes(self, predictions):
        # Extract boxes from predictions
        boxes = predictions[:, :4]

        # Scale boxes to original image dimensions
        boxes = self.rescale_boxes(boxes)

        # Convert boxes to xyxy format
        # boxes = xywh2xyxy(boxes)

        return boxes

    def rescale_boxes(self, boxes):
        # Rescale boxes to original image dimensions
        input_shape = np.array(
            [self.input_width, self.input_height, self.input_width, self.input_height]
        )
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array(
            [self.img_width, self.img_height, self.img_width, self.img_height]
        )
        return boxes

    def draw_detections(
        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
    ):
        return draw_detections(image, boxes, scores, class_ids, mask_alpha)

    def get_input_details(self):
        model_inputs = self.session.get_inputs()
        self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]

        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        model_outputs = self.session.get_outputs()
        self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]


if __name__ == "__main__":
    import tempfile

    import requests
    from huggingface_hub import hf_hub_download

    model_file = hf_hub_download(
        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
    )

    yolov10_detector = YOLOv10(model_file)

    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
        f.write(
            requests.get(
                "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
            ).content
        )
        f.seek(0)
        img = cv2.imread(f.name)

    # Detect objects
    combined_image = yolov10_detector.detect_objects(img)

    # Draw detections
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", combined_image)
    cv2.waitKey(0)
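The prepare_input method above is the usual ONNX image pipeline: BGR to RGB, resize to the model's input size, scale pixels to [0, 1], transpose HWC to CHW, and add a batch axis. A small self-contained shape check, assuming a 640x640 input (typical for YOLOv10 ONNX exports; the real size is read from the model by get_input_details):

import cv2
import numpy as np

input_width = input_height = 640  # assumed; the class reads the real value from the model
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in BGR frame

img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (input_width, input_height))
tensor = (img / 255.0).transpose(2, 0, 1)[np.newaxis].astype(np.float32)

assert tensor.shape == (1, 3, input_height, input_width)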
demo/object_detection/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
fastrtc
opencv-python
twilio
onnxruntime-gpu
demo/object_detection/utils.py (new file, 237 lines)
@@ -0,0 +1,237 @@
import cv2
import numpy as np

class_names = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

# Create a list of colors for each class where each color is a tuple of 3 integer values
rng = np.random.default_rng(3)
colors = rng.uniform(0, 255, size=(len(class_names), 3))


def nms(boxes, scores, iou_threshold):
    # Sort by score
    sorted_indices = np.argsort(scores)[::-1]

    keep_boxes = []
    while sorted_indices.size > 0:
        # Pick the box with the highest remaining score
        box_id = sorted_indices[0]
        keep_boxes.append(box_id)

        # Compute IoU of the picked box with the rest
        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

        # Remove boxes with IoU over the threshold
        keep_indices = np.where(ious < iou_threshold)[0]

        sorted_indices = sorted_indices[keep_indices + 1]

    return keep_boxes


def multiclass_nms(boxes, scores, class_ids, iou_threshold):
    unique_class_ids = np.unique(class_ids)

    keep_boxes = []
    for class_id in unique_class_ids:
        class_indices = np.where(class_ids == class_id)[0]
        class_boxes = boxes[class_indices, :]
        class_scores = scores[class_indices]

        class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
        keep_boxes.extend(class_indices[class_keep_boxes])

    return keep_boxes


def compute_iou(box, boxes):
    # Compute xmin, ymin, xmax, ymax for both boxes
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])

    # Compute intersection area
    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    # Compute union area
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union_area = box_area + boxes_area - intersection_area

    # Compute IoU
    iou = intersection_area / union_area

    return iou


def xywh2xyxy(x):
    # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
    y = np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
    det_img = image.copy()

    img_height, img_width = image.shape[:2]
    font_size = min([img_height, img_width]) * 0.0006
    text_thickness = int(min([img_height, img_width]) * 0.001)

    # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)

    # Draw bounding boxes and labels of detections
    for class_id, box, score in zip(class_ids, boxes, scores):
        color = colors[class_id]

        draw_box(det_img, box, color)  # type: ignore

        label = class_names[class_id]
        caption = f"{label} {int(score * 100)}%"
        draw_text(det_img, caption, box, color, font_size, text_thickness)  # type: ignore

    return det_img


def draw_box(
    image: np.ndarray,
    box: np.ndarray,
    color: tuple[int, int, int] = (0, 0, 255),
    thickness: int = 2,
) -> np.ndarray:
    x1, y1, x2, y2 = box.astype(int)
    return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)


def draw_text(
    image: np.ndarray,
    text: str,
    box: np.ndarray,
    color: tuple[int, int, int] = (0, 0, 255),
    font_size: float = 0.001,
    text_thickness: int = 2,
) -> np.ndarray:
    x1, y1, x2, y2 = box.astype(int)
    (tw, th), _ = cv2.getTextSize(
        text=text,
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=font_size,
        thickness=text_thickness,
    )
    th = int(th * 1.2)

    cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)

    return cv2.putText(
        image,
        text,
        (x1, y1),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_size,
        (255, 255, 255),
        text_thickness,
        cv2.LINE_AA,
    )


def draw_masks(
    image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
) -> np.ndarray:
    mask_img = image.copy()

    # Draw bounding boxes and labels of detections
    for box, class_id in zip(boxes, classes):
        color = colors[class_id]

        x1, y1, x2, y2 = box.astype(int)

        # Draw fill rectangle in mask image
        cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)  # type: ignore

    return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
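A quick hand-checked exercise of the IoU and NMS helpers above, assuming this file is importable as utils:

import numpy as np

from utils import compute_iou, nms

boxes = np.array(
    [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32
)  # xyxy format
scores = np.array([0.9, 0.8, 0.7])

# Boxes 0 and 1 overlap: intersection 9*9 = 81, union 100 + 100 - 81 = 119,
# so IoU ~= 0.68; box 2 is disjoint from box 0, so its IoU is 0.
print(compute_iou(boxes[0], boxes[1:]))  # ~= [0.68, 0.0]

# At a 0.5 threshold, box 1 is suppressed by box 0; boxes 0 and 2 survive.
print(nms(boxes, scores, iou_threshold=0.5))  # keeps boxes 0 and 2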