mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -9,6 +9,7 @@ from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
@@ -16,11 +17,14 @@ from typing import (
|
||||
)
|
||||
|
||||
from aiortc import (
|
||||
RTCConfiguration,
|
||||
RTCIceCandidate,
|
||||
RTCIceServer,
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
)
|
||||
from aiortc.contrib.media import MediaRelay # type: ignore
|
||||
from anyio.to_thread import run_sync
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastrtc.tracks import (
|
||||
@@ -37,6 +41,7 @@ from fastrtc.tracks import (
|
||||
from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
Context,
|
||||
RTCConfigurationCallable,
|
||||
create_message,
|
||||
webrtc_error_handler,
|
||||
)
|
||||
@@ -80,12 +85,32 @@ class WebRTCConnectionMixin:
|
||||
self.modality: Literal["video", "audio", "audio-video"]
|
||||
self.mode: Literal["send", "receive", "send-receive"]
|
||||
self.allow_extra_tracks: bool
|
||||
self.rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable | None
|
||||
self.server_rtc_configuration: RTCConfiguration | None
|
||||
|
||||
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
    """Close ``pc`` after ``time_limit`` seconds have elapsed.

    Args:
        pc: Peer connection to close once the delay is over.
        time_limit: Delay, in seconds, passed to ``asyncio.sleep``.
    """
    # Suspend for the whole allowance first; the connection is only
    # closed after the sleep completes.
    await asyncio.sleep(time_limit)
    await pc.close()
|
||||
|
||||
@staticmethod
def convert_to_aiortc_format(
    rtc_configuration: dict[str, Any] | None,
) -> RTCConfiguration | None:
    """Build an aiortc ``RTCConfiguration`` from a dict-style configuration.

    Args:
        rtc_configuration: Mapping that may hold an ``"iceServers"`` list
            whose entries carry ``"urls"`` and, optionally, ``"username"``
            and ``"credential"``; or ``None``.

    Returns:
        ``None`` when ``rtc_configuration`` is ``None``; otherwise an
        ``RTCConfiguration`` with one ``RTCIceServer`` per listed entry.

    Raises:
        KeyError: If an ICE server entry has no ``"urls"`` key.
    """
    if rtc_configuration is None:
        return None

    ice_servers = []
    for entry in rtc_configuration.get("iceServers", []):
        # "urls" is mandatory; the credentials fall back to None when absent.
        ice_servers.append(
            RTCIceServer(
                urls=entry["urls"],
                username=entry.get("username"),
                credential=entry.get("credential"),
            )
        )
    return RTCConfiguration(iceServers=ice_servers)
|
||||
|
||||
async def connection_timeout(
|
||||
self,
|
||||
pc: RTCPeerConnection,
|
||||
@@ -148,6 +173,15 @@ class WebRTCConnectionMixin:
|
||||
|
||||
return set_outputs
|
||||
|
||||
async def resolve_rtc_configuration(self) -> dict[str, Any] | None:
|
||||
if inspect.isfunction(self.rtc_configuration):
|
||||
if inspect.iscoroutinefunction(self.rtc_configuration):
|
||||
return await self.rtc_configuration()
|
||||
else:
|
||||
return await run_sync(self.rtc_configuration)
|
||||
else:
|
||||
return cast(dict[str, Any], self.rtc_configuration) or {}
|
||||
|
||||
async def handle_offer(self, body, set_outputs):
|
||||
logger.debug("Starting to handle offer")
|
||||
logger.debug("Offer body %s", body)
|
||||
@@ -169,13 +203,9 @@ class WebRTCConnectionMixin:
|
||||
pc = self.pcs[webrtc_id]
|
||||
if pc.connectionState != "closed":
|
||||
try:
|
||||
# Parse the candidate string from the browser
|
||||
candidate_str = body["candidate"].get("candidate", "")
|
||||
|
||||
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
||||
# We need to parse this string to extract the required fields
|
||||
|
||||
# Parse the candidate string
|
||||
parts = candidate_str.split()
|
||||
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
||||
foundation = parts[0].split(":", 1)[1]
|
||||
@@ -253,7 +283,7 @@ class WebRTCConnectionMixin:
|
||||
|
||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
pc = RTCPeerConnection(configuration=self.server_rtc_configuration)
|
||||
self.pcs[body["webrtc_id"]] = pc
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerImpl):
|
||||
|
||||
Reference in New Issue
Block a user