mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Tidy up connection logic (#90)
* Add code: * code * code --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -70,7 +70,7 @@ class WebRTCConnectionMixin:
|
|||||||
data_channels: dict[str, DataChannel] = {}
|
data_channels: dict[str, DataChannel] = {}
|
||||||
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
|
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
|
||||||
handlers: dict[str, HandlerType | Callable] = {}
|
handlers: dict[str, HandlerType | Callable] = {}
|
||||||
|
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
|
||||||
concurrency_limit: int | float
|
concurrency_limit: int | float
|
||||||
event_handler: HandlerType
|
event_handler: HandlerType
|
||||||
time_limit: float | int | None
|
time_limit: float | int | None
|
||||||
@@ -82,8 +82,24 @@ class WebRTCConnectionMixin:
|
|||||||
await asyncio.sleep(time_limit)
|
await asyncio.sleep(time_limit)
|
||||||
await pc.close()
|
await pc.close()
|
||||||
|
|
||||||
|
async def connection_timeout(
|
||||||
|
self,
|
||||||
|
pc: RTCPeerConnection,
|
||||||
|
webrtc_id: str,
|
||||||
|
time_limit: float,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.connection_timeouts[webrtc_id].wait(), time_limit
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, TimeoutError):
|
||||||
|
await pc.close()
|
||||||
|
self.connection_timeouts[webrtc_id].clear()
|
||||||
|
self.clean_up(webrtc_id)
|
||||||
|
|
||||||
def clean_up(self, webrtc_id: str):
|
def clean_up(self, webrtc_id: str):
|
||||||
self.handlers.pop(webrtc_id, None)
|
self.handlers.pop(webrtc_id, None)
|
||||||
|
self.connection_timeouts.pop(webrtc_id, None)
|
||||||
connection = self.connections.pop(webrtc_id, [])
|
connection = self.connections.pop(webrtc_id, [])
|
||||||
for conn in connection:
|
for conn in connection:
|
||||||
if isinstance(conn, AudioCallback):
|
if isinstance(conn, AudioCallback):
|
||||||
@@ -132,7 +148,7 @@ class WebRTCConnectionMixin:
|
|||||||
logger.debug("Offer body %s", body)
|
logger.debug("Offer body %s", body)
|
||||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=429,
|
status_code=200,
|
||||||
content={
|
content={
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
"meta": {
|
"meta": {
|
||||||
@@ -181,6 +197,7 @@ class WebRTCConnectionMixin:
|
|||||||
conn.stop()
|
conn.stop()
|
||||||
self.pcs.discard(pc)
|
self.pcs.discard(pc)
|
||||||
if pc.connectionState == "connected":
|
if pc.connectionState == "connected":
|
||||||
|
self.connection_timeouts[body["webrtc_id"]].set()
|
||||||
if self.time_limit is not None:
|
if self.time_limit is not None:
|
||||||
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
|
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
|
||||||
|
|
||||||
@@ -269,7 +286,7 @@ class WebRTCConnectionMixin:
|
|||||||
|
|
||||||
# handle offer
|
# handle offer
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
|
||||||
# send answer
|
# send answer
|
||||||
answer = await pc.createAnswer()
|
answer = await pc.createAnswer()
|
||||||
await pc.setLocalDescription(answer) # type: ignore
|
await pc.setLocalDescription(answer) # type: ignore
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from typing import Generator, Literal
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
from fastrtc import (
|
from fastrtc import (
|
||||||
Stream,
|
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
audio_to_float32,
|
|
||||||
ReplyOnPause,
|
ReplyOnPause,
|
||||||
|
Stream,
|
||||||
|
audio_to_float32,
|
||||||
get_twilio_turn_credentials,
|
get_twilio_turn_credentials,
|
||||||
)
|
)
|
||||||
from functools import lru_cache
|
|
||||||
import gradio as gr
|
|
||||||
from typing import Generator, Literal
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
import numpy as np
|
|
||||||
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
|
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ async def stream_updates(webrtc_id: str):
|
|||||||
|
|
||||||
### Handling Errors
|
### Handling Errors
|
||||||
|
|
||||||
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 429 error.
|
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 200 error.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -122,6 +122,8 @@ When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` rou
|
|||||||
|
|
||||||
Over `WebSocket`, the server will send the same message before closing the connection.
|
Over `WebSocket`, the server will send the same message before closing the connection.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
The server will sends a 200 status code because otherwise the gradio client will not be able to process the json response and display the error.
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
.config-selector {
|
.config-selector {
|
||||||
|
|||||||
@@ -154,6 +154,13 @@
|
|||||||
_time_limit = null;
|
_time_limit = null;
|
||||||
stop(pc);
|
stop(pc);
|
||||||
break;
|
break;
|
||||||
|
case "failed":
|
||||||
|
console.info("failed");
|
||||||
|
stream_state = "closed";
|
||||||
|
_time_limit = null;
|
||||||
|
dispatch("error", "Connection failed!");
|
||||||
|
stop(pc);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -209,6 +216,7 @@
|
|||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
console.info("catching");
|
console.info("catching");
|
||||||
|
clearTimeout(timeoutId);
|
||||||
stream_state = "closed";
|
stream_state = "closed";
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,6 +60,11 @@
|
|||||||
console.info("closed");
|
console.info("closed");
|
||||||
stop(pc);
|
stop(pc);
|
||||||
break;
|
break;
|
||||||
|
case "failed":
|
||||||
|
stream_state = "closed";
|
||||||
|
dispatch("error", "Connection failed!");
|
||||||
|
stop(pc);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,11 @@
|
|||||||
case "disconnected":
|
case "disconnected":
|
||||||
stop(pc);
|
stop(pc);
|
||||||
break;
|
break;
|
||||||
|
case "failed":
|
||||||
|
stream_state = "closed";
|
||||||
|
dispatch("error", "Connection failed!");
|
||||||
|
stop(pc);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,6 +155,12 @@
|
|||||||
stop(pc);
|
stop(pc);
|
||||||
await access_webcam();
|
await access_webcam();
|
||||||
break;
|
break;
|
||||||
|
case "failed":
|
||||||
|
stream_state = "closed";
|
||||||
|
_time_limit = null;
|
||||||
|
dispatch("error", "Connection failed!");
|
||||||
|
stop(pc);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "fastrtc"
|
name = "fastrtc"
|
||||||
version = "0.0.6"
|
version = "0.0.8post1"
|
||||||
description = "The realtime communication library for Python"
|
description = "The realtime communication library for Python"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "apache-2.0"
|
license = "apache-2.0"
|
||||||
|
|||||||
Reference in New Issue
Block a user