From f742c932356081e0fcf38b600be9efc98d894002 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 28 Mar 2025 21:12:58 -0400 Subject: [PATCH] add code (#223) --- backend/fastrtc/__init__.py | 2 ++ backend/fastrtc/tracks.py | 16 ++++++++++++++++ backend/fastrtc/utils.py | 17 +++++++++++++++++ backend/fastrtc/webrtc_connection_mixin.py | 10 +++++++++- 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py index 75e982a..97e1bd0 100644 --- a/backend/fastrtc/__init__.py +++ b/backend/fastrtc/__init__.py @@ -34,6 +34,7 @@ from .utils import ( audio_to_file, audio_to_float32, audio_to_int16, + get_current_context, wait_for_item, ) from .webrtc import ( @@ -77,4 +78,5 @@ __all__ = [ "SileroVadOptions", "VideoStreamHandler", "CloseStream", + "get_current_context", ] diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py index 8371e01..36fcea1 100644 --- a/backend/fastrtc/tracks.py +++ b/backend/fastrtc/tracks.py @@ -38,10 +38,12 @@ from numpy import typing as npt from fastrtc.utils import ( AdditionalOutputs, CloseStream, + Context, DataChannel, WebRTCError, create_message, current_channel, + current_context, player_worker_decode, split_output, ) @@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack): self, track: MediaStreamTrack, event_handler: VideoEventHandler, + context: Context, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, mode: Literal["send-receive", "send"] = "send-receive", @@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack): self.skip_frames = skip_frames self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue() self.latest_frame = None + self.context = context def set_channel(self, channel: DataChannel): self.channel = channel current_channel.set(channel) + current_context.set(self.context) self.channel_set.set() def set_args(self, args: list[Any]): @@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack): self.thread_quit.set() async def wait_for_channel(self): + current_context.set(self.context) if not self.channel_set.is_set(): await self.channel_set.wait() if current_channel.get() != self.channel: @@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack): self, track: MediaStreamTrack, event_handler: StreamHandlerBase, + context: Context, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, ) -> None: @@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack): self.last_timestamp = 0 self.channel = channel self.set_additional_outputs = set_additional_outputs + self.context = context def clear_queue(self): logger.debug("clearing queue") @@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack): self._start = None async def wait_for_channel(self): + current_context.set(self.context) if not self.event_handler.channel_set.is_set(): await self.event_handler.channel_set.wait() if current_channel.get() != self.event_handler.channel: @@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack): def event_handler_emit(self) -> EmitType: current_channel.set(self.event_handler.channel) + current_context.set(self.context) return cast(Callable, self.event_handler.emit)() async def process_input_frames(self) -> None: @@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack): def __init__( self, event_handler: Callable, + context: Context, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, fps: int = 30, @@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack): self.set_additional_outputs = set_additional_outputs self.fps = fps self.frame_ptime = 1.0 / fps + self.context = context def array_to_frame(self, array: np.ndarray) -> VideoFrame: return VideoFrame.from_ndarray(array, format="bgr24") @@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack): pts, time_base = await self.next_timestamp() await self.args_set.wait() current_channel.set(self.channel) + current_context.set(self.context) if self.generator is None: self.generator = cast( Generator[Any, None, Any], self.event_handler(*self.latest_args) @@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack): def __init__( self, event_handler: Callable, + context: Context, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, ) -> None: @@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack): self.set_additional_outputs = set_additional_outputs self.has_started = False self._start: float | None = None + self.context = context super().__init__() def clear_queue(self): @@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack): self.args_set.set() def next(self) -> tuple[int, np.ndarray] | None: + current_context.set(self.context) self.args_set.wait() current_channel.set(self.channel) if self.generator is None: diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py index 0fb1c06..5f2802d 100644 --- a/backend/fastrtc/utils.py +++ b/backend/fastrtc/utils.py @@ -8,6 +8,7 @@ import logging import tempfile import traceback from contextvars import ContextVar +from dataclasses import dataclass from typing import Any, Callable, Literal, Protocol, TypedDict, cast import av @@ -61,6 +62,22 @@ current_channel: ContextVar[DataChannel | None] = ContextVar( ) +@dataclass +class Context: + webrtc_id: str + + +current_context: ContextVar[Context | None] = ContextVar( + "current_context", default=None +) + + +def get_current_context() -> Context: + if not (ctx := current_context.get()): + raise RuntimeError("No context found") + return ctx + + def _send_log(message: str, type: str) -> None: async def _send(channel: DataChannel) -> None: channel.send( diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index 3b744dc..81458f4 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -38,6 +38,7 @@ from fastrtc.tracks import ( ) from fastrtc.utils import ( AdditionalOutputs, + Context, create_message, webrtc_error_handler, ) @@ -291,7 +292,7 @@ class WebRTCConnectionMixin: def _(track): relay = MediaRelay() handler = self.handlers[body["webrtc_id"]] - + context = Context(webrtc_id=body["webrtc_id"]) if self.modality == "video" and track.kind == "video": args = {} handler_ = handler @@ -304,6 +305,7 @@ class WebRTCConnectionMixin: event_handler=cast(Callable, handler_), set_additional_outputs=set_outputs, mode=cast(Literal["send", "send-receive"], self.mode), + context=context, **args, ) elif self.modality == "audio-video" and track.kind == "video": @@ -312,6 +314,7 @@ class WebRTCConnectionMixin: event_handler=handler, # type: ignore set_additional_outputs=set_outputs, fps=cast(StreamHandlerImpl, handler).fps, + context=context, ) elif self.modality in ["audio", "audio-video"] and track.kind == "audio": eh = cast(StreamHandlerImpl, handler) @@ -320,6 +323,7 @@ class WebRTCConnectionMixin: relay.subscribe(track), event_handler=eh, set_additional_outputs=set_outputs, + context=context, ) else: raise ValueError("Modality must be either video, audio, or audio-video") @@ -336,6 +340,7 @@ class WebRTCConnectionMixin: elif self.mode == "send": asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start()) + context = Context(webrtc_id=body["webrtc_id"]) if self.mode == "receive": if self.modality == "video": if isinstance(self.event_handler, VideoStreamHandler): @@ -343,16 +348,19 @@ class WebRTCConnectionMixin: cast(Callable, self.event_handler.callable), set_additional_outputs=set_outputs, fps=self.event_handler.fps, + context=context, ) else: cb = ServerToClientVideo( cast(Callable, self.event_handler), set_additional_outputs=set_outputs, + context=context, ) elif self.modality == "audio": cb = ServerToClientAudio( cast(Callable, self.event_handler), set_additional_outputs=set_outputs, + context=context, ) else: raise ValueError("Modality must be either video or audio")