Pass Websocket to the context if available (#329)

* Add code

* Code

* Fix

* Add code
This commit is contained in:
Freddy Boulton
2025-05-30 13:38:59 -04:00
committed by GitHub
parent 3fc258cd1b
commit 0c146ee45e
2 changed files with 20 additions and 14 deletions

View File

@@ -16,6 +16,7 @@ from typing import Any, Literal, Protocol, TypedDict, cast
import av
import librosa
import numpy as np
from fastapi import WebSocket
from numpy.typing import NDArray
from pydub import AudioSegment
@@ -67,6 +68,7 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
@dataclass
class Context:
webrtc_id: str
websocket: WebSocket | None = None
current_context: ContextVar[Context | None] = ContextVar(

View File

@@ -5,9 +5,9 @@ import logging
from collections.abc import Awaitable, Callable
from typing import Any, cast
import anyio
import librosa
import numpy as np
from anyio.to_thread import run_sync
from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState
@@ -52,9 +52,6 @@ def convert_to_mulaw(
return audioop.lin2ulaw(audio_data, 2) # type: ignore
run_sync = anyio.to_thread.run_sync # type: ignore
class WebSocketHandler:
def __init__(
self,
@@ -94,21 +91,14 @@ class WebSocketHandler:
await websocket.accept()
loop = asyncio.get_running_loop()
self.loop = loop
self.websocket = websocket
self.data_channel = WebSocketDataChannel(websocket, loop)
self.stream_handler._loop = loop
self.stream_handler.set_channel(self.data_channel)
self._emit_task = asyncio.create_task(self._emit_loop())
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
was_disconnected = False
self.start_up_task = asyncio.create_task(start_up)
try:
while not self.quit.is_set():
if websocket.application_state != WebSocketState.CONNECTED:
@@ -157,7 +147,17 @@ class WebSocketHandler:
self.stream_id = cast(str, message["streamSid"])
else:
self.stream_id = cast(str, message["websocket_id"])
current_context.set(Context(webrtc_id=self.stream_id))
self.websocket = websocket
current_context.set(
Context(webrtc_id=self.stream_id, websocket=websocket)
)
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
start_up = run_sync(self.stream_handler.start_up) # type: ignore
self.start_up_task = asyncio.create_task(start_up)
self.set_additional_outputs = self.set_additional_outputs_factory(
self.stream_id
)
@@ -189,11 +189,15 @@ class WebSocketHandler:
self.clean_up(cast(str, self.stream_id))
def emit_with_context(self):
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
current_context.set(
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
)
return self.stream_handler.emit()
def receive_with_context(self, frame: tuple[int, np.ndarray]):
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
current_context.set(
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
)
return self.stream_handler.receive(frame)
async def _emit_to_queue(self):