mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Pass Websocket to the context if available (#329)
* Add code * Code * Fix * Add code
This commit is contained in:
@@ -16,6 +16,7 @@ from typing import Any, Literal, Protocol, TypedDict, cast
|
|||||||
import av
|
import av
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from fastapi import WebSocket
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
@@ -67,6 +68,7 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Context:
|
class Context:
|
||||||
webrtc_id: str
|
webrtc_id: str
|
||||||
|
websocket: WebSocket | None = None
|
||||||
|
|
||||||
|
|
||||||
current_context: ContextVar[Context | None] = ContextVar(
|
current_context: ContextVar[Context | None] = ContextVar(
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import logging
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import anyio
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from anyio.to_thread import run_sync
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||||
|
|
||||||
@@ -52,9 +52,6 @@ def convert_to_mulaw(
|
|||||||
return audioop.lin2ulaw(audio_data, 2) # type: ignore
|
return audioop.lin2ulaw(audio_data, 2) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
run_sync = anyio.to_thread.run_sync # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketHandler:
|
class WebSocketHandler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -94,21 +91,14 @@ class WebSocketHandler:
|
|||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.websocket = websocket
|
|
||||||
self.data_channel = WebSocketDataChannel(websocket, loop)
|
self.data_channel = WebSocketDataChannel(websocket, loop)
|
||||||
self.stream_handler._loop = loop
|
self.stream_handler._loop = loop
|
||||||
self.stream_handler.set_channel(self.data_channel)
|
self.stream_handler.set_channel(self.data_channel)
|
||||||
self._emit_task = asyncio.create_task(self._emit_loop())
|
self._emit_task = asyncio.create_task(self._emit_loop())
|
||||||
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
||||||
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
|
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
|
||||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
|
||||||
start_up = self.stream_handler.start_up()
|
|
||||||
else:
|
|
||||||
start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore
|
|
||||||
|
|
||||||
was_disconnected = False
|
was_disconnected = False
|
||||||
|
|
||||||
self.start_up_task = asyncio.create_task(start_up)
|
|
||||||
try:
|
try:
|
||||||
while not self.quit.is_set():
|
while not self.quit.is_set():
|
||||||
if websocket.application_state != WebSocketState.CONNECTED:
|
if websocket.application_state != WebSocketState.CONNECTED:
|
||||||
@@ -157,7 +147,17 @@ class WebSocketHandler:
|
|||||||
self.stream_id = cast(str, message["streamSid"])
|
self.stream_id = cast(str, message["streamSid"])
|
||||||
else:
|
else:
|
||||||
self.stream_id = cast(str, message["websocket_id"])
|
self.stream_id = cast(str, message["websocket_id"])
|
||||||
current_context.set(Context(webrtc_id=self.stream_id))
|
self.websocket = websocket
|
||||||
|
current_context.set(
|
||||||
|
Context(webrtc_id=self.stream_id, websocket=websocket)
|
||||||
|
)
|
||||||
|
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||||
|
start_up = self.stream_handler.start_up()
|
||||||
|
else:
|
||||||
|
start_up = run_sync(self.stream_handler.start_up) # type: ignore
|
||||||
|
|
||||||
|
self.start_up_task = asyncio.create_task(start_up)
|
||||||
|
|
||||||
self.set_additional_outputs = self.set_additional_outputs_factory(
|
self.set_additional_outputs = self.set_additional_outputs_factory(
|
||||||
self.stream_id
|
self.stream_id
|
||||||
)
|
)
|
||||||
@@ -189,11 +189,15 @@ class WebSocketHandler:
|
|||||||
self.clean_up(cast(str, self.stream_id))
|
self.clean_up(cast(str, self.stream_id))
|
||||||
|
|
||||||
def emit_with_context(self):
|
def emit_with_context(self):
|
||||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
current_context.set(
|
||||||
|
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
|
||||||
|
)
|
||||||
return self.stream_handler.emit()
|
return self.stream_handler.emit()
|
||||||
|
|
||||||
def receive_with_context(self, frame: tuple[int, np.ndarray]):
|
def receive_with_context(self, frame: tuple[int, np.ndarray]):
|
||||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
current_context.set(
|
||||||
|
Context(webrtc_id=cast(str, self.stream_id), websocket=self.websocket)
|
||||||
|
)
|
||||||
return self.stream_handler.receive(frame)
|
return self.stream_handler.receive(frame)
|
||||||
|
|
||||||
async def _emit_to_queue(self):
|
async def _emit_to_queue(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user