mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
add code (#223)
This commit is contained in:
@@ -34,6 +34,7 @@ from .utils import (
|
|||||||
audio_to_file,
|
audio_to_file,
|
||||||
audio_to_float32,
|
audio_to_float32,
|
||||||
audio_to_int16,
|
audio_to_int16,
|
||||||
|
get_current_context,
|
||||||
wait_for_item,
|
wait_for_item,
|
||||||
)
|
)
|
||||||
from .webrtc import (
|
from .webrtc import (
|
||||||
@@ -77,4 +78,5 @@ __all__ = [
|
|||||||
"SileroVadOptions",
|
"SileroVadOptions",
|
||||||
"VideoStreamHandler",
|
"VideoStreamHandler",
|
||||||
"CloseStream",
|
"CloseStream",
|
||||||
|
"get_current_context",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -38,10 +38,12 @@ from numpy import typing as npt
|
|||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
CloseStream,
|
CloseStream,
|
||||||
|
Context,
|
||||||
DataChannel,
|
DataChannel,
|
||||||
WebRTCError,
|
WebRTCError,
|
||||||
create_message,
|
create_message,
|
||||||
current_channel,
|
current_channel,
|
||||||
|
current_context,
|
||||||
player_worker_decode,
|
player_worker_decode,
|
||||||
split_output,
|
split_output,
|
||||||
)
|
)
|
||||||
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self,
|
self,
|
||||||
track: MediaStreamTrack,
|
track: MediaStreamTrack,
|
||||||
event_handler: VideoEventHandler,
|
event_handler: VideoEventHandler,
|
||||||
|
context: Context,
|
||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
mode: Literal["send-receive", "send"] = "send-receive",
|
mode: Literal["send-receive", "send"] = "send-receive",
|
||||||
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self.skip_frames = skip_frames
|
self.skip_frames = skip_frames
|
||||||
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
||||||
self.latest_frame = None
|
self.latest_frame = None
|
||||||
|
self.context = context
|
||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
current_channel.set(channel)
|
current_channel.set(channel)
|
||||||
|
current_context.set(self.context)
|
||||||
self.channel_set.set()
|
self.channel_set.set()
|
||||||
|
|
||||||
def set_args(self, args: list[Any]):
|
def set_args(self, args: list[Any]):
|
||||||
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self.thread_quit.set()
|
self.thread_quit.set()
|
||||||
|
|
||||||
async def wait_for_channel(self):
|
async def wait_for_channel(self):
|
||||||
|
current_context.set(self.context)
|
||||||
if not self.channel_set.is_set():
|
if not self.channel_set.is_set():
|
||||||
await self.channel_set.wait()
|
await self.channel_set.wait()
|
||||||
if current_channel.get() != self.channel:
|
if current_channel.get() != self.channel:
|
||||||
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self,
|
self,
|
||||||
track: MediaStreamTrack,
|
track: MediaStreamTrack,
|
||||||
event_handler: StreamHandlerBase,
|
event_handler: StreamHandlerBase,
|
||||||
|
context: Context,
|
||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self.last_timestamp = 0
|
self.last_timestamp = 0
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
|
self.context = context
|
||||||
|
|
||||||
def clear_queue(self):
|
def clear_queue(self):
|
||||||
logger.debug("clearing queue")
|
logger.debug("clearing queue")
|
||||||
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self._start = None
|
self._start = None
|
||||||
|
|
||||||
async def wait_for_channel(self):
|
async def wait_for_channel(self):
|
||||||
|
current_context.set(self.context)
|
||||||
if not self.event_handler.channel_set.is_set():
|
if not self.event_handler.channel_set.is_set():
|
||||||
await self.event_handler.channel_set.wait()
|
await self.event_handler.channel_set.wait()
|
||||||
if current_channel.get() != self.event_handler.channel:
|
if current_channel.get() != self.event_handler.channel:
|
||||||
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
|
|
||||||
def event_handler_emit(self) -> EmitType:
|
def event_handler_emit(self) -> EmitType:
|
||||||
current_channel.set(self.event_handler.channel)
|
current_channel.set(self.event_handler.channel)
|
||||||
|
current_context.set(self.context)
|
||||||
return cast(Callable, self.event_handler.emit)()
|
return cast(Callable, self.event_handler.emit)()
|
||||||
|
|
||||||
async def process_input_frames(self) -> None:
|
async def process_input_frames(self) -> None:
|
||||||
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
|
context: Context,
|
||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
fps: int = 30,
|
fps: int = 30,
|
||||||
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
self.frame_ptime = 1.0 / fps
|
self.frame_ptime = 1.0 / fps
|
||||||
|
self.context = context
|
||||||
|
|
||||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||||
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
pts, time_base = await self.next_timestamp()
|
pts, time_base = await self.next_timestamp()
|
||||||
await self.args_set.wait()
|
await self.args_set.wait()
|
||||||
current_channel.set(self.channel)
|
current_channel.set(self.channel)
|
||||||
|
current_context.set(self.context)
|
||||||
if self.generator is None:
|
if self.generator is None:
|
||||||
self.generator = cast(
|
self.generator = cast(
|
||||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||||
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
|
context: Context,
|
||||||
channel: DataChannel | None = None,
|
channel: DataChannel | None = None,
|
||||||
set_additional_outputs: Callable | None = None,
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
self.has_started = False
|
self.has_started = False
|
||||||
self._start: float | None = None
|
self._start: float | None = None
|
||||||
|
self.context = context
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def clear_queue(self):
|
def clear_queue(self):
|
||||||
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
self.args_set.set()
|
self.args_set.set()
|
||||||
|
|
||||||
def next(self) -> tuple[int, np.ndarray] | None:
|
def next(self) -> tuple[int, np.ndarray] | None:
|
||||||
|
current_context.set(self.context)
|
||||||
self.args_set.wait()
|
self.args_set.wait()
|
||||||
current_channel.set(self.channel)
|
current_channel.set(self.channel)
|
||||||
if self.generator is None:
|
if self.generator is None:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import logging
|
|||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
||||||
|
|
||||||
import av
|
import av
|
||||||
@@ -61,6 +62,22 @@ current_channel: ContextVar[DataChannel | None] = ContextVar(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Context:
|
||||||
|
webrtc_id: str
|
||||||
|
|
||||||
|
|
||||||
|
current_context: ContextVar[Context | None] = ContextVar(
|
||||||
|
"current_context", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_context() -> Context:
|
||||||
|
if not (ctx := current_context.get()):
|
||||||
|
raise RuntimeError("No context found")
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
def _send_log(message: str, type: str) -> None:
|
def _send_log(message: str, type: str) -> None:
|
||||||
async def _send(channel: DataChannel) -> None:
|
async def _send(channel: DataChannel) -> None:
|
||||||
channel.send(
|
channel.send(
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from fastrtc.tracks import (
|
|||||||
)
|
)
|
||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
|
Context,
|
||||||
create_message,
|
create_message,
|
||||||
webrtc_error_handler,
|
webrtc_error_handler,
|
||||||
)
|
)
|
||||||
@@ -291,7 +292,7 @@ class WebRTCConnectionMixin:
|
|||||||
def _(track):
|
def _(track):
|
||||||
relay = MediaRelay()
|
relay = MediaRelay()
|
||||||
handler = self.handlers[body["webrtc_id"]]
|
handler = self.handlers[body["webrtc_id"]]
|
||||||
|
context = Context(webrtc_id=body["webrtc_id"])
|
||||||
if self.modality == "video" and track.kind == "video":
|
if self.modality == "video" and track.kind == "video":
|
||||||
args = {}
|
args = {}
|
||||||
handler_ = handler
|
handler_ = handler
|
||||||
@@ -304,6 +305,7 @@ class WebRTCConnectionMixin:
|
|||||||
event_handler=cast(Callable, handler_),
|
event_handler=cast(Callable, handler_),
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||||
|
context=context,
|
||||||
**args,
|
**args,
|
||||||
)
|
)
|
||||||
elif self.modality == "audio-video" and track.kind == "video":
|
elif self.modality == "audio-video" and track.kind == "video":
|
||||||
@@ -312,6 +314,7 @@ class WebRTCConnectionMixin:
|
|||||||
event_handler=handler, # type: ignore
|
event_handler=handler, # type: ignore
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
fps=cast(StreamHandlerImpl, handler).fps,
|
fps=cast(StreamHandlerImpl, handler).fps,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
||||||
eh = cast(StreamHandlerImpl, handler)
|
eh = cast(StreamHandlerImpl, handler)
|
||||||
@@ -320,6 +323,7 @@ class WebRTCConnectionMixin:
|
|||||||
relay.subscribe(track),
|
relay.subscribe(track),
|
||||||
event_handler=eh,
|
event_handler=eh,
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Modality must be either video, audio, or audio-video")
|
raise ValueError("Modality must be either video, audio, or audio-video")
|
||||||
@@ -336,6 +340,7 @@ class WebRTCConnectionMixin:
|
|||||||
elif self.mode == "send":
|
elif self.mode == "send":
|
||||||
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
||||||
|
|
||||||
|
context = Context(webrtc_id=body["webrtc_id"])
|
||||||
if self.mode == "receive":
|
if self.mode == "receive":
|
||||||
if self.modality == "video":
|
if self.modality == "video":
|
||||||
if isinstance(self.event_handler, VideoStreamHandler):
|
if isinstance(self.event_handler, VideoStreamHandler):
|
||||||
@@ -343,16 +348,19 @@ class WebRTCConnectionMixin:
|
|||||||
cast(Callable, self.event_handler.callable),
|
cast(Callable, self.event_handler.callable),
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
fps=self.event_handler.fps,
|
fps=self.event_handler.fps,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cb = ServerToClientVideo(
|
cb = ServerToClientVideo(
|
||||||
cast(Callable, self.event_handler),
|
cast(Callable, self.event_handler),
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
elif self.modality == "audio":
|
elif self.modality == "audio":
|
||||||
cb = ServerToClientAudio(
|
cb = ServerToClientAudio(
|
||||||
cast(Callable, self.event_handler),
|
cast(Callable, self.event_handler),
|
||||||
set_additional_outputs=set_outputs,
|
set_additional_outputs=set_outputs,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Modality must be either video or audio")
|
raise ValueError("Modality must be either video or audio")
|
||||||
|
|||||||
Reference in New Issue
Block a user