This commit is contained in:
Freddy Boulton
2025-03-28 21:12:58 -04:00
committed by GitHub
parent 8ed27fba78
commit f742c93235
4 changed files with 44 additions and 1 deletions

View File

@@ -38,10 +38,12 @@ from numpy import typing as npt
from fastrtc.utils import (
AdditionalOutputs,
CloseStream,
Context,
DataChannel,
WebRTCError,
create_message,
current_channel,
current_context,
player_worker_decode,
split_output,
)
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
self,
track: MediaStreamTrack,
event_handler: VideoEventHandler,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
self.skip_frames = skip_frames
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
self.latest_frame = None
self.context = context
def set_channel(self, channel: DataChannel):
self.channel = channel
current_channel.set(channel)
current_context.set(self.context)
self.channel_set.set()
def set_args(self, args: list[Any]):
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
self.thread_quit.set()
async def wait_for_channel(self):
current_context.set(self.context)
if not self.channel_set.is_set():
await self.channel_set.wait()
if current_channel.get() != self.channel:
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
self,
track: MediaStreamTrack,
event_handler: StreamHandlerBase,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.context = context
def clear_queue(self):
logger.debug("clearing queue")
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
self._start = None
async def wait_for_channel(self):
current_context.set(self.context)
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):
def event_handler_emit(self) -> EmitType:
current_channel.set(self.event_handler.channel)
current_context.set(self.context)
return cast(Callable, self.event_handler.emit)()
async def process_input_frames(self) -> None:
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
def __init__(
self,
event_handler: Callable,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
fps: int = 30,
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.fps = fps
self.frame_ptime = 1.0 / fps
self.context = context
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
pts, time_base = await self.next_timestamp()
await self.args_set.wait()
current_channel.set(self.channel)
current_context.set(self.context)
if self.generator is None:
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
def __init__(
self,
event_handler: Callable,
context: Context,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.has_started = False
self._start: float | None = None
self.context = context
super().__init__()
def clear_queue(self):
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set.set()
def next(self) -> tuple[int, np.ndarray] | None:
current_context.set(self.context)
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None: