diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py
index 4c646e5..797c1fb 100644
--- a/backend/fastrtc/__init__.py
+++ b/backend/fastrtc/__init__.py
@@ -21,6 +21,7 @@ from .tracks import (
AudioVideoStreamHandler,
StreamHandler,
VideoEmitType,
+ VideoStreamHandler,
)
from .utils import (
AdditionalOutputs,
@@ -73,4 +74,5 @@ __all__ = [
"PauseDetectionModel",
"get_silero_model",
"SileroVadOptions",
+ "VideoStreamHandler",
]
diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py
index b120f91..0663ad9 100644
--- a/backend/fastrtc/stream.py
+++ b/backend/fastrtc/stream.py
@@ -46,6 +46,11 @@ class UIArgs(TypedDict):
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
icon_radius: NotRequired[int]
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
+ send_input_on: NotRequired[Literal["submit", "change"]]
+ """When to send the input to the handler. Default is "change".
+ If "submit", the input will be sent when the submit event is triggered by the user.
+ If "change", the input will be sent whenever the user changes the input value.
+ """
class Stream(WebRTCConnectionMixin):
@@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click,
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[output_video],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click,
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin):
outputs=additional_output_components,
)
elif self.modality == "audio-video" and self.mode == "send-receive":
- with gr.Blocks() as demo:
+ css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
+ .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
+ with gr.Blocks(css=css) as demo:
gr.HTML(
f"""
@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin):
"""
)
with gr.Row():
- with gr.Column():
- with gr.Group():
+ with gr.Column(elem_classes=["my-column"]):
+ with gr.Group(elem_classes=["my-group"]):
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
@@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
+ send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py
index 0a9c5aa..9be63e2 100644
--- a/backend/fastrtc/tracks.py
+++ b/backend/fastrtc/tracks.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
+import fractions
import functools
import inspect
import logging
@@ -11,10 +12,12 @@ import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
+from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
+ Tuple,
TypeAlias,
Union,
cast,
@@ -29,7 +32,7 @@ from aiortc import (
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
-from aiortc.mediastreams import MediaStreamError
+from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
from numpy import typing as npt
from fastrtc.utils import (
@@ -56,6 +59,13 @@ VideoEmitType = (
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
+@dataclass
+class VideoStreamHandler:
+ callable: VideoEventHandler
+ fps: int = 30
+ skip_frames: bool = False
+
+
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
+ fps: int = 30,
+ skip_frames: bool = False,
) -> None:
- super().__init__() # don't forget this!
+ super().__init__()
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
self.mode = mode
self.channel_set = asyncio.Event()
self.has_started = False
+ self.fps = fps
+ self.frame_ptime = 1.0 / fps
+ self.skip_frames = skip_frames
+ self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
+ self.latest_frame = None
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
if current_channel.get() != self.channel:
current_channel.set(self.channel)
- async def recv(self): # type: ignore
- try:
+ async def accept_input(self):
+ self.has_started = True
+ while not self.thread_quit.is_set():
try:
frame = cast(VideoFrame, await self.track.recv())
+ self.latest_frame = frame
+ self.frame_queue.put_nowait(frame)
except MediaStreamError:
self.stop()
return
+ def accept_input_in_background(self):
+ if not self.has_started:
+ asyncio.create_task(self.accept_input())
+
+ async def recv(self): # type: ignore
+ self.accept_input_in_background()
+ try:
+ frame = await self.frame_queue.get()
+ if self.skip_frames:
+ frame = self.latest_frame
await self.wait_for_channel()
- frame_array = frame.to_ndarray(format="bgr24")
+ frame_array = frame.to_ndarray(format="bgr24") # type: ignore
if self.latest_args == "not_set":
return frame
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
-
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
-
+ self.function_running = False
return new_frame
except Exception as e:
logger.debug("exception %s", e)
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
else:
raise WebRTCError(str(e)) from e
+ async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+ """Override to control frame rate"""
+ if self.readyState != "live":
+ raise MediaStreamError
+
+ if hasattr(self, "_timestamp"):
+ self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
+ wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+ if wait > 0:
+ await asyncio.sleep(wait)
+ else:
+ self._start = time.time()
+ self._timestamp = 0
+ return self._timestamp, VIDEO_TIME_BASE
+
class StreamHandlerBase(ABC):
def __init__(
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
output_sample_rate: int = 24000,
output_frame_size: int = 960,
input_sample_rate: int = 48000,
+ fps: int = 30,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate
+ self.fps = fps
self.latest_args: list[Any] = []
self._resampler = None
self._channel: DataChannel | None = None
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
-HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
+HandlerType = (
+ StreamHandlerImpl
+ | VideoStreamHandlerImpl
+ | VideoEventHandler
+ | Callable
+ | VideoStreamHandler
+)
-class VideoStreamHandler(VideoCallback):
+class VideoStreamHandler_(VideoCallback):
async def process_frames(self):
while not self.thread_quit.is_set():
try:
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
+ fps: int = 30,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
+ self.fps = fps
+ self.frame_ptime = 1.0 / fps
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
+ async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+ """Override to control frame rate"""
+ if self.readyState != "live":
+ raise MediaStreamError
+
+ if hasattr(self, "_timestamp"):
+ self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
+ wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+ if wait > 0:
+ await asyncio.sleep(wait)
+ else:
+ self._start = time.time()
+ self._timestamp = 0
+ return self._timestamp, VIDEO_TIME_BASE
+
async def recv(self): # type: ignore
try:
pts, time_base = await self.next_timestamp()
diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py
index 4374e90..6363fcd 100644
--- a/backend/fastrtc/webrtc.py
+++ b/backend/fastrtc/webrtc.py
@@ -26,6 +26,7 @@ from .tracks import (
StreamHandlerBase,
StreamHandlerImpl,
VideoEventHandler,
+ VideoStreamHandler,
)
from .webrtc_connection_mixin import WebRTCConnectionMixin
@@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
| StreamHandlerImpl
| AudioVideoStreamHandlerImpl
| VideoEventHandler
+ | VideoStreamHandler
| None
) = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
@@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
concurrency_id: str | None = None,
time_limit: float | None = None,
trigger: Callable | None = None,
+ send_input_on: Literal["submit", "change"] = "change",
):
from gradio.blocks import Block
@@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
"In the webrtc stream event, the only output component must be the WebRTC component."
)
for input_component in inputs[1:]: # type: ignore
- if hasattr(input_component, "change"):
+ if hasattr(input_component, "change") and send_input_on == "change":
input_component.change( # type: ignore
self.set_input,
inputs=inputs,
@@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin):
time_limit=None,
js=js,
)
+ if hasattr(input_component, "submit") and send_input_on == "submit":
+ input_component.submit( # type: ignore
+ self.set_input,
+ inputs=inputs,
+ outputs=None,
+ concurrency_id=concurrency_id,
+ )
return self.tick( # type: ignore
self.set_input,
inputs=inputs,
diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py
index b7e5733..fae614c 100644
--- a/backend/fastrtc/webrtc_connection_mixin.py
+++ b/backend/fastrtc/webrtc_connection_mixin.py
@@ -31,7 +31,9 @@ from fastrtc.tracks import (
StreamHandlerBase,
StreamHandlerImpl,
VideoCallback,
+ VideoEventHandler,
VideoStreamHandler,
+ VideoStreamHandler_,
)
from fastrtc.utils import (
AdditionalOutputs,
@@ -41,7 +43,7 @@ from fastrtc.utils import (
Track = (
VideoCallback
- | VideoStreamHandler
+ | VideoStreamHandler_
| AudioCallback
| ServerToClientAudio
| ServerToClientVideo
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
if hasattr(handler, "video_emit"):
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
+ elif isinstance(self.event_handler, VideoStreamHandler):
+ self.event_handler.callable = cast(
+ VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
+ )
+ handler = self.event_handler
else:
handler = webrtc_error_handler(cast(Callable, self.event_handler))
@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
handler = self.handlers[body["webrtc_id"]]
if self.modality == "video" and track.kind == "video":
+ args = {}
+ handler_ = handler
+ if isinstance(handler, VideoStreamHandler):
+ handler_ = handler.callable
+ args["fps"] = handler.fps
+ args["skip_frames"] = handler.skip_frames
cb = VideoCallback(
relay.subscribe(track),
- event_handler=cast(Callable, handler),
+ event_handler=cast(Callable, handler_),
set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode),
+ **args,
)
elif self.modality == "audio-video" and track.kind == "video":
- cb = VideoStreamHandler(
+ cb = VideoStreamHandler_(
relay.subscribe(track),
event_handler=handler, # type: ignore
set_additional_outputs=set_outputs,
+ fps=cast(StreamHandlerImpl, handler).fps,
)
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
eh = cast(StreamHandlerImpl, handler)
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:
if self.mode == "receive":
if self.modality == "video":
- cb = ServerToClientVideo(
- cast(Callable, self.event_handler),
- set_additional_outputs=set_outputs,
- )
+ if isinstance(self.event_handler, VideoStreamHandler):
+ cb = ServerToClientVideo(
+ cast(Callable, self.event_handler.callable),
+ set_additional_outputs=set_outputs,
+ fps=self.event_handler.fps,
+ )
+ else:
+ cb = ServerToClientVideo(
+ cast(Callable, self.event_handler),
+ set_additional_outputs=set_outputs,
+ )
elif self.modality == "audio":
cb = ServerToClientAudio(
cast(Callable, self.event_handler),
diff --git a/docs/userguide/video.md b/docs/userguide/video.md
index a05d10a..8d1b6ef 100644
--- a/docs/userguide/video.md
+++ b/docs/userguide/video.md
@@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component.
mode="receive"
)
```
+
+## Skipping Frames
+
+If your event handler runs slower than the rate at which input frames arrive, the output feed will fall further and further behind and look very laggy.
+
+To fix this, set the `skip_frames` parameter to `True`. Frames that arrive while the event handler is still busy are then dropped, so the handler always processes the most recent frame instead of working through a growing backlog.
+
+``` py title="Skipping Frames"
+import time
+
+import numpy as np
+from fastrtc import Stream, VideoStreamHandler
+
+
+def process_image(image):
+ time.sleep(
+ 0.2
+ ) # Simulating 200ms processing time per frame; input arrives faster (30 FPS).
+ return np.flip(image, axis=0)
+
+
+stream = Stream(
+ handler=VideoStreamHandler(process_image, skip_frames=True),
+ modality="video",
+ mode="send-receive",
+)
+
+stream.ui.launch()
+```
+
+## Setting the Output Frame Rate
+
+You can control the output frame rate by passing the `fps` parameter to the `VideoStreamHandler`.
+
+``` py title="Setting the Output Frame Rate"
+def generation():
+ url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71"
+ cap = cv2.VideoCapture(url)
+ iterating = True
+
+ # FPS calculation variables
+ frame_count = 0
+ start_time = time.time()
+ fps = 0
+
+ while iterating:
+ iterating, frame = cap.read()
+
+ # Calculate and print FPS
+ frame_count += 1
+ elapsed_time = time.time() - start_time
+ if elapsed_time >= 1.0: # Update FPS every second
+ fps = frame_count / elapsed_time
+ yield frame, AdditionalOutputs(fps)
+ frame_count = 0
+ start_time = time.time()
+ else:
+ yield frame
+
+
+stream = Stream(
+ handler=VideoStreamHandler(generation, fps=60),
+ modality="video",
+ mode="receive",
+ additional_outputs=[gr.Number(label="FPS")],
+ additional_outputs_handler=lambda prev, cur: cur,
+)
+
+stream.ui.launch()
+```