diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index 42e21b2..87b84bb 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -102,6 +102,7 @@ class Stream(WebRTCConnectionMixin): allow_extra_tracks: bool = False, rtp_params: dict[str, Any] | None = None, rtc_configuration: RTCConfigurationCallable | None = None, + server_rtc_configuration: dict[str, Any] | None = None, track_constraints: dict[str, Any] | None = None, additional_inputs: list[Component] | None = None, additional_outputs: list[Component] | None = None, @@ -121,6 +122,8 @@ class Stream(WebRTCConnectionMixin): rtp_params: Optional dictionary of RTP encoding parameters. rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers). Required when deploying on Colab or Spaces. + server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. Note + that setting iceServers to be an empty list will mean no ICE servers will be used in the server. track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate). additional_inputs: Optional list of extra Gradio input components. additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`. 
@@ -149,6 +152,9 @@ class Stream(WebRTCConnectionMixin): self.track_constraints = track_constraints self.webrtc_component: WebRTC self.rtc_configuration = rtc_configuration + self.server_rtc_configuration = self.convert_to_aiortc_format( + server_rtc_configuration + ) self._ui = self._generate_default_ui(ui_args) self._ui.launch = self._wrap_gradio_launch(self._ui.launch) diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py index fa3977e..7f4a1db 100644 --- a/backend/fastrtc/webrtc.py +++ b/backend/fastrtc/webrtc.py @@ -2,7 +2,6 @@ from __future__ import annotations -import inspect import logging from collections.abc import Callable, Iterable, Sequence from typing import ( @@ -15,8 +14,6 @@ from typing import ( cast, ) -import anyio -import anyio.to_thread from gradio import wasm_utils from gradio.components.base import Component, server from gradio_client import handle_file @@ -82,6 +79,7 @@ class WebRTC(Component, WebRTCConnectionMixin): key: int | str | None = None, mirror_webcam: bool = True, rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable = None, + server_rtc_configuration: dict[str, Any] | None = None, track_constraints: dict[str, Any] | None = None, time_limit: float | None = None, allow_extra_tracks: bool = False, @@ -115,6 +113,8 @@ class WebRTC(Component, WebRTCConnectionMixin): key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved. mirror_webcam: if True webcam will be mirrored. Default is True. rtc_configuration: WebRTC configuration options. See https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection . If running the demo on a remote server, you will need to specify a rtc_configuration. See https://freddyaboulton.github.io/gradio-webrtc/deployment/ + server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. 
Note + that setting iceServers to be an empty list will mean no ICE servers will be used in the server. track_constraints: Media track constraints for WebRTC. For example, to set video height, width use {"width": {"exact": 800}, "height": {"exact": 600}, "aspectRatio": {"exact": 1.33333}} time_limit: Maximum duration in seconds for recording. allow_extra_tracks: Allow tracks not supported by the modality. For example, a peer connection with an audio track would be allowed even if modality is 'video', which normally throws a ``ValueError`` exception. @@ -134,6 +134,9 @@ class WebRTC(Component, WebRTCConnectionMixin): self.mirror_webcam = mirror_webcam self.concurrency_limit = 1 self.rtc_configuration = rtc_configuration + self.server_rtc_configuration = self.convert_to_aiortc_format( + server_rtc_configuration + ) self.allow_extra_tracks = allow_extra_tracks self.mode = mode self.modality = modality @@ -238,7 +241,6 @@ class WebRTC(Component, WebRTCConnectionMixin): inputs = list(inputs) async def handler(webrtc_id: str, *args): - print("webrtc_id", webrtc_id) async for next_outputs in self.output_stream(webrtc_id): yield fn(*args, *next_outputs.args) # type: ignore @@ -366,13 +368,7 @@ class WebRTC(Component, WebRTCConnectionMixin): @server async def turn(self, _): try: - if inspect.isfunction(self.rtc_configuration): - if inspect.iscoroutinefunction(self.rtc_configuration): - return await self.rtc_configuration() - else: - return await anyio.to_thread.run_sync(self.rtc_configuration) - else: - return self.rtc_configuration or {} + return await self.resolve_rtc_configuration() except Exception as e: return {"error": str(e)} diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index c2fe240..782092c 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -9,6 +9,7 @@ from collections import defaultdict from collections.abc import AsyncGenerator, Callable from 
dataclasses import dataclass, field from typing import ( + Any, Literal, ParamSpec, TypeVar, @@ -16,11 +17,14 @@ from typing import ( ) from aiortc import ( + RTCConfiguration, RTCIceCandidate, + RTCIceServer, RTCPeerConnection, RTCSessionDescription, ) from aiortc.contrib.media import MediaRelay # type: ignore +from anyio.to_thread import run_sync from fastapi.responses import JSONResponse from fastrtc.tracks import ( @@ -37,6 +41,7 @@ from fastrtc.tracks import ( from fastrtc.utils import ( AdditionalOutputs, Context, + RTCConfigurationCallable, create_message, webrtc_error_handler, ) @@ -80,12 +85,32 @@ class WebRTCConnectionMixin: self.modality: Literal["video", "audio", "audio-video"] self.mode: Literal["send", "receive", "send-receive"] self.allow_extra_tracks: bool + self.rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable | None + self.server_rtc_configuration: RTCConfiguration | None @staticmethod async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): await asyncio.sleep(time_limit) await pc.close() + @staticmethod + def convert_to_aiortc_format( + rtc_configuration: dict[str, Any] | None, + ) -> RTCConfiguration | None: + rtc_config = rtc_configuration + if rtc_config is not None: + rtc_config = RTCConfiguration( + iceServers=[ + RTCIceServer( + urls=server["urls"], + username=server.get("username"), + credential=server.get("credential"), + ) + for server in rtc_config.get("iceServers", []) + ] + ) + return rtc_config + async def connection_timeout( self, pc: RTCPeerConnection, @@ -148,6 +173,15 @@ class WebRTCConnectionMixin: return set_outputs + async def resolve_rtc_configuration(self) -> dict[str, Any] | None: + if inspect.isfunction(self.rtc_configuration): + if inspect.iscoroutinefunction(self.rtc_configuration): + return await self.rtc_configuration() + else: + return await run_sync(self.rtc_configuration) + else: + return cast(dict[str, Any], self.rtc_configuration) or {} + async def handle_offer(self, body, 
set_outputs): logger.debug("Starting to handle offer") logger.debug("Offer body %s", body) @@ -169,13 +203,9 @@ class WebRTCConnectionMixin: pc = self.pcs[webrtc_id] if pc.connectionState != "closed": try: - # Parse the candidate string from the browser candidate_str = body["candidate"].get("candidate", "") # Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10" - # We need to parse this string to extract the required fields - - # Parse the candidate string parts = candidate_str.split() if len(parts) >= 10 and parts[0].startswith("candidate:"): foundation = parts[0].split(":", 1)[1] @@ -253,7 +283,7 @@ class WebRTCConnectionMixin: offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"]) - pc = RTCPeerConnection() + pc = RTCPeerConnection(configuration=self.server_rtc_configuration) self.pcs[body["webrtc_id"]] = pc if isinstance(self.event_handler, StreamHandlerImpl): diff --git a/demo/talk_to_sambanova/app.py b/demo/talk_to_sambanova/app.py index 8866195..c74a68a 100644 --- a/demo/talk_to_sambanova/app.py +++ b/demo/talk_to_sambanova/app.py @@ -13,6 +13,7 @@ from fastrtc import ( AdditionalOutputs, ReplyOnPause, Stream, + get_cloudflare_turn_credentials, get_cloudflare_turn_credentials_async, get_stt_model, ) @@ -76,6 +77,7 @@ stream = Stream( additional_outputs_handler=lambda *a: (a[2], a[3]), concurrency_limit=20 if get_space() else None, rtc_configuration=get_cloudflare_turn_credentials_async, + server_rtc_configuration=get_cloudflare_turn_credentials(ttl=36_000), ) app = FastAPI() diff --git a/docs/deployment.md b/docs/deployment.md index 852a1f4..41bfc4a 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -13,7 +13,7 @@ Cloudflare also offers a managed TURN server with [Cloudflare Calls](https://www Cloudflare and Hugging Face have partnered to allow you to stream 10gb of WebRTC traffic per month for free with a Hugging Face account! 
```python -from fastrtc import Stream, get_cloudflare_turn_credentials_async +from fastrtc import Stream, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials # Make sure the HF_TOKEN environment variable is set # Or pass in a callable with all arguments set @@ -26,11 +26,14 @@ async def get_credentials(): stream = Stream( handler=..., rtc_configuration=get_credentials, + server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000), modality="audio", mode="send-receive", ) ``` +!!! tip + Setting an rtc configuration in the server is recommended but not required. It's a good practice to set short-lived credentials in the client (default `ttl` value of 10 minutes when calling `get_cloudflare_turn_credentials*`) but you can share the same credentials between server and client. ### With a Cloudflare API Token diff --git a/docs/reference/stream.md b/docs/reference/stream.md index 3e6c796..df5c634 100644 --- a/docs/reference/stream.md +++ b/docs/reference/stream.md @@ -39,7 +39,8 @@ This class encapsulates the logic for handling real-time communication (WebRTC) | `additional_outputs_handler` | `Callable \| None` | Handler for additional outputs. | | `track_constraints` | `dict[str, Any] \| None` | Constraints for media tracks (e.g., resolution). | | `webrtc_component` | `WebRTC` | The underlying Gradio WebRTC component instance. | -| `rtc_configuration` | `dict[str, Any] \| None` | Configuration for the RTCPeerConnection (e.g., ICE servers). | +| `rtc_configuration` | `dict[str, Any] \| None \| Callable` | Configuration for the RTCPeerConnection (e.g., ICE servers). | +| `server_rtc_configuration` | `dict[str, Any] \| None` | Configuration for the RTCPeerConnection (e.g., ICE servers) to be used in the server. | | `_ui` | `Blocks` | The Gradio Blocks UI instance. 
| ## Methods diff --git a/pyproject.toml b/pyproject.toml index 35011f6..9d6aaac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "fastrtc" -version = "0.0.21" +version = "0.0.22.rc2" description = "The realtime communication library for Python" readme = "README.md" license = "MIT" @@ -28,17 +28,17 @@ keywords = [ "video processing", "gradio-custom-component", ] -# Add dependencies here dependencies = [ "gradio>=4.0,<6.0", "aiortc", + "aioice>=0.10.1", "audioop-lts;python_version>='3.13'", "librosa", "numpy>=2.0.2", # because of librosa "numba>=0.60.0", "standard-aifc;python_version>='3.13'", "standard-sunau;python_version>='3.13'", -] +] # Add dependencies here classifiers = [ 'Development Status :: 3 - Alpha', 'Operating System :: OS Independent',