Tidy up connection logic (#90)

* Add code:

* code

* code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-02-26 18:21:26 -05:00
committed by GitHub
parent e44341d781
commit 43e42c1b22
8 changed files with 56 additions and 12 deletions

View File

@@ -70,7 +70,7 @@ class WebRTCConnectionMixin:
data_channels: dict[str, DataChannel] = {} data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue) additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
handlers: dict[str, HandlerType | Callable] = {} handlers: dict[str, HandlerType | Callable] = {}
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
concurrency_limit: int | float concurrency_limit: int | float
event_handler: HandlerType event_handler: HandlerType
time_limit: float | int | None time_limit: float | int | None
@@ -82,8 +82,24 @@ class WebRTCConnectionMixin:
await asyncio.sleep(time_limit) await asyncio.sleep(time_limit)
await pc.close() await pc.close()
async def connection_timeout(
self,
pc: RTCPeerConnection,
webrtc_id: str,
time_limit: float,
):
try:
await asyncio.wait_for(
self.connection_timeouts[webrtc_id].wait(), time_limit
)
except (asyncio.TimeoutError, TimeoutError):
await pc.close()
self.connection_timeouts[webrtc_id].clear()
self.clean_up(webrtc_id)
def clean_up(self, webrtc_id: str): def clean_up(self, webrtc_id: str):
self.handlers.pop(webrtc_id, None) self.handlers.pop(webrtc_id, None)
self.connection_timeouts.pop(webrtc_id, None)
connection = self.connections.pop(webrtc_id, []) connection = self.connections.pop(webrtc_id, [])
for conn in connection: for conn in connection:
if isinstance(conn, AudioCallback): if isinstance(conn, AudioCallback):
@@ -132,7 +148,7 @@ class WebRTCConnectionMixin:
logger.debug("Offer body %s", body) logger.debug("Offer body %s", body)
if len(self.connections) >= cast(int, self.concurrency_limit): if len(self.connections) >= cast(int, self.concurrency_limit):
return JSONResponse( return JSONResponse(
status_code=429, status_code=200,
content={ content={
"status": "failed", "status": "failed",
"meta": { "meta": {
@@ -181,6 +197,7 @@ class WebRTCConnectionMixin:
conn.stop() conn.stop()
self.pcs.discard(pc) self.pcs.discard(pc)
if pc.connectionState == "connected": if pc.connectionState == "connected":
self.connection_timeouts[body["webrtc_id"]].set()
if self.time_limit is not None: if self.time_limit is not None:
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit)) asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
@@ -269,7 +286,7 @@ class WebRTCConnectionMixin:
# handle offer # handle offer
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)
asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
# send answer # send answer
answer = await pc.createAnswer() answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore await pc.setLocalDescription(answer) # type: ignore

View File

@@ -1,16 +1,17 @@
from functools import lru_cache
from typing import Generator, Literal
import gradio as gr
import numpy as np
from fastrtc import ( from fastrtc import (
Stream,
AdditionalOutputs, AdditionalOutputs,
audio_to_float32,
ReplyOnPause, ReplyOnPause,
Stream,
audio_to_float32,
get_twilio_turn_credentials, get_twilio_turn_credentials,
) )
from functools import lru_cache
import gradio as gr
from typing import Generator, Literal
from numpy.typing import NDArray
import numpy as np
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
from numpy.typing import NDArray
@lru_cache(maxsize=None) @lru_cache(maxsize=None)

View File

@@ -109,7 +109,7 @@ async def stream_updates(webrtc_id: str):
### Handling Errors ### Handling Errors
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 429 error. When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 200 error.
```json ```json
{ {
@@ -122,6 +122,8 @@ When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` rou
Over `WebSocket`, the server will send the same message before closing the connection. Over `WebSocket`, the server will send the same message before closing the connection.
!!! tip
The server will sends a 200 status code because otherwise the gradio client will not be able to process the json response and display the error.
<style> <style>
.config-selector { .config-selector {

View File

@@ -154,6 +154,13 @@
_time_limit = null; _time_limit = null;
stop(pc); stop(pc);
break; break;
case "failed":
console.info("failed");
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default: default:
break; break;
} }
@@ -209,6 +216,7 @@
}) })
.catch(() => { .catch(() => {
console.info("catching"); console.info("catching");
clearTimeout(timeoutId);
stream_state = "closed"; stream_state = "closed";
}); });
} }

View File

@@ -60,6 +60,11 @@
console.info("closed"); console.info("closed");
stop(pc); stop(pc);
break; break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default: default:
break; break;
} }

View File

@@ -47,6 +47,11 @@
case "disconnected": case "disconnected":
stop(pc); stop(pc);
break; break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default: default:
break; break;
} }

View File

@@ -155,6 +155,12 @@
stop(pc); stop(pc);
await access_webcam(); await access_webcam();
break; break;
case "failed":
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default: default:
break; break;
} }

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "fastrtc" name = "fastrtc"
version = "0.0.6" version = "0.0.8post1"
description = "The realtime communication library for Python" description = "The realtime communication library for Python"
readme = "README.md" readme = "README.md"
license = "apache-2.0" license = "apache-2.0"