Pass Websocket to the context if available (#329)

* Add code

* Code

* Fix

* Add code
This commit is contained in:
Freddy Boulton
2025-05-30 13:38:59 -04:00
committed by GitHub
parent 3fc258cd1b
commit 0c146ee45e
2 changed files with 20 additions and 14 deletions

View File

@@ -16,6 +16,7 @@ from typing import Any, Literal, Protocol, TypedDict, cast
import av import av
import librosa import librosa
import numpy as np import numpy as np
from fastapi import WebSocket
from numpy.typing import NDArray from numpy.typing import NDArray
from pydub import AudioSegment from pydub import AudioSegment
@@ -67,6 +68,7 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
@dataclass @dataclass
class Context: class Context:
webrtc_id: str webrtc_id: str
websocket: WebSocket | None = None
current_context: ContextVar[Context | None] = ContextVar( current_context: ContextVar[Context | None] = ContextVar(

View File

@@ -5,9 +5,9 @@ import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any, cast from typing import Any, cast
import anyio
import librosa import librosa
import numpy as np import numpy as np
from anyio.to_thread import run_sync
from fastapi import WebSocket from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState from fastapi.websockets import WebSocketDisconnect, WebSocketState
@@ -52,9 +52,6 @@ def convert_to_mulaw(
return audioop.lin2ulaw(audio_data, 2) # type: ignore return audioop.lin2ulaw(audio_data, 2) # type: ignore
run_sync = anyio.to_thread.run_sync # type: ignore
class WebSocketHandler: class WebSocketHandler:
def __init__( def __init__(
self, self,
@@ -94,21 +91,14 @@ class WebSocketHandler:
await websocket.accept() await websocket.accept()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
self.loop = loop self.loop = loop
self.websocket = websocket
self.data_channel = WebSocketDataChannel(websocket, loop) self.data_channel = WebSocketDataChannel(websocket, loop)
self.stream_handler._loop = loop self.stream_handler._loop = loop
self.stream_handler.set_channel(self.data_channel) self.stream_handler.set_channel(self.data_channel)
self._emit_task = asyncio.create_task(self._emit_loop()) self._emit_task = asyncio.create_task(self._emit_loop())
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue()) self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop()) self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
was_disconnected = False was_disconnected = False
self.start_up_task = asyncio.create_task(start_up)
try: try:
while not self.quit.is_set(): while not self.quit.is_set():
if websocket.application_state != WebSocketState.CONNECTED: if websocket.application_state != WebSocketState.CONNECTED:
@@ -157,7 +147,17 @@ class WebSocketHandler:
self.stream_id = cast(str, message["streamSid"]) self.stream_id = cast(str, message["streamSid"])
else: else:
self.stream_id = cast(str, message["websocket_id"]) self.stream_id = cast(str, message["websocket_id"])
current_context.set(Context(webrtc_id=self.stream_id)) self.websocket = websocket
current_context.set(
Context(webrtc_id=self.stream_id, websocket=websocket)
)
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
start_up = run_sync(self.stream_handler.start_up) # type: ignore
self.start_up_task = asyncio.create_task(start_up)
self.set_additional_outputs = self.set_additional_outputs_factory( self.set_additional_outputs = self.set_additional_outputs_factory(
self.stream_id self.stream_id
) )
@@ -189,11 +189,15 @@ class WebSocketHandler:
self.clean_up(cast(str, self.stream_id)) self.clean_up(cast(str, self.stream_id))
def emit_with_context(self): def emit_with_context(self):
current_context.set(Context(webrtc_id=cast(str, self.stream_id))) current_context.set(
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
)
return self.stream_handler.emit() return self.stream_handler.emit()
def receive_with_context(self, frame: tuple[int, np.ndarray]): def receive_with_context(self, frame: tuple[int, np.ndarray]):
current_context.set(Context(webrtc_id=cast(str, self.stream_id))) current_context.set(
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
)
return self.stream_handler.receive(frame) return self.stream_handler.receive(frame)
async def _emit_to_queue(self): async def _emit_to_queue(self):