From 0c146ee45e4d99d93743168b0611e5d376976168 Mon Sep 17 00:00:00 2001 From: Freddy Boulton <41651716+freddyaboulton@users.noreply.github.com> Date: Fri, 30 May 2025 13:38:59 -0400 Subject: [PATCH] Pass Websocket to the context if available (#329) * Add code * Code * Fix * Add code --- backend/fastrtc/utils.py | 2 ++ backend/fastrtc/websocket.py | 32 ++++++++++++++++++-------------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py index 93b5728..d001da2 100644 --- a/backend/fastrtc/utils.py +++ b/backend/fastrtc/utils.py @@ -16,6 +16,7 @@ from typing import Any, Literal, Protocol, TypedDict, cast import av import librosa import numpy as np +from fastapi import WebSocket from numpy.typing import NDArray from pydub import AudioSegment @@ -67,6 +68,7 @@ current_channel: ContextVar[DataChannel | None] = ContextVar( @dataclass class Context: webrtc_id: str + websocket: WebSocket | None = None current_context: ContextVar[Context | None] = ContextVar( diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py index f2b182d..0489ceb 100644 --- a/backend/fastrtc/websocket.py +++ b/backend/fastrtc/websocket.py @@ -5,9 +5,9 @@ import logging from collections.abc import Awaitable, Callable from typing import Any, cast -import anyio import librosa import numpy as np +from anyio.to_thread import run_sync from fastapi import WebSocket from fastapi.websockets import WebSocketDisconnect, WebSocketState @@ -52,9 +52,6 @@ def convert_to_mulaw( return audioop.lin2ulaw(audio_data, 2) # type: ignore -run_sync = anyio.to_thread.run_sync # type: ignore - - class WebSocketHandler: def __init__( self, @@ -94,21 +91,14 @@ class WebSocketHandler: await websocket.accept() loop = asyncio.get_running_loop() self.loop = loop - self.websocket = websocket self.data_channel = WebSocketDataChannel(websocket, loop) self.stream_handler._loop = loop self.stream_handler.set_channel(self.data_channel) self._emit_task = asyncio.create_task(self._emit_loop()) self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue()) self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop()) - if isinstance(self.stream_handler, AsyncStreamHandler): - start_up = self.stream_handler.start_up() - else: - start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore - was_disconnected = False - self.start_up_task = asyncio.create_task(start_up) try: while not self.quit.is_set(): if websocket.application_state != WebSocketState.CONNECTED: @@ -157,7 +147,17 @@ class WebSocketHandler: self.stream_id = cast(str, message["streamSid"]) else: self.stream_id = cast(str, message["websocket_id"]) - current_context.set(Context(webrtc_id=self.stream_id)) + self.websocket = websocket + current_context.set( + Context(webrtc_id=self.stream_id, websocket=websocket) + ) + if isinstance(self.stream_handler, AsyncStreamHandler): + start_up = self.stream_handler.start_up() + else: + start_up = run_sync(self.stream_handler.start_up) # type: ignore + + self.start_up_task = asyncio.create_task(start_up) + self.set_additional_outputs = self.set_additional_outputs_factory( self.stream_id ) @@ -189,11 +189,15 @@ class WebSocketHandler: self.clean_up(cast(str, self.stream_id)) def emit_with_context(self): - current_context.set(Context(webrtc_id=cast(str, self.stream_id))) + current_context.set( + Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket) + ) return self.stream_handler.emit() def receive_with_context(self, frame: tuple[int, np.ndarray]): - current_context.set(Context(webrtc_id=cast(str, self.stream_id))) + current_context.set( + Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket) + ) return self.stream_handler.receive(frame) async def _emit_to_queue(self):