mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 09:29:23 +08:00
add code (#223)
This commit is contained in:
@@ -34,6 +34,7 @@ from .utils import (
|
||||
audio_to_file,
|
||||
audio_to_float32,
|
||||
audio_to_int16,
|
||||
get_current_context,
|
||||
wait_for_item,
|
||||
)
|
||||
from .webrtc import (
|
||||
@@ -77,4 +78,5 @@ __all__ = [
|
||||
"SileroVadOptions",
|
||||
"VideoStreamHandler",
|
||||
"CloseStream",
|
||||
"get_current_context",
|
||||
]
|
||||
|
||||
@@ -38,10 +38,12 @@ from numpy import typing as npt
|
||||
from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
CloseStream,
|
||||
Context,
|
||||
DataChannel,
|
||||
WebRTCError,
|
||||
create_message,
|
||||
current_channel,
|
||||
current_context,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: VideoEventHandler,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
mode: Literal["send-receive", "send"] = "send-receive",
|
||||
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.skip_frames = skip_frames
|
||||
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
||||
self.latest_frame = None
|
||||
self.context = context
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
current_channel.set(channel)
|
||||
current_context.set(self.context)
|
||||
self.channel_set.set()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.thread_quit.set()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
current_context.set(self.context)
|
||||
if not self.channel_set.is_set():
|
||||
await self.channel_set.wait()
|
||||
if current_channel.get() != self.channel:
|
||||
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandlerBase,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.context = context
|
||||
|
||||
def clear_queue(self):
|
||||
logger.debug("clearing queue")
|
||||
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self._start = None
|
||||
|
||||
async def wait_for_channel(self):
|
||||
current_context.set(self.context)
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
|
||||
def event_handler_emit(self) -> EmitType:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
current_context.set(self.context)
|
||||
return cast(Callable, self.event_handler.emit)()
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
fps: int = 30,
|
||||
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.fps = fps
|
||||
self.frame_ptime = 1.0 / fps
|
||||
self.context = context
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
pts, time_base = await self.next_timestamp()
|
||||
await self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
current_context.set(self.context)
|
||||
if self.generator is None:
|
||||
self.generator = cast(
|
||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.has_started = False
|
||||
self._start: float | None = None
|
||||
self.context = context
|
||||
super().__init__()
|
||||
|
||||
def clear_queue(self):
|
||||
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.args_set.set()
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
current_context.set(self.context)
|
||||
self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
if self.generator is None:
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import tempfile
|
||||
import traceback
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
@@ -61,6 +62,22 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
webrtc_id: str
|
||||
|
||||
|
||||
current_context: ContextVar[Context | None] = ContextVar(
|
||||
"current_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def get_current_context() -> Context:
|
||||
if not (ctx := current_context.get()):
|
||||
raise RuntimeError("No context found")
|
||||
return ctx
|
||||
|
||||
|
||||
def _send_log(message: str, type: str) -> None:
|
||||
async def _send(channel: DataChannel) -> None:
|
||||
channel.send(
|
||||
|
||||
@@ -38,6 +38,7 @@ from fastrtc.tracks import (
|
||||
)
|
||||
from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
Context,
|
||||
create_message,
|
||||
webrtc_error_handler,
|
||||
)
|
||||
@@ -291,7 +292,7 @@ class WebRTCConnectionMixin:
|
||||
def _(track):
|
||||
relay = MediaRelay()
|
||||
handler = self.handlers[body["webrtc_id"]]
|
||||
|
||||
context = Context(webrtc_id=body["webrtc_id"])
|
||||
if self.modality == "video" and track.kind == "video":
|
||||
args = {}
|
||||
handler_ = handler
|
||||
@@ -304,6 +305,7 @@ class WebRTCConnectionMixin:
|
||||
event_handler=cast(Callable, handler_),
|
||||
set_additional_outputs=set_outputs,
|
||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||
context=context,
|
||||
**args,
|
||||
)
|
||||
elif self.modality == "audio-video" and track.kind == "video":
|
||||
@@ -312,6 +314,7 @@ class WebRTCConnectionMixin:
|
||||
event_handler=handler, # type: ignore
|
||||
set_additional_outputs=set_outputs,
|
||||
fps=cast(StreamHandlerImpl, handler).fps,
|
||||
context=context,
|
||||
)
|
||||
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
||||
eh = cast(StreamHandlerImpl, handler)
|
||||
@@ -320,6 +323,7 @@ class WebRTCConnectionMixin:
|
||||
relay.subscribe(track),
|
||||
event_handler=eh,
|
||||
set_additional_outputs=set_outputs,
|
||||
context=context,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Modality must be either video, audio, or audio-video")
|
||||
@@ -336,6 +340,7 @@ class WebRTCConnectionMixin:
|
||||
elif self.mode == "send":
|
||||
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
||||
|
||||
context = Context(webrtc_id=body["webrtc_id"])
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
if isinstance(self.event_handler, VideoStreamHandler):
|
||||
@@ -343,16 +348,19 @@ class WebRTCConnectionMixin:
|
||||
cast(Callable, self.event_handler.callable),
|
||||
set_additional_outputs=set_outputs,
|
||||
fps=self.event_handler.fps,
|
||||
context=context,
|
||||
)
|
||||
else:
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
context=context,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = ServerToClientAudio(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
context=context,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Modality must be either video or audio")
|
||||
|
||||
Reference in New Issue
Block a user