mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
format
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from .webrtc import WebRTC
|
from .webrtc import WebRTC
|
||||||
|
|
||||||
__all__ = ['WebRTC']
|
__all__ = ["WebRTC"]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from aiortc import RTCPeerConnection, RTCSessionDescription
|
|||||||
from aiortc.contrib.media import MediaRelay
|
from aiortc.contrib.media import MediaRelay
|
||||||
from aiortc import VideoStreamTrack
|
from aiortc import VideoStreamTrack
|
||||||
from aiortc.mediastreams import MediaStreamError
|
from aiortc.mediastreams import MediaStreamError
|
||||||
from aiortc.contrib.media import VideoFrame # type: ignore
|
from aiortc.contrib.media import VideoFrame # type: ignore
|
||||||
from gradio_client import handle_file
|
from gradio_client import handle_file
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -29,7 +29,6 @@ if wasm_utils.IS_WASM:
|
|||||||
raise ValueError("Not supported in gradio-lite!")
|
raise ValueError("Not supported in gradio-lite!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VideoCallback(VideoStreamTrack):
|
class VideoCallback(VideoStreamTrack):
|
||||||
"""
|
"""
|
||||||
This works for streaming input and output
|
This works for streaming input and output
|
||||||
@@ -47,7 +46,9 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
|
|
||||||
def add_frame_to_payload(self, args: list[Any], frame: np.ndarray | None) -> list[Any]:
|
def add_frame_to_payload(
|
||||||
|
self, args: list[Any], frame: np.ndarray | None
|
||||||
|
) -> list[Any]:
|
||||||
new_args = []
|
new_args = []
|
||||||
for val in args:
|
for val in args:
|
||||||
if isinstance(val, str) and val == "__webrtc_value__":
|
if isinstance(val, str) and val == "__webrtc_value__":
|
||||||
@@ -55,14 +56,12 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
else:
|
else:
|
||||||
new_args.append(val)
|
new_args.append(val)
|
||||||
return new_args
|
return new_args
|
||||||
|
|
||||||
|
|
||||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
frame = await self.track.recv()
|
frame = await self.track.recv()
|
||||||
except MediaStreamError:
|
except MediaStreamError:
|
||||||
@@ -70,12 +69,10 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
frame_array = frame.to_ndarray(format="bgr24")
|
frame_array = frame.to_ndarray(format="bgr24")
|
||||||
|
|
||||||
if self.latest_args == "not_set":
|
if self.latest_args == "not_set":
|
||||||
print("args not set")
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|
||||||
args = self.add_frame_to_payload(self.latest_args, frame_array)
|
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||||
|
|
||||||
array = self.event_handler(*args)
|
array = self.event_handler(*args)
|
||||||
|
|
||||||
new_frame = self.array_to_frame(array)
|
new_frame = self.array_to_frame(array)
|
||||||
@@ -110,18 +107,11 @@ class WebRTC(Component):
|
|||||||
relay = MediaRelay()
|
relay = MediaRelay()
|
||||||
connections: dict[str, VideoCallback] = {}
|
connections: dict[str, VideoCallback] = {}
|
||||||
|
|
||||||
EVENTS = [
|
EVENTS = ["tick"]
|
||||||
"tick"
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
value: str
|
value: None = None,
|
||||||
| Path
|
|
||||||
| tuple[str | Path, str | Path | None]
|
|
||||||
| Callable
|
|
||||||
| None = None,
|
|
||||||
*,
|
|
||||||
height: int | str | None = None,
|
height: int | str | None = None,
|
||||||
width: int | str | None = None,
|
width: int | str | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
@@ -138,10 +128,6 @@ class WebRTC(Component):
|
|||||||
render: bool = True,
|
render: bool = True,
|
||||||
key: int | str | None = None,
|
key: int | str | None = None,
|
||||||
mirror_webcam: bool = True,
|
mirror_webcam: bool = True,
|
||||||
show_share_button: bool | None = None,
|
|
||||||
show_download_button: bool | None = None,
|
|
||||||
min_length: int | None = None,
|
|
||||||
max_length: int | None = None,
|
|
||||||
rtc_configuration: dict[str, Any] | None = None,
|
rtc_configuration: dict[str, Any] | None = None,
|
||||||
time_limit: float | None = None,
|
time_limit: float | None = None,
|
||||||
):
|
):
|
||||||
@@ -180,14 +166,6 @@ class WebRTC(Component):
|
|||||||
self.width = width
|
self.width = width
|
||||||
self.mirror_webcam = mirror_webcam
|
self.mirror_webcam = mirror_webcam
|
||||||
self.concurrency_limit = 1
|
self.concurrency_limit = 1
|
||||||
self.show_share_button = (
|
|
||||||
(utils.get_space() is not None)
|
|
||||||
if show_share_button is None
|
|
||||||
else show_share_button
|
|
||||||
)
|
|
||||||
self.show_download_button = show_download_button
|
|
||||||
self.min_length = min_length
|
|
||||||
self.max_length = max_length
|
|
||||||
self.rtc_configuration = rtc_configuration
|
self.rtc_configuration = rtc_configuration
|
||||||
self.event_handler: Callable | None = None
|
self.event_handler: Callable | None = None
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -216,7 +194,6 @@ class WebRTC(Component):
|
|||||||
"""
|
"""
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
def postprocess(self, value: Any) -> str:
|
def postprocess(self, value: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -229,7 +206,7 @@ class WebRTC(Component):
|
|||||||
def set_output(self, webrtc_id: str, *args):
|
def set_output(self, webrtc_id: str, *args):
|
||||||
if webrtc_id in self.connections:
|
if webrtc_id in self.connections:
|
||||||
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
fn: Callable[..., Any] | None = None,
|
fn: Callable[..., Any] | None = None,
|
||||||
@@ -238,8 +215,8 @@ class WebRTC(Component):
|
|||||||
js: str | None = None,
|
js: str | None = None,
|
||||||
concurrency_limit: int | None | Literal["default"] = "default",
|
concurrency_limit: int | None | Literal["default"] = "default",
|
||||||
concurrency_id: str | None = None,
|
concurrency_id: str | None = None,
|
||||||
time_limit: float | None = None):
|
time_limit: float | None = None,
|
||||||
|
):
|
||||||
from gradio.blocks import Block
|
from gradio.blocks import Block
|
||||||
|
|
||||||
if isinstance(inputs, Block):
|
if isinstance(inputs, Block):
|
||||||
@@ -248,23 +225,33 @@ class WebRTC(Component):
|
|||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
|
|
||||||
if cast(list[Block], inputs)[0] != self:
|
if cast(list[Block], inputs)[0] != self:
|
||||||
raise ValueError("In the webrtc stream event, the first input component must be the WebRTC component.")
|
raise ValueError(
|
||||||
|
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||||
|
)
|
||||||
|
|
||||||
if len(cast(list[Block], outputs)) != 1 and cast(list[Block], outputs)[0] != self:
|
if (
|
||||||
raise ValueError("In the webrtc stream event, the only output component must be the WebRTC component.")
|
len(cast(list[Block], outputs)) != 1
|
||||||
|
and cast(list[Block], outputs)[0] != self
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||||
|
)
|
||||||
|
|
||||||
self.concurrency_limit = 1 if concurrency_limit in ["default", None] else concurrency_limit
|
self.concurrency_limit = (
|
||||||
|
1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||||
|
)
|
||||||
self.event_handler = fn
|
self.event_handler = fn
|
||||||
self.time_limit = time_limit
|
self.time_limit = time_limit
|
||||||
return self.tick(self.set_output,
|
return self.tick( # type: ignore
|
||||||
inputs=inputs,
|
self.set_output,
|
||||||
outputs=None,
|
inputs=inputs,
|
||||||
concurrency_id=concurrency_id,
|
outputs=None,
|
||||||
concurrency_limit=None,
|
concurrency_id=concurrency_id,
|
||||||
stream_every=0.5,
|
concurrency_limit=None,
|
||||||
time_limit=None,
|
stream_every=0.5,
|
||||||
js=js
|
time_limit=None,
|
||||||
)
|
js=js,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||||
@@ -273,11 +260,10 @@ class WebRTC(Component):
|
|||||||
|
|
||||||
@server
|
@server
|
||||||
async def offer(self, body):
|
async def offer(self, body):
|
||||||
|
|
||||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||||
return {"status": "failed"}
|
return {"status": "failed"}
|
||||||
|
|
||||||
offer = RTCSessionDescription(sdp=body['sdp'], type=body['type'])
|
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||||
|
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
self.pcs.add(pc)
|
self.pcs.add(pc)
|
||||||
@@ -287,14 +273,14 @@ class WebRTC(Component):
|
|||||||
print(pc.iceConnectionState)
|
print(pc.iceConnectionState)
|
||||||
if pc.iceConnectionState == "failed":
|
if pc.iceConnectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
self.connections.pop(body['webrtc_id'], None)
|
self.connections.pop(body["webrtc_id"], None)
|
||||||
self.pcs.discard(pc)
|
self.pcs.discard(pc)
|
||||||
|
|
||||||
@pc.on("connectionstatechange")
|
@pc.on("connectionstatechange")
|
||||||
async def on_connectionstatechange():
|
async def on_connectionstatechange():
|
||||||
if pc.connectionState in ["failed", "closed"]:
|
if pc.connectionState in ["failed", "closed"]:
|
||||||
await pc.close()
|
await pc.close()
|
||||||
self.connections.pop(body['webrtc_id'], None)
|
self.connections.pop(body["webrtc_id"], None)
|
||||||
self.pcs.discard(pc)
|
self.pcs.discard(pc)
|
||||||
if pc.connectionState == "connected":
|
if pc.connectionState == "connected":
|
||||||
if self.time_limit is not None:
|
if self.time_limit is not None:
|
||||||
@@ -303,12 +289,12 @@ class WebRTC(Component):
|
|||||||
@pc.on("track")
|
@pc.on("track")
|
||||||
def on_track(track):
|
def on_track(track):
|
||||||
cb = VideoCallback(
|
cb = VideoCallback(
|
||||||
self.relay.subscribe(track),
|
self.relay.subscribe(track),
|
||||||
event_handler=cast(Callable, self.event_handler)
|
event_handler=cast(Callable, self.event_handler),
|
||||||
)
|
)
|
||||||
self.connections[body['webrtc_id']] = cb
|
self.connections[body["webrtc_id"]] = cb
|
||||||
pc.addTrack(cb)
|
pc.addTrack(cb)
|
||||||
|
|
||||||
# handle offer
|
# handle offer
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
@@ -317,9 +303,9 @@ class WebRTC(Component):
|
|||||||
await pc.setLocalDescription(answer) # type: ignore
|
await pc.setLocalDescription(answer) # type: ignore
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"sdp": pc.localDescription.sdp,
|
"sdp": pc.localDescription.sdp,
|
||||||
"type": pc.localDescription.type,
|
"type": pc.localDescription.type,
|
||||||
}
|
}
|
||||||
|
|
||||||
def example_payload(self) -> Any:
|
def example_payload(self) -> Any:
|
||||||
return {
|
return {
|
||||||
@@ -331,7 +317,5 @@ class WebRTC(Component):
|
|||||||
def example_value(self) -> Any:
|
def example_value(self) -> Any:
|
||||||
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
|
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
|
||||||
|
|
||||||
|
|
||||||
def api_info(self) -> Any:
|
def api_info(self) -> Any:
|
||||||
return {"type": "number"}
|
return {"type": "number"}
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ else:
|
|||||||
|
|
||||||
rtc_configuration = None
|
rtc_configuration = None
|
||||||
|
|
||||||
|
|
||||||
def detection(image, conf_threshold=0.3):
|
def detection(image, conf_threshold=0.3):
|
||||||
image = cv2.resize(image, (model.input_width, model.input_height))
|
image = cv2.resize(image, (model.input_width, model.input_height))
|
||||||
new_image = model.detect_objects(image, conf_threshold)
|
new_image = model.detect_objects(image, conf_threshold)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from utils import draw_detections
|
|||||||
|
|
||||||
class YOLOv10:
|
class YOLOv10:
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
self.initialize_model(path)
|
self.initialize_model(path)
|
||||||
|
|
||||||
@@ -53,7 +52,11 @@ class YOLOv10:
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
|
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
|
||||||
boxes, scores, class_ids, = self.process_output(outputs, conf_threshold)
|
(
|
||||||
|
boxes,
|
||||||
|
scores,
|
||||||
|
class_ids,
|
||||||
|
) = self.process_output(outputs, conf_threshold)
|
||||||
return self.draw_detections(image, boxes, scores, class_ids)
|
return self.draw_detections(image, boxes, scores, class_ids)
|
||||||
|
|
||||||
def process_output(self, output, conf_threshold=0.3):
|
def process_output(self, output, conf_threshold=0.3):
|
||||||
@@ -83,7 +86,7 @@ class YOLOv10:
|
|||||||
boxes = self.rescale_boxes(boxes)
|
boxes = self.rescale_boxes(boxes)
|
||||||
|
|
||||||
# Convert boxes to xyxy format
|
# Convert boxes to xyxy format
|
||||||
#boxes = xywh2xyxy(boxes)
|
# boxes = xywh2xyxy(boxes)
|
||||||
|
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
@@ -98,10 +101,10 @@ class YOLOv10:
|
|||||||
)
|
)
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
|
def draw_detections(
|
||||||
return draw_detections(
|
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
|
||||||
image, boxes, scores, class_ids, mask_alpha
|
):
|
||||||
)
|
return draw_detections(image, boxes, scores, class_ids, mask_alpha)
|
||||||
|
|
||||||
def get_input_details(self):
|
def get_input_details(self):
|
||||||
model_inputs = self.session.get_inputs()
|
model_inputs = self.session.get_inputs()
|
||||||
@@ -139,7 +142,6 @@ if __name__ == "__main__":
|
|||||||
# # Detect Objects
|
# # Detect Objects
|
||||||
combined_image = yolov8_detector.detect_objects(img)
|
combined_image = yolov8_detector.detect_objects(img)
|
||||||
|
|
||||||
|
|
||||||
# Draw detections
|
# Draw detections
|
||||||
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
|
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
|
||||||
cv2.imshow("Output", combined_image)
|
cv2.imshow("Output", combined_image)
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
|
|||||||
font_size = min([img_height, img_width]) * 0.0006
|
font_size = min([img_height, img_width]) * 0.0006
|
||||||
text_thickness = int(min([img_height, img_width]) * 0.001)
|
text_thickness = int(min([img_height, img_width]) * 0.001)
|
||||||
|
|
||||||
#det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
|
# det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
|
||||||
|
|
||||||
# Draw bounding boxes and labels of detections
|
# Draw bounding boxes and labels of detections
|
||||||
for class_id, box, score in zip(class_ids, boxes, scores):
|
for class_id, box, score in zip(class_ids, boxes, scores):
|
||||||
|
|||||||
Reference in New Issue
Block a user