Some Video Fixes (#200)

* FPS control:

* add code

* Add code
This commit is contained in:
Freddy Boulton
2025-03-20 20:45:46 -04:00
committed by GitHub
parent bce7cb95a6
commit 3fed4cb2ad
6 changed files with 208 additions and 20 deletions

View File

@@ -21,6 +21,7 @@ from .tracks import (
AudioVideoStreamHandler,
StreamHandler,
VideoEmitType,
VideoStreamHandler,
)
from .utils import (
AdditionalOutputs,
@@ -73,4 +74,5 @@ __all__ = [
"PauseDetectionModel",
"get_silero_model",
"SileroVadOptions",
"VideoStreamHandler",
]

View File

@@ -46,6 +46,11 @@ class UIArgs(TypedDict):
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
icon_radius: NotRequired[int]
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
send_input_on: NotRequired[Literal["submit", "change"]]
"""When to send the input to the handler. Default is "change".
If "submit", the input will be sent when the submit event is triggered by the user.
If "change", the input will be sent whenever the user changes the input value.
"""
class Stream(WebRTCConnectionMixin):
@@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click,
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[output_video],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin):
trigger=button.click,
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
@@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin):
outputs=additional_output_components,
)
elif self.modality == "audio-video" and self.mode == "send-receive":
with gr.Blocks() as demo:
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
gr.HTML(
f"""
<h1 style='text-align: center'>
@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin):
"""
)
with gr.Row():
with gr.Column():
with gr.Group():
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
@@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin):
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import fractions
import functools
import inspect
import logging
@@ -11,10 +12,12 @@ import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
Tuple,
TypeAlias,
Union,
cast,
@@ -29,7 +32,7 @@ from aiortc import (
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
from numpy import typing as npt
from fastrtc.utils import (
@@ -56,6 +59,13 @@ VideoEmitType = (
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
@dataclass
class VideoStreamHandler:
    """Pairs a per-frame video callback with frame-rate options.

    Passed as the ``handler`` of a Stream/WebRTC component to control
    output fps and frame-skipping behavior of the underlying track.
    """

    # The user's per-frame handler: receives an ndarray-like frame and
    # returns a VideoEmitType (processed frame and/or AdditionalOutputs).
    callable: VideoEventHandler
    # Target output frame rate; used to derive the per-frame pacing
    # interval (1 / fps) in the video tracks.
    fps: int = 30
    # If True, the track feeds the handler the most recently received
    # frame instead of the oldest queued one, so a slow handler does not
    # fall progressively behind the live feed.
    skip_frames: bool = False
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
fps: int = 30,
skip_frames: bool = False,
) -> None:
super().__init__() # don't forget this!
super().__init__()
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
self.mode = mode
self.channel_set = asyncio.Event()
self.has_started = False
self.fps = fps
self.frame_ptime = 1.0 / fps
self.skip_frames = skip_frames
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
self.latest_frame = None
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self): # type: ignore
try:
async def accept_input(self):
    """Continuously read frames from the client's track into the queue.

    Runs as a background task until ``thread_quit`` is set or the track
    ends. Every received frame is recorded as ``latest_frame`` and
    enqueued for ``recv()`` to consume.
    """
    self.has_started = True
    while not self.thread_quit.is_set():
        try:
            frame = cast(VideoFrame, await self.track.recv())
            self.latest_frame = frame
            if self.skip_frames:
                # recv() only ever uses the newest frame in skip_frames
                # mode, so stale entries would otherwise accumulate
                # without bound whenever the event handler is slower
                # than the incoming frame rate. Drop them here.
                # (Single event loop, no await between empty() and
                # get_nowait(), so this drain cannot race.)
                while not self.frame_queue.empty():
                    self.frame_queue.get_nowait()
            self.frame_queue.put_nowait(frame)
        except MediaStreamError:
            # The client track ended or was closed; stop this track too.
            self.stop()
            return
def accept_input_in_background(self):
    """Start the frame-reader task exactly once for this track."""
    if not self.has_started:
        # Mark started *before* scheduling: the task body only sets the
        # flag once it actually runs, so two quick calls could otherwise
        # spawn duplicate readers for the same track.
        self.has_started = True
        # Keep a reference to the task: the event loop holds only weak
        # references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._input_task = asyncio.create_task(self.accept_input())
async def recv(self): # type: ignore
self.accept_input_in_background()
try:
frame = await self.frame_queue.get()
if self.skip_frames:
frame = self.latest_frame
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
frame_array = frame.to_ndarray(format="bgr24") # type: ignore
if self.latest_args == "not_set":
return frame
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
self.function_running = False
return new_frame
except Exception as e:
logger.debug("exception %s", e)
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
else:
raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
    """Override to control frame rate.

    Replaces aiortc's ``VideoStreamTrack.next_timestamp`` so that frame
    pacing follows ``self.frame_ptime`` (1 / fps) rather than aiortc's
    fixed default rate.

    Returns:
        Tuple of (pts, time_base): the presentation timestamp in
        ``VIDEO_CLOCK_RATE`` units and the matching time base fraction.

    Raises:
        MediaStreamError: if the track is no longer "live".
    """
    if self.readyState != "live":
        raise MediaStreamError
    if hasattr(self, "_timestamp"):
        # Subsequent frames: advance by one frame period and sleep until
        # the wall-clock deadline so output does not exceed the target fps.
        # NOTE(review): uses time.time() (wall clock, not monotonic), same
        # as aiortc's stock implementation.
        self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
        wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
    else:
        # First frame: anchor the reference clock and start pts at zero.
        self._start = time.time()
        self._timestamp = 0
    return self._timestamp, VIDEO_TIME_BASE
class StreamHandlerBase(ABC):
def __init__(
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
output_sample_rate: int = 24000,
output_frame_size: int = 960,
input_sample_rate: int = 48000,
fps: int = 30,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate
self.fps = fps
self.latest_args: list[Any] = []
self._resampler = None
self._channel: DataChannel | None = None
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
HandlerType = (
StreamHandlerImpl
| VideoStreamHandlerImpl
| VideoEventHandler
| Callable
| VideoStreamHandler
)
class VideoStreamHandler(VideoCallback):
class VideoStreamHandler_(VideoCallback):
async def process_frames(self):
while not self.thread_quit.is_set():
try:
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
fps: int = 30,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.fps = fps
self.frame_ptime = 1.0 / fps
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
    """Override to control frame rate.

    Replaces aiortc's ``VideoStreamTrack.next_timestamp`` so that frames
    emitted by the server-side generator are paced at
    ``self.frame_ptime`` (1 / fps) rather than aiortc's default rate.

    Returns:
        Tuple of (pts, time_base): the presentation timestamp in
        ``VIDEO_CLOCK_RATE`` units and the matching time base fraction.

    Raises:
        MediaStreamError: if the track is no longer "live".
    """
    if self.readyState != "live":
        raise MediaStreamError
    if hasattr(self, "_timestamp"):
        # Subsequent frames: advance by one frame period and sleep until
        # the wall-clock deadline so output does not exceed the target fps.
        self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
        wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
    else:
        # First frame: anchor the reference clock and start pts at zero.
        self._start = time.time()
        self._timestamp = 0
    return self._timestamp, VIDEO_TIME_BASE
async def recv(self): # type: ignore
try:
pts, time_base = await self.next_timestamp()

View File

@@ -26,6 +26,7 @@ from .tracks import (
StreamHandlerBase,
StreamHandlerImpl,
VideoEventHandler,
VideoStreamHandler,
)
from .webrtc_connection_mixin import WebRTCConnectionMixin
@@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
| StreamHandlerImpl
| AudioVideoStreamHandlerImpl
| VideoEventHandler
| VideoStreamHandler
| None
) = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
@@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
concurrency_id: str | None = None,
time_limit: float | None = None,
trigger: Callable | None = None,
send_input_on: Literal["submit", "change"] = "change",
):
from gradio.blocks import Block
@@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
"In the webrtc stream event, the only output component must be the WebRTC component."
)
for input_component in inputs[1:]: # type: ignore
if hasattr(input_component, "change"):
if hasattr(input_component, "change") and send_input_on == "change":
input_component.change( # type: ignore
self.set_input,
inputs=inputs,
@@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin):
time_limit=None,
js=js,
)
if hasattr(input_component, "submit") and send_input_on == "submit":
input_component.submit( # type: ignore
self.set_input,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
)
return self.tick( # type: ignore
self.set_input,
inputs=inputs,

View File

@@ -31,7 +31,9 @@ from fastrtc.tracks import (
StreamHandlerBase,
StreamHandlerImpl,
VideoCallback,
VideoEventHandler,
VideoStreamHandler,
VideoStreamHandler_,
)
from fastrtc.utils import (
AdditionalOutputs,
@@ -41,7 +43,7 @@ from fastrtc.utils import (
Track = (
VideoCallback
| VideoStreamHandler
| VideoStreamHandler_
| AudioCallback
| ServerToClientAudio
| ServerToClientVideo
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
if hasattr(handler, "video_emit"):
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
elif isinstance(self.event_handler, VideoStreamHandler):
self.event_handler.callable = cast(
VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
)
handler = self.event_handler
else:
handler = webrtc_error_handler(cast(Callable, self.event_handler))
@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
handler = self.handlers[body["webrtc_id"]]
if self.modality == "video" and track.kind == "video":
args = {}
handler_ = handler
if isinstance(handler, VideoStreamHandler):
handler_ = handler.callable
args["fps"] = handler.fps
args["skip_frames"] = handler.skip_frames
cb = VideoCallback(
relay.subscribe(track),
event_handler=cast(Callable, handler),
event_handler=cast(Callable, handler_),
set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode),
**args,
)
elif self.modality == "audio-video" and track.kind == "video":
cb = VideoStreamHandler(
cb = VideoStreamHandler_(
relay.subscribe(track),
event_handler=handler, # type: ignore
set_additional_outputs=set_outputs,
fps=cast(StreamHandlerImpl, handler).fps,
)
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
eh = cast(StreamHandlerImpl, handler)
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:
if self.mode == "receive":
if self.modality == "video":
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
if isinstance(self.event_handler, VideoStreamHandler):
cb = ServerToClientVideo(
cast(Callable, self.event_handler.callable),
set_additional_outputs=set_outputs,
fps=self.event_handler.fps,
)
else:
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio":
cb = ServerToClientAudio(
cast(Callable, self.event_handler),

View File

@@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component.
mode="receive"
)
```
## Skipping Frames
If your event handler is not quite real-time yet, then the output feed will look very laggy.
To fix this, you can set the `skip_frames` parameter to `True`. This will skip the frames that are received while the event handler is still running.
``` py title="Skipping Frames"
import time
import numpy as np
from fastrtc import Stream, VideoStreamHandler
def process_image(image):
time.sleep(
0.2
) # Simulating 200ms processing time per frame; input arrives faster (30 FPS).
return np.flip(image, axis=0)
stream = Stream(
handler=VideoStreamHandler(process_image, skip_frames=True),
modality="video",
mode="send-receive",
)
stream.ui.launch()
```
## Setting the Output Frame Rate
You can set the output frame rate by setting the `fps` parameter in the `VideoStreamHandler`.
``` py title="Setting the Output Frame Rate"
def generation():
url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71"
cap = cv2.VideoCapture(url)
iterating = True
# FPS calculation variables
frame_count = 0
start_time = time.time()
fps = 0
while iterating:
iterating, frame = cap.read()
# Calculate and print FPS
frame_count += 1
elapsed_time = time.time() - start_time
if elapsed_time >= 1.0: # Update FPS every second
fps = frame_count / elapsed_time
yield frame, AdditionalOutputs(fps)
frame_count = 0
start_time = time.time()
else:
yield frame
stream = Stream(
handler=VideoStreamHandler(generation, fps=60),
modality="video",
mode="receive",
additional_outputs=[gr.Number(label="FPS")],
additional_outputs_handler=lambda prev, cur: cur,
)
stream.ui.launch()
```