diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index 0663ad9..1742fa2 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -5,6 +5,7 @@ from typing import ( AsyncContextManager, Callable, Literal, + Optional, TypedDict, cast, ) @@ -28,7 +29,8 @@ curr_dir = Path(__file__).parent class Body(BaseModel): - sdp: str + sdp: Optional[str] = None + candidate: Optional[dict[str, Any]] = None type: str webrtc_id: str diff --git a/backend/fastrtc/templates/component/index.js b/backend/fastrtc/templates/component/index.js index 295998a..7623847 100644 --- a/backend/fastrtc/templates/component/index.js +++ b/backend/fastrtc/templates/component/index.js @@ -17747,16 +17747,13 @@ function R6(n, e, t = () => { } async function L6(n, e, t, r = () => { }) { - return n.createOffer().then((a) => n.setLocalDescription(a)).then(() => new Promise((a) => { - if (console.debug("ice gathering state", n.iceGatheringState), n.iceGatheringState === "complete") - a(); - else { - const i = () => { - n.iceGatheringState === "complete" && (console.debug("ice complete"), n.removeEventListener("icegatheringstatechange", i), a()); - }; - n.addEventListener("icegatheringstatechange", i); - } - })).then(() => { + return n.onicecandidate = ({ candidate: a }) => { + a && (console.debug("Sending ICE candidate", a), e({ + candidate: a.toJSON(), + webrtc_id: t, + type: "ice-candidate" + }).catch((i) => console.error("Error sending ICE candidate:", i))); + }, n.createOffer().then((a) => n.setLocalDescription(a)).then(() => { var a = n.localDescription; return R6( e, diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index fae614c..70762ad 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -17,6 +17,7 @@ from typing import ( ) from aiortc import ( + RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, ) @@ -66,7 +67,7 @@ class OutputQueue: class WebRTCConnectionMixin: def __init__(self): - self.pcs = set([]) + self.pcs: dict[str, RTCPeerConnection] = {} self.relay = MediaRelay() self.connections = defaultdict(list) self.data_channels = {} @@ -149,6 +150,83 @@ class WebRTCConnectionMixin: async def handle_offer(self, body, set_outputs): logger.debug("Starting to handle offer") logger.debug("Offer body %s", body) + + if body.get("type") == "ice-candidate" and "candidate" in body: + webrtc_id = body.get("webrtc_id") + if webrtc_id not in self.pcs: + logger.warning( + f"Received ICE candidate for unknown connection: {webrtc_id}" + ) + return JSONResponse( + status_code=200, + content={ + "status": "failed", + "meta": {"error": "unknown_connection"}, + }, + ) + + pc = self.pcs[webrtc_id] + if pc.connectionState != "closed": + try: + # Parse the candidate string from the browser + candidate_str = body["candidate"].get("candidate", "") + + # Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10" + # We need to parse this string to extract the required fields + + # Parse the candidate string + parts = candidate_str.split() + if len(parts) >= 10 and parts[0].startswith("candidate:"): + foundation = parts[0].split(":", 1)[1] + component = int(parts[1]) + protocol = parts[2] + priority = int(parts[3]) + ip = parts[4] + port = int(parts[5]) + # Find the candidate type + typ_index = parts.index("typ") + candidate_type = parts[typ_index + 1] + + # Create the RTCIceCandidate object + ice_candidate = RTCIceCandidate( + component=component, + foundation=foundation, + ip=ip, + port=port, + priority=priority, + protocol=protocol, + type=candidate_type, + sdpMid=body["candidate"].get("sdpMid"), + sdpMLineIndex=body["candidate"].get("sdpMLineIndex"), + ) + + # Add the candidate to the peer connection + await pc.addIceCandidate(ice_candidate) + logger.debug(f"Added ICE candidate for {webrtc_id}") + return JSONResponse( + status_code=200, content={"status": "success"} + ) + else: + logger.error(f"Invalid candidate format: {candidate_str}") + return JSONResponse( + status_code=200, + content={ + "status": "failed", + "meta": {"error": "invalid_candidate_format"}, + }, + ) + except Exception as e: + logger.error(f"Error adding ICE candidate: {e}", exc_info=True) + return JSONResponse( + status_code=200, + content={"status": "failed", "meta": {"error": str(e)}}, + ) + + return JSONResponse( + status_code=200, + content={"status": "failed", "meta": {"error": "connection_closed"}}, + ) + if len(self.connections) >= cast(int, self.concurrency_limit): return JSONResponse( status_code=200, @@ -164,7 +242,7 @@ class WebRTCConnectionMixin: offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"]) pc = RTCPeerConnection() - self.pcs.add(pc) + self.pcs[body["webrtc_id"]] = pc if isinstance(self.event_handler, StreamHandlerBase): handler = self.event_handler.copy() @@ -192,7 +270,7 @@ class WebRTCConnectionMixin: if pc.iceConnectionState == "failed": await pc.close() self.connections.pop(body["webrtc_id"], None) - self.pcs.discard(pc) + self.pcs.pop(body["webrtc_id"], None) @pc.on("connectionstatechange") async def _(): @@ -203,7 +281,7 @@ class WebRTCConnectionMixin: if connection: for conn in connection: conn.stop() - self.pcs.discard(pc) + self.pcs.pop(body["webrtc_id"], None) if pc.connectionState == "connected": self.connection_timeouts[body["webrtc_id"]].set() if self.time_limit is not None: diff --git a/demo/talk_to_sambanova/index.html b/demo/talk_to_sambanova/index.html index d4e206a..594fe37 100644 --- a/demo/talk_to_sambanova/index.html +++ b/demo/talk_to_sambanova/index.html @@ -357,19 +357,20 @@ const offer = await peerConnection.createOffer(); await peerConnection.setLocalDescription(offer); - await new Promise((resolve) => { - if (peerConnection.iceGatheringState === "complete") { - resolve(); - } else { - const checkState = () => { - if (peerConnection.iceGatheringState === "complete") { - peerConnection.removeEventListener("icegatheringstatechange", checkState); - resolve(); - } - }; - peerConnection.addEventListener("icegatheringstatechange", checkState); + peerConnection.onicecandidate = ({ candidate }) => { + if (candidate) { + console.debug("Sending ICE candidate", candidate); + fetch('/webrtc/offer', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + candidate: candidate.toJSON(), + webrtc_id: webrtc_id, + type: "ice-candidate", + }) + }) } - }); + }; peerConnection.addEventListener('connectionstatechange', () => { console.log('connectionstatechange', peerConnection.connectionState); diff --git a/docs/userguide/api.md b/docs/userguide/api.md index 4ff8cdb..aa344da 100644 --- a/docs/userguide/api.md +++ b/docs/userguide/api.md @@ -358,6 +358,25 @@ async function setupWebRTC(peerConnection) { const offer = await peerConnection.createOffer(); await peerConnection.setLocalDescription(offer); + let webrtc_id = Math.random().toString(36).substring(7) + + // Send ICE candidates to server + // (especially needed when server is behind firewall) + peerConnection.onicecandidate = ({ candidate }) => { + if (candidate) { + console.debug("Sending ICE candidate", candidate); + fetch('/webrtc/offer', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + candidate: candidate.toJSON(), + webrtc_id: webrtc_id, + type: "ice-candidate", + }) + }) + } + }; + // Send offer to server const response = await fetch('/webrtc/offer', { method: 'POST', @@ -365,7 +384,7 @@ async function setupWebRTC(peerConnection) { body: JSON.stringify({ sdp: offer.sdp, type: offer.type, - webrtc_id: Math.random().toString(36).substring(7) + webrtc_id: webrtc_id }) }); diff --git a/frontend/shared/webrtc_utils.ts b/frontend/shared/webrtc_utils.ts index b400c46..7159cec 100644 --- a/frontend/shared/webrtc_utils.ts +++ b/frontend/shared/webrtc_utils.ts @@ -129,29 +129,22 @@ async function negotiate( webrtc_id: string, reject_cb: (msg: object) => void = () => {}, ): Promise { + pc.onicecandidate = ({ candidate }) => { + if (candidate) { + console.debug("Sending ICE candidate", candidate); + server_fn({ + candidate: candidate.toJSON(), + webrtc_id: webrtc_id, + type: "ice-candidate", + }).catch((err) => console.error("Error sending ICE candidate:", err)); + } + }; + return pc .createOffer() .then((offer) => { return pc.setLocalDescription(offer); }) - .then(() => { - // wait for ICE gathering to complete - return new Promise((resolve) => { - console.debug("ice gathering state", pc.iceGatheringState); - if (pc.iceGatheringState === "complete") { - resolve(); - } else { - const checkState = () => { - if (pc.iceGatheringState === "complete") { - console.debug("ice complete"); - pc.removeEventListener("icegatheringstatechange", checkState); - resolve(); - } - }; - pc.addEventListener("icegatheringstatechange", checkState); - } - }); - }) .then(() => { var offer = pc.localDescription; return make_offer(