commit f742c93235 (parent 8ed27fba78)
Author: Freddy Boulton
Date: 2025-03-28 21:12:58 -04:00
Committed by: GitHub

4 changed files with 44 additions and 1 deletion

View File

@@ -34,6 +34,7 @@ from .utils import (
     audio_to_file,
     audio_to_float32,
     audio_to_int16,
+    get_current_context,
     wait_for_item,
 )
 from .webrtc import (
@@ -77,4 +78,5 @@ __all__ = [
     "SileroVadOptions",
     "VideoStreamHandler",
     "CloseStream",
+    "get_current_context",
 ]

View File

@@ -38,10 +38,12 @@ from numpy import typing as npt
 from fastrtc.utils import (
     AdditionalOutputs,
     CloseStream,
+    Context,
     DataChannel,
     WebRTCError,
     create_message,
     current_channel,
+    current_context,
     player_worker_decode,
     split_output,
 )
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
         self,
         track: MediaStreamTrack,
         event_handler: VideoEventHandler,
+        context: Context,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         mode: Literal["send-receive", "send"] = "send-receive",
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
         self.skip_frames = skip_frames
         self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
         self.latest_frame = None
+        self.context = context

     def set_channel(self, channel: DataChannel):
         self.channel = channel
         current_channel.set(channel)
+        current_context.set(self.context)
         self.channel_set.set()

     def set_args(self, args: list[Any]):
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
         self.thread_quit.set()

     async def wait_for_channel(self):
+        current_context.set(self.context)
         if not self.channel_set.is_set():
             await self.channel_set.wait()
         if current_channel.get() != self.channel:
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
         self,
         track: MediaStreamTrack,
         event_handler: StreamHandlerBase,
+        context: Context,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
         self.last_timestamp = 0
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs
+        self.context = context

     def clear_queue(self):
         logger.debug("clearing queue")
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
         self._start = None

     async def wait_for_channel(self):
+        current_context.set(self.context)
         if not self.event_handler.channel_set.is_set():
             await self.event_handler.channel_set.wait()
         if current_channel.get() != self.event_handler.channel:
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):

     def event_handler_emit(self) -> EmitType:
         current_channel.set(self.event_handler.channel)
+        current_context.set(self.context)
         return cast(Callable, self.event_handler.emit)()

     async def process_input_frames(self) -> None:
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
     def __init__(
         self,
         event_handler: Callable,
+        context: Context,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         fps: int = 30,
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
         self.set_additional_outputs = set_additional_outputs
         self.fps = fps
         self.frame_ptime = 1.0 / fps
+        self.context = context

     def array_to_frame(self, array: np.ndarray) -> VideoFrame:
         return VideoFrame.from_ndarray(array, format="bgr24")
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
         pts, time_base = await self.next_timestamp()
         await self.args_set.wait()
         current_channel.set(self.channel)
+        current_context.set(self.context)
         if self.generator is None:
             self.generator = cast(
                 Generator[Any, None, Any], self.event_handler(*self.latest_args)
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
     def __init__(
         self,
         event_handler: Callable,
+        context: Context,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
         self.set_additional_outputs = set_additional_outputs
         self.has_started = False
         self._start: float | None = None
+        self.context = context
         super().__init__()

     def clear_queue(self):
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
         self.args_set.set()

     def next(self) -> tuple[int, np.ndarray] | None:
+        current_context.set(self.context)
         self.args_set.wait()
         current_channel.set(self.channel)
         if self.generator is None:

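Why current_context.set(self.context) is repeated in set_channel, wait_for_channel, event_handler_emit, and next rather than being set once in __init__: a contextvars value is local to the task or thread that set it, so every coroutine or worker that may end up calling into the user's handler has to set the variable itself. A minimal standalone sketch of that behaviour (the names ctx_var, handler, and per_connection_task are illustrative, not fastrtc code):

import asyncio
from contextvars import ContextVar

ctx_var: ContextVar[str | None] = ContextVar("ctx_var", default=None)

def handler() -> str:
    # Reads whatever value was set in the *current* task's context.
    return ctx_var.get() or "no context"

async def per_connection_task(webrtc_id: str) -> None:
    # Each task sets the variable for itself; the value does not leak
    # into sibling tasks or back into the caller's context.
    ctx_var.set(webrtc_id)
    print(handler())

async def main() -> None:
    await asyncio.gather(
        per_connection_task("conn-a"),
        per_connection_task("conn-b"),
    )
    print(handler())  # still "no context" here in the main task

asyncio.run(main())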
View File

@@ -8,6 +8,7 @@ import logging
 import tempfile
 import traceback
 from contextvars import ContextVar
+from dataclasses import dataclass
 from typing import Any, Callable, Literal, Protocol, TypedDict, cast

 import av
@@ -61,6 +62,22 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
 )


+@dataclass
+class Context:
+    webrtc_id: str
+
+
+current_context: ContextVar[Context | None] = ContextVar(
+    "current_context", default=None
+)
+
+
+def get_current_context() -> Context:
+    if not (ctx := current_context.get()):
+        raise RuntimeError("No context found")
+    return ctx
+
+
 def _send_log(message: str, type: str) -> None:
     async def _send(channel: DataChannel) -> None:
         channel.send(

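Because the tracks above set current_context before user code runs, a handler can call get_current_context() to find out which connection it is serving; outside a running stream the function raises RuntimeError. A hedged usage sketch (the echo handler and its generator signature are illustrative; only get_current_context and Context come from this commit):

import numpy as np
from fastrtc import get_current_context

def echo(audio: tuple[int, np.ndarray]):
    # Raises RuntimeError("No context found") if no track has set the context,
    # e.g. when the function is called directly outside a stream.
    ctx = get_current_context()
    print(f"received a frame from connection {ctx.webrtc_id}")
    yield audio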
View File

@@ -38,6 +38,7 @@ from fastrtc.tracks import (
 )
 from fastrtc.utils import (
     AdditionalOutputs,
+    Context,
     create_message,
     webrtc_error_handler,
 )
@@ -291,7 +292,7 @@ class WebRTCConnectionMixin:
         def _(track):
             relay = MediaRelay()
             handler = self.handlers[body["webrtc_id"]]
+            context = Context(webrtc_id=body["webrtc_id"])
             if self.modality == "video" and track.kind == "video":
                 args = {}
                 handler_ = handler
@@ -304,6 +305,7 @@ class WebRTCConnectionMixin:
                     event_handler=cast(Callable, handler_),
                     set_additional_outputs=set_outputs,
                     mode=cast(Literal["send", "send-receive"], self.mode),
+                    context=context,
                     **args,
                 )
             elif self.modality == "audio-video" and track.kind == "video":
@@ -312,6 +314,7 @@ class WebRTCConnectionMixin:
                     event_handler=handler,  # type: ignore
                     set_additional_outputs=set_outputs,
                     fps=cast(StreamHandlerImpl, handler).fps,
+                    context=context,
                 )
             elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
                 eh = cast(StreamHandlerImpl, handler)
@@ -320,6 +323,7 @@ class WebRTCConnectionMixin:
                     relay.subscribe(track),
                     event_handler=eh,
                     set_additional_outputs=set_outputs,
+                    context=context,
                 )
             else:
                 raise ValueError("Modality must be either video, audio, or audio-video")
@@ -336,6 +340,7 @@ class WebRTCConnectionMixin:
         elif self.mode == "send":
             asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())

+        context = Context(webrtc_id=body["webrtc_id"])
         if self.mode == "receive":
             if self.modality == "video":
                 if isinstance(self.event_handler, VideoStreamHandler):
@@ -343,16 +348,19 @@ class WebRTCConnectionMixin:
                         cast(Callable, self.event_handler.callable),
                         set_additional_outputs=set_outputs,
                         fps=self.event_handler.fps,
+                        context=context,
                     )
                 else:
                     cb = ServerToClientVideo(
                         cast(Callable, self.event_handler),
                         set_additional_outputs=set_outputs,
+                        context=context,
                     )
             elif self.modality == "audio":
                 cb = ServerToClientAudio(
                     cast(Callable, self.event_handler),
                     set_additional_outputs=set_outputs,
+                    context=context,
                 )
             else:
                 raise ValueError("Modality must be either video or audio")