mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Tidy up connection logic (#90)
* Add code: * code * code --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -70,7 +70,7 @@ class WebRTCConnectionMixin:
|
||||
data_channels: dict[str, DataChannel] = {}
|
||||
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
|
||||
handlers: dict[str, HandlerType | Callable] = {}
|
||||
|
||||
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
|
||||
concurrency_limit: int | float
|
||||
event_handler: HandlerType
|
||||
time_limit: float | int | None
|
||||
@@ -82,8 +82,24 @@ class WebRTCConnectionMixin:
|
||||
await asyncio.sleep(time_limit)
|
||||
await pc.close()
|
||||
|
||||
async def connection_timeout(
|
||||
self,
|
||||
pc: RTCPeerConnection,
|
||||
webrtc_id: str,
|
||||
time_limit: float,
|
||||
):
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.connection_timeouts[webrtc_id].wait(), time_limit
|
||||
)
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
await pc.close()
|
||||
self.connection_timeouts[webrtc_id].clear()
|
||||
self.clean_up(webrtc_id)
|
||||
|
||||
def clean_up(self, webrtc_id: str):
|
||||
self.handlers.pop(webrtc_id, None)
|
||||
self.connection_timeouts.pop(webrtc_id, None)
|
||||
connection = self.connections.pop(webrtc_id, [])
|
||||
for conn in connection:
|
||||
if isinstance(conn, AudioCallback):
|
||||
@@ -132,7 +148,7 @@ class WebRTCConnectionMixin:
|
||||
logger.debug("Offer body %s", body)
|
||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "failed",
|
||||
"meta": {
|
||||
@@ -181,6 +197,7 @@ class WebRTCConnectionMixin:
|
||||
conn.stop()
|
||||
self.pcs.discard(pc)
|
||||
if pc.connectionState == "connected":
|
||||
self.connection_timeouts[body["webrtc_id"]].set()
|
||||
if self.time_limit is not None:
|
||||
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))
|
||||
|
||||
@@ -269,7 +286,7 @@ class WebRTCConnectionMixin:
|
||||
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
|
||||
# send answer
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user