Freddy Boulton
2025-03-28 21:12:58 -04:00
committed by GitHub
parent 8ed27fba78
commit f742c93235
4 changed files with 44 additions and 1 deletion

View File

@@ -34,6 +34,7 @@ from .utils import (
audio_to_file,
audio_to_float32,
audio_to_int16,
get_current_context,
wait_for_item,
)
from .webrtc import (
@@ -77,4 +78,5 @@ __all__ = [
"SileroVadOptions",
"VideoStreamHandler",
"CloseStream",
"get_current_context",
]
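With get_current_context exported from the package root, an event handler can look up which WebRTC connection it is currently serving. A minimal sketch, assuming a video handler registered with a fastrtc Stream (the handler itself is illustrative and not part of this commit):

import numpy as np
from fastrtc import get_current_context

def flip(frame: np.ndarray) -> np.ndarray:
    # Raises RuntimeError if no context has been set for the current
    # task, i.e. when called outside of a live stream.
    ctx = get_current_context()
    print(f"frame from connection {ctx.webrtc_id}")
    return np.flip(frame, axis=0)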

View File

@@ -38,10 +38,12 @@ from numpy import typing as npt
from fastrtc.utils import (
AdditionalOutputs,
CloseStream,
Context,
DataChannel,
WebRTCError,
create_message,
current_channel,
current_context,
player_worker_decode,
split_output,
)
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
self,
track: MediaStreamTrack,
event_handler: VideoEventHandler,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
self.skip_frames = skip_frames
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
self.latest_frame = None
self.context = context
def set_channel(self, channel: DataChannel):
self.channel = channel
current_channel.set(channel)
current_context.set(self.context)
self.channel_set.set()
def set_args(self, args: list[Any]):
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
self.thread_quit.set()
async def wait_for_channel(self):
current_context.set(self.context)
if not self.channel_set.is_set():
await self.channel_set.wait()
if current_channel.get() != self.channel:
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
self,
track: MediaStreamTrack,
event_handler: StreamHandlerBase,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.context = context
def clear_queue(self):
logger.debug("clearing queue")
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
self._start = None
async def wait_for_channel(self):
current_context.set(self.context)
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):
def event_handler_emit(self) -> EmitType:
current_channel.set(self.event_handler.channel)
current_context.set(self.context)
return cast(Callable, self.event_handler.emit)()
async def process_input_frames(self) -> None:
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
def __init__(
self,
event_handler: Callable,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
fps: int = 30,
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.fps = fps
self.frame_ptime = 1.0 / fps
self.context = context
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
pts, time_base = await self.next_timestamp()
await self.args_set.wait()
current_channel.set(self.channel)
current_context.set(self.context)
if self.generator is None:
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
def __init__(
self,
event_handler: Callable,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.has_started = False
self._start: float | None = None
self.context = context
super().__init__()
def clear_queue(self):
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set.set()
def next(self) -> tuple[int, np.ndarray] | None:
current_context.set(self.context)
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
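Each track now stores the per-connection Context and republishes it via current_context.set() in wait_for_channel and in the emit/next paths shown above. Repeating the set() matters because a ContextVar value is scoped to the asyncio task (or thread) that sets it, so setting it once at construction time would not be visible to the tasks that later run the event handler. A small, self-contained illustration of that behaviour (plain contextvars/asyncio, not fastrtc code):

import asyncio
from contextvars import ContextVar

var: ContextVar[str | None] = ContextVar("var", default=None)

async def worker(name: str) -> None:
    # Each task runs in its own copy of the context, so this set()
    # is visible only to code running inside this task.
    var.set(name)
    await asyncio.sleep(0)
    assert var.get() == name

async def main() -> None:
    await asyncio.gather(worker("a"), worker("b"))
    assert var.get() is None  # the workers' set() calls did not leak out

asyncio.run(main())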

View File

@@ -8,6 +8,7 @@ import logging
import tempfile
import traceback
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
import av
@@ -61,6 +62,22 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
)
@dataclass
class Context:
    webrtc_id: str


current_context: ContextVar[Context | None] = ContextVar(
    "current_context", default=None
)


def get_current_context() -> Context:
    if not (ctx := current_context.get()):
        raise RuntimeError("No context found")
    return ctx
def _send_log(message: str, type: str) -> None:
async def _send(channel: DataChannel) -> None:
channel.send(
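get_current_context only reads whatever current_context holds for the calling task; the track classes in the previous file set it for each connection. The contract can also be exercised directly, as in this minimal illustration (the webrtc_id value is arbitrary):

from fastrtc.utils import Context, current_context, get_current_context

token = current_context.set(Context(webrtc_id="abc123"))
try:
    assert get_current_context().webrtc_id == "abc123"
finally:
    # Restore the previous value; with no context set,
    # get_current_context() raises RuntimeError("No context found").
    current_context.reset(token)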

View File

@@ -38,6 +38,7 @@ from fastrtc.tracks import (
)
from fastrtc.utils import (
AdditionalOutputs,
Context,
create_message,
webrtc_error_handler,
)
@@ -291,7 +292,7 @@ class WebRTCConnectionMixin:
def _(track):
relay = MediaRelay()
handler = self.handlers[body["webrtc_id"]]
context = Context(webrtc_id=body["webrtc_id"])
if self.modality == "video" and track.kind == "video":
args = {}
handler_ = handler
@@ -304,6 +305,7 @@ class WebRTCConnectionMixin:
event_handler=cast(Callable, handler_),
set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode),
context=context,
**args,
)
elif self.modality == "audio-video" and track.kind == "video":
@@ -312,6 +314,7 @@ class WebRTCConnectionMixin:
event_handler=handler, # type: ignore
set_additional_outputs=set_outputs,
fps=cast(StreamHandlerImpl, handler).fps,
context=context,
)
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
eh = cast(StreamHandlerImpl, handler)
@@ -320,6 +323,7 @@ class WebRTCConnectionMixin:
relay.subscribe(track),
event_handler=eh,
set_additional_outputs=set_outputs,
context=context,
)
else:
raise ValueError("Modality must be either video, audio, or audio-video")
@@ -336,6 +340,7 @@ class WebRTCConnectionMixin:
elif self.mode == "send":
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
context = Context(webrtc_id=body["webrtc_id"])
if self.mode == "receive":
if self.modality == "video":
if isinstance(self.event_handler, VideoStreamHandler):
@@ -343,16 +348,19 @@ class WebRTCConnectionMixin:
cast(Callable, self.event_handler.callable),
set_additional_outputs=set_outputs,
fps=self.event_handler.fps,
context=context,
)
else:
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
context=context,
)
elif self.modality == "audio":
cb = ServerToClientAudio(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
context=context,
)
else:
raise ValueError("Modality must be either video or audio")