Tidy up connection logic (#90)

* Add code:

* code

* code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-02-26 18:21:26 -05:00
committed by GitHub
parent e44341d781
commit 43e42c1b22
8 changed files with 56 additions and 12 deletions

View File

@@ -70,7 +70,7 @@ class WebRTCConnectionMixin:
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
handlers: dict[str, HandlerType | Callable] = {}
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
concurrency_limit: int | float
event_handler: HandlerType
time_limit: float | int | None
@@ -82,8 +82,24 @@ class WebRTCConnectionMixin:
await asyncio.sleep(time_limit)
await pc.close()
async def connection_timeout(
self,
pc: RTCPeerConnection,
webrtc_id: str,
time_limit: float,
):
try:
await asyncio.wait_for(
self.connection_timeouts[webrtc_id].wait(), time_limit
)
except (asyncio.TimeoutError, TimeoutError):
await pc.close()
self.connection_timeouts[webrtc_id].clear()
self.clean_up(webrtc_id)
def clean_up(self, webrtc_id: str):
self.handlers.pop(webrtc_id, None)
self.connection_timeouts.pop(webrtc_id, None)
connection = self.connections.pop(webrtc_id, [])
for conn in connection:
if isinstance(conn, AudioCallback):
@@ -132,7 +148,7 @@ class WebRTCConnectionMixin:
logger.debug("Offer body %s", body)
if len(self.connections) >= cast(int, self.concurrency_limit):
return JSONResponse(
status_code=429,
status_code=200,
content={
"status": "failed",
"meta": {
@@ -181,6 +197,7 @@ class WebRTCConnectionMixin:
conn.stop()
self.pcs.discard(pc)
if pc.connectionState == "connected":
self.connection_timeouts[body["webrtc_id"]].set()
if self.time_limit is not None:
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
@@ -269,7 +286,7 @@ class WebRTCConnectionMixin:
# handle offer
await pc.setRemoteDescription(offer)
asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore

View File

@@ -1,16 +1,17 @@
from functools import lru_cache
from typing import Generator, Literal
import gradio as gr
import numpy as np
from fastrtc import (
Stream,
AdditionalOutputs,
audio_to_float32,
ReplyOnPause,
Stream,
audio_to_float32,
get_twilio_turn_credentials,
)
from functools import lru_cache
import gradio as gr
from typing import Generator, Literal
from numpy.typing import NDArray
import numpy as np
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
from numpy.typing import NDArray
@lru_cache(maxsize=None)

View File

@@ -109,7 +109,7 @@ async def stream_updates(webrtc_id: str):
### Handling Errors
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 429 error.
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 200 error.
```json
{
@@ -122,6 +122,8 @@ When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` rou
Over `WebSocket`, the server will send the same message before closing the connection.
!!! tip
The server will sends a 200 status code because otherwise the gradio client will not be able to process the json response and display the error.
<style>
.config-selector {

View File

@@ -154,6 +154,13 @@
_time_limit = null;
stop(pc);
break;
case "failed":
console.info("failed");
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}
@@ -209,6 +216,7 @@
})
.catch(() => {
console.info("catching");
clearTimeout(timeoutId);
stream_state = "closed";
});
}

View File

@@ -60,6 +60,11 @@
console.info("closed");
stop(pc);
break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}

View File

@@ -47,6 +47,11 @@
case "disconnected":
stop(pc);
break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}

View File

@@ -155,6 +155,12 @@
stop(pc);
await access_webcam();
break;
case "failed":
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "fastrtc"
version = "0.0.6"
version = "0.0.8post1"
description = "The realtime communication library for Python"
readme = "README.md"
license = "apache-2.0"