Some Video Fixes (#200)

* FPS control:

* add code

* Add code
This commit is contained in:
Freddy Boulton
2025-03-20 20:45:46 -04:00
committed by GitHub
parent bce7cb95a6
commit 3fed4cb2ad
6 changed files with 208 additions and 20 deletions

View File

@@ -21,6 +21,7 @@ from .tracks import (
AudioVideoStreamHandler, AudioVideoStreamHandler,
StreamHandler, StreamHandler,
VideoEmitType, VideoEmitType,
VideoStreamHandler,
) )
from .utils import ( from .utils import (
AdditionalOutputs, AdditionalOutputs,
@@ -73,4 +74,5 @@ __all__ = [
"PauseDetectionModel", "PauseDetectionModel",
"get_silero_model", "get_silero_model",
"SileroVadOptions", "SileroVadOptions",
"VideoStreamHandler",
] ]

View File

@@ -46,6 +46,11 @@ class UIArgs(TypedDict):
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme.""" """Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
icon_radius: NotRequired[int] icon_radius: NotRequired[int]
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%.""" """Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
send_input_on: NotRequired[Literal["submit", "change"]]
"""When to send the input to the handler. Default is "change".
If "submit", the input will be sent when the submit event is triggered by the user.
If "change", the input will be sent whenever the user changes the input value.
"""
class Stream(WebRTCConnectionMixin): class Stream(WebRTCConnectionMixin):
@@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click, trigger=button.click,
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[output_video], outputs=[output_video],
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image], outputs=[image],
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click, trigger=button.click,
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image], outputs=[image],
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image], outputs=[image],
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler
@@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin):
outputs=additional_output_components, outputs=additional_output_components,
) )
elif self.modality == "audio-video" and self.mode == "send-receive": elif self.modality == "audio-video" and self.mode == "send-receive":
with gr.Blocks() as demo: css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
gr.HTML( gr.HTML(
f""" f"""
<h1 style='text-align: center'> <h1 style='text-align: center'>
@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin):
""" """
) )
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(elem_classes=["my-column"]):
with gr.Group(): with gr.Group(elem_classes=["my-group"]):
image = WebRTC( image = WebRTC(
label="Stream", label="Stream",
rtc_configuration=self.rtc_configuration, rtc_configuration=self.rtc_configuration,
@@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image], outputs=[image],
time_limit=self.time_limit, time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
) )
if additional_output_components: if additional_output_components:
assert self.additional_outputs_handler assert self.additional_outputs_handler

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import fractions
import functools import functools
import inspect import inspect
import logging import logging
@@ -11,10 +12,12 @@ import time
import traceback import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
Generator, Generator,
Literal, Literal,
Tuple,
TypeAlias, TypeAlias,
Union, Union,
cast, cast,
@@ -29,7 +32,7 @@ from aiortc import (
VideoStreamTrack, VideoStreamTrack,
) )
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
from numpy import typing as npt from numpy import typing as npt
from fastrtc.utils import ( from fastrtc.utils import (
@@ -56,6 +59,13 @@ VideoEmitType = (
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType] VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
@dataclass
class VideoStreamHandler:
callable: VideoEventHandler
fps: int = 30
skip_frames: bool = False
class VideoCallback(VideoStreamTrack): class VideoCallback(VideoStreamTrack):
""" """
This works for streaming input and output This works for streaming input and output
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
channel: DataChannel | None = None, channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None, set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive", mode: Literal["send-receive", "send"] = "send-receive",
fps: int = 30,
skip_frames: bool = False,
) -> None: ) -> None:
super().__init__() # don't forget this! super().__init__()
self.track = track self.track = track
self.event_handler = event_handler self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
self.mode = mode self.mode = mode
self.channel_set = asyncio.Event() self.channel_set = asyncio.Event()
self.has_started = False self.has_started = False
self.fps = fps
self.frame_ptime = 1.0 / fps
self.skip_frames = skip_frames
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
self.latest_frame = None
def set_channel(self, channel: DataChannel): def set_channel(self, channel: DataChannel):
self.channel = channel self.channel = channel
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
if current_channel.get() != self.channel: if current_channel.get() != self.channel:
current_channel.set(self.channel) current_channel.set(self.channel)
async def recv(self): # type: ignore async def accept_input(self):
try: self.has_started = True
while not self.thread_quit.is_set():
try: try:
frame = cast(VideoFrame, await self.track.recv()) frame = cast(VideoFrame, await self.track.recv())
self.latest_frame = frame
self.frame_queue.put_nowait(frame)
except MediaStreamError: except MediaStreamError:
self.stop() self.stop()
return return
def accept_input_in_background(self):
if not self.has_started:
asyncio.create_task(self.accept_input())
async def recv(self): # type: ignore
self.accept_input_in_background()
try:
frame = await self.frame_queue.get()
if self.skip_frames:
frame = self.latest_frame
await self.wait_for_channel() await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24") frame_array = frame.to_ndarray(format="bgr24") # type: ignore
if self.latest_args == "not_set": if self.latest_args == "not_set":
return frame return frame
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array) args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args)) array, outputs = split_output(self.event_handler(*args))
if ( if (
isinstance(outputs, AdditionalOutputs) isinstance(outputs, AdditionalOutputs)
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
pts, time_base = await self.next_timestamp() pts, time_base = await self.next_timestamp()
new_frame.pts = pts new_frame.pts = pts
new_frame.time_base = time_base new_frame.time_base = time_base
self.function_running = False
return new_frame return new_frame
except Exception as e: except Exception as e:
logger.debug("exception %s", e) logger.debug("exception %s", e)
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
else: else:
raise WebRTCError(str(e)) from e raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
if hasattr(self, "_timestamp"):
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
if wait > 0:
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, VIDEO_TIME_BASE
class StreamHandlerBase(ABC): class StreamHandlerBase(ABC):
def __init__( def __init__(
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
output_sample_rate: int = 24000, output_sample_rate: int = 24000,
output_frame_size: int = 960, output_frame_size: int = 960,
input_sample_rate: int = 48000, input_sample_rate: int = 48000,
fps: int = 30,
) -> None: ) -> None:
self.expected_layout = expected_layout self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate self.input_sample_rate = input_sample_rate
self.fps = fps
self.latest_args: list[Any] = [] self.latest_args: list[Any] = []
self._resampler = None self._resampler = None
self._channel: DataChannel | None = None self._channel: DataChannel | None = None
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable HandlerType = (
StreamHandlerImpl
| VideoStreamHandlerImpl
| VideoEventHandler
| Callable
| VideoStreamHandler
)
class VideoStreamHandler(VideoCallback): class VideoStreamHandler_(VideoCallback):
async def process_frames(self): async def process_frames(self):
while not self.thread_quit.is_set(): while not self.thread_quit.is_set():
try: try:
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
event_handler: Callable, event_handler: Callable,
channel: DataChannel | None = None, channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None, set_additional_outputs: Callable | None = None,
fps: int = 30,
) -> None: ) -> None:
super().__init__() # don't forget this! super().__init__() # don't forget this!
self.event_handler = event_handler self.event_handler = event_handler
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator: Generator[Any, None, Any] | None = None self.generator: Generator[Any, None, Any] | None = None
self.channel = channel self.channel = channel
self.set_additional_outputs = set_additional_outputs self.set_additional_outputs = set_additional_outputs
self.fps = fps
self.frame_ptime = 1.0 / fps
def array_to_frame(self, array: np.ndarray) -> VideoFrame: def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24") return VideoFrame.from_ndarray(array, format="bgr24")
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args) self.latest_args = list(args)
self.args_set.set() self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
if hasattr(self, "_timestamp"):
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
if wait > 0:
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, VIDEO_TIME_BASE
async def recv(self): # type: ignore async def recv(self): # type: ignore
try: try:
pts, time_base = await self.next_timestamp() pts, time_base = await self.next_timestamp()

View File

@@ -26,6 +26,7 @@ from .tracks import (
StreamHandlerBase, StreamHandlerBase,
StreamHandlerImpl, StreamHandlerImpl,
VideoEventHandler, VideoEventHandler,
VideoStreamHandler,
) )
from .webrtc_connection_mixin import WebRTCConnectionMixin from .webrtc_connection_mixin import WebRTCConnectionMixin
@@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
| StreamHandlerImpl | StreamHandlerImpl
| AudioVideoStreamHandlerImpl | AudioVideoStreamHandlerImpl
| VideoEventHandler | VideoEventHandler
| VideoStreamHandler
| None | None
) = None, ) = None,
inputs: Block | Sequence[Block] | set[Block] | None = None, inputs: Block | Sequence[Block] | set[Block] | None = None,
@@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
concurrency_id: str | None = None, concurrency_id: str | None = None,
time_limit: float | None = None, time_limit: float | None = None,
trigger: Callable | None = None, trigger: Callable | None = None,
send_input_on: Literal["submit", "change"] = "change",
): ):
from gradio.blocks import Block from gradio.blocks import Block
@@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
"In the webrtc stream event, the only output component must be the WebRTC component." "In the webrtc stream event, the only output component must be the WebRTC component."
) )
for input_component in inputs[1:]: # type: ignore for input_component in inputs[1:]: # type: ignore
if hasattr(input_component, "change"): if hasattr(input_component, "change") and send_input_on == "change":
input_component.change( # type: ignore input_component.change( # type: ignore
self.set_input, self.set_input,
inputs=inputs, inputs=inputs,
@@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin):
time_limit=None, time_limit=None,
js=js, js=js,
) )
if hasattr(input_component, "submit") and send_input_on == "submit":
input_component.submit( # type: ignore
self.set_input,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
)
return self.tick( # type: ignore return self.tick( # type: ignore
self.set_input, self.set_input,
inputs=inputs, inputs=inputs,

View File

@@ -31,7 +31,9 @@ from fastrtc.tracks import (
StreamHandlerBase, StreamHandlerBase,
StreamHandlerImpl, StreamHandlerImpl,
VideoCallback, VideoCallback,
VideoEventHandler,
VideoStreamHandler, VideoStreamHandler,
VideoStreamHandler_,
) )
from fastrtc.utils import ( from fastrtc.utils import (
AdditionalOutputs, AdditionalOutputs,
@@ -41,7 +43,7 @@ from fastrtc.utils import (
Track = ( Track = (
VideoCallback VideoCallback
| VideoStreamHandler | VideoStreamHandler_
| AudioCallback | AudioCallback
| ServerToClientAudio | ServerToClientAudio
| ServerToClientVideo | ServerToClientVideo
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
if hasattr(handler, "video_emit"): if hasattr(handler, "video_emit"):
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
elif isinstance(self.event_handler, VideoStreamHandler):
self.event_handler.callable = cast(
VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
)
handler = self.event_handler
else: else:
handler = webrtc_error_handler(cast(Callable, self.event_handler)) handler = webrtc_error_handler(cast(Callable, self.event_handler))
@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
handler = self.handlers[body["webrtc_id"]] handler = self.handlers[body["webrtc_id"]]
if self.modality == "video" and track.kind == "video": if self.modality == "video" and track.kind == "video":
args = {}
handler_ = handler
if isinstance(handler, VideoStreamHandler):
handler_ = handler.callable
args["fps"] = handler.fps
args["skip_frames"] = handler.skip_frames
cb = VideoCallback( cb = VideoCallback(
relay.subscribe(track), relay.subscribe(track),
event_handler=cast(Callable, handler), event_handler=cast(Callable, handler_),
set_additional_outputs=set_outputs, set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode), mode=cast(Literal["send", "send-receive"], self.mode),
**args,
) )
elif self.modality == "audio-video" and track.kind == "video": elif self.modality == "audio-video" and track.kind == "video":
cb = VideoStreamHandler( cb = VideoStreamHandler_(
relay.subscribe(track), relay.subscribe(track),
event_handler=handler, # type: ignore event_handler=handler, # type: ignore
set_additional_outputs=set_outputs, set_additional_outputs=set_outputs,
fps=cast(StreamHandlerImpl, handler).fps,
) )
elif self.modality in ["audio", "audio-video"] and track.kind == "audio": elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
eh = cast(StreamHandlerImpl, handler) eh = cast(StreamHandlerImpl, handler)
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:
if self.mode == "receive": if self.mode == "receive":
if self.modality == "video": if self.modality == "video":
cb = ServerToClientVideo( if isinstance(self.event_handler, VideoStreamHandler):
cast(Callable, self.event_handler), cb = ServerToClientVideo(
set_additional_outputs=set_outputs, cast(Callable, self.event_handler.callable),
) set_additional_outputs=set_outputs,
fps=self.event_handler.fps,
)
else:
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio": elif self.modality == "audio":
cb = ServerToClientAudio( cb = ServerToClientAudio(
cast(Callable, self.event_handler), cast(Callable, self.event_handler),

View File

@@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component.
mode="receive" mode="receive"
) )
``` ```
## Skipping Frames
If your event handler cannot process frames in real time, the output feed will appear laggy.
To fix this, set the `skip_frames` parameter to `True`. Frames that arrive while the event handler is still running will be skipped, so only the most recent frame is processed.
``` py title="Skipping Frames"
import time
import numpy as np
from fastrtc import Stream, VideoStreamHandler
def process_image(image):
time.sleep(
0.2
) # Simulating 200ms processing time per frame; input arrives faster (30 FPS).
return np.flip(image, axis=0)
stream = Stream(
handler=VideoStreamHandler(process_image, skip_frames=True),
modality="video",
mode="send-receive",
)
stream.ui.launch()
```
## Setting the Output Frame Rate
You can control the output frame rate with the `fps` parameter of the `VideoStreamHandler`.
``` py title="Setting the Output Frame Rate"
def generation():
url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71"
cap = cv2.VideoCapture(url)
iterating = True
# FPS calculation variables
frame_count = 0
start_time = time.time()
fps = 0
while iterating:
iterating, frame = cap.read()
# Calculate and print FPS
frame_count += 1
elapsed_time = time.time() - start_time
if elapsed_time >= 1.0: # Update FPS every second
fps = frame_count / elapsed_time
yield frame, AdditionalOutputs(fps)
frame_count = 0
start_time = time.time()
else:
yield frame
stream = Stream(
handler=VideoStreamHandler(generation, fps=60),
modality="video",
mode="receive",
additional_outputs=[gr.Number(label="FPS")],
additional_outputs_handler=lambda prev, cur: cur,
)
stream.ui.launch()
```