Improve error handling for websockets (#238)

* Improve error handling for websockets

* Correctly run clean_up
This commit is contained in:
Václav Volhejn
2025-04-02 22:29:03 +02:00
committed by GitHub
parent 06885d06c4
commit cc8d82f233

View File

@@ -8,6 +8,7 @@ import anyio
import librosa
import numpy as np
from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState
from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, DataChannel, split_output
@@ -97,9 +98,15 @@ class WebSocketHandler:
else:
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
was_disconnected = False
self.start_up_task = asyncio.create_task(start_up)
try:
while not self.quit.is_set():
if websocket.application_state != WebSocketState.CONNECTED:
was_disconnected = True
break
message = await websocket.receive_json()
if message["event"] == "media":
@@ -117,15 +124,23 @@ class WebSocketHandler:
target_sr=self.stream_handler.input_sample_rate,
)
audio_array = (audio_array * 32768).astype(np.int16)
if isinstance(self.stream_handler, AsyncStreamHandler):
await self.stream_handler.receive(
(self.stream_handler.input_sample_rate, audio_array)
)
else:
await run_sync(
self.stream_handler.receive,
(self.stream_handler.input_sample_rate, audio_array),
)
try:
if isinstance(self.stream_handler, AsyncStreamHandler):
await self.stream_handler.receive(
(self.stream_handler.input_sample_rate, audio_array)
)
else:
await run_sync(
self.stream_handler.receive,
(self.stream_handler.input_sample_rate, audio_array),
)
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug("Error in websocket handler %s", e)
elif message["event"] == "start":
if self.stream_handler.phone_mode:
@@ -138,17 +153,13 @@ class WebSocketHandler:
await self.set_handler(self.stream_id, self)
elif message["event"] == "stop":
self.quit.set()
self.clean_up(cast(str, self.stream_id))
return
return # Still runs the `finally` block
elif message["event"] == "ping":
await websocket.send_json({"event": "pong"})
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug("Error in websocket handler %s", e)
except WebSocketDisconnect:
# Surprisingly, this leaves `websocket.application_state` as CONNECTED
# in the `finally` block, so we use this variable
was_disconnected = True
finally:
if self._emit_task:
self._emit_task.cancel()
@@ -156,7 +167,11 @@ class WebSocketHandler:
self._emit_to_queue_task.cancel()
if self.start_up_task:
self.start_up_task.cancel()
await websocket.close()
if not was_disconnected:
await websocket.close()
self.clean_up(cast(str, self.stream_id))
async def _emit_to_queue(self):
try: