mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add support for trickle ice (#193)
* cherry-pick trickle-ice * Add code * Add code * format
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import (
|
|||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
Callable,
|
Callable,
|
||||||
Literal,
|
Literal,
|
||||||
|
Optional,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@@ -28,7 +29,8 @@ curr_dir = Path(__file__).parent
|
|||||||
|
|
||||||
|
|
||||||
class Body(BaseModel):
|
class Body(BaseModel):
|
||||||
sdp: str
|
sdp: Optional[str] = None
|
||||||
|
candidate: Optional[dict[str, Any]] = None
|
||||||
type: str
|
type: str
|
||||||
webrtc_id: str
|
webrtc_id: str
|
||||||
|
|
||||||
|
|||||||
@@ -17747,16 +17747,13 @@ function R6(n, e, t = () => {
|
|||||||
}
|
}
|
||||||
async function L6(n, e, t, r = () => {
|
async function L6(n, e, t, r = () => {
|
||||||
}) {
|
}) {
|
||||||
return n.createOffer().then((a) => n.setLocalDescription(a)).then(() => new Promise((a) => {
|
return n.onicecandidate = ({ candidate: a }) => {
|
||||||
if (console.debug("ice gathering state", n.iceGatheringState), n.iceGatheringState === "complete")
|
a && (console.debug("Sending ICE candidate", a), e({
|
||||||
a();
|
candidate: a.toJSON(),
|
||||||
else {
|
webrtc_id: t,
|
||||||
const i = () => {
|
type: "ice-candidate"
|
||||||
n.iceGatheringState === "complete" && (console.debug("ice complete"), n.removeEventListener("icegatheringstatechange", i), a());
|
}).catch((i) => console.error("Error sending ICE candidate:", i)));
|
||||||
};
|
}, n.createOffer().then((a) => n.setLocalDescription(a)).then(() => {
|
||||||
n.addEventListener("icegatheringstatechange", i);
|
|
||||||
}
|
|
||||||
})).then(() => {
|
|
||||||
var a = n.localDescription;
|
var a = n.localDescription;
|
||||||
return R6(
|
return R6(
|
||||||
e,
|
e,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from aiortc import (
|
from aiortc import (
|
||||||
|
RTCIceCandidate,
|
||||||
RTCPeerConnection,
|
RTCPeerConnection,
|
||||||
RTCSessionDescription,
|
RTCSessionDescription,
|
||||||
)
|
)
|
||||||
@@ -66,7 +67,7 @@ class OutputQueue:
|
|||||||
|
|
||||||
class WebRTCConnectionMixin:
|
class WebRTCConnectionMixin:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.pcs = set([])
|
self.pcs: dict[str, RTCPeerConnection] = {}
|
||||||
self.relay = MediaRelay()
|
self.relay = MediaRelay()
|
||||||
self.connections = defaultdict(list)
|
self.connections = defaultdict(list)
|
||||||
self.data_channels = {}
|
self.data_channels = {}
|
||||||
@@ -149,6 +150,83 @@ class WebRTCConnectionMixin:
|
|||||||
async def handle_offer(self, body, set_outputs):
|
async def handle_offer(self, body, set_outputs):
|
||||||
logger.debug("Starting to handle offer")
|
logger.debug("Starting to handle offer")
|
||||||
logger.debug("Offer body %s", body)
|
logger.debug("Offer body %s", body)
|
||||||
|
|
||||||
|
if body.get("type") == "ice-candidate" and "candidate" in body:
|
||||||
|
webrtc_id = body.get("webrtc_id")
|
||||||
|
if webrtc_id not in self.pcs:
|
||||||
|
logger.warning(
|
||||||
|
f"Received ICE candidate for unknown connection: {webrtc_id}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"status": "failed",
|
||||||
|
"meta": {"error": "unknown_connection"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
pc = self.pcs[webrtc_id]
|
||||||
|
if pc.connectionState != "closed":
|
||||||
|
try:
|
||||||
|
# Parse the candidate string from the browser
|
||||||
|
candidate_str = body["candidate"].get("candidate", "")
|
||||||
|
|
||||||
|
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
||||||
|
# We need to parse this string to extract the required fields
|
||||||
|
|
||||||
|
# Parse the candidate string
|
||||||
|
parts = candidate_str.split()
|
||||||
|
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
||||||
|
foundation = parts[0].split(":", 1)[1]
|
||||||
|
component = int(parts[1])
|
||||||
|
protocol = parts[2]
|
||||||
|
priority = int(parts[3])
|
||||||
|
ip = parts[4]
|
||||||
|
port = int(parts[5])
|
||||||
|
# Find the candidate type
|
||||||
|
typ_index = parts.index("typ")
|
||||||
|
candidate_type = parts[typ_index + 1]
|
||||||
|
|
||||||
|
# Create the RTCIceCandidate object
|
||||||
|
ice_candidate = RTCIceCandidate(
|
||||||
|
component=component,
|
||||||
|
foundation=foundation,
|
||||||
|
ip=ip,
|
||||||
|
port=port,
|
||||||
|
priority=priority,
|
||||||
|
protocol=protocol,
|
||||||
|
type=candidate_type,
|
||||||
|
sdpMid=body["candidate"].get("sdpMid"),
|
||||||
|
sdpMLineIndex=body["candidate"].get("sdpMLineIndex"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the candidate to the peer connection
|
||||||
|
await pc.addIceCandidate(ice_candidate)
|
||||||
|
logger.debug(f"Added ICE candidate for {webrtc_id}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200, content={"status": "success"}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Invalid candidate format: {candidate_str}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"status": "failed",
|
||||||
|
"meta": {"error": "invalid_candidate_format"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding ICE candidate: {e}", exc_info=True)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={"status": "failed", "meta": {"error": str(e)}},
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
@@ -164,7 +242,7 @@ class WebRTCConnectionMixin:
|
|||||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||||
|
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
self.pcs.add(pc)
|
self.pcs[body["webrtc_id"]] = pc
|
||||||
|
|
||||||
if isinstance(self.event_handler, StreamHandlerBase):
|
if isinstance(self.event_handler, StreamHandlerBase):
|
||||||
handler = self.event_handler.copy()
|
handler = self.event_handler.copy()
|
||||||
@@ -192,7 +270,7 @@ class WebRTCConnectionMixin:
|
|||||||
if pc.iceConnectionState == "failed":
|
if pc.iceConnectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
self.connections.pop(body["webrtc_id"], None)
|
self.connections.pop(body["webrtc_id"], None)
|
||||||
self.pcs.discard(pc)
|
self.pcs.pop(body["webrtc_id"], None)
|
||||||
|
|
||||||
@pc.on("connectionstatechange")
|
@pc.on("connectionstatechange")
|
||||||
async def _():
|
async def _():
|
||||||
@@ -203,7 +281,7 @@ class WebRTCConnectionMixin:
|
|||||||
if connection:
|
if connection:
|
||||||
for conn in connection:
|
for conn in connection:
|
||||||
conn.stop()
|
conn.stop()
|
||||||
self.pcs.discard(pc)
|
self.pcs.pop(body["webrtc_id"], None)
|
||||||
if pc.connectionState == "connected":
|
if pc.connectionState == "connected":
|
||||||
self.connection_timeouts[body["webrtc_id"]].set()
|
self.connection_timeouts[body["webrtc_id"]].set()
|
||||||
if self.time_limit is not None:
|
if self.time_limit is not None:
|
||||||
|
|||||||
@@ -357,19 +357,20 @@
|
|||||||
const offer = await peerConnection.createOffer();
|
const offer = await peerConnection.createOffer();
|
||||||
await peerConnection.setLocalDescription(offer);
|
await peerConnection.setLocalDescription(offer);
|
||||||
|
|
||||||
await new Promise((resolve) => {
|
peerConnection.onicecandidate = ({ candidate }) => {
|
||||||
if (peerConnection.iceGatheringState === "complete") {
|
if (candidate) {
|
||||||
resolve();
|
console.debug("Sending ICE candidate", candidate);
|
||||||
} else {
|
fetch('/webrtc/offer', {
|
||||||
const checkState = () => {
|
method: 'POST',
|
||||||
if (peerConnection.iceGatheringState === "complete") {
|
headers: { 'Content-Type': 'application/json' },
|
||||||
peerConnection.removeEventListener("icegatheringstatechange", checkState);
|
body: JSON.stringify({
|
||||||
resolve();
|
candidate: candidate.toJSON(),
|
||||||
|
webrtc_id: webrtc_id,
|
||||||
|
type: "ice-candidate",
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
peerConnection.addEventListener("icegatheringstatechange", checkState);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
peerConnection.addEventListener('connectionstatechange', () => {
|
peerConnection.addEventListener('connectionstatechange', () => {
|
||||||
console.log('connectionstatechange', peerConnection.connectionState);
|
console.log('connectionstatechange', peerConnection.connectionState);
|
||||||
|
|||||||
@@ -358,6 +358,25 @@ async function setupWebRTC(peerConnection) {
|
|||||||
const offer = await peerConnection.createOffer();
|
const offer = await peerConnection.createOffer();
|
||||||
await peerConnection.setLocalDescription(offer);
|
await peerConnection.setLocalDescription(offer);
|
||||||
|
|
||||||
|
let webrtc_id = Math.random().toString(36).substring(7)
|
||||||
|
|
||||||
|
// Send ICE candidates to server
|
||||||
|
// (especially needed when server is behind firewall)
|
||||||
|
peerConnection.onicecandidate = ({ candidate }) => {
|
||||||
|
if (candidate) {
|
||||||
|
console.debug("Sending ICE candidate", candidate);
|
||||||
|
fetch('/webrtc/offer', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
candidate: candidate.toJSON(),
|
||||||
|
webrtc_id: webrtc_id,
|
||||||
|
type: "ice-candidate",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Send offer to server
|
// Send offer to server
|
||||||
const response = await fetch('/webrtc/offer', {
|
const response = await fetch('/webrtc/offer', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -365,7 +384,7 @@ async function setupWebRTC(peerConnection) {
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
sdp: offer.sdp,
|
sdp: offer.sdp,
|
||||||
type: offer.type,
|
type: offer.type,
|
||||||
webrtc_id: Math.random().toString(36).substring(7)
|
webrtc_id: webrtc_id
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -129,29 +129,22 @@ async function negotiate(
|
|||||||
webrtc_id: string,
|
webrtc_id: string,
|
||||||
reject_cb: (msg: object) => void = () => {},
|
reject_cb: (msg: object) => void = () => {},
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
|
pc.onicecandidate = ({ candidate }) => {
|
||||||
|
if (candidate) {
|
||||||
|
console.debug("Sending ICE candidate", candidate);
|
||||||
|
server_fn({
|
||||||
|
candidate: candidate.toJSON(),
|
||||||
|
webrtc_id: webrtc_id,
|
||||||
|
type: "ice-candidate",
|
||||||
|
}).catch((err) => console.error("Error sending ICE candidate:", err));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return pc
|
return pc
|
||||||
.createOffer()
|
.createOffer()
|
||||||
.then((offer) => {
|
.then((offer) => {
|
||||||
return pc.setLocalDescription(offer);
|
return pc.setLocalDescription(offer);
|
||||||
})
|
})
|
||||||
.then(() => {
|
|
||||||
// wait for ICE gathering to complete
|
|
||||||
return new Promise<void>((resolve) => {
|
|
||||||
console.debug("ice gathering state", pc.iceGatheringState);
|
|
||||||
if (pc.iceGatheringState === "complete") {
|
|
||||||
resolve();
|
|
||||||
} else {
|
|
||||||
const checkState = () => {
|
|
||||||
if (pc.iceGatheringState === "complete") {
|
|
||||||
console.debug("ice complete");
|
|
||||||
pc.removeEventListener("icegatheringstatechange", checkState);
|
|
||||||
resolve();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
pc.addEventListener("icegatheringstatechange", checkState);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.then(() => {
|
.then(() => {
|
||||||
var offer = pc.localDescription;
|
var offer = pc.localDescription;
|
||||||
return make_offer(
|
return make_offer(
|
||||||
|
|||||||
Reference in New Issue
Block a user