This commit is contained in:
freddyaboulton
2024-10-28 09:59:08 -07:00
parent c051736fbb
commit d1c43edcd4
10 changed files with 225 additions and 18 deletions

View File

@@ -1,4 +1,5 @@
from .reply_on_pause import ReplyOnPause
from .utils import AdditionalOutputs
from .webrtc import StreamHandler, WebRTC
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]

View File

@@ -1,9 +1,10 @@
import asyncio
import fractions
import logging
from typing import Callable
from typing import Any, Callable, Protocol, cast
import av
import numpy as np
logger = logging.getLogger(__name__)
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
class AdditionalOutputs:
def __init__(self, *args) -> None:
self.args = args
class DataChannel(Protocol):
def send(self, message: str) -> None: ...
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
return data, None
if not len(data) == 2:
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], AdditionalOutputs):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)
return data[0], cast(AdditionalOutputs, data[1])
return data, None
async def player_worker_decode(
next_frame: Callable,
queue: asyncio.Queue,
thread_quit: asyncio.Event,
channel: Callable[[], DataChannel | None] | None,
set_additional_outputs: Callable | None,
quit_on_none: bool = False,
sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME),
@@ -31,7 +60,17 @@ async def player_worker_decode(
while not thread_quit.is_set():
try:
# Get next frame
frame = await asyncio.wait_for(next_frame(), timeout=60)
frame, outputs = split_output(
await asyncio.wait_for(next_frame(), timeout=60)
)
if (
isinstance(outputs, AdditionalOutputs)
and set_additional_outputs
and channel
and channel()
):
set_additional_outputs(outputs)
cast(DataChannel, channel()).send("change")
if frame is None:
if quit_on_none:
@@ -65,7 +104,7 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize())
logger.debug("Queue size utils.py: %s", queue.qsize())
except (TimeoutError, asyncio.TimeoutError):
logger.warning(

View File

@@ -11,7 +11,19 @@ import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
Generator,
Iterable,
Literal,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import anyio.to_thread
import av
@@ -29,7 +41,12 @@ from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
from gradio_webrtc.utils import player_worker_decode
from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
player_worker_decode,
split_output,
)
if TYPE_CHECKING:
from gradio.blocks import Block
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
self,
track: MediaStreamTrack,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__() # don't forget this!
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array = self.event_handler(*args)
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
new_frame = self.array_to_frame(array)
if frame:
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
self,
track: MediaStreamTrack,
event_handler: StreamHandler,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
self.track = track
self.event_handler = event_handler
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
self._start: float | None = None
self.has_started = False
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__()
async def process_input_frames(self) -> None:
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
False,
self.event_handler.output_sample_rate,
self.event_handler.output_frame_size,
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
self.args_set = asyncio.Event()
self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
)
try:
next_array = next(self.generator)
next_array, outputs = split_output(next(self.generator))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
except StopIteration:
self.stop()
return
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set = threading.Event()
self.queue = asyncio.Queue()
self.thread_quit = asyncio.Event()
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.has_started = False
self._start: float | None = None
super().__init__()
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
True,
)
)
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
super().stop()
# For the return type
R = TypeVar("R")
# For the parameter specification
P = ParamSpec("P")
class WebRTC(Component):
"""
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -369,8 +426,10 @@ class WebRTC(Component):
connections: dict[
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {}
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, AdditionalOutputs] = {}
EVENTS = ["tick"]
EVENTS = ["tick", "state_change"]
def __init__(
self,
@@ -470,6 +529,14 @@ class WebRTC(Component):
value=value,
)
def set_additional_outputs(
self, webrtc_id: str
) -> Callable[[AdditionalOutputs], None]:
def set_outputs(outputs: AdditionalOutputs):
self.additional_outputs[webrtc_id] = outputs
return set_outputs
def preprocess(self, payload: str) -> str:
"""
Parameters:
@@ -498,6 +565,38 @@ class WebRTC(Component):
self.connections[webrtc_id].latest_args = list(args)
self.connections[webrtc_id].args_set.set() # type: ignore
def change(
self,
fn: Callable[Concatenate[P], R],
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
):
inputs = inputs or []
if inputs and not isinstance(inputs, Iterable):
inputs = [inputs]
inputs = list(inputs)
def handler(webrtc_id: str, *args):
if webrtc_id in self.additional_outputs:
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
return (
tuple([None for _ in range(len(outputs))])
if isinstance(outputs, Iterable)
else None
)
return self.state_change( # type: ignore
fn=handler,
inputs=[self] + cast(list, inputs),
outputs=outputs,
js=js,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
)
def stream(
self,
fn: Callable[..., Any] | StreamHandler | None = None,
@@ -599,6 +698,8 @@ class WebRTC(Component):
pc = RTCPeerConnection()
self.pcs.add(pc)
set_outputs = self.set_additional_outputs(body["webrtc_id"])
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
logger.debug("ICE connection state change %s", pc.iceConnectionState)
@@ -627,27 +728,61 @@ class WebRTC(Component):
cb = VideoCallback(
relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio":
cb = AudioCallback(
relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler).copy(),
set_additional_outputs=set_outputs,
)
self.connections[body["webrtc_id"]] = cb
if body["webrtc_id"] in self.data_channels:
self.connections[body["webrtc_id"]].channel = self.data_channels[
body["webrtc_id"]
]
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
if self.mode == "receive":
if self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler))
cb = ServerToClientAudio(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
@pc.on("datachannel")
def on_datachannel(channel):
print("data channel established")
logger.debug(f"Data channel established: {channel.label}")
self.data_channels[body["webrtc_id"]] = channel
async def set_channel(webrtc_id: str):
print("webrtc_id", webrtc_id)
while not self.connections.get(webrtc_id):
await asyncio.sleep(0.05)
print("setting channel")
self.connections[webrtc_id].channel = channel
asyncio.create_task(set_channel(body["webrtc_id"]))
@channel.on("message")
def on_message(message):
logger.debug(f"Received message: {message}")
if channel.readyState == "open":
channel.send(f"Server received: {message}")
# handle offer
await pc.setRemoteDescription(offer)