This commit is contained in:
freddyaboulton
2024-10-28 09:59:08 -07:00
parent c051736fbb
commit d1c43edcd4
10 changed files with 225 additions and 18 deletions

View File

@@ -1,4 +1,5 @@
from .reply_on_pause import ReplyOnPause
from .utils import AdditionalOutputs
from .webrtc import StreamHandler, WebRTC
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]

View File

@@ -1,9 +1,10 @@
import asyncio
import fractions
import logging
from typing import Callable
from typing import Any, Callable, Protocol, cast
import av
import numpy as np
logger = logging.getLogger(__name__)
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
class AdditionalOutputs:
def __init__(self, *args) -> None:
self.args = args
class DataChannel(Protocol):
def send(self, message: str) -> None: ...
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
return data, None
if not len(data) == 2:
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], AdditionalOutputs):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)
return data[0], cast(AdditionalOutputs, data[1])
return data, None
async def player_worker_decode(
next_frame: Callable,
queue: asyncio.Queue,
thread_quit: asyncio.Event,
channel: Callable[[], DataChannel | None] | None,
set_additional_outputs: Callable | None,
quit_on_none: bool = False,
sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME),
@@ -31,7 +60,17 @@ async def player_worker_decode(
while not thread_quit.is_set():
try:
# Get next frame
frame = await asyncio.wait_for(next_frame(), timeout=60)
frame, outputs = split_output(
await asyncio.wait_for(next_frame(), timeout=60)
)
if (
isinstance(outputs, AdditionalOutputs)
and set_additional_outputs
and channel
and channel()
):
set_additional_outputs(outputs)
cast(DataChannel, channel()).send("change")
if frame is None:
if quit_on_none:
@@ -65,7 +104,7 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize())
logger.debug("Queue size utils.py: %s", queue.qsize())
except (TimeoutError, asyncio.TimeoutError):
logger.warning(

View File

@@ -11,7 +11,19 @@ import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
Generator,
Iterable,
Literal,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import anyio.to_thread
import av
@@ -29,7 +41,12 @@ from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
from gradio_webrtc.utils import player_worker_decode
from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
player_worker_decode,
split_output,
)
if TYPE_CHECKING:
from gradio.blocks import Block
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
self,
track: MediaStreamTrack,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__() # don't forget this!
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array = self.event_handler(*args)
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
new_frame = self.array_to_frame(array)
if frame:
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
self,
track: MediaStreamTrack,
event_handler: StreamHandler,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
self.track = track
self.event_handler = event_handler
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
self._start: float | None = None
self.has_started = False
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__()
async def process_input_frames(self) -> None:
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
False,
self.event_handler.output_sample_rate,
self.event_handler.output_frame_size,
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
self.args_set = asyncio.Event()
self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
)
try:
next_array = next(self.generator)
next_array, outputs = split_output(next(self.generator))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
except StopIteration:
self.stop()
return
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set = threading.Event()
self.queue = asyncio.Queue()
self.thread_quit = asyncio.Event()
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.has_started = False
self._start: float | None = None
super().__init__()
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
True,
)
)
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
super().stop()
# For the return type
R = TypeVar("R")
# For the parameter specification
P = ParamSpec("P")
class WebRTC(Component):
"""
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -369,8 +426,10 @@ class WebRTC(Component):
connections: dict[
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {}
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, AdditionalOutputs] = {}
EVENTS = ["tick"]
EVENTS = ["tick", "state_change"]
def __init__(
self,
@@ -470,6 +529,14 @@ class WebRTC(Component):
value=value,
)
def set_additional_outputs(
self, webrtc_id: str
) -> Callable[[AdditionalOutputs], None]:
def set_outputs(outputs: AdditionalOutputs):
self.additional_outputs[webrtc_id] = outputs
return set_outputs
def preprocess(self, payload: str) -> str:
"""
Parameters:
@@ -498,6 +565,38 @@ class WebRTC(Component):
self.connections[webrtc_id].latest_args = list(args)
self.connections[webrtc_id].args_set.set() # type: ignore
def change(
self,
fn: Callable[Concatenate[P], R],
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
):
inputs = inputs or []
if inputs and not isinstance(inputs, Iterable):
inputs = [inputs]
inputs = list(inputs)
def handler(webrtc_id: str, *args):
if webrtc_id in self.additional_outputs:
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
return (
tuple([None for _ in range(len(outputs))])
if isinstance(outputs, Iterable)
else None
)
return self.state_change( # type: ignore
fn=handler,
inputs=[self] + cast(list, inputs),
outputs=outputs,
js=js,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
)
def stream(
self,
fn: Callable[..., Any] | StreamHandler | None = None,
@@ -599,6 +698,8 @@ class WebRTC(Component):
pc = RTCPeerConnection()
self.pcs.add(pc)
set_outputs = self.set_additional_outputs(body["webrtc_id"])
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
logger.debug("ICE connection state change %s", pc.iceConnectionState)
@@ -627,27 +728,61 @@ class WebRTC(Component):
cb = VideoCallback(
relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio":
cb = AudioCallback(
relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler).copy(),
set_additional_outputs=set_outputs,
)
self.connections[body["webrtc_id"]] = cb
if body["webrtc_id"] in self.data_channels:
self.connections[body["webrtc_id"]].channel = self.data_channels[
body["webrtc_id"]
]
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
if self.mode == "receive":
if self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler))
cb = ServerToClientAudio(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
@pc.on("datachannel")
def on_datachannel(channel):
print("data channel established")
logger.debug(f"Data channel established: {channel.label}")
self.data_channels[body["webrtc_id"]] = channel
async def set_channel(webrtc_id: str):
print("webrtc_id", webrtc_id)
while not self.connections.get(webrtc_id):
await asyncio.sleep(0.05)
print("setting channel")
self.connections[webrtc_id].channel = channel
asyncio.create_task(set_channel(body["webrtc_id"]))
@channel.on("message")
def on_message(message):
logger.debug(f"Received message: {message}")
if channel.readyState == "open":
channel.send(f"Server received: {message}")
# handle offer
await pc.setRemoteDescription(offer)

View File

@@ -34,6 +34,10 @@
export let mode: "send-receive" | "receive" = "send-receive";
export let track_constraints: MediaTrackConstraints = {};
const on_change_cb = () => {
gradio.dispatch("state_change");
}
let dragging = false;
$: console.log("value", value);
@@ -63,6 +67,7 @@
{#if mode == "receive" && modality === "video"}
<StaticVideo
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}
@@ -73,6 +78,7 @@
{:else if mode == "receive" && modality === "audio"}
<StaticAudio
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}
@@ -91,6 +97,7 @@
{server}
{rtc_configuration}
{time_limit}
{on_change_cb}
on:clear={() => gradio.dispatch("clear")}
on:play={() => gradio.dispatch("play")}
on:pause={() => gradio.dispatch("pause")}
@@ -109,6 +116,7 @@
{:else if mode === "send-receive" && modality === "audio"}
<InteractiveAudio
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}

View File

@@ -25,9 +25,9 @@
export let i18n: I18nFormatter;
export let time_limit: number | null = null;
export let track_constraints: MediaTrackConstraints = {};
let _time_limit: number | null = null;
export let on_change_cb: () => void;
$: console.log("time_limit", time_limit);
let _time_limit: number | null = null;
export let server: {
offer: (body: any) => Promise<any>;
@@ -41,12 +41,15 @@
const dispatch = createEventDispatcher<{
tick: undefined;
state_change: undefined;
error: string
play: undefined;
stop: undefined;
}>();
onMount(() => {
window.setInterval(() => {
if (stream_state == "open") {
@@ -103,7 +106,7 @@
}
if (stream == null) return;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")

View File

@@ -21,6 +21,7 @@
};
export let rtc_configuration: Object;
export let track_constraints: MediaTrackConstraints = {};
export let on_change_cb: () => void;
const dispatch = createEventDispatcher<{
change: FileData | null;
@@ -50,6 +51,7 @@
{include_audio}
{time_limit}
{track_constraints}
{on_change_cb}
on:error
on:start_recording
on:stop_recording

View File

@@ -17,6 +17,7 @@
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter;
export let on_change_cb: () => void;
export let server: {
offer: (body: any) => Promise<any>;
@@ -68,7 +69,7 @@
}
)
let stream = null;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")

View File

@@ -13,6 +13,7 @@
export let label: string | undefined = undefined;
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let on_change_cb: () => void;
export let server: {
offer: (body: any) => Promise<any>;
};
@@ -59,7 +60,7 @@
}
}
)
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")

View File

@@ -24,6 +24,7 @@
let _time_limit: number | null = null;
export let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed";
export let on_change_cb: () => void;
const _webrtc_id = Math.random().toString(36).substring(2);
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
@@ -139,7 +140,7 @@
)
stream_state = "waiting"
webrtc_id = Math.random().toString(36).substring(2);
start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => {
start(stream, pc, video_source, server.offer, webrtc_id, "video", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")

View File

@@ -44,8 +44,24 @@ export function createPeerConnection(pc, node) {
return pc;
}
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") {
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id,
modality: "video" | "audio" = "video", on_change_cb: () => void = () => {}) {
pc = createPeerConnection(pc, node);
const data_channel = pc.createDataChannel("text");
data_channel.onopen = () => {
console.debug("Data channel is open");
data_channel.send("handshake");
};
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
if (event.data === "change") {
console.debug("Change event received");
on_change_cb();
}
};
if (stream) {
stream.getTracks().forEach((track) => {
console.debug("Track stream callback", track);