mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
Improve error handling for websockets (#238)
* Improve error handling for websockets * Correctly run clean_up
This commit is contained in:
@@ -8,6 +8,7 @@ import anyio
|
||||
import librosa
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||
|
||||
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
||||
from .utils import AdditionalOutputs, DataChannel, split_output
|
||||
@@ -97,9 +98,15 @@ class WebSocketHandler:
|
||||
else:
|
||||
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
|
||||
|
||||
was_disconnected = False
|
||||
|
||||
self.start_up_task = asyncio.create_task(start_up)
|
||||
try:
|
||||
while not self.quit.is_set():
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
was_disconnected = True
|
||||
break
|
||||
|
||||
message = await websocket.receive_json()
|
||||
|
||||
if message["event"] == "media":
|
||||
@@ -117,15 +124,23 @@ class WebSocketHandler:
|
||||
target_sr=self.stream_handler.input_sample_rate,
|
||||
)
|
||||
audio_array = (audio_array * 32768).astype(np.int16)
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
await self.stream_handler.receive(
|
||||
(self.stream_handler.input_sample_rate, audio_array)
|
||||
)
|
||||
else:
|
||||
await run_sync(
|
||||
self.stream_handler.receive,
|
||||
(self.stream_handler.input_sample_rate, audio_array),
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
await self.stream_handler.receive(
|
||||
(self.stream_handler.input_sample_rate, audio_array)
|
||||
)
|
||||
else:
|
||||
await run_sync(
|
||||
self.stream_handler.receive,
|
||||
(self.stream_handler.input_sample_rate, audio_array),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.debug("Error in websocket handler %s", e)
|
||||
|
||||
elif message["event"] == "start":
|
||||
if self.stream_handler.phone_mode:
|
||||
@@ -138,17 +153,13 @@ class WebSocketHandler:
|
||||
await self.set_handler(self.stream_id, self)
|
||||
elif message["event"] == "stop":
|
||||
self.quit.set()
|
||||
self.clean_up(cast(str, self.stream_id))
|
||||
return
|
||||
return # Still runs the `finally` block
|
||||
elif message["event"] == "ping":
|
||||
await websocket.send_json({"event": "pong"})
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.debug("Error in websocket handler %s", e)
|
||||
except WebSocketDisconnect:
|
||||
# Surprisingly, this leaves `websocket.application_state` as CONNECTED
|
||||
# in the `finally` block, so we use this variable
|
||||
was_disconnected = True
|
||||
finally:
|
||||
if self._emit_task:
|
||||
self._emit_task.cancel()
|
||||
@@ -156,7 +167,11 @@ class WebSocketHandler:
|
||||
self._emit_to_queue_task.cancel()
|
||||
if self.start_up_task:
|
||||
self.start_up_task.cancel()
|
||||
await websocket.close()
|
||||
|
||||
if not was_disconnected:
|
||||
await websocket.close()
|
||||
|
||||
self.clean_up(cast(str, self.stream_id))
|
||||
|
||||
async def _emit_to_queue(self):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user