mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
add code (#223)
This commit is contained in:
@@ -38,10 +38,12 @@ from numpy import typing as npt
|
||||
from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
CloseStream,
|
||||
Context,
|
||||
DataChannel,
|
||||
WebRTCError,
|
||||
create_message,
|
||||
current_channel,
|
||||
current_context,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
@@ -83,6 +85,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: VideoEventHandler,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
mode: Literal["send-receive", "send"] = "send-receive",
|
||||
@@ -104,10 +107,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.skip_frames = skip_frames
|
||||
self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
|
||||
self.latest_frame = None
|
||||
self.context = context
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
current_channel.set(channel)
|
||||
current_context.set(self.context)
|
||||
self.channel_set.set()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
@@ -145,6 +150,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.thread_quit.set()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
current_context.set(self.context)
|
||||
if not self.channel_set.is_set():
|
||||
await self.channel_set.wait()
|
||||
if current_channel.get() != self.channel:
|
||||
@@ -486,6 +492,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandlerBase,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
@@ -502,6 +509,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.context = context
|
||||
|
||||
def clear_queue(self):
|
||||
logger.debug("clearing queue")
|
||||
@@ -514,6 +522,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
self._start = None
|
||||
|
||||
async def wait_for_channel(self):
|
||||
current_context.set(self.context)
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
@@ -532,6 +541,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
|
||||
def event_handler_emit(self) -> EmitType:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
current_context.set(self.context)
|
||||
return cast(Callable, self.event_handler.emit)()
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
@@ -649,6 +659,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
fps: int = 30,
|
||||
@@ -662,6 +673,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.fps = fps
|
||||
self.frame_ptime = 1.0 / fps
|
||||
self.context = context
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
@@ -693,6 +705,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
pts, time_base = await self.next_timestamp()
|
||||
await self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
current_context.set(self.context)
|
||||
if self.generator is None:
|
||||
self.generator = cast(
|
||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
@@ -736,6 +749,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
context: Context,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
@@ -751,6 +765,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.has_started = False
|
||||
self._start: float | None = None
|
||||
self.context = context
|
||||
super().__init__()
|
||||
|
||||
def clear_queue(self):
|
||||
@@ -766,6 +781,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.args_set.set()
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
current_context.set(self.context)
|
||||
self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
if self.generator is None:
|
||||
|
||||
Reference in New Issue
Block a user