mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Pass Websocket to the context if available (#329)
* Add code * Code * Fix * Add code
This commit is contained in:
@@ -16,6 +16,7 @@ from typing import Any, Literal, Protocol, TypedDict, cast
|
||||
import av
|
||||
import librosa
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
from numpy.typing import NDArray
|
||||
from pydub import AudioSegment
|
||||
|
||||
@@ -67,6 +68,7 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||
@dataclass
|
||||
class Context:
|
||||
webrtc_id: str
|
||||
websocket: WebSocket | None = None
|
||||
|
||||
|
||||
current_context: ContextVar[Context | None] = ContextVar(
|
||||
|
||||
@@ -5,9 +5,9 @@ import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import anyio
|
||||
import librosa
|
||||
import numpy as np
|
||||
from anyio.to_thread import run_sync
|
||||
from fastapi import WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||
|
||||
@@ -52,9 +52,6 @@ def convert_to_mulaw(
|
||||
return audioop.lin2ulaw(audio_data, 2) # type: ignore
|
||||
|
||||
|
||||
run_sync = anyio.to_thread.run_sync # type: ignore
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -94,21 +91,14 @@ class WebSocketHandler:
|
||||
await websocket.accept()
|
||||
loop = asyncio.get_running_loop()
|
||||
self.loop = loop
|
||||
self.websocket = websocket
|
||||
self.data_channel = WebSocketDataChannel(websocket, loop)
|
||||
self.stream_handler._loop = loop
|
||||
self.stream_handler.set_channel(self.data_channel)
|
||||
self._emit_task = asyncio.create_task(self._emit_loop())
|
||||
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
||||
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
start_up = self.stream_handler.start_up()
|
||||
else:
|
||||
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
|
||||
|
||||
was_disconnected = False
|
||||
|
||||
self.start_up_task = asyncio.create_task(start_up)
|
||||
try:
|
||||
while not self.quit.is_set():
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
@@ -157,7 +147,17 @@ class WebSocketHandler:
|
||||
self.stream_id = cast(str, message["streamSid"])
|
||||
else:
|
||||
self.stream_id = cast(str, message["websocket_id"])
|
||||
current_context.set(Context(webrtc_id=self.stream_id))
|
||||
self.websocket = websocket
|
||||
current_context.set(
|
||||
Context(webrtc_id=self.stream_id, websocket=websocket)
|
||||
)
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
start_up = self.stream_handler.start_up()
|
||||
else:
|
||||
start_up = run_sync(self.stream_handler.start_up) # type: ignore
|
||||
|
||||
self.start_up_task = asyncio.create_task(start_up)
|
||||
|
||||
self.set_additional_outputs = self.set_additional_outputs_factory(
|
||||
self.stream_id
|
||||
)
|
||||
@@ -189,11 +189,15 @@ class WebSocketHandler:
|
||||
self.clean_up(cast(str, self.stream_id))
|
||||
|
||||
def emit_with_context(self):
|
||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
||||
current_context.set(
|
||||
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
|
||||
)
|
||||
return self.stream_handler.emit()
|
||||
|
||||
def receive_with_context(self, frame: tuple[int, np.ndarray]):
|
||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
||||
current_context.set(
|
||||
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
|
||||
)
|
||||
return self.stream_handler.receive(frame)
|
||||
|
||||
async def _emit_to_queue(self):
|
||||
|
||||
Reference in New Issue
Block a user