Cloudflare turn integration (#264)

* Turn integration

* Add code:

* type hint

* Fix typehint

* add code

* format

* WIP

* trickle ice

* bump version

* Better docs

* Modify

* code

* Mute icon for whisper

* Add code

* llama 4 demo

* code

* OpenAI interruptions

* fix docs
This commit is contained in:
Freddy Boulton
2025-04-09 09:36:51 -04:00
committed by GitHub
parent f70b27bd41
commit 837330dcd8
37 changed files with 2914 additions and 780 deletions

View File

@@ -1,6 +1,10 @@
from .credentials import (
get_cloudflare_turn_credentials,
get_cloudflare_turn_credentials_async,
get_hf_turn_credentials,
get_hf_turn_credentials_async,
get_turn_credentials,
get_turn_credentials_async,
get_twilio_turn_credentials,
)
from .pause_detection import (
@@ -70,6 +74,10 @@ __all__ = [
"Warning",
"get_tts_model",
"KokoroTTSOptions",
"get_cloudflare_turn_credentials_async",
"get_hf_turn_credentials_async",
"get_turn_credentials_async",
"get_cloudflare_turn_credentials",
"wait_for_item",
"UIArgs",
"ModelOptions",

View File

@@ -1,29 +1,268 @@
import os
import warnings
from typing import Literal
import requests
import httpx
CLOUDFLARE_FASTRTC_TURN_URL = "https://turn.fastrtc.org/credentials"
async_httpx_client = httpx.AsyncClient()
def get_hf_turn_credentials(token=None):
def _format_response(response):
if response.is_success:
return response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
def get_hf_turn_credentials(token=None, ttl=600):
"""Retrieves TURN credentials from Hugging Face (deprecated).
This function fetches TURN server credentials using a Hugging Face token.
It is deprecated and `get_cloudflare_turn_credentials` should be used instead.
Args:
token (str, optional): Hugging Face API token. Defaults to None, in which
case the HF_TOKEN environment variable is used.
ttl (int, optional): Time-to-live for the credentials in seconds.
Defaults to 600.
Returns:
dict: A dictionary containing the TURN credentials.
Raises:
ValueError: If no token is provided and the HF_TOKEN environment variable
is not set.
Exception: If the request to the TURN server fails.
"""
warnings.warn(
"get_hf_turn_credentials is deprecated. Use get_cloudflare_turn_credentials instead.",
UserWarning,
)
if token is None:
token = os.getenv("HF_TOKEN")
credentials = requests.get(
"https://fastrtc-turn-server-login.hf.space/credentials",
headers={"X-HF-Access-Token": token},
if token is None:
raise ValueError(
"HF_TOKEN environment variable must be set or token must be provided to use get_hf_turn_credentials"
)
response = httpx.get(
CLOUDFLARE_FASTRTC_TURN_URL,
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
params={"ttl": ttl},
)
if not credentials.status_code == 200:
raise ValueError("Failed to get credentials from HF turn server")
return {
"iceServers": [
{
"urls": "turn:gradio-turn.com:80",
**credentials.json(),
return _format_response(response)
async def get_hf_turn_credentials_async(
token=None, ttl=600, client: httpx.AsyncClient | None = None
):
"""Asynchronously retrieves TURN credentials from Hugging Face (deprecated).
This function asynchronously fetches TURN server credentials using a Hugging Face
token. It is deprecated and `get_cloudflare_turn_credentials_async` should be
used instead.
Args:
token (str, optional): Hugging Face API token. Defaults to None, in which
case the HF_TOKEN environment variable is used.
ttl (int, optional): Time-to-live for the credentials in seconds.
Defaults to 600.
client (httpx.AsyncClient | None, optional): An existing httpx async client
to use for the request. If None, a default client is used. Defaults to None.
Returns:
dict: A dictionary containing the TURN credentials.
Raises:
ValueError: If no token is provided and the HF_TOKEN environment variable
is not set.
Exception: If the request to the TURN server fails.
"""
warnings.warn(
"get_hf_turn_credentials_async is deprecated. Use get_cloudflare_turn_credentials_async instead.",
UserWarning,
)
if client is None:
client = async_httpx_client
if token is None:
token = os.getenv("HF_TOKEN")
if token is None:
raise ValueError(
"HF_TOKEN environment variable must be set or token must be provided to use get_hf_turn_credentials"
)
async with client:
response = await client.get(
"https://turn.fastrtc.org/credentials",
headers={"Authorization": f"Bearer {token}"},
params={"ttl": ttl},
)
return _format_response(response)
def get_cloudflare_turn_credentials(
turn_key_id=None, turn_key_api_token=None, hf_token=None, ttl=600
):
"""Retrieves TURN credentials from Cloudflare or Hugging Face.
Fetches TURN server credentials either directly from Cloudflare using API keys
or via the Hugging Face TURN endpoint using an HF token. The HF token method
takes precedence if provided.
Args:
turn_key_id (str, optional): Cloudflare TURN key ID. Defaults to None,
in which case the CLOUDFLARE_TURN_KEY_ID environment variable is used.
turn_key_api_token (str, optional): Cloudflare TURN key API token.
Defaults to None, in which case the CLOUDFLARE_TURN_KEY_API_TOKEN
environment variable is used.
hf_token (str, optional): Hugging Face API token. If provided, this method
is used instead of Cloudflare keys. Defaults to None, in which case
the HF_TOKEN environment variable is used.
ttl (int, optional): Time-to-live for the credentials in seconds.
Defaults to 600.
Returns:
dict: A dictionary containing the TURN credentials (ICE servers).
Raises:
ValueError: If neither HF token nor Cloudflare keys (either as arguments
or environment variables) are provided.
Exception: If the request to the credential server fails.
"""
if hf_token is None:
hf_token = os.getenv("HF_TOKEN")
if hf_token is not None:
return httpx.get(
CLOUDFLARE_FASTRTC_TURN_URL,
headers={"Authorization": f"Bearer {hf_token}"},
params={"ttl": ttl},
).json()
else:
if turn_key_id is None or turn_key_api_token is None:
turn_key_id = os.getenv("CLOUDFLARE_TURN_KEY_ID")
turn_key_api_token = os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
if turn_key_id is None or turn_key_api_token is None:
raise ValueError(
"HF_TOKEN or CLOUDFLARE_TURN_KEY_ID and CLOUDFLARE_TURN_KEY_API_TOKEN must be set to use get_cloudflare_turn_credentials_sync"
)
response = httpx.post(
f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
headers={
"Authorization": f"Bearer {turn_key_api_token}",
"Content-Type": "application/json",
},
]
}
json={"ttl": ttl},
)
if response.is_success:
return response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
async def get_cloudflare_turn_credentials_async(
turn_key_id=None,
turn_key_api_token=None,
hf_token=None,
ttl=600,
client: httpx.AsyncClient | None = None,
):
"""Asynchronously retrieves TURN credentials from Cloudflare or Hugging Face.
Asynchronously fetches TURN server credentials either directly from Cloudflare
using API keys or via the Hugging Face TURN endpoint using an HF token. The HF
token method takes precedence if provided.
Args:
turn_key_id (str, optional): Cloudflare TURN key ID. Defaults to None,
in which case the CLOUDFLARE_TURN_KEY_ID environment variable is used.
turn_key_api_token (str, optional): Cloudflare TURN key API token.
Defaults to None, in which case the CLOUDFLARE_TURN_KEY_API_TOKEN
environment variable is used.
hf_token (str, optional): Hugging Face API token. If provided, this method
is used instead of Cloudflare keys. Defaults to None, in which case
the HF_TOKEN environment variable is used.
ttl (int, optional): Time-to-live for the credentials in seconds.
Defaults to 600.
client (httpx.AsyncClient | None, optional): An existing httpx async client
to use for the request. If None, a new client is created per request.
Defaults to None.
Returns:
dict: A dictionary containing the TURN credentials (ICE servers).
Raises:
ValueError: If neither HF token nor Cloudflare keys (either as arguments
or environment variables) are provided.
Exception: If the request to the credential server fails.
"""
if client is None:
client = async_httpx_client
if hf_token is None:
hf_token = os.getenv("HF_TOKEN", "").strip()
if hf_token is not None:
async with httpx.AsyncClient() as client:
response = await client.get(
CLOUDFLARE_FASTRTC_TURN_URL,
headers={"Authorization": f"Bearer {hf_token}"},
params={"ttl": ttl},
)
return _format_response(response)
else:
if turn_key_id is None or turn_key_api_token is None:
turn_key_id = os.getenv("CLOUDFLARE_TURN_KEY_ID")
turn_key_api_token = os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
if turn_key_id is None or turn_key_api_token is None:
raise ValueError(
"HF_TOKEN or CLOUDFLARE_TURN_KEY_ID and CLOUDFLARE_TURN_KEY_API_TOKEN must be set to use get_cloudflare_turn_credentials"
)
async with httpx.AsyncClient() as client:
response = await client.post(
f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
headers={
"Authorization": f"Bearer {turn_key_api_token}",
"Content-Type": "application/json",
},
json={"ttl": ttl},
)
if response.is_success:
return response.json()
else:
raise Exception(
f"Failed to get TURN credentials: {response.status_code} {response.text}"
)
def get_twilio_turn_credentials(twilio_sid=None, twilio_token=None):
"""Retrieves TURN credentials from Twilio.
Uses the Twilio REST API to generate temporary TURN credentials. Requires
the `twilio` package to be installed.
Args:
twilio_sid (str, optional): Twilio Account SID. Defaults to None, in which
case the TWILIO_ACCOUNT_SID environment variable is used.
twilio_token (str, optional): Twilio Auth Token. Defaults to None, in which
case the TWILIO_AUTH_TOKEN environment variable is used.
Returns:
dict: A dictionary containing the TURN credentials formatted for WebRTC,
including 'iceServers' and 'iceTransportPolicy'.
Raises:
ImportError: If the `twilio` package is not installed.
ValueError: If Twilio credentials (SID and token) are not provided either
as arguments or environment variables.
TwilioRestException: If the Twilio API request fails.
"""
try:
from twilio.rest import Client
except ImportError:
@@ -43,10 +282,105 @@ def get_twilio_turn_credentials(twilio_sid=None, twilio_token=None):
}
def get_turn_credentials(method: Literal["hf", "twilio"] = "hf", **kwargs):
def get_turn_credentials(
method: Literal["hf", "twilio", "cloudflare"] = "cloudflare", **kwargs
):
"""Retrieves TURN credentials from the specified provider.
Acts as a dispatcher function to call the appropriate credential retrieval
function based on the method specified.
Args:
method (Literal["hf", "twilio", "cloudflare"], optional): The provider
to use. 'hf' uses the deprecated Hugging Face endpoint. 'cloudflare'
uses either Cloudflare keys or the HF endpoint. 'twilio' uses the
Twilio API. Defaults to "cloudflare".
**kwargs: Additional keyword arguments passed directly to the underlying
provider-specific function (e.g., `token`, `ttl` for 'hf';
`twilio_sid`, `twilio_token` for 'twilio'; `turn_key_id`,
`turn_key_api_token`, `hf_token`, `ttl` for 'cloudflare').
Returns:
dict: A dictionary containing the TURN credentials from the chosen provider.
Raises:
ValueError: If an invalid method is specified.
Also raises exceptions from the underlying provider functions (see their
docstrings).
"""
if method == "hf":
return get_hf_turn_credentials(**kwargs)
warnings.warn(
"Method 'hf' is deprecated. Use 'cloudflare' instead.", UserWarning
)
# Ensure only relevant kwargs are passed
hf_kwargs = {k: v for k, v in kwargs.items() if k in ["token", "ttl"]}
return get_hf_turn_credentials(**hf_kwargs)
elif method == "cloudflare":
# Ensure only relevant kwargs are passed
cf_kwargs = {
k: v
for k, v in kwargs.items()
if k in ["turn_key_id", "turn_key_api_token", "hf_token", "ttl"]
}
return get_cloudflare_turn_credentials(**cf_kwargs)
elif method == "twilio":
return get_twilio_turn_credentials(**kwargs)
# Ensure only relevant kwargs are passed
twilio_kwargs = {
k: v for k, v in kwargs.items() if k in ["twilio_sid", "twilio_token"]
}
return get_twilio_turn_credentials(**twilio_kwargs)
else:
raise ValueError("Invalid method. Must be 'hf' or 'twilio'")
raise ValueError("Invalid method. Must be 'hf', 'twilio', or 'cloudflare'")
async def get_turn_credentials_async(
method: Literal["hf", "twilio", "cloudflare"] = "cloudflare", **kwargs
):
"""Asynchronously retrieves TURN credentials from the specified provider.
Acts as an async dispatcher function to call the appropriate async credential
retrieval function based on the method specified.
Args:
method (Literal["hf", "twilio", "cloudflare"], optional): The provider
to use. 'hf' uses the deprecated Hugging Face endpoint. 'cloudflare'
uses either Cloudflare keys or the HF endpoint. 'twilio' is not
supported asynchronously by this function yet. Defaults to "cloudflare".
**kwargs: Additional keyword arguments passed directly to the underlying
provider-specific async function (e.g., `token`, `ttl`, `client` for 'hf';
`turn_key_id`, `turn_key_api_token`, `hf_token`, `ttl`, `client` for
'cloudflare').
Returns:
dict: A dictionary containing the TURN credentials from the chosen provider.
Raises:
ValueError: If an invalid or unsupported method is specified (currently
'twilio' is not supported asynchronously here).
NotImplementedError: If method 'twilio' is requested.
Also raises exceptions from the underlying provider functions (see their
docstrings).
"""
if method == "hf":
warnings.warn(
"Method 'hf' is deprecated. Use 'cloudflare' instead.", UserWarning
)
# Ensure only relevant kwargs are passed
hf_kwargs = {k: v for k, v in kwargs.items() if k in ["token", "ttl", "client"]}
return await get_hf_turn_credentials_async(**hf_kwargs)
elif method == "cloudflare":
# Ensure only relevant kwargs are passed
cf_kwargs = {
k: v
for k, v in kwargs.items()
if k in ["turn_key_id", "turn_key_api_token", "hf_token", "ttl", "client"]
}
return await get_cloudflare_turn_credentials_async(**cf_kwargs)
elif method == "twilio":
# Twilio client library doesn't have a standard async interface for this.
# You might need to run the sync version in an executor or use a different library.
raise NotImplementedError(
"Async retrieval for Twilio credentials is not implemented."
)
else:
raise ValueError("Invalid method. Must be 'hf', 'twilio', or 'cloudflare'")

View File

@@ -1,3 +1,4 @@
import inspect
import logging
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
@@ -9,6 +10,7 @@ from typing import (
cast,
)
import anyio
import gradio as gr
from fastapi import FastAPI, Request, WebSocket
from fastapi.responses import HTMLResponse
@@ -18,6 +20,7 @@ from pydantic import BaseModel
from typing_extensions import NotRequired
from .tracks import HandlerType, StreamHandlerImpl
from .utils import RTCConfigurationCallable
from .webrtc import WebRTC
from .webrtc_connection_mixin import WebRTCConnectionMixin
from .websocket import WebSocketHandler
@@ -98,7 +101,7 @@ class Stream(WebRTCConnectionMixin):
time_limit: float | None = None,
allow_extra_tracks: bool = False,
rtp_params: dict[str, Any] | None = None,
rtc_configuration: dict[str, Any] | None = None,
rtc_configuration: RTCConfigurationCallable | None = None,
track_constraints: dict[str, Any] | None = None,
additional_inputs: list[Component] | None = None,
additional_outputs: list[Component] | None = None,
@@ -116,7 +119,7 @@ class Stream(WebRTCConnectionMixin):
time_limit: Maximum execution time for the handler function in seconds.
allow_extra_tracks: If True, allows connections with tracks not matching the modality.
rtp_params: Optional dictionary of RTP encoding parameters.
rtc_configuration: Optional dictionary for RTCPeerConnection configuration (e.g., ICE servers).
rtc_configuration: Optional Callable or dictionary for RTCPeerConnection configuration (e.g., ICE servers).
Required when deploying on Colab or Spaces.
track_constraints: Optional dictionary of constraints for media tracks (e.g., resolution, frame rate).
additional_inputs: Optional list of extra Gradio input components.
@@ -749,6 +752,15 @@ class Stream(WebRTCConnectionMixin):
body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
)
async def get_rtc_configuration(self):
if inspect.isfunction(self.rtc_configuration):
if inspect.iscoroutinefunction(self.rtc_configuration):
return await self.rtc_configuration()
else:
return anyio.to_thread.run_sync(self.rtc_configuration) # type: ignore
else:
return self.rtc_configuration
async def handle_incoming_call(self, request: Request):
"""
Handle incoming telephone calls (e.g., via Twilio).

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ import json
import logging
import tempfile
import traceback
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Any, Literal, Protocol, TypedDict, cast
@@ -486,3 +486,15 @@ async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
return await asyncio.wait_for(queue.get(), timeout=timeout)
except (TimeoutError, asyncio.TimeoutError):
return None
RTCConfigurationCallable = (
Callable[[], dict[str, Any]]
| Callable[[], Coroutine[dict[str, Any], Any, dict[str, Any]]]
| Callable[[str | None, str | None, str | None], dict[str, Any]]
| Callable[
[str | None, str | None, str | None],
Coroutine[dict[str, Any], Any, dict[str, Any]],
]
| dict[str, Any]
)

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import logging
from collections.abc import Callable, Iterable, Sequence
from typing import (
@@ -14,6 +15,8 @@ from typing import (
cast,
)
import anyio
import anyio.to_thread
from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
@@ -26,6 +29,7 @@ from .tracks import (
VideoEventHandler,
VideoStreamHandler,
)
from .utils import RTCConfigurationCallable
from .webrtc_connection_mixin import WebRTCConnectionMixin
if TYPE_CHECKING:
@@ -77,7 +81,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
render: bool = True,
key: int | str | None = None,
mirror_webcam: bool = True,
rtc_configuration: dict[str, Any] | None = None,
rtc_configuration: dict[str, Any] | None | RTCConfigurationCallable = None,
track_constraints: dict[str, Any] | None = None,
time_limit: float | None = None,
allow_extra_tracks: bool = False,
@@ -359,6 +363,19 @@ class WebRTC(Component, WebRTCConnectionMixin):
concurrency_id=concurrency_id,
)
@server
async def turn(self, _):
try:
if inspect.isfunction(self.rtc_configuration):
if inspect.iscoroutinefunction(self.rtc_configuration):
return await self.rtc_configuration()
else:
return await anyio.to_thread.run_sync(self.rtc_configuration)
else:
return self.rtc_configuration or {}
except Exception as e:
return {"error": str(e)}
@server
async def offer(self, body):
return await self.handle_offer(