Some Video Fixes (#200)

* FPS control:

* add code

* Add code
This commit is contained in:
Freddy Boulton
2025-03-20 20:45:46 -04:00
committed by GitHub
parent bce7cb95a6
commit 3fed4cb2ad
6 changed files with 208 additions and 20 deletions

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import fractions
import functools
import inspect
import logging
@@ -11,10 +12,12 @@ import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
Tuple,
TypeAlias,
Union,
cast,
@@ -29,7 +32,7 @@ from aiortc import (
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
from numpy import typing as npt
from fastrtc.utils import (
@@ -56,6 +59,13 @@ VideoEmitType = (
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
@dataclass
class VideoStreamHandler:
callable: VideoEventHandler
fps: int = 30
skip_frames: bool = False
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
fps: int = 30,
skip_frames: bool = False,
) -> None:
super().__init__() # don't forget this!
super().__init__()
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
self.mode = mode
self.channel_set = asyncio.Event()
self.has_started = False
self.fps = fps
self.frame_ptime = 1.0 / fps
self.skip_frames = skip_frames
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
self.latest_frame = None
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self): # type: ignore
try:
async def accept_input(self):
self.has_started = True
while not self.thread_quit.is_set():
try:
frame = cast(VideoFrame, await self.track.recv())
self.latest_frame = frame
self.frame_queue.put_nowait(frame)
except MediaStreamError:
self.stop()
return
def accept_input_in_background(self):
if not self.has_started:
asyncio.create_task(self.accept_input())
async def recv(self): # type: ignore
self.accept_input_in_background()
try:
frame = await self.frame_queue.get()
if self.skip_frames:
frame = self.latest_frame
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
frame_array = frame.to_ndarray(format="bgr24") # type: ignore
if self.latest_args == "not_set":
return frame
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
self.function_running = False
return new_frame
except Exception as e:
logger.debug("exception %s", e)
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
else:
raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
if hasattr(self, "_timestamp"):
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
if wait > 0:
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, VIDEO_TIME_BASE
class StreamHandlerBase(ABC):
def __init__(
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
output_sample_rate: int = 24000,
output_frame_size: int = 960,
input_sample_rate: int = 48000,
fps: int = 30,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate
self.fps = fps
self.latest_args: list[Any] = []
self._resampler = None
self._channel: DataChannel | None = None
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
HandlerType = (
StreamHandlerImpl
| VideoStreamHandlerImpl
| VideoEventHandler
| Callable
| VideoStreamHandler
)
class VideoStreamHandler(VideoCallback):
class VideoStreamHandler_(VideoCallback):
async def process_frames(self):
while not self.thread_quit.is_set():
try:
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
fps: int = 30,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.fps = fps
self.frame_ptime = 1.0 / fps
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
if hasattr(self, "_timestamp"):
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
if wait > 0:
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, VIDEO_TIME_BASE
async def recv(self): # type: ignore
try:
pts, time_base = await self.next_timestamp()