@@ -21,6 +21,7 @@ from .tracks import (
     AudioVideoStreamHandler,
     StreamHandler,
     VideoEmitType,
+    VideoStreamHandler,
 )
 from .utils import (
     AdditionalOutputs,
@@ -73,4 +74,5 @@ __all__ = [
     "PauseDetectionModel",
     "get_silero_model",
     "SileroVadOptions",
+    "VideoStreamHandler",
 ]

@@ -46,6 +46,11 @@ class UIArgs(TypedDict):
     """Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
     icon_radius: NotRequired[int]
     """Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
+    send_input_on: NotRequired[Literal["submit", "change"]]
+    """When to send the input to the handler. Default is "change".
+    If "submit", the input will be sent when the submit event is triggered by the user.
+    If "change", the input will be sent whenever the user changes the input value.
+    """


 class Stream(WebRTCConnectionMixin):
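For orientation, the new option travels through the UI layer's `ui_args` (see the `ui_args.get("send_input_on", "change")` calls in the hunks below). A minimal sketch of opting into submit-based delivery, assuming fastrtc's `Stream` accepts these options via its `ui_args` and `additional_inputs` parameters (the handler and textbox here are illustrative):

```python
import gradio as gr
from fastrtc import Stream


def process_frame(frame, prompt):
    # Illustrative handler: receives the video frame plus the latest
    # value of each additional input component.
    return frame


stream = Stream(
    handler=process_frame,
    modality="video",
    mode="send-receive",
    additional_inputs=[gr.Textbox(label="Prompt")],
    # With "submit", the prompt is only re-sent when the user presses Enter
    # or otherwise triggers the textbox's submit event.
    ui_args={"send_input_on": "submit"},
)
```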
@@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin):
                     trigger=button.click,
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin):
                     outputs=[output_video],
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin):
                     outputs=[image],
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin):
                     trigger=button.click,
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin):
                     outputs=[image],
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin):
                     outputs=[image],
                     time_limit=self.time_limit,
                     concurrency_limit=self.concurrency_limit,  # type: ignore
+                    send_input_on=ui_args.get("send_input_on", "change"),
                 )
                 if additional_output_components:
                     assert self.additional_outputs_handler
@@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin):
                     outputs=additional_output_components,
                 )
         elif self.modality == "audio-video" and self.mode == "send-receive":
-            with gr.Blocks() as demo:
+            css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
+            .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
+            with gr.Blocks(css=css) as demo:
                 gr.HTML(
                     f"""
                     <h1 style='text-align: center'>
@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin):
                     """
                 )
                 with gr.Row():
-                    with gr.Column():
-                        with gr.Group():
+                    with gr.Column(elem_classes=["my-column"]):
+                        with gr.Group(elem_classes=["my-group"]):
                             image = WebRTC(
                                 label="Stream",
                                 rtc_configuration=self.rtc_configuration,
@@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin):
                         outputs=[image],
                         time_limit=self.time_limit,
                         concurrency_limit=self.concurrency_limit,  # type: ignore
+                        send_input_on=ui_args.get("send_input_on", "change"),
                     )
                 if additional_output_components:
                     assert self.additional_outputs_handler

@@ -3,6 +3,7 @@
 from __future__ import annotations

 import asyncio
+import fractions
 import functools
 import inspect
 import logging
@@ -11,10 +12,12 @@ import time
 import traceback
 from abc import ABC, abstractmethod
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import (
     Any,
     Generator,
     Literal,
+    Tuple,
     TypeAlias,
     Union,
     cast,
@@ -29,7 +32,7 @@ from aiortc import (
     VideoStreamTrack,
 )
 from aiortc.contrib.media import AudioFrame, VideoFrame  # type: ignore
-from aiortc.mediastreams import MediaStreamError
+from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
 from numpy import typing as npt

 from fastrtc.utils import (
@@ -56,6 +59,13 @@ VideoEmitType = (
 VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]


+@dataclass
+class VideoStreamHandler:
+    callable: VideoEventHandler
+    fps: int = 30
+    skip_frames: bool = False
+
+
 class VideoCallback(VideoStreamTrack):
     """
     This works for streaming input and output
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         mode: Literal["send-receive", "send"] = "send-receive",
+        fps: int = 30,
+        skip_frames: bool = False,
     ) -> None:
-        super().__init__()  # don't forget this!
+        super().__init__()
         self.track = track
         self.event_handler = event_handler
         self.latest_args: str | list[Any] = "not_set"
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
         self.mode = mode
         self.channel_set = asyncio.Event()
         self.has_started = False
+        self.fps = fps
+        self.frame_ptime = 1.0 / fps
+        self.skip_frames = skip_frames
+        self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
+        self.latest_frame = None

     def set_channel(self, channel: DataChannel):
         self.channel = channel
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
         if current_channel.get() != self.channel:
             current_channel.set(self.channel)

-    async def recv(self):  # type: ignore
-        try:
+    async def accept_input(self):
+        self.has_started = True
+        while not self.thread_quit.is_set():
+            try:
                 frame = cast(VideoFrame, await self.track.recv())
+                self.latest_frame = frame
+                self.frame_queue.put_nowait(frame)
             except MediaStreamError:
                 self.stop()
                 return
+
+    def accept_input_in_background(self):
+        if not self.has_started:
+            asyncio.create_task(self.accept_input())
+
+    async def recv(self):  # type: ignore
+        self.accept_input_in_background()
+        try:
+            frame = await self.frame_queue.get()
+            if self.skip_frames:
+                frame = self.latest_frame
             await self.wait_for_channel()
-            frame_array = frame.to_ndarray(format="bgr24")
+            frame_array = frame.to_ndarray(format="bgr24")  # type: ignore
             if self.latest_args == "not_set":
                 return frame

             args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)

             array, outputs = split_output(self.event_handler(*args))
             if (
                 isinstance(outputs, AdditionalOutputs)
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
             pts, time_base = await self.next_timestamp()
             new_frame.pts = pts
             new_frame.time_base = time_base

             self.function_running = False
             return new_frame
         except Exception as e:
             logger.debug("exception %s", e)
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
             else:
                 raise WebRTCError(str(e)) from e

+    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+        """Override to control frame rate"""
+        if self.readyState != "live":
+            raise MediaStreamError
+
+        if hasattr(self, "_timestamp"):
+            self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
+            wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+            if wait > 0:
+                await asyncio.sleep(wait)
+        else:
+            self._start = time.time()
+            self._timestamp = 0
+        return self._timestamp, VIDEO_TIME_BASE
+

 class StreamHandlerBase(ABC):
     def __init__(
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
         output_sample_rate: int = 24000,
         output_frame_size: int = 960,
         input_sample_rate: int = 48000,
+        fps: int = 30,
     ) -> None:
         self.expected_layout = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
         self.input_sample_rate = input_sample_rate
+        self.fps = fps
         self.latest_args: list[Any] = []
         self._resampler = None
         self._channel: DataChannel | None = None
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
 AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
 AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler

-HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
+HandlerType = (
+    StreamHandlerImpl
+    | VideoStreamHandlerImpl
+    | VideoEventHandler
+    | Callable
+    | VideoStreamHandler
+)


-class VideoStreamHandler(VideoCallback):
+class VideoStreamHandler_(VideoCallback):
     async def process_frames(self):
         while not self.thread_quit.is_set():
             try:
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
         event_handler: Callable,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
+        fps: int = 30,
     ) -> None:
         super().__init__()  # don't forget this!
         self.event_handler = event_handler
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
         self.generator: Generator[Any, None, Any] | None = None
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs
+        self.fps = fps
+        self.frame_ptime = 1.0 / fps

     def array_to_frame(self, array: np.ndarray) -> VideoFrame:
         return VideoFrame.from_ndarray(array, format="bgr24")
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
         self.latest_args = list(args)
         self.args_set.set()

+    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+        """Override to control frame rate"""
+        if self.readyState != "live":
+            raise MediaStreamError
+
+        if hasattr(self, "_timestamp"):
+            self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
+            wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+            if wait > 0:
+                await asyncio.sleep(wait)
+        else:
+            self._start = time.time()
+            self._timestamp = 0
+        return self._timestamp, VIDEO_TIME_BASE
+
     async def recv(self):  # type: ignore
         try:
             pts, time_base = await self.next_timestamp()

@@ -26,6 +26,7 @@ from .tracks import (
     StreamHandlerBase,
     StreamHandlerImpl,
     VideoEventHandler,
+    VideoStreamHandler,
 )
 from .webrtc_connection_mixin import WebRTCConnectionMixin

@@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
             | StreamHandlerImpl
             | AudioVideoStreamHandlerImpl
             | VideoEventHandler
+            | VideoStreamHandler
             | None
         ) = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
@@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
         concurrency_id: str | None = None,
         time_limit: float | None = None,
         trigger: Callable | None = None,
+        send_input_on: Literal["submit", "change"] = "change",
     ):
         from gradio.blocks import Block

@@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
                 "In the webrtc stream event, the only output component must be the WebRTC component."
             )
         for input_component in inputs[1:]:  # type: ignore
-            if hasattr(input_component, "change"):
+            if hasattr(input_component, "change") and send_input_on == "change":
                 input_component.change(  # type: ignore
                     self.set_input,
                     inputs=inputs,
@@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin):
                     time_limit=None,
                     js=js,
                 )
+            if hasattr(input_component, "submit") and send_input_on == "submit":
+                input_component.submit(  # type: ignore
+                    self.set_input,
+                    inputs=inputs,
+                    outputs=None,
+                    concurrency_id=concurrency_id,
+                )
         return self.tick(  # type: ignore
             self.set_input,
             inputs=inputs,

@@ -31,7 +31,9 @@ from fastrtc.tracks import (
     StreamHandlerBase,
     StreamHandlerImpl,
     VideoCallback,
+    VideoEventHandler,
     VideoStreamHandler,
+    VideoStreamHandler_,
 )
 from fastrtc.utils import (
     AdditionalOutputs,
@@ -41,7 +43,7 @@ from fastrtc.utils import (

 Track = (
     VideoCallback
-    | VideoStreamHandler
+    | VideoStreamHandler_
     | AudioCallback
     | ServerToClientAudio
     | ServerToClientVideo
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
             handler.video_receive = webrtc_error_handler(handler.video_receive)  # type: ignore
             if hasattr(handler, "video_emit"):
                 handler.video_emit = webrtc_error_handler(handler.video_emit)  # type: ignore
+        elif isinstance(self.event_handler, VideoStreamHandler):
+            self.event_handler.callable = cast(
+                VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
+            )
+            handler = self.event_handler
         else:
             handler = webrtc_error_handler(cast(Callable, self.event_handler))

@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
         handler = self.handlers[body["webrtc_id"]]

         if self.modality == "video" and track.kind == "video":
+            args = {}
+            handler_ = handler
+            if isinstance(handler, VideoStreamHandler):
+                handler_ = handler.callable
+                args["fps"] = handler.fps
+                args["skip_frames"] = handler.skip_frames
             cb = VideoCallback(
                 relay.subscribe(track),
-                event_handler=cast(Callable, handler),
+                event_handler=cast(Callable, handler_),
                 set_additional_outputs=set_outputs,
                 mode=cast(Literal["send", "send-receive"], self.mode),
+                **args,
             )
         elif self.modality == "audio-video" and track.kind == "video":
-            cb = VideoStreamHandler(
+            cb = VideoStreamHandler_(
                 relay.subscribe(track),
                 event_handler=handler,  # type: ignore
                 set_additional_outputs=set_outputs,
+                fps=cast(StreamHandlerImpl, handler).fps,
             )
         elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
             eh = cast(StreamHandlerImpl, handler)
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:

         if self.mode == "receive":
             if self.modality == "video":
-                cb = ServerToClientVideo(
-                    cast(Callable, self.event_handler),
-                    set_additional_outputs=set_outputs,
-                )
+                if isinstance(self.event_handler, VideoStreamHandler):
+                    cb = ServerToClientVideo(
+                        cast(Callable, self.event_handler.callable),
+                        set_additional_outputs=set_outputs,
+                        fps=self.event_handler.fps,
+                    )
+                else:
+                    cb = ServerToClientVideo(
+                        cast(Callable, self.event_handler),
+                        set_additional_outputs=set_outputs,
+                    )
             elif self.modality == "audio":
                 cb = ServerToClientAudio(
                     cast(Callable, self.event_handler),

@@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component.
     mode="receive"
 )
 ```
+
+## Skipping Frames
+
+If your event handler cannot keep up with the incoming frame rate, the output feed will lag further and further behind real time.
+
+To fix this, set the `skip_frames` parameter to `True`. Frames that arrive while the event handler is still busy are then discarded, so the handler always runs on the most recent frame.
+
+``` py title="Skipping Frames"
+import time
+
+import numpy as np
+from fastrtc import Stream, VideoStreamHandler
+
+
+def process_image(image):
+    time.sleep(
+        0.2
+    )  # Simulating 200 ms of processing time per frame; input arrives faster (30 FPS).
+    return np.flip(image, axis=0)
+
+
+stream = Stream(
+    handler=VideoStreamHandler(process_image, skip_frames=True),
+    modality="video",
+    mode="send-receive",
+)
+
+stream.ui.launch()
+```
+
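The two options compose: a frame-rate cap and frame skipping can be used together. A minimal sketch (reusing `process_image` from the example above; the exact numbers are illustrative):

```python
# Drop stale frames and pace the output track at 10 FPS.
stream = Stream(
    handler=VideoStreamHandler(process_image, fps=10, skip_frames=True),
    modality="video",
    mode="send-receive",
)
```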
+## Setting the Output Frame Rate
+
+You can set the output frame rate by passing the `fps` parameter to `VideoStreamHandler`.
+
+``` py title="Setting the Output Frame Rate"
+import time
+
+import cv2
+import gradio as gr
+from fastrtc import AdditionalOutputs, Stream, VideoStreamHandler
+
+
+def generation():
+    url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71"
+    cap = cv2.VideoCapture(url)
+    iterating = True
+
+    # FPS calculation variables
+    frame_count = 0
+    start_time = time.time()
+    fps = 0
+
+    while iterating:
+        iterating, frame = cap.read()
+        if not iterating:
+            break  # the clip has ended; avoid yielding a None frame
+
+        # Calculate the FPS and emit it as an additional output
+        frame_count += 1
+        elapsed_time = time.time() - start_time
+        if elapsed_time >= 1.0:  # Update FPS every second
+            fps = frame_count / elapsed_time
+            yield frame, AdditionalOutputs(fps)
+            frame_count = 0
+            start_time = time.time()
+        else:
+            yield frame
+
+
+stream = Stream(
+    handler=VideoStreamHandler(generation, fps=60),
+    modality="video",
+    mode="receive",
+    additional_outputs=[gr.Number(label="FPS")],
+    additional_outputs_handler=lambda prev, cur: cur,
+)
+
+stream.ui.launch()
+```
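For intuition, the `fps` parameter works through the `next_timestamp` override shown in the `tracks.py` hunks above: each emitted frame advances the RTP timestamp by `1/fps` seconds on aiortc's 90 kHz video clock, and the track sleeps whenever it runs ahead of that schedule. A small self-contained sketch of that arithmetic (`VIDEO_CLOCK_RATE` matches aiortc's standard video clock):

```python
VIDEO_CLOCK_RATE = 90_000  # aiortc's RTP video clock, in ticks per second


def timestamp_step(fps: int) -> int:
    """Per-frame timestamp increment: (1/fps) seconds expressed in clock ticks."""
    frame_ptime = 1.0 / fps  # mirrors self.frame_ptime = 1.0 / fps above
    return int(frame_ptime * VIDEO_CLOCK_RATE)


assert timestamp_step(30) == 3000  # the default fps
assert timestamp_step(60) == 1500  # the fps=60 example above
```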