mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import fractions
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
@@ -11,10 +12,12 @@ import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Generator,
|
||||
Literal,
|
||||
Tuple,
|
||||
TypeAlias,
|
||||
Union,
|
||||
cast,
|
||||
@@ -29,7 +32,7 @@ from aiortc import (
|
||||
VideoStreamTrack,
|
||||
)
|
||||
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
|
||||
from aiortc.mediastreams import MediaStreamError
|
||||
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
|
||||
from numpy import typing as npt
|
||||
|
||||
from fastrtc.utils import (
|
||||
@@ -56,6 +59,13 @@ VideoEmitType = (
|
||||
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoStreamHandler:
|
||||
callable: VideoEventHandler
|
||||
fps: int = 30
|
||||
skip_frames: bool = False
|
||||
|
||||
|
||||
class VideoCallback(VideoStreamTrack):
|
||||
"""
|
||||
This works for streaming input and output
|
||||
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
mode: Literal["send-receive", "send"] = "send-receive",
|
||||
fps: int = 30,
|
||||
skip_frames: bool = False,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.mode = mode
|
||||
self.channel_set = asyncio.Event()
|
||||
self.has_started = False
|
||||
self.fps = fps
|
||||
self.frame_ptime = 1.0 / fps
|
||||
self.skip_frames = skip_frames
|
||||
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
||||
self.latest_frame = None
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
|
||||
if current_channel.get() != self.channel:
|
||||
current_channel.set(self.channel)
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
async def accept_input(self):
|
||||
self.has_started = True
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
frame = cast(VideoFrame, await self.track.recv())
|
||||
self.latest_frame = frame
|
||||
self.frame_queue.put_nowait(frame)
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
def accept_input_in_background(self):
|
||||
if not self.has_started:
|
||||
asyncio.create_task(self.accept_input())
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
self.accept_input_in_background()
|
||||
try:
|
||||
frame = await self.frame_queue.get()
|
||||
if self.skip_frames:
|
||||
frame = self.latest_frame
|
||||
await self.wait_for_channel()
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
frame_array = frame.to_ndarray(format="bgr24") # type: ignore
|
||||
if self.latest_args == "not_set":
|
||||
return frame
|
||||
|
||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||
|
||||
array, outputs = split_output(self.event_handler(*args))
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
pts, time_base = await self.next_timestamp()
|
||||
new_frame.pts = pts
|
||||
new_frame.time_base = time_base
|
||||
|
||||
self.function_running = False
|
||||
return new_frame
|
||||
except Exception as e:
|
||||
logger.debug("exception %s", e)
|
||||
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
||||
"""Override to control frame rate"""
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
if hasattr(self, "_timestamp"):
|
||||
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
|
||||
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
self._start = time.time()
|
||||
self._timestamp = 0
|
||||
return self._timestamp, VIDEO_TIME_BASE
|
||||
|
||||
|
||||
class StreamHandlerBase(ABC):
|
||||
def __init__(
|
||||
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 960,
|
||||
input_sample_rate: int = 48000,
|
||||
fps: int = 30,
|
||||
) -> None:
|
||||
self.expected_layout = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.fps = fps
|
||||
self.latest_args: list[Any] = []
|
||||
self._resampler = None
|
||||
self._channel: DataChannel | None = None
|
||||
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
|
||||
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
|
||||
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
|
||||
|
||||
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
|
||||
HandlerType = (
|
||||
StreamHandlerImpl
|
||||
| VideoStreamHandlerImpl
|
||||
| VideoEventHandler
|
||||
| Callable
|
||||
| VideoStreamHandler
|
||||
)
|
||||
|
||||
|
||||
class VideoStreamHandler(VideoCallback):
|
||||
class VideoStreamHandler_(VideoCallback):
|
||||
async def process_frames(self):
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
fps: int = 30,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.event_handler = event_handler
|
||||
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.fps = fps
|
||||
self.frame_ptime = 1.0 / fps
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.latest_args = list(args)
|
||||
self.args_set.set()
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
||||
"""Override to control frame rate"""
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
if hasattr(self, "_timestamp"):
|
||||
self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
|
||||
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
self._start = time.time()
|
||||
self._timestamp = 0
|
||||
return self._timestamp, VIDEO_TIME_BASE
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
pts, time_base = await self.next_timestamp()
|
||||
|
||||
Reference in New Issue
Block a user