mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add support for trickle ice (#193)
* cherry-pick trickle-ice * Add code * Add code * format
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from aiortc import (
|
||||
RTCIceCandidate,
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
)
|
||||
@@ -66,7 +67,7 @@ class OutputQueue:
|
||||
|
||||
class WebRTCConnectionMixin:
|
||||
def __init__(self):
|
||||
self.pcs = set([])
|
||||
self.pcs: dict[str, RTCPeerConnection] = {}
|
||||
self.relay = MediaRelay()
|
||||
self.connections = defaultdict(list)
|
||||
self.data_channels = {}
|
||||
@@ -149,6 +150,83 @@ class WebRTCConnectionMixin:
|
||||
async def handle_offer(self, body, set_outputs):
|
||||
logger.debug("Starting to handle offer")
|
||||
logger.debug("Offer body %s", body)
|
||||
|
||||
if body.get("type") == "ice-candidate" and "candidate" in body:
|
||||
webrtc_id = body.get("webrtc_id")
|
||||
if webrtc_id not in self.pcs:
|
||||
logger.warning(
|
||||
f"Received ICE candidate for unknown connection: {webrtc_id}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "failed",
|
||||
"meta": {"error": "unknown_connection"},
|
||||
},
|
||||
)
|
||||
|
||||
pc = self.pcs[webrtc_id]
|
||||
if pc.connectionState != "closed":
|
||||
try:
|
||||
# Parse the candidate string from the browser
|
||||
candidate_str = body["candidate"].get("candidate", "")
|
||||
|
||||
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
||||
# We need to parse this string to extract the required fields
|
||||
|
||||
# Parse the candidate string
|
||||
parts = candidate_str.split()
|
||||
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
||||
foundation = parts[0].split(":", 1)[1]
|
||||
component = int(parts[1])
|
||||
protocol = parts[2]
|
||||
priority = int(parts[3])
|
||||
ip = parts[4]
|
||||
port = int(parts[5])
|
||||
# Find the candidate type
|
||||
typ_index = parts.index("typ")
|
||||
candidate_type = parts[typ_index + 1]
|
||||
|
||||
# Create the RTCIceCandidate object
|
||||
ice_candidate = RTCIceCandidate(
|
||||
component=component,
|
||||
foundation=foundation,
|
||||
ip=ip,
|
||||
port=port,
|
||||
priority=priority,
|
||||
protocol=protocol,
|
||||
type=candidate_type,
|
||||
sdpMid=body["candidate"].get("sdpMid"),
|
||||
sdpMLineIndex=body["candidate"].get("sdpMLineIndex"),
|
||||
)
|
||||
|
||||
# Add the candidate to the peer connection
|
||||
await pc.addIceCandidate(ice_candidate)
|
||||
logger.debug(f"Added ICE candidate for {webrtc_id}")
|
||||
return JSONResponse(
|
||||
status_code=200, content={"status": "success"}
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid candidate format: {candidate_str}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "failed",
|
||||
"meta": {"error": "invalid_candidate_format"},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding ICE candidate: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"status": "failed", "meta": {"error": str(e)}},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
||||
)
|
||||
|
||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
@@ -164,7 +242,7 @@ class WebRTCConnectionMixin:
|
||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
self.pcs[body["webrtc_id"]] = pc
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerBase):
|
||||
handler = self.event_handler.copy()
|
||||
@@ -192,7 +270,7 @@ class WebRTCConnectionMixin:
|
||||
if pc.iceConnectionState == "failed":
|
||||
await pc.close()
|
||||
self.connections.pop(body["webrtc_id"], None)
|
||||
self.pcs.discard(pc)
|
||||
self.pcs.pop(body["webrtc_id"], None)
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def _():
|
||||
@@ -203,7 +281,7 @@ class WebRTCConnectionMixin:
|
||||
if connection:
|
||||
for conn in connection:
|
||||
conn.stop()
|
||||
self.pcs.discard(pc)
|
||||
self.pcs.pop(body["webrtc_id"], None)
|
||||
if pc.connectionState == "connected":
|
||||
self.connection_timeouts[body["webrtc_id"]].set()
|
||||
if self.time_limit is not None:
|
||||
|
||||
Reference in New Issue
Block a user