diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py index 4c646e5..797c1fb 100644 --- a/backend/fastrtc/__init__.py +++ b/backend/fastrtc/__init__.py @@ -21,6 +21,7 @@ from .tracks import ( AudioVideoStreamHandler, StreamHandler, VideoEmitType, + VideoStreamHandler, ) from .utils import ( AdditionalOutputs, @@ -73,4 +74,5 @@ __all__ = [ "PauseDetectionModel", "get_silero_model", "SileroVadOptions", + "VideoStreamHandler", ] diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index b120f91..0663ad9 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -46,6 +46,11 @@ class UIArgs(TypedDict): """Color of the pulse animation. Default is var(--color-accent) of the demo theme.""" icon_radius: NotRequired[int] """Border radius of the icon button expressed as a percentage of the button size. Default is 50%.""" + send_input_on: NotRequired[Literal["submit", "change"]] + """When to send the input to the handler. Default is "change". + If "submit", the input will be sent when the submit event is triggered by the user. + If "change", the input will be sent whenever the user changes the input value. + """ class Stream(WebRTCConnectionMixin): @@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin): trigger=button.click, time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin): outputs=[output_video], time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin): outputs=[image], time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin): trigger=button.click, time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin): outputs=[image], time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin): outputs=[image], time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler @@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin): outputs=additional_output_components, ) elif self.modality == "audio-video" and self.mode == "send-receive": - with gr.Blocks() as demo: + css = """.my-group {max-width: 600px !important; max-height: 600 !important;} + .my-column {display: flex !important; justify-content: center !important; align-items: center !important};""" + with gr.Blocks(css=css) as demo: gr.HTML( f"""

@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin): """ ) with gr.Row(): - with gr.Column(): - with gr.Group(): + with gr.Column(elem_classes=["my-column"]): + with gr.Group(elem_classes=["my-group"]): image = WebRTC( label="Stream", rtc_configuration=self.rtc_configuration, @@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin): outputs=[image], time_limit=self.time_limit, concurrency_limit=self.concurrency_limit, # type: ignore + send_input_on=ui_args.get("send_input_on", "change"), ) if additional_output_components: assert self.additional_outputs_handler diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py index 0a9c5aa..9be63e2 100644 --- a/backend/fastrtc/tracks.py +++ b/backend/fastrtc/tracks.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import fractions import functools import inspect import logging @@ -11,10 +12,12 @@ import time import traceback from abc import ABC, abstractmethod from collections.abc import Callable +from dataclasses import dataclass from typing import ( Any, Generator, Literal, + Tuple, TypeAlias, Union, cast, @@ -29,7 +32,7 @@ from aiortc import ( VideoStreamTrack, ) from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore -from aiortc.mediastreams import MediaStreamError +from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError from numpy import typing as npt from fastrtc.utils import ( @@ -56,6 +59,13 @@ VideoEmitType = ( VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType] +@dataclass +class VideoStreamHandler: + callable: VideoEventHandler + fps: int = 30 + skip_frames: bool = False + + class VideoCallback(VideoStreamTrack): """ This works for streaming input and output @@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack): channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, mode: Literal["send-receive", "send"] = "send-receive", + fps: int = 30, + skip_frames: bool = False, ) -> None: - super().__init__() # don't forget this! + super().__init__() self.track = track self.event_handler = event_handler self.latest_args: str | list[Any] = "not_set" @@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack): self.mode = mode self.channel_set = asyncio.Event() self.has_started = False + self.fps = fps + self.frame_ptime = 1.0 / fps + self.skip_frames = skip_frames + self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue() + self.latest_frame = None def set_channel(self, channel: DataChannel): self.channel = channel @@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack): if current_channel.get() != self.channel: current_channel.set(self.channel) - async def recv(self): # type: ignore - try: + async def accept_input(self): + self.has_started = True + while not self.thread_quit.is_set(): try: frame = cast(VideoFrame, await self.track.recv()) + self.latest_frame = frame + self.frame_queue.put_nowait(frame) except MediaStreamError: self.stop() return + def accept_input_in_background(self): + if not self.has_started: + asyncio.create_task(self.accept_input()) + + async def recv(self): # type: ignore + self.accept_input_in_background() + try: + frame = await self.frame_queue.get() + if self.skip_frames: + frame = self.latest_frame await self.wait_for_channel() - frame_array = frame.to_ndarray(format="bgr24") + frame_array = frame.to_ndarray(format="bgr24") # type: ignore if self.latest_args == "not_set": return frame args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array) - array, outputs = split_output(self.event_handler(*args)) if ( isinstance(outputs, AdditionalOutputs) @@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack): pts, time_base = await self.next_timestamp() new_frame.pts = pts new_frame.time_base = time_base - + self.function_running = False return new_frame except Exception as e: logger.debug("exception %s", e) @@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack): else: raise WebRTCError(str(e)) from e + async def next_timestamp(self) -> Tuple[int, fractions.Fraction]: + """Override to control frame rate""" + if self.readyState != "live": + raise MediaStreamError + + if hasattr(self, "_timestamp"): + self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE) + wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time() + if wait > 0: + await asyncio.sleep(wait) + else: + self._start = time.time() + self._timestamp = 0 + return self._timestamp, VIDEO_TIME_BASE + class StreamHandlerBase(ABC): def __init__( @@ -180,11 +224,13 @@ class StreamHandlerBase(ABC): output_sample_rate: int = 24000, output_frame_size: int = 960, input_sample_rate: int = 48000, + fps: int = 30, ) -> None: self.expected_layout = expected_layout self.output_sample_rate = output_sample_rate self.output_frame_size = output_frame_size self.input_sample_rate = input_sample_rate + self.fps = fps self.latest_args: list[Any] = [] self._resampler = None self._channel: DataChannel | None = None @@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler -HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable +HandlerType = ( + StreamHandlerImpl + | VideoStreamHandlerImpl + | VideoEventHandler + | Callable + | VideoStreamHandler +) -class VideoStreamHandler(VideoCallback): +class VideoStreamHandler_(VideoCallback): async def process_frames(self): while not self.thread_quit.is_set(): try: @@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack): event_handler: Callable, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, + fps: int = 30, ) -> None: super().__init__() # don't forget this! self.event_handler = event_handler @@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack): self.generator: Generator[Any, None, Any] | None = None self.channel = channel self.set_additional_outputs = set_additional_outputs + self.fps = fps + self.frame_ptime = 1.0 / fps def array_to_frame(self, array: np.ndarray) -> VideoFrame: return VideoFrame.from_ndarray(array, format="bgr24") @@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack): self.latest_args = list(args) self.args_set.set() + async def next_timestamp(self) -> Tuple[int, fractions.Fraction]: + """Override to control frame rate""" + if self.readyState != "live": + raise MediaStreamError + + if hasattr(self, "_timestamp"): + self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE) + wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time() + if wait > 0: + await asyncio.sleep(wait) + else: + self._start = time.time() + self._timestamp = 0 + return self._timestamp, VIDEO_TIME_BASE + async def recv(self): # type: ignore try: pts, time_base = await self.next_timestamp() diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py index 4374e90..6363fcd 100644 --- a/backend/fastrtc/webrtc.py +++ b/backend/fastrtc/webrtc.py @@ -26,6 +26,7 @@ from .tracks import ( StreamHandlerBase, StreamHandlerImpl, VideoEventHandler, + VideoStreamHandler, ) from .webrtc_connection_mixin import WebRTCConnectionMixin @@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin): | StreamHandlerImpl | AudioVideoStreamHandlerImpl | VideoEventHandler + | VideoStreamHandler | None ) = None, inputs: Block | Sequence[Block] | set[Block] | None = None, @@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin): concurrency_id: str | None = None, time_limit: float | None = None, trigger: Callable | None = None, + send_input_on: Literal["submit", "change"] = "change", ): from gradio.blocks import Block @@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin): "In the webrtc stream event, the only output component must be the WebRTC component." ) for input_component in inputs[1:]: # type: ignore - if hasattr(input_component, "change"): + if hasattr(input_component, "change") and send_input_on == "change": input_component.change( # type: ignore self.set_input, inputs=inputs, @@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin): time_limit=None, js=js, ) + if hasattr(input_component, "submit") and send_input_on == "submit": + input_component.submit( # type: ignore + self.set_input, + inputs=inputs, + outputs=None, + concurrency_id=concurrency_id, + ) return self.tick( # type: ignore self.set_input, inputs=inputs, diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index b7e5733..fae614c 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -31,7 +31,9 @@ from fastrtc.tracks import ( StreamHandlerBase, StreamHandlerImpl, VideoCallback, + VideoEventHandler, VideoStreamHandler, + VideoStreamHandler_, ) from fastrtc.utils import ( AdditionalOutputs, @@ -41,7 +43,7 @@ from fastrtc.utils import ( Track = ( VideoCallback - | VideoStreamHandler + | VideoStreamHandler_ | AudioCallback | ServerToClientAudio | ServerToClientVideo @@ -174,6 +176,11 @@ class WebRTCConnectionMixin: handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore if hasattr(handler, "video_emit"): handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore + elif isinstance(self.event_handler, VideoStreamHandler): + self.event_handler.callable = cast( + VideoEventHandler, webrtc_error_handler(self.event_handler.callable) + ) + handler = self.event_handler else: handler = webrtc_error_handler(cast(Callable, self.event_handler)) @@ -208,17 +215,25 @@ class WebRTCConnectionMixin: handler = self.handlers[body["webrtc_id"]] if self.modality == "video" and track.kind == "video": + args = {} + handler_ = handler + if isinstance(handler, VideoStreamHandler): + handler_ = handler.callable + args["fps"] = handler.fps + args["skip_frames"] = handler.skip_frames cb = VideoCallback( relay.subscribe(track), - event_handler=cast(Callable, handler), + event_handler=cast(Callable, handler_), set_additional_outputs=set_outputs, mode=cast(Literal["send", "send-receive"], self.mode), + **args, ) elif self.modality == "audio-video" and track.kind == "video": - cb = VideoStreamHandler( + cb = VideoStreamHandler_( relay.subscribe(track), event_handler=handler, # type: ignore set_additional_outputs=set_outputs, + fps=cast(StreamHandlerImpl, handler).fps, ) elif self.modality in ["audio", "audio-video"] and track.kind == "audio": eh = cast(StreamHandlerImpl, handler) @@ -245,10 +260,17 @@ class WebRTCConnectionMixin: if self.mode == "receive": if self.modality == "video": - cb = ServerToClientVideo( - cast(Callable, self.event_handler), - set_additional_outputs=set_outputs, - ) + if isinstance(self.event_handler, VideoStreamHandler): + cb = ServerToClientVideo( + cast(Callable, self.event_handler.callable), + set_additional_outputs=set_outputs, + fps=self.event_handler.fps, + ) + else: + cb = ServerToClientVideo( + cast(Callable, self.event_handler), + set_additional_outputs=set_outputs, + ) elif self.modality == "audio": cb = ServerToClientAudio( cast(Callable, self.event_handler), diff --git a/docs/userguide/video.md b/docs/userguide/video.md index a05d10a..8d1b6ef 100644 --- a/docs/userguide/video.md +++ b/docs/userguide/video.md @@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component. mode="receive" ) ``` + +## Skipping Frames + +If your event handler is not quite real-time yet, then the output feed will look very laggy. + +To fix this, you can set the `skip_frames` parameter to `True`. This will skip the frames that are received while the event handler is still running. + +``` py title="Skipping Frames" +import time + +import numpy as np +from fastrtc import Stream, VideoStreamHandler + + +def process_image(image): + time.sleep( + 0.2 + ) # Simulating 200ms processing time per frame; input arrives faster (30 FPS). + return np.flip(image, axis=0) + + +stream = Stream( + handler=VideoStreamHandler(process_image, skip_frames=True), + modality="video", + mode="send-receive", +) + +stream.ui.launch() +``` + +## Setting the Output Frame Rate + +You can set the output frame rate by setting the `fps` parameter in the `VideoStreamHandler`. + +``` py title="Setting the Output Frame Rate" +def generation(): + url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71" + cap = cv2.VideoCapture(url) + iterating = True + + # FPS calculation variables + frame_count = 0 + start_time = time.time() + fps = 0 + + while iterating: + iterating, frame = cap.read() + + # Calculate and print FPS + frame_count += 1 + elapsed_time = time.time() - start_time + if elapsed_time >= 1.0: # Update FPS every second + fps = frame_count / elapsed_time + yield frame, AdditionalOutputs(fps) + frame_count = 0 + start_time = time.time() + else: + yield frame + + +stream = Stream( + handler=VideoStreamHandler(generation, fps=60), + modality="video", + mode="receive", + additional_outputs=[gr.Number(label="FPS")], + additional_outputs_handler=lambda prev, cur: cur, +) + +stream.ui.launch() +```