From cc8d82f23384ab8865dc10c0194da4477b81fca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=A1clav=20Volhejn?= <8401624+vvolhejn@users.noreply.github.com> Date: Wed, 2 Apr 2025 22:29:03 +0200 Subject: [PATCH] Improve error handling for websockets (#238) * Improve error handling for websockets * Correctly run clean_up --- backend/fastrtc/websocket.py | 53 +++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py index 4d182b9..5d015de 100644 --- a/backend/fastrtc/websocket.py +++ b/backend/fastrtc/websocket.py @@ -8,6 +8,7 @@ import anyio import librosa import numpy as np from fastapi import WebSocket +from fastapi.websockets import WebSocketDisconnect, WebSocketState from .tracks import AsyncStreamHandler, StreamHandlerImpl from .utils import AdditionalOutputs, DataChannel, split_output @@ -97,9 +98,15 @@ class WebSocketHandler: else: start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore + was_disconnected = False + self.start_up_task = asyncio.create_task(start_up) try: while not self.quit.is_set(): + if websocket.application_state != WebSocketState.CONNECTED: + was_disconnected = True + break + message = await websocket.receive_json() if message["event"] == "media": @@ -117,15 +124,23 @@ class WebSocketHandler: target_sr=self.stream_handler.input_sample_rate, ) audio_array = (audio_array * 32768).astype(np.int16) - if isinstance(self.stream_handler, AsyncStreamHandler): - await self.stream_handler.receive( - (self.stream_handler.input_sample_rate, audio_array) - ) - else: - await run_sync( - self.stream_handler.receive, - (self.stream_handler.input_sample_rate, audio_array), - ) + + try: + if isinstance(self.stream_handler, AsyncStreamHandler): + await self.stream_handler.receive( + (self.stream_handler.input_sample_rate, audio_array) + ) + else: + await run_sync( + self.stream_handler.receive, + (self.stream_handler.input_sample_rate, audio_array), + ) + except Exception as e: + print(e) + import traceback + + traceback.print_exc() + logger.debug("Error in websocket handler %s", e) elif message["event"] == "start": if self.stream_handler.phone_mode: @@ -138,17 +153,13 @@ class WebSocketHandler: await self.set_handler(self.stream_id, self) elif message["event"] == "stop": self.quit.set() - self.clean_up(cast(str, self.stream_id)) - return + return # Still runs the `finally` block elif message["event"] == "ping": await websocket.send_json({"event": "pong"}) - - except Exception as e: - print(e) - import traceback - - traceback.print_exc() - logger.debug("Error in websocket handler %s", e) + except WebSocketDisconnect: + # Surprisingly, this leaves `websocket.application_state` as CONNECTED + # in the `finally` block, so we use this variable + was_disconnected = True finally: if self._emit_task: self._emit_task.cancel() @@ -156,7 +167,11 @@ class WebSocketHandler: self._emit_to_queue_task.cancel() if self.start_up_task: self.start_up_task.cancel() - await websocket.close() + + if not was_disconnected: + await websocket.close() + + self.clean_up(cast(str, self.stream_id)) async def _emit_to_queue(self): try: