Improve error handling for websockets (#238)

* Improve error handling for websockets

* Correctly run clean_up
This commit is contained in:
Václav Volhejn
2025-04-02 22:29:03 +02:00
committed by GitHub
parent 06885d06c4
commit cc8d82f233

View File

@@ -8,6 +8,7 @@ import anyio
import librosa import librosa
import numpy as np import numpy as np
from fastapi import WebSocket from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState
from .tracks import AsyncStreamHandler, StreamHandlerImpl from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, DataChannel, split_output from .utils import AdditionalOutputs, DataChannel, split_output
@@ -97,9 +98,15 @@ class WebSocketHandler:
else: else:
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
was_disconnected = False
self.start_up_task = asyncio.create_task(start_up) self.start_up_task = asyncio.create_task(start_up)
try: try:
while not self.quit.is_set(): while not self.quit.is_set():
if websocket.application_state != WebSocketState.CONNECTED:
was_disconnected = True
break
message = await websocket.receive_json() message = await websocket.receive_json()
if message["event"] == "media": if message["event"] == "media":
@@ -117,6 +124,8 @@ class WebSocketHandler:
target_sr=self.stream_handler.input_sample_rate, target_sr=self.stream_handler.input_sample_rate,
) )
audio_array = (audio_array * 32768).astype(np.int16) audio_array = (audio_array * 32768).astype(np.int16)
try:
if isinstance(self.stream_handler, AsyncStreamHandler): if isinstance(self.stream_handler, AsyncStreamHandler):
await self.stream_handler.receive( await self.stream_handler.receive(
(self.stream_handler.input_sample_rate, audio_array) (self.stream_handler.input_sample_rate, audio_array)
@@ -126,6 +135,12 @@ class WebSocketHandler:
self.stream_handler.receive, self.stream_handler.receive,
(self.stream_handler.input_sample_rate, audio_array), (self.stream_handler.input_sample_rate, audio_array),
) )
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug("Error in websocket handler %s", e)
elif message["event"] == "start": elif message["event"] == "start":
if self.stream_handler.phone_mode: if self.stream_handler.phone_mode:
@@ -138,17 +153,13 @@ class WebSocketHandler:
await self.set_handler(self.stream_id, self) await self.set_handler(self.stream_id, self)
elif message["event"] == "stop": elif message["event"] == "stop":
self.quit.set() self.quit.set()
self.clean_up(cast(str, self.stream_id)) return # Still runs the `finally` block
return
elif message["event"] == "ping": elif message["event"] == "ping":
await websocket.send_json({"event": "pong"}) await websocket.send_json({"event": "pong"})
except WebSocketDisconnect:
except Exception as e: # Surprisingly, this leaves `websocket.application_state` as CONNECTED
print(e) # in the `finally` block, so we use this variable
import traceback was_disconnected = True
traceback.print_exc()
logger.debug("Error in websocket handler %s", e)
finally: finally:
if self._emit_task: if self._emit_task:
self._emit_task.cancel() self._emit_task.cancel()
@@ -156,8 +167,12 @@ class WebSocketHandler:
self._emit_to_queue_task.cancel() self._emit_to_queue_task.cancel()
if self.start_up_task: if self.start_up_task:
self.start_up_task.cancel() self.start_up_task.cancel()
if not was_disconnected:
await websocket.close() await websocket.close()
self.clean_up(cast(str, self.stream_id))
async def _emit_to_queue(self): async def _emit_to_queue(self):
try: try:
while not self.quit.is_set(): while not self.quit.is_set():