mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
413 lines
16 KiB
Python
413 lines
16 KiB
Python
"""Mixin for handling WebRTC connections."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import inspect
|
|
import logging
|
|
from collections import defaultdict
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import (
|
|
AsyncGenerator,
|
|
Literal,
|
|
ParamSpec,
|
|
TypeVar,
|
|
cast,
|
|
)
|
|
|
|
from aiortc import (
|
|
RTCIceCandidate,
|
|
RTCPeerConnection,
|
|
RTCSessionDescription,
|
|
)
|
|
from aiortc.contrib.media import MediaRelay # type: ignore
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastrtc.tracks import (
|
|
AudioCallback,
|
|
HandlerType,
|
|
ServerToClientAudio,
|
|
ServerToClientVideo,
|
|
StreamHandlerBase,
|
|
StreamHandlerImpl,
|
|
VideoCallback,
|
|
VideoEventHandler,
|
|
VideoStreamHandler,
|
|
VideoStreamHandler_,
|
|
)
|
|
from fastrtc.utils import (
|
|
AdditionalOutputs,
|
|
Context,
|
|
create_message,
|
|
webrtc_error_handler,
|
|
)
|
|
|
|
Track = (
|
|
VideoCallback
|
|
| VideoStreamHandler_
|
|
| AudioCallback
|
|
| ServerToClientAudio
|
|
| ServerToClientVideo
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# For the return type
|
|
R = TypeVar("R")
|
|
# For the parameter specification
|
|
P = ParamSpec("P")
|
|
|
|
|
|
@dataclass
|
|
class OutputQueue:
|
|
queue: asyncio.Queue[AdditionalOutputs] = field(default_factory=asyncio.Queue)
|
|
quit: asyncio.Event = field(default_factory=asyncio.Event)
|
|
|
|
|
|
class WebRTCConnectionMixin:
|
|
def __init__(self):
|
|
self.pcs: dict[str, RTCPeerConnection] = {}
|
|
self.relay = MediaRelay()
|
|
self.connections = defaultdict(list)
|
|
self.data_channels = {}
|
|
self.additional_outputs = defaultdict(OutputQueue)
|
|
self.handlers = {}
|
|
self.connection_timeouts = defaultdict(asyncio.Event)
|
|
# These attributes should be set by subclasses:
|
|
self.concurrency_limit: int | None
|
|
self.event_handler: HandlerType | None
|
|
self.time_limit: float | None
|
|
self.modality: Literal["video", "audio", "audio-video"]
|
|
self.mode: Literal["send", "receive", "send-receive"]
|
|
|
|
@staticmethod
|
|
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
|
await asyncio.sleep(time_limit)
|
|
await pc.close()
|
|
|
|
async def connection_timeout(
|
|
self,
|
|
pc: RTCPeerConnection,
|
|
webrtc_id: str,
|
|
time_limit: float,
|
|
):
|
|
try:
|
|
await asyncio.wait_for(
|
|
self.connection_timeouts[webrtc_id].wait(), time_limit
|
|
)
|
|
except (asyncio.TimeoutError, TimeoutError):
|
|
await pc.close()
|
|
self.connection_timeouts[webrtc_id].clear()
|
|
self.clean_up(webrtc_id)
|
|
|
|
def clean_up(self, webrtc_id: str):
|
|
self.handlers.pop(webrtc_id, None)
|
|
self.connection_timeouts.pop(webrtc_id, None)
|
|
connection = self.connections.pop(webrtc_id, [])
|
|
for conn in connection:
|
|
if isinstance(conn, AudioCallback):
|
|
if inspect.iscoroutinefunction(conn.event_handler.shutdown):
|
|
asyncio.create_task(conn.event_handler.shutdown())
|
|
conn.event_handler.reset()
|
|
else:
|
|
conn.event_handler.shutdown()
|
|
conn.event_handler.reset()
|
|
output = self.additional_outputs.pop(webrtc_id, None)
|
|
if output:
|
|
logger.debug("setting quit for webrtc id %s", webrtc_id)
|
|
output.quit.set()
|
|
self.data_channels.pop(webrtc_id, None)
|
|
return connection
|
|
|
|
def set_input(self, webrtc_id: str, *args):
|
|
if webrtc_id in self.connections:
|
|
for conn in self.connections[webrtc_id]:
|
|
conn.set_args(list(args))
|
|
|
|
async def output_stream(
|
|
self, webrtc_id: str
|
|
) -> AsyncGenerator[AdditionalOutputs, None]:
|
|
outputs = self.additional_outputs[webrtc_id]
|
|
while not outputs.quit.is_set():
|
|
try:
|
|
yield await asyncio.wait_for(outputs.queue.get(), 10)
|
|
except (asyncio.TimeoutError, TimeoutError):
|
|
logger.debug("Timeout waiting for output")
|
|
|
|
async def fetch_latest_output(self, webrtc_id: str) -> AdditionalOutputs:
|
|
outputs = self.additional_outputs[webrtc_id]
|
|
return await asyncio.wait_for(outputs.queue.get(), 10)
|
|
|
|
def set_additional_outputs(
|
|
self, webrtc_id: str
|
|
) -> Callable[[AdditionalOutputs], None]:
|
|
def set_outputs(outputs: AdditionalOutputs):
|
|
self.additional_outputs[webrtc_id].queue.put_nowait(outputs)
|
|
|
|
return set_outputs
|
|
|
|
async def handle_offer(self, body, set_outputs):
|
|
logger.debug("Starting to handle offer")
|
|
logger.debug("Offer body %s", body)
|
|
|
|
if body.get("type") == "ice-candidate" and "candidate" in body:
|
|
webrtc_id = body.get("webrtc_id")
|
|
if webrtc_id not in self.pcs:
|
|
logger.warning(
|
|
f"Received ICE candidate for unknown connection: {webrtc_id}"
|
|
)
|
|
return JSONResponse(
|
|
status_code=200,
|
|
content={
|
|
"status": "failed",
|
|
"meta": {"error": "unknown_connection"},
|
|
},
|
|
)
|
|
|
|
pc = self.pcs[webrtc_id]
|
|
if pc.connectionState != "closed":
|
|
try:
|
|
# Parse the candidate string from the browser
|
|
candidate_str = body["candidate"].get("candidate", "")
|
|
|
|
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
|
# We need to parse this string to extract the required fields
|
|
|
|
# Parse the candidate string
|
|
parts = candidate_str.split()
|
|
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
|
foundation = parts[0].split(":", 1)[1]
|
|
component = int(parts[1])
|
|
protocol = parts[2]
|
|
priority = int(parts[3])
|
|
ip = parts[4]
|
|
port = int(parts[5])
|
|
# Find the candidate type
|
|
typ_index = parts.index("typ")
|
|
candidate_type = parts[typ_index + 1]
|
|
|
|
# Create the RTCIceCandidate object
|
|
ice_candidate = RTCIceCandidate(
|
|
component=component,
|
|
foundation=foundation,
|
|
ip=ip,
|
|
port=port,
|
|
priority=priority,
|
|
protocol=protocol,
|
|
type=candidate_type,
|
|
sdpMid=body["candidate"].get("sdpMid"),
|
|
sdpMLineIndex=body["candidate"].get("sdpMLineIndex"),
|
|
)
|
|
|
|
# Add the candidate to the peer connection
|
|
await pc.addIceCandidate(ice_candidate)
|
|
logger.debug(f"Added ICE candidate for {webrtc_id}")
|
|
return JSONResponse(
|
|
status_code=200, content={"status": "success"}
|
|
)
|
|
else:
|
|
logger.error(f"Invalid candidate format: {candidate_str}")
|
|
return JSONResponse(
|
|
status_code=200,
|
|
content={
|
|
"status": "failed",
|
|
"meta": {"error": "invalid_candidate_format"},
|
|
},
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error adding ICE candidate: {e}", exc_info=True)
|
|
return JSONResponse(
|
|
status_code=200,
|
|
content={"status": "failed", "meta": {"error": str(e)}},
|
|
)
|
|
|
|
return JSONResponse(
|
|
status_code=200,
|
|
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
|
)
|
|
|
|
if len(self.connections) >= cast(int, self.concurrency_limit):
|
|
return JSONResponse(
|
|
status_code=200,
|
|
content={
|
|
"status": "failed",
|
|
"meta": {
|
|
"error": "concurrency_limit_reached",
|
|
"limit": self.concurrency_limit,
|
|
},
|
|
},
|
|
)
|
|
|
|
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
|
|
|
pc = RTCPeerConnection()
|
|
self.pcs[body["webrtc_id"]] = pc
|
|
|
|
if isinstance(self.event_handler, StreamHandlerBase):
|
|
handler = self.event_handler.copy()
|
|
handler.emit = webrtc_error_handler(handler.emit) # type: ignore
|
|
handler.receive = webrtc_error_handler(handler.receive) # type: ignore
|
|
handler.start_up = webrtc_error_handler(handler.start_up) # type: ignore
|
|
handler.shutdown = webrtc_error_handler(handler.shutdown) # type: ignore
|
|
if hasattr(handler, "video_receive"):
|
|
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
|
|
if hasattr(handler, "video_emit"):
|
|
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
|
|
elif isinstance(self.event_handler, VideoStreamHandler):
|
|
self.event_handler.callable = cast(
|
|
VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
|
|
)
|
|
handler = self.event_handler
|
|
else:
|
|
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
|
|
|
self.handlers[body["webrtc_id"]] = handler
|
|
|
|
@pc.on("iceconnectionstatechange")
|
|
async def on_iceconnectionstatechange():
|
|
logger.debug("ICE connection state change %s", pc.iceConnectionState)
|
|
if pc.iceConnectionState == "failed":
|
|
await pc.close()
|
|
self.connections.pop(body["webrtc_id"], None)
|
|
self.pcs.pop(body["webrtc_id"], None)
|
|
|
|
@pc.on("connectionstatechange")
|
|
async def _():
|
|
logger.debug("pc.connectionState %s", pc.connectionState)
|
|
if pc.connectionState in ["failed", "closed"]:
|
|
await pc.close()
|
|
connection = self.clean_up(body["webrtc_id"])
|
|
if connection:
|
|
for conn in connection:
|
|
conn.stop()
|
|
self.pcs.pop(body["webrtc_id"], None)
|
|
if pc.connectionState == "connected":
|
|
self.connection_timeouts[body["webrtc_id"]].set()
|
|
if self.time_limit is not None:
|
|
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
|
|
|
|
@pc.on("track")
|
|
def _(track):
|
|
relay = MediaRelay()
|
|
handler = self.handlers[body["webrtc_id"]]
|
|
context = Context(webrtc_id=body["webrtc_id"])
|
|
if self.modality == "video" and track.kind == "video":
|
|
args = {}
|
|
handler_ = handler
|
|
if isinstance(handler, VideoStreamHandler):
|
|
handler_ = handler.callable
|
|
args["fps"] = handler.fps
|
|
args["skip_frames"] = handler.skip_frames
|
|
cb = VideoCallback(
|
|
relay.subscribe(track),
|
|
event_handler=cast(Callable, handler_),
|
|
set_additional_outputs=set_outputs,
|
|
mode=cast(Literal["send", "send-receive"], self.mode),
|
|
context=context,
|
|
**args,
|
|
)
|
|
elif self.modality == "audio-video" and track.kind == "video":
|
|
cb = VideoStreamHandler_(
|
|
relay.subscribe(track),
|
|
event_handler=handler, # type: ignore
|
|
set_additional_outputs=set_outputs,
|
|
fps=cast(StreamHandlerImpl, handler).fps,
|
|
context=context,
|
|
)
|
|
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
|
eh = cast(StreamHandlerImpl, handler)
|
|
eh._loop = asyncio.get_running_loop()
|
|
cb = AudioCallback(
|
|
relay.subscribe(track),
|
|
event_handler=eh,
|
|
set_additional_outputs=set_outputs,
|
|
context=context,
|
|
)
|
|
else:
|
|
if self.modality not in ["video", "audio", "audio-video"]:
|
|
msg = "Modality must be either video, audio, or audio-video"
|
|
else:
|
|
msg = f"Unsupported track kind '{track.kind}' for modality '{self.modality}'"
|
|
raise ValueError(msg)
|
|
if body["webrtc_id"] not in self.connections:
|
|
self.connections[body["webrtc_id"]] = []
|
|
|
|
self.connections[body["webrtc_id"]].append(cb)
|
|
if body["webrtc_id"] in self.data_channels:
|
|
for conn in self.connections[body["webrtc_id"]]:
|
|
conn.set_channel(self.data_channels[body["webrtc_id"]])
|
|
if self.mode == "send-receive":
|
|
logger.debug("Adding track to peer connection %s", cb)
|
|
pc.addTrack(cb)
|
|
elif self.mode == "send":
|
|
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
|
|
|
context = Context(webrtc_id=body["webrtc_id"])
|
|
if self.mode == "receive":
|
|
if self.modality == "video":
|
|
if isinstance(self.event_handler, VideoStreamHandler):
|
|
cb = ServerToClientVideo(
|
|
cast(Callable, self.event_handler.callable),
|
|
set_additional_outputs=set_outputs,
|
|
fps=self.event_handler.fps,
|
|
context=context,
|
|
)
|
|
else:
|
|
cb = ServerToClientVideo(
|
|
cast(Callable, self.event_handler),
|
|
set_additional_outputs=set_outputs,
|
|
context=context,
|
|
)
|
|
elif self.modality == "audio":
|
|
cb = ServerToClientAudio(
|
|
cast(Callable, self.event_handler),
|
|
set_additional_outputs=set_outputs,
|
|
context=context,
|
|
)
|
|
else:
|
|
raise ValueError("Modality must be either video or audio")
|
|
|
|
logger.debug("Adding track to peer connection %s", cb)
|
|
pc.addTrack(cb)
|
|
self.connections[body["webrtc_id"]].append(cb)
|
|
cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
|
|
|
|
@pc.on("datachannel")
|
|
def _(channel):
|
|
logger.debug(f"Data channel established: {channel.label}")
|
|
|
|
self.data_channels[body["webrtc_id"]] = channel
|
|
|
|
async def set_channel(webrtc_id: str):
|
|
while not self.connections.get(webrtc_id):
|
|
await asyncio.sleep(0.05)
|
|
logger.debug("setting channel for webrtc id %s", webrtc_id)
|
|
for conn in self.connections[webrtc_id]:
|
|
conn.set_channel(channel)
|
|
|
|
asyncio.create_task(set_channel(body["webrtc_id"]))
|
|
|
|
@channel.on("message")
|
|
def _(message):
|
|
logger.debug(f"Received message: {message}")
|
|
if channel.readyState == "open":
|
|
channel.send(
|
|
create_message("log", data=f"Server received: {message}")
|
|
)
|
|
|
|
# handle offer
|
|
await pc.setRemoteDescription(offer)
|
|
asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
|
|
# send answer
|
|
answer = await pc.createAnswer()
|
|
await pc.setLocalDescription(answer) # type: ignore
|
|
logger.debug("done handling offer about to return")
|
|
await asyncio.sleep(0.1)
|
|
|
|
return {
|
|
"sdp": pc.localDescription.sdp,
|
|
"type": pc.localDescription.type,
|
|
}
|