mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 09:29:23 +08:00
Add support for trickle ice (#193)
* cherry-pick trickle-ice * Add code * Add code * format
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import (
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
cast,
|
||||
)
|
||||
@@ -28,7 +29,8 @@ curr_dir = Path(__file__).parent
|
||||
|
||||
|
||||
class Body(BaseModel):
|
||||
sdp: str
|
||||
sdp: Optional[str] = None
|
||||
candidate: Optional[dict[str, Any]] = None
|
||||
type: str
|
||||
webrtc_id: str
|
||||
|
||||
|
||||
@@ -17747,16 +17747,13 @@ function R6(n, e, t = () => {
|
||||
}
|
||||
async function L6(n, e, t, r = () => {
|
||||
}) {
|
||||
return n.createOffer().then((a) => n.setLocalDescription(a)).then(() => new Promise((a) => {
|
||||
if (console.debug("ice gathering state", n.iceGatheringState), n.iceGatheringState === "complete")
|
||||
a();
|
||||
else {
|
||||
const i = () => {
|
||||
n.iceGatheringState === "complete" && (console.debug("ice complete"), n.removeEventListener("icegatheringstatechange", i), a());
|
||||
};
|
||||
n.addEventListener("icegatheringstatechange", i);
|
||||
}
|
||||
})).then(() => {
|
||||
return n.onicecandidate = ({ candidate: a }) => {
|
||||
a && (console.debug("Sending ICE candidate", a), e({
|
||||
candidate: a.toJSON(),
|
||||
webrtc_id: t,
|
||||
type: "ice-candidate"
|
||||
}).catch((i) => console.error("Error sending ICE candidate:", i)));
|
||||
}, n.createOffer().then((a) => n.setLocalDescription(a)).then(() => {
|
||||
var a = n.localDescription;
|
||||
return R6(
|
||||
e,
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from aiortc import (
|
||||
RTCIceCandidate,
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
)
|
||||
@@ -66,7 +67,7 @@ class OutputQueue:
|
||||
|
||||
class WebRTCConnectionMixin:
|
||||
def __init__(self):
|
||||
self.pcs = set([])
|
||||
self.pcs: dict[str, RTCPeerConnection] = {}
|
||||
self.relay = MediaRelay()
|
||||
self.connections = defaultdict(list)
|
||||
self.data_channels = {}
|
||||
@@ -149,6 +150,83 @@ class WebRTCConnectionMixin:
|
||||
async def handle_offer(self, body, set_outputs):
|
||||
logger.debug("Starting to handle offer")
|
||||
logger.debug("Offer body %s", body)
|
||||
|
||||
if body.get("type") == "ice-candidate" and "candidate" in body:
|
||||
webrtc_id = body.get("webrtc_id")
|
||||
if webrtc_id not in self.pcs:
|
||||
logger.warning(
|
||||
f"Received ICE candidate for unknown connection: {webrtc_id}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "failed",
|
||||
"meta": {"error": "unknown_connection"},
|
||||
},
|
||||
)
|
||||
|
||||
pc = self.pcs[webrtc_id]
|
||||
if pc.connectionState != "closed":
|
||||
try:
|
||||
# Parse the candidate string from the browser
|
||||
candidate_str = body["candidate"].get("candidate", "")
|
||||
|
||||
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
||||
# We need to parse this string to extract the required fields
|
||||
|
||||
# Parse the candidate string
|
||||
parts = candidate_str.split()
|
||||
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
||||
foundation = parts[0].split(":", 1)[1]
|
||||
component = int(parts[1])
|
||||
protocol = parts[2]
|
||||
priority = int(parts[3])
|
||||
ip = parts[4]
|
||||
port = int(parts[5])
|
||||
# Find the candidate type
|
||||
typ_index = parts.index("typ")
|
||||
candidate_type = parts[typ_index + 1]
|
||||
|
||||
# Create the RTCIceCandidate object
|
||||
ice_candidate = RTCIceCandidate(
|
||||
component=component,
|
||||
foundation=foundation,
|
||||
ip=ip,
|
||||
port=port,
|
||||
priority=priority,
|
||||
protocol=protocol,
|
||||
type=candidate_type,
|
||||
sdpMid=body["candidate"].get("sdpMid"),
|
||||
sdpMLineIndex=body["candidate"].get("sdpMLineIndex"),
|
||||
)
|
||||
|
||||
# Add the candidate to the peer connection
|
||||
await pc.addIceCandidate(ice_candidate)
|
||||
logger.debug(f"Added ICE candidate for {webrtc_id}")
|
||||
return JSONResponse(
|
||||
status_code=200, content={"status": "success"}
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid candidate format: {candidate_str}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "failed",
|
||||
"meta": {"error": "invalid_candidate_format"},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding ICE candidate: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"status": "failed", "meta": {"error": str(e)}},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
||||
)
|
||||
|
||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
@@ -164,7 +242,7 @@ class WebRTCConnectionMixin:
|
||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
self.pcs[body["webrtc_id"]] = pc
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerBase):
|
||||
handler = self.event_handler.copy()
|
||||
@@ -192,7 +270,7 @@ class WebRTCConnectionMixin:
|
||||
if pc.iceConnectionState == "failed":
|
||||
await pc.close()
|
||||
self.connections.pop(body["webrtc_id"], None)
|
||||
self.pcs.discard(pc)
|
||||
self.pcs.pop(body["webrtc_id"], None)
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def _():
|
||||
@@ -203,7 +281,7 @@ class WebRTCConnectionMixin:
|
||||
if connection:
|
||||
for conn in connection:
|
||||
conn.stop()
|
||||
self.pcs.discard(pc)
|
||||
self.pcs.pop(body["webrtc_id"], None)
|
||||
if pc.connectionState == "connected":
|
||||
self.connection_timeouts[body["webrtc_id"]].set()
|
||||
if self.time_limit is not None:
|
||||
|
||||
@@ -357,19 +357,20 @@
|
||||
const offer = await peerConnection.createOffer();
|
||||
await peerConnection.setLocalDescription(offer);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
if (peerConnection.iceGatheringState === "complete") {
|
||||
resolve();
|
||||
} else {
|
||||
const checkState = () => {
|
||||
if (peerConnection.iceGatheringState === "complete") {
|
||||
peerConnection.removeEventListener("icegatheringstatechange", checkState);
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
peerConnection.addEventListener("icegatheringstatechange", checkState);
|
||||
peerConnection.onicecandidate = ({ candidate }) => {
|
||||
if (candidate) {
|
||||
console.debug("Sending ICE candidate", candidate);
|
||||
fetch('/webrtc/offer', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
candidate: candidate.toJSON(),
|
||||
webrtc_id: webrtc_id,
|
||||
type: "ice-candidate",
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
peerConnection.addEventListener('connectionstatechange', () => {
|
||||
console.log('connectionstatechange', peerConnection.connectionState);
|
||||
|
||||
@@ -358,6 +358,25 @@ async function setupWebRTC(peerConnection) {
|
||||
const offer = await peerConnection.createOffer();
|
||||
await peerConnection.setLocalDescription(offer);
|
||||
|
||||
let webrtc_id = Math.random().toString(36).substring(7)
|
||||
|
||||
// Send ICE candidates to server
|
||||
// (especially needed when server is behind firewall)
|
||||
peerConnection.onicecandidate = ({ candidate }) => {
|
||||
if (candidate) {
|
||||
console.debug("Sending ICE candidate", candidate);
|
||||
fetch('/webrtc/offer', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
candidate: candidate.toJSON(),
|
||||
webrtc_id: webrtc_id,
|
||||
type: "ice-candidate",
|
||||
})
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Send offer to server
|
||||
const response = await fetch('/webrtc/offer', {
|
||||
method: 'POST',
|
||||
@@ -365,7 +384,7 @@ async function setupWebRTC(peerConnection) {
|
||||
body: JSON.stringify({
|
||||
sdp: offer.sdp,
|
||||
type: offer.type,
|
||||
webrtc_id: Math.random().toString(36).substring(7)
|
||||
webrtc_id: webrtc_id
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
@@ -129,29 +129,22 @@ async function negotiate(
|
||||
webrtc_id: string,
|
||||
reject_cb: (msg: object) => void = () => {},
|
||||
): Promise<void> {
|
||||
pc.onicecandidate = ({ candidate }) => {
|
||||
if (candidate) {
|
||||
console.debug("Sending ICE candidate", candidate);
|
||||
server_fn({
|
||||
candidate: candidate.toJSON(),
|
||||
webrtc_id: webrtc_id,
|
||||
type: "ice-candidate",
|
||||
}).catch((err) => console.error("Error sending ICE candidate:", err));
|
||||
}
|
||||
};
|
||||
|
||||
return pc
|
||||
.createOffer()
|
||||
.then((offer) => {
|
||||
return pc.setLocalDescription(offer);
|
||||
})
|
||||
.then(() => {
|
||||
// wait for ICE gathering to complete
|
||||
return new Promise<void>((resolve) => {
|
||||
console.debug("ice gathering state", pc.iceGatheringState);
|
||||
if (pc.iceGatheringState === "complete") {
|
||||
resolve();
|
||||
} else {
|
||||
const checkState = () => {
|
||||
if (pc.iceGatheringState === "complete") {
|
||||
console.debug("ice complete");
|
||||
pc.removeEventListener("icegatheringstatechange", checkState);
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
pc.addEventListener("icegatheringstatechange", checkState);
|
||||
}
|
||||
});
|
||||
})
|
||||
.then(() => {
|
||||
var offer = pc.localDescription;
|
||||
return make_offer(
|
||||
|
||||
Reference in New Issue
Block a user