mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -102,6 +102,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
allow_extra_tracks: bool = False,
|
allow_extra_tracks: bool = False,
|
||||||
rtp_params: dict[str, Any] | None = None,
|
rtp_params: dict[str, Any] | None = None,
|
||||||
rtc_configuration: RTCConfigurationCallable | None = None,
|
rtc_configuration: RTCConfigurationCallable | None = None,
|
||||||
|
server_rtc_configuration: dict[str, Any] | None = None,
|
||||||
track_constraints: dict[str, Any] | None = None,
|
track_constraints: dict[str, Any] | None = None,
|
||||||
additional_inputs: list[Component] | None = None,
|
additional_inputs: list[Component] | None = None,
|
||||||
additional_outputs: list[Component] | None = None,
|
additional_outputs: list[Component] | None = None,
|
||||||
@@ -121,6 +122,8 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
rtp_params: Optional dictionary of RTP encoding parameters.
|
rtp_params: Optional dictionary of RTP encoding parameters.
|
||||||
rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers).
|
rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers).
|
||||||
Required when deploying on Colab or Spaces.
|
Required when deploying on Colab or Spaces.
|
||||||
|
server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. Note
|
||||||
|
that setting iceServers to be an empty list will mean no ICE servers will be used in the server.
|
||||||
track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
|
track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
|
||||||
additional_inputs: Optional list of extra Gradio input components.
|
additional_inputs: Optional list of extra Gradio input components.
|
||||||
additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`.
|
additional_outputs: Optional list of extra Gradio output components. Requires `additional_outputs_handler`.
|
||||||
@@ -149,6 +152,9 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
self.track_constraints = track_constraints
|
self.track_constraints = track_constraints
|
||||||
self.webrtc_component: WebRTC
|
self.webrtc_component: WebRTC
|
||||||
self.rtc_configuration = rtc_configuration
|
self.rtc_configuration = rtc_configuration
|
||||||
|
self.server_rtc_configuration = self.convert_to_aiortc_format(
|
||||||
|
server_rtc_configuration
|
||||||
|
)
|
||||||
self._ui = self._generate_default_ui(ui_args)
|
self._ui = self._generate_default_ui(ui_args)
|
||||||
self._ui.launch = self._wrap_gradio_launch(self._ui.launch)
|
self._ui.launch = self._wrap_gradio_launch(self._ui.launch)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -15,8 +14,6 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import anyio
|
|
||||||
import anyio.to_thread
|
|
||||||
from gradio import wasm_utils
|
from gradio import wasm_utils
|
||||||
from gradio.components.base import Component, server
|
from gradio.components.base import Component, server
|
||||||
from gradio_client import handle_file
|
from gradio_client import handle_file
|
||||||
@@ -82,6 +79,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
key: int | str | None = None,
|
key: int | str | None = None,
|
||||||
mirror_webcam: bool = True,
|
mirror_webcam: bool = True,
|
||||||
rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable = None,
|
rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable = None,
|
||||||
|
server_rtc_configuration: dict[str, Any] | None = None,
|
||||||
track_constraints: dict[str, Any] | None = None,
|
track_constraints: dict[str, Any] | None = None,
|
||||||
time_limit: float | None = None,
|
time_limit: float | None = None,
|
||||||
allow_extra_tracks: bool = False,
|
allow_extra_tracks: bool = False,
|
||||||
@@ -115,6 +113,8 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.
|
key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.
|
||||||
mirror_webcam: if True webcam will be mirrored. Default is True.
|
mirror_webcam: if True webcam will be mirrored. Default is True.
|
||||||
rtc_configuration: WebRTC configuration options. See https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection . If running the demo on a remote server, you will need to specify a rtc_configuration. See https://freddyaboulton.github.io/gradio-webrtc/deployment/
|
rtc_configuration: WebRTC configuration options. See https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection . If running the demo on a remote server, you will need to specify a rtc_configuration. See https://freddyaboulton.github.io/gradio-webrtc/deployment/
|
||||||
|
server_rtc_configuration: Optional dictionary for RTCPeerConnection configuration on the server side. Note
|
||||||
|
that setting iceServers to be an empty list will mean no ICE servers will be used in the server.
|
||||||
track_constraints: Media track constraints for WebRTC. For example, to set video height, width use {"width": {"exact": 800}, "height": {"exact": 600}, "aspectRatio": {"exact": 1.33333}}
|
track_constraints: Media track constraints for WebRTC. For example, to set video height, width use {"width": {"exact": 800}, "height": {"exact": 600}, "aspectRatio": {"exact": 1.33333}}
|
||||||
time_limit: Maximum duration in seconds for recording.
|
time_limit: Maximum duration in seconds for recording.
|
||||||
allow_extra_tracks: Allow tracks not supported by the modality. For example, a peer connection with an audio track would be allowed even if modality is 'video', which normally throws a ``ValueError`` exception.
|
allow_extra_tracks: Allow tracks not supported by the modality. For example, a peer connection with an audio track would be allowed even if modality is 'video', which normally throws a ``ValueError`` exception.
|
||||||
@@ -134,6 +134,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
self.mirror_webcam = mirror_webcam
|
self.mirror_webcam = mirror_webcam
|
||||||
self.concurrency_limit = 1
|
self.concurrency_limit = 1
|
||||||
self.rtc_configuration = rtc_configuration
|
self.rtc_configuration = rtc_configuration
|
||||||
|
self.server_rtc_configuration = self.convert_to_aiortc_format(
|
||||||
|
server_rtc_configuration
|
||||||
|
)
|
||||||
self.allow_extra_tracks = allow_extra_tracks
|
self.allow_extra_tracks = allow_extra_tracks
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.modality = modality
|
self.modality = modality
|
||||||
@@ -238,7 +241,6 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
inputs = list(inputs)
|
inputs = list(inputs)
|
||||||
|
|
||||||
async def handler(webrtc_id: str, *args):
|
async def handler(webrtc_id: str, *args):
|
||||||
print("webrtc_id", webrtc_id)
|
|
||||||
async for next_outputs in self.output_stream(webrtc_id):
|
async for next_outputs in self.output_stream(webrtc_id):
|
||||||
yield fn(*args, *next_outputs.args) # type: ignore
|
yield fn(*args, *next_outputs.args) # type: ignore
|
||||||
|
|
||||||
@@ -366,13 +368,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
@server
|
@server
|
||||||
async def turn(self, _):
|
async def turn(self, _):
|
||||||
try:
|
try:
|
||||||
if inspect.isfunction(self.rtc_configuration):
|
return await self.resolve_rtc_configuration()
|
||||||
if inspect.iscoroutinefunction(self.rtc_configuration):
|
|
||||||
return await self.rtc_configuration()
|
|
||||||
else:
|
|
||||||
return await anyio.to_thread.run_sync(self.rtc_configuration)
|
|
||||||
else:
|
|
||||||
return self.rtc_configuration or {}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from collections import defaultdict
|
|||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Literal,
|
Literal,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -16,11 +17,14 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from aiortc import (
|
from aiortc import (
|
||||||
|
RTCConfiguration,
|
||||||
RTCIceCandidate,
|
RTCIceCandidate,
|
||||||
|
RTCIceServer,
|
||||||
RTCPeerConnection,
|
RTCPeerConnection,
|
||||||
RTCSessionDescription,
|
RTCSessionDescription,
|
||||||
)
|
)
|
||||||
from aiortc.contrib.media import MediaRelay # type: ignore
|
from aiortc.contrib.media import MediaRelay # type: ignore
|
||||||
|
from anyio.to_thread import run_sync
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from fastrtc.tracks import (
|
from fastrtc.tracks import (
|
||||||
@@ -37,6 +41,7 @@ from fastrtc.tracks import (
|
|||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
Context,
|
Context,
|
||||||
|
RTCConfigurationCallable,
|
||||||
create_message,
|
create_message,
|
||||||
webrtc_error_handler,
|
webrtc_error_handler,
|
||||||
)
|
)
|
||||||
@@ -80,12 +85,32 @@ class WebRTCConnectionMixin:
|
|||||||
self.modality: Literal["video", "audio", "audio-video"]
|
self.modality: Literal["video", "audio", "audio-video"]
|
||||||
self.mode: Literal["send", "receive", "send-receive"]
|
self.mode: Literal["send", "receive", "send-receive"]
|
||||||
self.allow_extra_tracks: bool
|
self.allow_extra_tracks: bool
|
||||||
|
self.rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable | None
|
||||||
|
self.server_rtc_configuration: RTCConfiguration | None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||||
await asyncio.sleep(time_limit)
|
await asyncio.sleep(time_limit)
|
||||||
await pc.close()
|
await pc.close()
|
||||||
|
|
||||||
|
@staticmethod
def convert_to_aiortc_format(
    rtc_configuration: dict[str, Any] | None,
) -> RTCConfiguration | None:
    """Translate a browser-style RTC configuration dict into aiortc objects.

    Args:
        rtc_configuration: Dictionary with an optional ``"iceServers"`` list.
            Each entry must provide ``"urls"`` and may provide ``"username"``
            and ``"credential"``. ``None`` means "no configuration".

    Returns:
        ``None`` when ``rtc_configuration`` is ``None``; otherwise an
        ``RTCConfiguration`` holding one ``RTCIceServer`` per entry. A missing
        or empty ``"iceServers"`` list yields a configuration with no servers.

    Raises:
        KeyError: If an ICE server entry has no ``"urls"`` key.
    """
    if rtc_configuration is None:
        return None

    ice_servers = []
    for entry in rtc_configuration.get("iceServers", []):
        ice_servers.append(
            RTCIceServer(
                urls=entry["urls"],
                username=entry.get("username"),
                credential=entry.get("credential"),
            )
        )
    return RTCConfiguration(iceServers=ice_servers)
|
||||||
|
|
||||||
async def connection_timeout(
|
async def connection_timeout(
|
||||||
self,
|
self,
|
||||||
pc: RTCPeerConnection,
|
pc: RTCPeerConnection,
|
||||||
@@ -148,6 +173,15 @@ class WebRTCConnectionMixin:
|
|||||||
|
|
||||||
return set_outputs
|
return set_outputs
|
||||||
|
|
||||||
|
async def resolve_rtc_configuration(self) -> dict[str, Any] | None:
    """Resolve ``self.rtc_configuration`` to a concrete configuration value.

    ``self.rtc_configuration`` may be a dictionary, ``None``, or a callable
    (sync or async) that produces the configuration.

    Returns:
        The value produced by the callable when one is configured; otherwise
        the configured dictionary, or an empty dict when it is ``None`` or
        empty.
    """
    config = self.rtc_configuration

    # ``callable`` instead of ``inspect.isfunction``: the latter is False for
    # ``functools.partial`` objects, bound methods and callable instances,
    # which would then be returned un-called as if they were the dict.
    if not callable(config):
        return cast(dict[str, Any], config) or {}

    if inspect.iscoroutinefunction(config):
        return await config()

    # Sync callables are executed through ``run_sync`` (worker thread) so the
    # event loop is not blocked while they run.
    result = await run_sync(config)
    # A callable object with an ``async def __call__`` is not detected by
    # ``iscoroutinefunction``; await whatever awaitable it handed back.
    if inspect.isawaitable(result):
        result = await result
    return result
|
||||||
|
|
||||||
async def handle_offer(self, body, set_outputs):
|
async def handle_offer(self, body, set_outputs):
|
||||||
logger.debug("Starting to handle offer")
|
logger.debug("Starting to handle offer")
|
||||||
logger.debug("Offer body %s", body)
|
logger.debug("Offer body %s", body)
|
||||||
@@ -169,13 +203,9 @@ class WebRTCConnectionMixin:
|
|||||||
pc = self.pcs[webrtc_id]
|
pc = self.pcs[webrtc_id]
|
||||||
if pc.connectionState != "closed":
|
if pc.connectionState != "closed":
|
||||||
try:
|
try:
|
||||||
# Parse the candidate string from the browser
|
|
||||||
candidate_str = body["candidate"].get("candidate", "")
|
candidate_str = body["candidate"].get("candidate", "")
|
||||||
|
|
||||||
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
# Example format: "candidate:2393089663 1 udp 2122260223 192.168.86.60 63692 typ host generation 0 ufrag LkZb network-id 1 network-cost 10"
|
||||||
# We need to parse this string to extract the required fields
|
|
||||||
|
|
||||||
# Parse the candidate string
|
|
||||||
parts = candidate_str.split()
|
parts = candidate_str.split()
|
||||||
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
if len(parts) >= 10 and parts[0].startswith("candidate:"):
|
||||||
foundation = parts[0].split(":", 1)[1]
|
foundation = parts[0].split(":", 1)[1]
|
||||||
@@ -253,7 +283,7 @@ class WebRTCConnectionMixin:
|
|||||||
|
|
||||||
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
offer = RTCSessionDescription(sdp=body["sdp"], type=body["type"])
|
||||||
|
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection(configuration=self.server_rtc_configuration)
|
||||||
self.pcs[body["webrtc_id"]] = pc
|
self.pcs[body["webrtc_id"]] = pc
|
||||||
|
|
||||||
if isinstance(self.event_handler, StreamHandlerImpl):
|
if isinstance(self.event_handler, StreamHandlerImpl):
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from fastrtc import (
|
|||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
ReplyOnPause,
|
ReplyOnPause,
|
||||||
Stream,
|
Stream,
|
||||||
|
get_cloudflare_turn_credentials,
|
||||||
get_cloudflare_turn_credentials_async,
|
get_cloudflare_turn_credentials_async,
|
||||||
get_stt_model,
|
get_stt_model,
|
||||||
)
|
)
|
||||||
@@ -76,6 +77,7 @@ stream = Stream(
|
|||||||
additional_outputs_handler=lambda *a: (a[2], a[3]),
|
additional_outputs_handler=lambda *a: (a[2], a[3]),
|
||||||
concurrency_limit=20 if get_space() else None,
|
concurrency_limit=20 if get_space() else None,
|
||||||
rtc_configuration=get_cloudflare_turn_credentials_async,
|
rtc_configuration=get_cloudflare_turn_credentials_async,
|
||||||
|
server_rtc_configuration=get_cloudflare_turn_credentials(ttl=36_000),
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Cloudflare also offers a managed TURN server with [Cloudflare Calls](https://www
|
|||||||
Cloudflare and Hugging Face have partnered to allow you to stream 10gb of WebRTC traffic per month for free with a Hugging Face account!
|
Cloudflare and Hugging Face have partnered to allow you to stream 10gb of WebRTC traffic per month for free with a Hugging Face account!
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from fastrtc import Stream, get_cloudflare_turn_credentials_async
|
from fastrtc import Stream, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
|
||||||
|
|
||||||
# Make sure the HF_TOKEN environment variable is set
|
# Make sure the HF_TOKEN environment variable is set
|
||||||
# Or pass in a callable with all arguments set
|
# Or pass in a callable with all arguments set
|
||||||
@@ -26,11 +26,14 @@ async def get_credentials():
|
|||||||
stream = Stream(
|
stream = Stream(
|
||||||
handler=...,
|
handler=...,
|
||||||
rtc_configuration=get_credentials,
|
rtc_configuration=get_credentials,
|
||||||
|
server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
|
||||||
modality="audio",
|
modality="audio",
|
||||||
mode="send-receive",
|
mode="send-receive",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
Setting an RTC configuration on the server is recommended but not required. It is good practice to use short-lived credentials in the client (the default `ttl` is 10 minutes when calling `get_cloudflare_turn_credentials*`), but you can also share the same credentials between server and client.
|
||||||
|
|
||||||
### With a Cloudflare API Token
|
### With a Cloudflare API Token
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ This class encapsulates the logic for handling real-time communication (WebRTC)
|
|||||||
| `additional_outputs_handler` | `Callable \| None` | Handler for additional outputs. |
|
| `additional_outputs_handler` | `Callable \| None` | Handler for additional outputs. |
|
||||||
| `track_constraints` | `dict[str, Any] \| None` | Constraints for media tracks (e.g., resolution). |
|
| `track_constraints` | `dict[str, Any] \| None` | Constraints for media tracks (e.g., resolution). |
|
||||||
| `webrtc_component` | `WebRTC` | The underlying Gradio WebRTC component instance. |
|
| `webrtc_component` | `WebRTC` | The underlying Gradio WebRTC component instance. |
|
||||||
| `rtc_configuration` | `dict[str, Any] \| None` | Configuration for the RTCPeerConnection (e.g., ICE servers). |
|
| `rtc_configuration` | `dict[str, Any] \| None \| Callable` | Configuration for the RTCPeerConnection (e.g., ICE servers). |
|
||||||
|
| `server_rtc_configuration` | `RTCConfiguration \| None` | Configuration for the RTCPeerConnection (e.g., ICE servers) used on the server side, converted to aiortc format. |
|
||||||
| `_ui` | `Blocks` | The Gradio Blocks UI instance. |
|
| `_ui` | `Blocks` | The Gradio Blocks UI instance. |
|
||||||
|
|
||||||
## Methods
|
## Methods
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "fastrtc"
|
name = "fastrtc"
|
||||||
version = "0.0.21"
|
version = "0.0.22.rc2"
|
||||||
description = "The realtime communication library for Python"
|
description = "The realtime communication library for Python"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -28,17 +28,17 @@ keywords = [
|
|||||||
"video processing",
|
"video processing",
|
||||||
"gradio-custom-component",
|
"gradio-custom-component",
|
||||||
]
|
]
|
||||||
# Add dependencies here
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"gradio>=4.0,<6.0",
|
"gradio>=4.0,<6.0",
|
||||||
"aiortc",
|
"aiortc",
|
||||||
|
"aioice>=0.10.1",
|
||||||
"audioop-lts;python_version>='3.13'",
|
"audioop-lts;python_version>='3.13'",
|
||||||
"librosa",
|
"librosa",
|
||||||
"numpy>=2.0.2", # because of librosa
|
"numpy>=2.0.2", # because of librosa
|
||||||
"numba>=0.60.0",
|
"numba>=0.60.0",
|
||||||
"standard-aifc;python_version>='3.13'",
|
"standard-aifc;python_version>='3.13'",
|
||||||
"standard-sunau;python_version>='3.13'",
|
"standard-sunau;python_version>='3.13'",
|
||||||
]
|
] # Add dependencies here
|
||||||
classifiers = [
|
classifiers = [
|
||||||
'Development Status :: 3 - Alpha',
|
'Development Status :: 3 - Alpha',
|
||||||
'Operating System :: OS Independent',
|
'Operating System :: OS Independent',
|
||||||
|
|||||||
Reference in New Issue
Block a user