"""High-level ``Stream`` wrapper: default Gradio UI plus FastAPI routes."""

import contextlib
import functools
import logging
from pathlib import Path
from typing import (
    Any,
    AsyncContextManager,
    Callable,
    Literal,
    TypedDict,
    cast,
)

import gradio as gr
from fastapi import FastAPI, Request, WebSocket
from fastapi.responses import HTMLResponse
from gradio import Blocks
from gradio.components.base import Component
from pydantic import BaseModel
from typing_extensions import NotRequired

from .tracks import HandlerType, StreamHandlerImpl
from .webrtc import WebRTC
from .webrtc_connection_mixin import WebRTCConnectionMixin
from .websocket import WebSocketHandler

logger = logging.getLogger(__name__)

curr_dir = Path(__file__).parent


def _chain_lifespan(
    lifespan: Callable[[FastAPI], AsyncContextManager] | None,
    on_startup: Callable[[], None],
):
    """Return a lifespan that runs ``on_startup`` once the app is starting.

    If ``lifespan`` is given it is entered first and ``on_startup`` runs
    inside it; otherwise ``on_startup`` runs on its own.
    """

    @contextlib.asynccontextmanager
    async def new_lifespan(app: FastAPI):
        if lifespan is None:
            on_startup()
            yield
        else:
            async with lifespan(app):
                on_startup()
                yield

    return new_lifespan


class Body(BaseModel):
    """Request body accepted by the ``/webrtc/offer`` route."""

    sdp: str
    type: str
    webrtc_id: str


class UIArgs(TypedDict):
    """Optional cosmetic settings for the auto-generated UI."""

    title: NotRequired[str]
    """Title of the demo"""
    icon: NotRequired[str]
    """Icon to display on the button instead of the wave animation. The icon should be a path/url to a .svg/.png/.jpeg file."""
    icon_button_color: NotRequired[str]
    """Color of the icon button. Default is var(--color-accent) of the demo theme."""
    pulse_color: NotRequired[str]
    """Color of the pulse animation. Default is var(--color-accent) of the demo theme."""


class Stream(WebRTCConnectionMixin):
    """Bundle a stream handler with a default Gradio UI and FastAPI routes."""

    def __init__(
        self,
        handler: HandlerType,
        *,
        additional_outputs_handler: Callable | None = None,
        mode: Literal["send-receive", "receive", "send"] = "send-receive",
        modality: Literal["video", "audio", "audio-video"] = "video",
        concurrency_limit: int | None | Literal["default"] = "default",
        time_limit: float | None = None,
        rtp_params: dict[str, Any] | None = None,
        rtc_configuration: dict[str, Any] | None = None,
        additional_inputs: list[Component] | None = None,
        additional_outputs: list[Component] | None = None,
        ui_args: UIArgs | None = None,
    ):
        """Store the configuration and build the default UI.

        Args:
            handler: Event handler passed to the WebRTC component's stream
                event and copied for each websocket/telephone connection.
            additional_outputs_handler: Callback wired to
                ``on_additional_outputs``; required when
                ``additional_outputs`` is non-empty.
            mode: Direction of the stream.
            modality: Kind of media carried by the stream.
            concurrency_limit: Maximum number of connections; ``"default"``
                and ``None`` are both stored as ``1``.
            time_limit: Forwarded to the stream event as ``time_limit``.
            rtp_params: Stored on the instance as ``rtp_params``.
            rtc_configuration: Passed to every ``WebRTC`` component; also
                required when running in Colab or Spaces.
            additional_inputs: Extra components rendered and passed as
                inputs to the handler.
            additional_outputs: Extra components updated through
                ``additional_outputs_handler``.
            ui_args: Cosmetic options for the generated UI.

        Raises:
            ValueError: If additional outputs are given without a handler,
                or the modality/mode combination is unsupported.
        """
        self.mode = mode
        self.modality = modality
        self.rtp_params = rtp_params
        self.event_handler = handler
        self.concurrency_limit = cast(
            (int | float),
            1 if concurrency_limit in ("default", None) else concurrency_limit,
        )
        self.time_limit = time_limit
        self.additional_output_components = additional_outputs
        self.additional_input_components = additional_inputs
        self.additional_outputs_handler = additional_outputs_handler
        self.rtc_configuration = rtc_configuration
        # Must come last: UI generation reads the attributes set above.
        self._ui = self._generate_default_ui(ui_args)
        self._ui.launch = self._wrap_gradio_launch(self._ui.launch)

    def mount(self, app: FastAPI):
        """Register the WebRTC, websocket and telephone routes on ``app``.

        Also replaces the app router's lifespan with one that prints the
        startup message.
        """
        app.router.post("/webrtc/offer")(self.offer)
        app.router.websocket("/telephone/handler")(self.telephone_handler)
        app.router.post("/telephone/incoming")(self.handle_incoming_call)
        app.router.websocket("/websocket/offer")(self.websocket_offer)
        lifespan = self._inject_startup_message(app.router.lifespan_context)
        app.router.lifespan_context = lifespan

    @staticmethod
    def print_error(env: Literal["colab", "spaces"]):
        """Print a styled error about a missing ``rtc_configuration`` and raise.

        Raises:
            RuntimeError: Always.
        """
        import click

        print(
            click.style("ERROR", fg="red")
            + f":\t Running in {env} is not possible without providing a valid rtc_configuration. "
            + "See "
            + click.style("https://fastrtc.org/deployment/", fg="cyan")
            + " for more information."
        )
        raise RuntimeError(
            f"Running in {env} is not possible without providing a valid rtc_configuration. "
            + "See https://fastrtc.org/deployment/ for more information."
        )

    def _check_colab_or_spaces(self):
        """Fail if running in Colab or Spaces without an ``rtc_configuration``."""
        from gradio.utils import colab_check, get_space

        if colab_check() and not self.rtc_configuration:
            self.print_error("colab")
        if get_space() and not self.rtc_configuration:
            self.print_error("spaces")

    def _wrap_gradio_launch(self, callable):
        """Wrap a Gradio ``launch`` callable so startup runs the environment check.

        The wrapper chains ``_check_colab_or_spaces`` onto any ``lifespan``
        supplied through ``app_kwargs``.
        """

        @functools.wraps(callable)
        def wrapper(*args, **kwargs):
            # Tolerate an explicit ``app_kwargs=None`` and copy so the
            # caller's dict is never mutated.
            app_kwargs = dict(kwargs.get("app_kwargs") or {})
            app_kwargs["lifespan"] = _chain_lifespan(
                app_kwargs.get("lifespan"), self._check_colab_or_spaces
            )
            kwargs["app_kwargs"] = app_kwargs
            return callable(*args, **kwargs)

        return wrapper

    def _inject_startup_message(
        self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
    ):
        """Return a lifespan that checks the environment and prints a docs hint."""
        import click

        def print_startup_message():
            self._check_colab_or_spaces()
            print(
                click.style("INFO", fg="green")
                + ":\t Visit "
                + click.style("https://fastrtc.org/userguide/api/", fg="cyan")
                + " for WebRTC or Websocket API docs."
            )

        return _chain_lifespan(lifespan, print_startup_message)

    @staticmethod
    def _render_title(ui_args: UIArgs | dict, default: str) -> None:
        """Render the demo title (or ``default``) as a centred heading."""
        gr.HTML(
            f"""
            <h1 style='text-align: center'>
            {ui_args.get("title", default)}
            </h1>
            """
        )

    def _attach_stream(
        self,
        component: WebRTC,
        inputs: list,
        trigger: Callable | None = None,
    ) -> None:
        """Wire ``component``'s stream event to the event handler.

        ``trigger`` is only forwarded when given (receive-mode layouts).
        """
        extra: dict[str, Any] = {}
        if trigger is not None:
            extra["trigger"] = trigger
        component.stream(
            fn=self.event_handler,
            inputs=inputs,
            outputs=[component],
            time_limit=self.time_limit,
            concurrency_limit=self.concurrency_limit,  # type: ignore
            **extra,
        )

    def _attach_additional_outputs(
        self,
        component: WebRTC,
        outputs: list,
        *,
        echo_as_inputs: bool,
    ) -> None:
        """Wire the additional-outputs handler when there are output components.

        ``echo_as_inputs`` also passes the output components as ``inputs``,
        which is what some of the layouts do.
        """
        if not outputs:
            return
        # _generate_default_ui has already rejected outputs without a handler.
        handler = cast(Callable, self.additional_outputs_handler)
        if echo_as_inputs:
            component.on_additional_outputs(handler, inputs=outputs, outputs=outputs)
        else:
            component.on_additional_outputs(handler, outputs=outputs)

    def _render_grouped_layout(
        self,
        inputs: list,
        outputs: list,
        shared: list,
        **webrtc_kwargs: Any,
    ) -> WebRTC:
        """Render the audio-style layout and return its ``WebRTC`` component.

        The stream and the inputs not in ``shared`` go in one column; all
        outputs go in a second column, present only when there are outputs.
        """
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    webrtc = WebRTC(
                        label="Stream",
                        rtc_configuration=self.rtc_configuration,
                        **webrtc_kwargs,
                    )
                    for component in inputs:
                        if component not in shared:
                            component.render()
            if outputs:
                with gr.Column():
                    for component in outputs:
                        component.render()
        return webrtc

    def _generate_default_ui(
        self,
        ui_args: UIArgs | None = None,
    ):
        """Build the default ``gr.Blocks`` UI for this modality and mode.

        Raises:
            ValueError: If additional output components are configured
                without ``additional_outputs_handler``, or the modality/mode
                combination has no default layout.
        """
        ui_args = ui_args or {}
        additional_input_components = self.additional_input_components or []
        additional_output_components = self.additional_output_components or []
        if additional_output_components and not self.additional_outputs_handler:
            raise ValueError(
                "additional_outputs_handler must be provided if there are additional output components."
            )
        # Components used as both input and output must be rendered once only.
        # Keep this a strict intersection: if it held every output, the video
        # layouts below would never render any output component.
        same_components = [
            component
            for component in additional_input_components
            if component in additional_output_components
        ]
        icon_kwargs = {
            "icon": ui_args.get("icon"),
            "icon_button_color": ui_args.get("icon_button_color"),
            "pulse_color": ui_args.get("pulse_color"),
        }

        if self.modality == "video" and self.mode == "receive":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Video Streaming (Powered by FastRTC ⚡️)")
                with gr.Row():
                    with gr.Column():
                        for component in additional_input_components:
                            component.render()
                        # Always created: it is the trigger for the stream.
                        button = gr.Button("Start Stream", variant="primary")
                    with gr.Column():
                        output_video = WebRTC(
                            label="Video Stream",
                            rtc_configuration=self.rtc_configuration,
                            mode="receive",
                            modality="video",
                        )
                        for component in additional_output_components:
                            if component not in same_components:
                                component.render()
                self._attach_stream(
                    output_video, additional_input_components, trigger=button.click
                )
                self._attach_additional_outputs(
                    output_video, additional_output_components, echo_as_inputs=False
                )
        elif self.modality == "video" and self.mode == "send":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Video Streaming (Powered by FastRTC ⚡️)")
                with gr.Row():
                    if additional_input_components:
                        with gr.Column():
                            for component in additional_input_components:
                                component.render()
                    with gr.Column():
                        output_video = WebRTC(
                            label="Video Stream",
                            rtc_configuration=self.rtc_configuration,
                            mode="send",
                            modality="video",
                        )
                        for component in additional_output_components:
                            if component not in same_components:
                                component.render()
                self._attach_stream(
                    output_video, [output_video] + additional_input_components
                )
                self._attach_additional_outputs(
                    output_video, additional_output_components, echo_as_inputs=False
                )
        elif self.modality == "video" and self.mode == "send-receive":
            css = """.my-group {max-width: 600px !important; max-height: 600 !important;} .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
            with gr.Blocks(css=css) as demo:
                self._render_title(ui_args, "Video Streaming (Powered by FastRTC ⚡️)")
                with gr.Column(elem_classes=["my-column"]):
                    with gr.Group(elem_classes=["my-group"]):
                        image = WebRTC(
                            label="Stream",
                            rtc_configuration=self.rtc_configuration,
                            mode="send-receive",
                            modality="video",
                        )
                        for component in additional_input_components:
                            component.render()
                if additional_output_components:
                    with gr.Column():
                        for component in additional_output_components:
                            if component not in same_components:
                                component.render()
                self._attach_stream(image, [image] + additional_input_components)
                self._attach_additional_outputs(
                    image, additional_output_components, echo_as_inputs=True
                )
        elif self.modality == "audio" and self.mode == "receive":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Audio Streaming (Powered by FastRTC ⚡️)")
                with gr.Row():
                    with gr.Column():
                        for component in additional_input_components:
                            component.render()
                        button = gr.Button("Start Stream", variant="primary")
                    # The stream component is needed whether or not there are
                    # additional outputs, so it is created unconditionally.
                    with gr.Column():
                        output_video = WebRTC(
                            label="Audio Stream",
                            rtc_configuration=self.rtc_configuration,
                            mode="receive",
                            modality="audio",
                            **icon_kwargs,
                        )
                        for component in additional_output_components:
                            if component not in same_components:
                                component.render()
                self._attach_stream(
                    output_video, additional_input_components, trigger=button.click
                )
                self._attach_additional_outputs(
                    output_video, additional_output_components, echo_as_inputs=True
                )
        elif self.modality == "audio" and self.mode == "send":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Audio Streaming (Powered by WebRTC ⚡️)")
                image = self._render_grouped_layout(
                    additional_input_components,
                    additional_output_components,
                    same_components,
                    mode="send",
                    modality="audio",
                )
                self._attach_stream(image, [image] + additional_input_components)
                self._attach_additional_outputs(
                    image, additional_output_components, echo_as_inputs=True
                )
        elif self.modality == "audio" and self.mode == "send-receive":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Audio Streaming (Powered by WebRTC ⚡️)")
                image = self._render_grouped_layout(
                    additional_input_components,
                    additional_output_components,
                    same_components,
                    mode="send-receive",
                    modality="audio",
                    **icon_kwargs,
                )
                self._attach_stream(image, [image] + additional_input_components)
                self._attach_additional_outputs(
                    image, additional_output_components, echo_as_inputs=True
                )
        elif self.modality == "audio-video" and self.mode == "send-receive":
            with gr.Blocks() as demo:
                self._render_title(ui_args, "Audio Streaming (Powered by WebRTC ⚡️)")
                image = self._render_grouped_layout(
                    additional_input_components,
                    additional_output_components,
                    same_components,
                    mode="send-receive",
                    modality="audio-video",
                    **icon_kwargs,
                )
                self._attach_stream(image, [image] + additional_input_components)
                self._attach_additional_outputs(
                    image, additional_output_components, echo_as_inputs=True
                )
        else:
            raise ValueError(f"Invalid modality: {self.modality} and mode: {self.mode}")
        return demo

    @property
    def ui(self) -> Blocks:
        """The Gradio ``Blocks`` UI associated with this stream."""
        return self._ui

    @ui.setter
    def ui(self, blocks: Blocks):
        # NOTE(review): a replacement Blocks does not get the wrapped
        # ``launch`` that __init__ installs on the default UI - confirm
        # whether that is intended.
        self._ui = blocks

    async def offer(self, body: Body):
        """Handle a POSTed WebRTC offer by delegating to ``handle_offer``."""
        return await self.handle_offer(
            body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
        )

    async def handle_incoming_call(self, request: Request):
        """Return TwiML that connects the call to ``/telephone/handler``."""
        from twilio.twiml.voice_response import Connect, VoiceResponse

        response = VoiceResponse()
        response.say("Connecting to the AI assistant.")
        connect = Connect()
        connect.stream(url=f"wss://{request.url.hostname}/telephone/handler")
        response.append(connect)
        response.say("The call has been disconnected.")
        return HTMLResponse(content=str(response), media_type="application/xml")

    async def _reject_if_at_capacity(
        self, ws_handler: WebSocketHandler, websocket: WebSocket
    ) -> bool:
        """Refuse the connection when the concurrency limit is reached.

        Sends a ``concurrency_limit_reached`` failure payload and closes
        ``websocket``. Returns ``True`` if the connection was rejected.
        """
        if len(self.connections) < self.concurrency_limit:
            return False
        await cast(WebSocket, ws_handler.websocket).send_json(
            {
                "status": "failed",
                "meta": {
                    "error": "concurrency_limit_reached",
                    "limit": self.concurrency_limit,
                },
            }
        )
        await websocket.close()
        return True

    async def telephone_handler(self, websocket: WebSocket):
        """Serve a telephone websocket with a phone-mode copy of the handler."""
        handler = cast(StreamHandlerImpl, self.event_handler.copy())
        handler.phone_mode = True

        async def set_handler(s: str, a: WebSocketHandler):
            # NOTE(review): unlike websocket_offer, the connection is not
            # recorded in self.connections here - confirm this is intended.
            await self._reject_if_at_capacity(a, websocket)

        ws = WebSocketHandler(
            handler, set_handler, lambda s: None, lambda s: lambda a: None
        )
        await ws.handle_websocket(websocket)

    async def websocket_offer(self, websocket: WebSocket):
        """Serve a websocket stream and record it in ``self.connections``."""
        handler = cast(StreamHandlerImpl, self.event_handler.copy())
        handler.phone_mode = False

        async def set_handler(s: str, a: WebSocketHandler):
            if await self._reject_if_at_capacity(a, websocket):
                return
            self.connections[s] = [a]  # type: ignore

        def clean_up(s):
            self.clean_up(s)

        ws = WebSocketHandler(
            handler, set_handler, clean_up, lambda s: self.set_additional_outputs(s)
        )
        await ws.handle_websocket(websocket)

    def fastphone(
        self,
        token: str | None = None,
        host: str = "127.0.0.1",
        port: int = 8000,
        **kwargs,
    ):
        """Serve the stream over a tunnel and register it for phone access.

        Starts uvicorn in a background thread, opens a Gradio tunnel,
        registers the tunnel host with ``https://api.fastrtc.org`` and prints
        the phone number, code and remaining quota. Blocks until interrupted,
        then unregisters and tears the tunnels down.

        Args:
            token: Value for the ``Authorization`` header; falls back to
                ``huggingface_hub.get_token()``.
            host: Host passed to uvicorn and the tunnel.
            port: Port passed to uvicorn and the tunnel.
            **kwargs: Extra keyword arguments forwarded to ``uvicorn.run``.

        Raises:
            httpx.HTTPStatusError: If registration returns an error status.
        """
        import secrets
        import threading
        import time
        import urllib.parse

        import click
        import httpx
        import uvicorn
        from gradio.networking import setup_tunnel
        from gradio.tunneling import CURRENT_TUNNELS
        from huggingface_hub import get_token

        app = FastAPI()
        self.mount(app)

        t = threading.Thread(
            target=uvicorn.run,
            args=(app,),
            kwargs={"host": host, "port": port, **kwargs},
        )
        t.start()

        url = setup_tunnel(
            host, port, share_token=secrets.token_urlsafe(32), share_server_address=None
        )
        host = urllib.parse.urlparse(url).netloc

        api_url = "https://api.fastrtc.org"
        headers = {"Authorization": token or get_token() or ""}
        r = httpx.post(api_url + "/register", json={"url": host}, headers=headers)
        r.raise_for_status()
        data = r.json()
        code = f"{data['code']}"
        phone_number = data["phone"]
        reset_date = data["reset_date"]
        print(
            click.style("INFO", fg="green")
            + ":\t Your FastPhone is now live! Call "
            + click.style(phone_number, fg="cyan")
            + " and use code "
            + click.style(code, fg="cyan")
            + " to connect to your stream."
        )
        minutes, seconds = divmod(data["time_remaining"], 60)
        remaining = f"{str(int(minutes)).zfill(2)}:{str(int(seconds)).zfill(2)}"
        print(
            click.style("INFO", fg="green")
            + ":\t You have "
            + click.style(remaining, fg="cyan")
            + " minutes remaining in your quota (Resetting on "
            + click.style(f"{reset_date}", fg="cyan")
            + ")"
        )
        print(
            click.style("INFO", fg="green")
            + ":\t Visit "
            + click.style(
                "https://fastrtc.org/userguide/audio/#telephone-integration",
                fg="cyan",
            )
            + " for information on making your handler compatible with phone usage."
        )
        try:
            while True:
                time.sleep(0.1)
        except (KeyboardInterrupt, OSError):
            print(
                click.style("INFO", fg="green")
                + ":\t Keyboard interruption in main thread... closing server."
            )
            try:
                httpx.post(
                    api_url + "/unregister",
                    json={"url": host, "code": code},
                    headers=headers,
                )
            finally:
                # Tear down even if the unregister request itself fails.
                t.join(timeout=5)
                for tunnel in CURRENT_TUNNELS:
                    tunnel.kill()