mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
add code
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .reply_on_pause import ReplyOnPause
|
||||
from .utils import AdditionalOutputs
|
||||
from .webrtc import StreamHandler, WebRTC
|
||||
|
||||
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||
__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import fractions
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Protocol, cast
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
|
||||
AUDIO_PTIME = 0.020
|
||||
|
||||
|
||||
class AdditionalOutputs:
|
||||
def __init__(self, *args) -> None:
|
||||
self.args = args
|
||||
|
||||
|
||||
class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, tuple):
|
||||
# handle the bare audio case
|
||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||
return data, None
|
||||
if not len(data) == 2:
|
||||
raise ValueError(
|
||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||
)
|
||||
if not isinstance(data[-1], AdditionalOutputs):
|
||||
raise ValueError(
|
||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||
)
|
||||
return data[0], cast(AdditionalOutputs, data[1])
|
||||
return data, None
|
||||
|
||||
|
||||
async def player_worker_decode(
|
||||
next_frame: Callable,
|
||||
queue: asyncio.Queue,
|
||||
thread_quit: asyncio.Event,
|
||||
channel: Callable[[], DataChannel | None] | None,
|
||||
set_additional_outputs: Callable | None,
|
||||
quit_on_none: bool = False,
|
||||
sample_rate: int = 48000,
|
||||
frame_size: int = int(48000 * AUDIO_PTIME),
|
||||
@@ -31,7 +60,17 @@ async def player_worker_decode(
|
||||
while not thread_quit.is_set():
|
||||
try:
|
||||
# Get next frame
|
||||
frame = await asyncio.wait_for(next_frame(), timeout=60)
|
||||
frame, outputs = split_output(
|
||||
await asyncio.wait_for(next_frame(), timeout=60)
|
||||
)
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and set_additional_outputs
|
||||
and channel
|
||||
and channel()
|
||||
):
|
||||
set_additional_outputs(outputs)
|
||||
cast(DataChannel, channel()).send("change")
|
||||
|
||||
if frame is None:
|
||||
if quit_on_none:
|
||||
@@ -65,7 +104,7 @@ async def player_worker_decode(
|
||||
processed_frame.time_base = audio_time_base
|
||||
audio_samples += processed_frame.samples
|
||||
await queue.put(processed_frame)
|
||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
logger.warning(
|
||||
|
||||
@@ -11,7 +11,19 @@ import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Generator,
|
||||
Iterable,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio.to_thread
|
||||
import av
|
||||
@@ -29,7 +41,12 @@ from gradio import wasm_utils
|
||||
from gradio.components.base import Component, server
|
||||
from gradio_client import handle_file
|
||||
|
||||
from gradio_webrtc.utils import player_worker_decode
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
|
||||
def add_frame_to_payload(
|
||||
self, args: list[Any], frame: np.ndarray | None
|
||||
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
|
||||
|
||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||
|
||||
array = self.event_handler(*args)
|
||||
array, outputs = split_output(self.event_handler(*args))
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and self.set_additional_outputs
|
||||
and self.channel
|
||||
):
|
||||
self.set_additional_outputs(outputs)
|
||||
self.channel.send("change")
|
||||
|
||||
new_frame = self.array_to_frame(array)
|
||||
if frame:
|
||||
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandler,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
self._start: float | None = None
|
||||
self.has_started = False
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
callable,
|
||||
self.queue,
|
||||
self.thread_quit,
|
||||
lambda: self.channel,
|
||||
self.set_additional_outputs,
|
||||
False,
|
||||
self.event_handler.output_sample_rate,
|
||||
self.event_handler.output_frame_size,
|
||||
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.event_handler = event_handler
|
||||
self.args_set = asyncio.Event()
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
)
|
||||
|
||||
try:
|
||||
next_array = next(self.generator)
|
||||
next_array, outputs = split_output(next(self.generator))
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and self.set_additional_outputs
|
||||
and self.channel
|
||||
):
|
||||
self.set_additional_outputs(outputs)
|
||||
self.channel.send("change")
|
||||
except StopIteration:
|
||||
self.stop()
|
||||
return
|
||||
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.event_handler = event_handler
|
||||
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.args_set = threading.Event()
|
||||
self.queue = asyncio.Queue()
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.has_started = False
|
||||
self._start: float | None = None
|
||||
super().__init__()
|
||||
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
callable,
|
||||
self.queue,
|
||||
self.thread_quit,
|
||||
lambda: self.channel,
|
||||
self.set_additional_outputs,
|
||||
True,
|
||||
)
|
||||
)
|
||||
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
super().stop()
|
||||
|
||||
|
||||
# For the return type
|
||||
R = TypeVar("R")
|
||||
# For the parameter specification
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class WebRTC(Component):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
@@ -369,8 +426,10 @@ class WebRTC(Component):
|
||||
connections: dict[
|
||||
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
|
||||
] = {}
|
||||
data_channels: dict[str, DataChannel] = {}
|
||||
additional_outputs: dict[str, AdditionalOutputs] = {}
|
||||
|
||||
EVENTS = ["tick"]
|
||||
EVENTS = ["tick", "state_change"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -470,6 +529,14 @@ class WebRTC(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def set_additional_outputs(
|
||||
self, webrtc_id: str
|
||||
) -> Callable[[AdditionalOutputs], None]:
|
||||
def set_outputs(outputs: AdditionalOutputs):
|
||||
self.additional_outputs[webrtc_id] = outputs
|
||||
|
||||
return set_outputs
|
||||
|
||||
def preprocess(self, payload: str) -> str:
|
||||
"""
|
||||
Parameters:
|
||||
@@ -498,6 +565,38 @@ class WebRTC(Component):
|
||||
self.connections[webrtc_id].latest_args = list(args)
|
||||
self.connections[webrtc_id].args_set.set() # type: ignore
|
||||
|
||||
def change(
|
||||
self,
|
||||
fn: Callable[Concatenate[P], R],
|
||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
outputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
js: str | None = None,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
):
|
||||
inputs = inputs or []
|
||||
if inputs and not isinstance(inputs, Iterable):
|
||||
inputs = [inputs]
|
||||
inputs = list(inputs)
|
||||
|
||||
def handler(webrtc_id: str, *args):
|
||||
if webrtc_id in self.additional_outputs:
|
||||
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
|
||||
return (
|
||||
tuple([None for _ in range(len(outputs))])
|
||||
if isinstance(outputs, Iterable)
|
||||
else None
|
||||
)
|
||||
|
||||
return self.state_change( # type: ignore
|
||||
fn=handler,
|
||||
inputs=[self] + cast(list, inputs),
|
||||
outputs=outputs,
|
||||
js=js,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable[..., Any] | StreamHandler | None = None,
|
||||
@@ -599,6 +698,8 @@ class WebRTC(Component):
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
|
||||
set_outputs = self.set_additional_outputs(body["webrtc_id"])
|
||||
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
logger.debug("ICE connection state change %s", pc.iceConnectionState)
|
||||
@@ -627,27 +728,61 @@ class WebRTC(Component):
|
||||
cb = VideoCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = AudioCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(StreamHandler, self.event_handler).copy(),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
if body["webrtc_id"] in self.data_channels:
|
||||
self.connections[body["webrtc_id"]].channel = self.data_channels[
|
||||
body["webrtc_id"]
|
||||
]
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = ServerToClientAudio(cast(Callable, self.event_handler))
|
||||
cb = ServerToClientAudio(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
print("data channel established")
|
||||
logger.debug(f"Data channel established: {channel.label}")
|
||||
|
||||
self.data_channels[body["webrtc_id"]] = channel
|
||||
|
||||
async def set_channel(webrtc_id: str):
|
||||
print("webrtc_id", webrtc_id)
|
||||
while not self.connections.get(webrtc_id):
|
||||
await asyncio.sleep(0.05)
|
||||
print("setting channel")
|
||||
self.connections[webrtc_id].channel = channel
|
||||
|
||||
asyncio.create_task(set_channel(body["webrtc_id"]))
|
||||
|
||||
@channel.on("message")
|
||||
def on_message(message):
|
||||
logger.debug(f"Received message: {message}")
|
||||
if channel.readyState == "open":
|
||||
channel.send(f"Server received: {message}")
|
||||
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user