mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -21,6 +21,7 @@ from .tracks import (
|
|||||||
AudioVideoStreamHandler,
|
AudioVideoStreamHandler,
|
||||||
StreamHandler,
|
StreamHandler,
|
||||||
VideoEmitType,
|
VideoEmitType,
|
||||||
|
VideoStreamHandler,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
@@ -73,4 +74,5 @@ __all__ = [
|
|||||||
"PauseDetectionModel",
|
"PauseDetectionModel",
|
||||||
"get_silero_model",
|
"get_silero_model",
|
||||||
"SileroVadOptions",
|
"SileroVadOptions",
|
||||||
|
"VideoStreamHandler",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -46,6 +46,11 @@ class UIArgs(TypedDict):
|
|||||||
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
|
"""Color of the pulse animation. Default is var(--color-accent) of the demo theme."""
|
||||||
icon_radius: NotRequired[int]
|
icon_radius: NotRequired[int]
|
||||||
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
|
"""Border radius of the icon button expressed as a percentage of the button size. Default is 50%."""
|
||||||
|
send_input_on: NotRequired[Literal["submit", "change"]]
|
||||||
|
"""When to send the input to the handler. Default is "change".
|
||||||
|
If "submit", the input will be sent when the submit event is triggered by the user.
|
||||||
|
If "change", the input will be sent whenever the user changes the input value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Stream(WebRTCConnectionMixin):
|
class Stream(WebRTCConnectionMixin):
|
||||||
@@ -229,6 +234,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
trigger=button.click,
|
trigger=button.click,
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -275,6 +281,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=[output_video],
|
outputs=[output_video],
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -325,6 +332,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=[image],
|
outputs=[image],
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -377,6 +385,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
trigger=button.click,
|
trigger=button.click,
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -428,6 +437,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=[image],
|
outputs=[image],
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -480,6 +490,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=[image],
|
outputs=[image],
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
@@ -489,7 +500,9 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=additional_output_components,
|
outputs=additional_output_components,
|
||||||
)
|
)
|
||||||
elif self.modality == "audio-video" and self.mode == "send-receive":
|
elif self.modality == "audio-video" and self.mode == "send-receive":
|
||||||
with gr.Blocks() as demo:
|
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
|
||||||
|
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||||
|
with gr.Blocks(css=css) as demo:
|
||||||
gr.HTML(
|
gr.HTML(
|
||||||
f"""
|
f"""
|
||||||
<h1 style='text-align: center'>
|
<h1 style='text-align: center'>
|
||||||
@@ -506,8 +519,8 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column(elem_classes=["my-column"]):
|
||||||
with gr.Group():
|
with gr.Group(elem_classes=["my-group"]):
|
||||||
image = WebRTC(
|
image = WebRTC(
|
||||||
label="Stream",
|
label="Stream",
|
||||||
rtc_configuration=self.rtc_configuration,
|
rtc_configuration=self.rtc_configuration,
|
||||||
@@ -532,6 +545,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
outputs=[image],
|
outputs=[image],
|
||||||
time_limit=self.time_limit,
|
time_limit=self.time_limit,
|
||||||
concurrency_limit=self.concurrency_limit, # type: ignore
|
concurrency_limit=self.concurrency_limit, # type: ignore
|
||||||
|
send_input_on=ui_args.get("send_input_on", "change"),
|
||||||
)
|
)
|
||||||
if additional_output_components:
|
if additional_output_components:
|
||||||
assert self.additional_outputs_handler
|
assert self.additional_outputs_handler
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import fractions
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
@@ -11,10 +12,12 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Generator,
|
Generator,
|
||||||
Literal,
|
Literal,
|
||||||
|
Tuple,
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@@ -29,7 +32,7 @@ from aiortc import (
|
|||||||
VideoStreamTrack,
|
VideoStreamTrack,
|
||||||
)
|
)
|
||||||
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
|
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
|
||||||
from aiortc.mediastreams import MediaStreamError
|
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
|
||||||
from numpy import typing as npt
|
from numpy import typing as npt
|
||||||
|
|
||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
@@ -56,6 +59,13 @@ VideoEmitType = (
|
|||||||
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
|
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VideoStreamHandler:
    """Bundle a video event handler with its streaming options.

    Pass an instance wherever a plain video event handler is accepted in
    order to configure the output frame rate or frame skipping.

    Raises:
        ValueError: If ``fps`` is not greater than zero.
    """

    # Function invoked with each incoming frame (or generator function in
    # "receive" mode). The field name shadows the builtin ``callable`` but is
    # part of the public interface, so it is kept as-is.
    callable: VideoEventHandler
    # Target output frame rate; the tracks derive the frame period as 1.0 / fps.
    fps: int = 30
    # When true, the track hands the handler the most recently received frame
    # instead of the next queued one.
    skip_frames: bool = False

    def __post_init__(self) -> None:
        # fps is used as a divisor when the track is built, so fail here with
        # a clear message instead of a ZeroDivisionError at connection time
        # (or backwards-running timestamps for a negative value).
        if self.fps <= 0:
            raise ValueError(f"fps must be a positive number, got {self.fps!r}")
|
||||||
|
|
||||||
|
|
||||||
class VideoCallback(VideoStreamTrack):
|
class VideoCallback(VideoStreamTrack):
|
||||||
"""
|
"""
|
||||||
This works for streaming input and output
|
This works for streaming input and output
|
||||||
@@ -70,8 +80,10 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
mode: Literal["send-receive", "send"] = "send-receive",
|
mode: Literal["send-receive", "send"] = "send-receive",
|
||||||
|
fps: int = 30,
|
||||||
|
skip_frames: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__() # don't forget this!
|
super().__init__()
|
||||||
self.track = track
|
self.track = track
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
@@ -81,6 +93,11 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.channel_set = asyncio.Event()
|
self.channel_set = asyncio.Event()
|
||||||
self.has_started = False
|
self.has_started = False
|
||||||
|
self.fps = fps
|
||||||
|
self.frame_ptime = 1.0 / fps
|
||||||
|
self.skip_frames = skip_frames
|
||||||
|
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
||||||
|
self.latest_frame = None
|
||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
@@ -127,21 +144,33 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
if current_channel.get() != self.channel:
|
if current_channel.get() != self.channel:
|
||||||
current_channel.set(self.channel)
|
current_channel.set(self.channel)
|
||||||
|
|
||||||
async def recv(self): # type: ignore
|
async def accept_input(self):
    """Read frames from the upstream track until the stream ends.

    Marks the callback as started, then repeatedly awaits
    ``self.track.recv()``. Every frame is remembered as ``latest_frame`` and
    queued for ``recv``. When the track raises ``MediaStreamError`` the
    callback is stopped and the loop exits.
    """
    self.has_started = True
    while not self.thread_quit.is_set():
        try:
            frame = cast(VideoFrame, await self.track.recv())
        except MediaStreamError:
            self.stop()
            return

        self.latest_frame = frame
        if self.skip_frames:
            # In skip mode ``recv`` replaces whatever it dequeues with
            # ``latest_frame``, so older queue entries are never used. Drop
            # them so the queue cannot grow without bound when the handler
            # is slower than the incoming stream.
            while not self.frame_queue.empty():
                self.frame_queue.get_nowait()
        self.frame_queue.put_nowait(frame)
|
||||||
|
|
||||||
|
def accept_input_in_background(self):
    """Start the ``accept_input`` loop as a background task, at most once.

    Does nothing when the loop has already started or has already been
    scheduled. Must be called while an event loop is running, because it
    uses ``asyncio.create_task``.
    """
    # ``has_started`` only flips once the task first runs, so also check for
    # an already-scheduled task to avoid starting a second reader.
    if self.has_started or getattr(self, "_accept_input_task", None) is not None:
        return
    # The event loop only keeps weak references to tasks; hold a strong one
    # on the instance so the reader cannot be garbage collected mid-run.
    self._accept_input_task = asyncio.create_task(self.accept_input())
|
||||||
|
|
||||||
|
async def recv(self): # type: ignore
|
||||||
|
self.accept_input_in_background()
|
||||||
|
try:
|
||||||
|
frame = await self.frame_queue.get()
|
||||||
|
if self.skip_frames:
|
||||||
|
frame = self.latest_frame
|
||||||
await self.wait_for_channel()
|
await self.wait_for_channel()
|
||||||
frame_array = frame.to_ndarray(format="bgr24")
|
frame_array = frame.to_ndarray(format="bgr24") # type: ignore
|
||||||
if self.latest_args == "not_set":
|
if self.latest_args == "not_set":
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||||
|
|
||||||
array, outputs = split_output(self.event_handler(*args))
|
array, outputs = split_output(self.event_handler(*args))
|
||||||
if (
|
if (
|
||||||
isinstance(outputs, AdditionalOutputs)
|
isinstance(outputs, AdditionalOutputs)
|
||||||
@@ -161,7 +190,7 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
pts, time_base = await self.next_timestamp()
|
pts, time_base = await self.next_timestamp()
|
||||||
new_frame.pts = pts
|
new_frame.pts = pts
|
||||||
new_frame.time_base = time_base
|
new_frame.time_base = time_base
|
||||||
|
self.function_running = False
|
||||||
return new_frame
|
return new_frame
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("exception %s", e)
|
logger.debug("exception %s", e)
|
||||||
@@ -172,6 +201,21 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
else:
|
else:
|
||||||
raise WebRTCError(str(e)) from e
|
raise WebRTCError(str(e)) from e
|
||||||
|
|
||||||
|
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
    """Return the presentation timestamp and time base for the next frame.

    Paces output according to ``self.frame_ptime``. The first call records
    the start time and returns timestamp 0. Each later call advances the
    timestamp by one frame period (in ``VIDEO_CLOCK_RATE`` ticks) and sleeps
    until that frame is due.

    Raises:
        MediaStreamError: If the track's ``readyState`` is not ``"live"``.
    """
    if self.readyState != "live":
        raise MediaStreamError

    if not hasattr(self, "_timestamp"):
        # First frame: anchor the clock and start counting from zero.
        self._start = time.time()
        self._timestamp = 0
        return self._timestamp, VIDEO_TIME_BASE

    self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
    due_at = self._start + (self._timestamp / VIDEO_CLOCK_RATE)
    delay = due_at - time.time()
    if delay > 0:
        await asyncio.sleep(delay)
    return self._timestamp, VIDEO_TIME_BASE
|
||||||
|
|
||||||
|
|
||||||
class StreamHandlerBase(ABC):
|
class StreamHandlerBase(ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -180,11 +224,13 @@ class StreamHandlerBase(ABC):
|
|||||||
output_sample_rate: int = 24000,
|
output_sample_rate: int = 24000,
|
||||||
output_frame_size: int = 960,
|
output_frame_size: int = 960,
|
||||||
input_sample_rate: int = 48000,
|
input_sample_rate: int = 48000,
|
||||||
|
fps: int = 30,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.expected_layout = expected_layout
|
self.expected_layout = expected_layout
|
||||||
self.output_sample_rate = output_sample_rate
|
self.output_sample_rate = output_sample_rate
|
||||||
self.output_frame_size = output_frame_size
|
self.output_frame_size = output_frame_size
|
||||||
self.input_sample_rate = input_sample_rate
|
self.input_sample_rate = input_sample_rate
|
||||||
|
self.fps = fps
|
||||||
self.latest_args: list[Any] = []
|
self.latest_args: list[Any] = []
|
||||||
self._resampler = None
|
self._resampler = None
|
||||||
self._channel: DataChannel | None = None
|
self._channel: DataChannel | None = None
|
||||||
@@ -353,10 +399,16 @@ VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
|
|||||||
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
|
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
|
||||||
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
|
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
|
||||||
|
|
||||||
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
|
HandlerType = (
|
||||||
|
StreamHandlerImpl
|
||||||
|
| VideoStreamHandlerImpl
|
||||||
|
| VideoEventHandler
|
||||||
|
| Callable
|
||||||
|
| VideoStreamHandler
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VideoStreamHandler(VideoCallback):
|
class VideoStreamHandler_(VideoCallback):
|
||||||
async def process_frames(self):
|
async def process_frames(self):
|
||||||
while not self.thread_quit.is_set():
|
while not self.thread_quit.is_set():
|
||||||
try:
|
try:
|
||||||
@@ -576,6 +628,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
|
fps: int = 30,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__() # don't forget this!
|
super().__init__() # don't forget this!
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
@@ -584,6 +637,8 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
self.generator: Generator[Any, None, Any] | None = None
|
self.generator: Generator[Any, None, Any] | None = None
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
|
self.fps = fps
|
||||||
|
self.frame_ptime = 1.0 / fps
|
||||||
|
|
||||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||||
@@ -595,6 +650,21 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
self.latest_args = list(args)
|
self.latest_args = list(args)
|
||||||
self.args_set.set()
|
self.args_set.set()
|
||||||
|
|
||||||
|
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
    """Compute the timestamp of the next outgoing frame, honouring ``fps``.

    On the first call the start time is stored and the timestamp is 0. On
    every following call the timestamp grows by
    ``int(self.frame_ptime * VIDEO_CLOCK_RATE)`` and the coroutine sleeps for
    whatever time remains until that frame is due.

    Returns:
        The timestamp together with ``VIDEO_TIME_BASE``.

    Raises:
        MediaStreamError: If the track's ``readyState`` is not ``"live"``.
    """
    if self.readyState != "live":
        raise MediaStreamError

    already_running = hasattr(self, "_timestamp")
    if already_running:
        ticks_per_frame = int(self.frame_ptime * VIDEO_CLOCK_RATE)
        self._timestamp += ticks_per_frame
        remaining = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
        # Only sleep when running ahead of schedule; never when behind.
        if remaining > 0:
            await asyncio.sleep(remaining)
    else:
        self._start = time.time()
        self._timestamp = 0

    return self._timestamp, VIDEO_TIME_BASE
|
||||||
|
|
||||||
async def recv(self): # type: ignore
|
async def recv(self): # type: ignore
|
||||||
try:
|
try:
|
||||||
pts, time_base = await self.next_timestamp()
|
pts, time_base = await self.next_timestamp()
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from .tracks import (
|
|||||||
StreamHandlerBase,
|
StreamHandlerBase,
|
||||||
StreamHandlerImpl,
|
StreamHandlerImpl,
|
||||||
VideoEventHandler,
|
VideoEventHandler,
|
||||||
|
VideoStreamHandler,
|
||||||
)
|
)
|
||||||
from .webrtc_connection_mixin import WebRTCConnectionMixin
|
from .webrtc_connection_mixin import WebRTCConnectionMixin
|
||||||
|
|
||||||
@@ -254,6 +255,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
| StreamHandlerImpl
|
| StreamHandlerImpl
|
||||||
| AudioVideoStreamHandlerImpl
|
| AudioVideoStreamHandlerImpl
|
||||||
| VideoEventHandler
|
| VideoEventHandler
|
||||||
|
| VideoStreamHandler
|
||||||
| None
|
| None
|
||||||
) = None,
|
) = None,
|
||||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||||
@@ -263,6 +265,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
concurrency_id: str | None = None,
|
concurrency_id: str | None = None,
|
||||||
time_limit: float | None = None,
|
time_limit: float | None = None,
|
||||||
trigger: Callable | None = None,
|
trigger: Callable | None = None,
|
||||||
|
send_input_on: Literal["submit", "change"] = "change",
|
||||||
):
|
):
|
||||||
from gradio.blocks import Block
|
from gradio.blocks import Block
|
||||||
|
|
||||||
@@ -304,7 +307,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||||
)
|
)
|
||||||
for input_component in inputs[1:]: # type: ignore
|
for input_component in inputs[1:]: # type: ignore
|
||||||
if hasattr(input_component, "change"):
|
if hasattr(input_component, "change") and send_input_on == "change":
|
||||||
input_component.change( # type: ignore
|
input_component.change( # type: ignore
|
||||||
self.set_input,
|
self.set_input,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
@@ -314,6 +317,13 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
|||||||
time_limit=None,
|
time_limit=None,
|
||||||
js=js,
|
js=js,
|
||||||
)
|
)
|
||||||
|
if hasattr(input_component, "submit") and send_input_on == "submit":
|
||||||
|
input_component.submit( # type: ignore
|
||||||
|
self.set_input,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=None,
|
||||||
|
concurrency_id=concurrency_id,
|
||||||
|
)
|
||||||
return self.tick( # type: ignore
|
return self.tick( # type: ignore
|
||||||
self.set_input,
|
self.set_input,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ from fastrtc.tracks import (
|
|||||||
StreamHandlerBase,
|
StreamHandlerBase,
|
||||||
StreamHandlerImpl,
|
StreamHandlerImpl,
|
||||||
VideoCallback,
|
VideoCallback,
|
||||||
|
VideoEventHandler,
|
||||||
VideoStreamHandler,
|
VideoStreamHandler,
|
||||||
|
VideoStreamHandler_,
|
||||||
)
|
)
|
||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
@@ -41,7 +43,7 @@ from fastrtc.utils import (
|
|||||||
|
|
||||||
Track = (
|
Track = (
|
||||||
VideoCallback
|
VideoCallback
|
||||||
| VideoStreamHandler
|
| VideoStreamHandler_
|
||||||
| AudioCallback
|
| AudioCallback
|
||||||
| ServerToClientAudio
|
| ServerToClientAudio
|
||||||
| ServerToClientVideo
|
| ServerToClientVideo
|
||||||
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
|
|||||||
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
|
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
|
||||||
if hasattr(handler, "video_emit"):
|
if hasattr(handler, "video_emit"):
|
||||||
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
|
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
|
||||||
|
elif isinstance(self.event_handler, VideoStreamHandler):
|
||||||
|
self.event_handler.callable = cast(
|
||||||
|
VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
|
||||||
|
)
|
||||||
|
handler = self.event_handler
|
||||||
else:
|
else:
|
||||||
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
||||||
|
|
||||||
@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
|
|||||||
handler = self.handlers[body["webrtc_id"]]
|
handler = self.handlers[body["webrtc_id"]]
|
||||||
|
|
||||||
if self.modality == "video" and track.kind == "video":
|
if self.modality == "video" and track.kind == "video":
|
||||||
|
args = {}
|
||||||
|
handler_ = handler
|
||||||
|
if isinstance(handler, VideoStreamHandler):
|
||||||
|
handler_ = handler.callable
|
||||||
|
args["fps"] = handler.fps
|
||||||
|
args["skip_frames"] = handler.skip_frames
|
||||||
cb = VideoCallback(
|
cb = VideoCallback(
|
||||||
relay.subscribe(track),
|
relay.subscribe(track),
|
||||||
event_handler=cast(Callable, handler),
|
event_handler=cast(Callable, handler_),
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||||
|
**args,
|
||||||
)
|
)
|
||||||
elif self.modality == "audio-video" and track.kind == "video":
|
elif self.modality == "audio-video" and track.kind == "video":
|
||||||
cb = VideoStreamHandler(
|
cb = VideoStreamHandler_(
|
||||||
relay.subscribe(track),
|
relay.subscribe(track),
|
||||||
event_handler=handler, # type: ignore
|
event_handler=handler, # type: ignore
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
|
fps=cast(StreamHandlerImpl, handler).fps,
|
||||||
)
|
)
|
||||||
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
||||||
eh = cast(StreamHandlerImpl, handler)
|
eh = cast(StreamHandlerImpl, handler)
|
||||||
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:
|
|||||||
|
|
||||||
if self.mode == "receive":
|
if self.mode == "receive":
|
||||||
if self.modality == "video":
|
if self.modality == "video":
|
||||||
cb = ServerToClientVideo(
|
if isinstance(self.event_handler, VideoStreamHandler):
|
||||||
cast(Callable, self.event_handler),
|
cb = ServerToClientVideo(
|
||||||
set_additional_outputs=set_outputs,
|
cast(Callable, self.event_handler.callable),
|
||||||
)
|
set_additional_outputs=set_outputs,
|
||||||
|
fps=self.event_handler.fps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cb = ServerToClientVideo(
|
||||||
|
cast(Callable, self.event_handler),
|
||||||
|
set_additional_outputs=set_outputs,
|
||||||
|
)
|
||||||
elif self.modality == "audio":
|
elif self.modality == "audio":
|
||||||
cb = ServerToClientAudio(
|
cb = ServerToClientAudio(
|
||||||
cast(Callable, self.event_handler),
|
cast(Callable, self.event_handler),
|
||||||
|
|||||||
@@ -55,3 +55,73 @@ and set the `mode="receive"` in the `WebRTC` component.
|
|||||||
mode="receive"
|
mode="receive"
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Skipping Frames
|
||||||
|
|
||||||
|
If your event handler is not quite real-time yet, then the output feed will look very laggy.
|
||||||
|
|
||||||
|
To fix this, set the `skip_frames` parameter of `VideoStreamHandler` to `True`. Frames that arrive while the event handler is still running are skipped, so the handler is always given the most recent frame.
|
||||||
|
|
||||||
|
``` py title="Skipping Frames"
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastrtc import Stream, VideoStreamHandler
|
||||||
|
|
||||||
|
|
||||||
|
def process_image(image):
|
||||||
|
time.sleep(
|
||||||
|
0.2
|
||||||
|
) # Simulating 200ms processing time per frame; input arrives faster (30 FPS).
|
||||||
|
return np.flip(image, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
stream = Stream(
|
||||||
|
handler=VideoStreamHandler(process_image, skip_frames=True),
|
||||||
|
modality="video",
|
||||||
|
mode="send-receive",
|
||||||
|
)
|
||||||
|
|
||||||
|
stream.ui.launch()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setting the Output Frame Rate
|
||||||
|
|
||||||
|
You can set the output frame rate by setting the `fps` parameter in the `VideoStreamHandler`.
|
||||||
|
|
||||||
|
``` py title="Setting the Output Frame Rate"
|
||||||
|
def generation():
|
||||||
|
url = "https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71"
|
||||||
|
cap = cv2.VideoCapture(url)
|
||||||
|
iterating = True
|
||||||
|
|
||||||
|
# FPS calculation variables
|
||||||
|
frame_count = 0
|
||||||
|
start_time = time.time()
|
||||||
|
fps = 0
|
||||||
|
|
||||||
|
while iterating:
|
||||||
|
iterating, frame = cap.read()
|
||||||
|
|
||||||
|
# Calculate and print FPS
|
||||||
|
frame_count += 1
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
if elapsed_time >= 1.0: # Update FPS every second
|
||||||
|
fps = frame_count / elapsed_time
|
||||||
|
yield frame, AdditionalOutputs(fps)
|
||||||
|
frame_count = 0
|
||||||
|
start_time = time.time()
|
||||||
|
else:
|
||||||
|
yield frame
|
||||||
|
|
||||||
|
|
||||||
|
stream = Stream(
|
||||||
|
handler=VideoStreamHandler(generation, fps=60),
|
||||||
|
modality="video",
|
||||||
|
mode="receive",
|
||||||
|
additional_outputs=[gr.Number(label="FPS")],
|
||||||
|
additional_outputs_handler=lambda prev, cur: cur,
|
||||||
|
)
|
||||||
|
|
||||||
|
stream.ui.launch()
|
||||||
|
```
|
||||||
|
|||||||
Reference in New Issue
Block a user