Merge pull request #13 from freddyaboulton/return-other-components

Be able to update other components
This commit is contained in:
Freddy Boulton
2024-10-28 10:08:35 -07:00
committed by GitHub
10 changed files with 225 additions and 18 deletions

View File

@@ -1,4 +1,5 @@
from .reply_on_pause import ReplyOnPause from .reply_on_pause import ReplyOnPause
from .utils import AdditionalOutputs
from .webrtc import StreamHandler, WebRTC from .webrtc import StreamHandler, WebRTC
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] __all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]

View File

@@ -1,9 +1,10 @@
import asyncio import asyncio
import fractions import fractions
import logging import logging
from typing import Callable from typing import Any, Callable, Protocol, cast
import av import av
import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020 AUDIO_PTIME = 0.020
class AdditionalOutputs:
    """Wrap extra values that a handler returns next to its media output.

    The positional arguments are kept unchanged, as a tuple, on ``args``.
    """

    def __init__(self, *args) -> None:
        self.args = args

    def __repr__(self) -> str:
        # Show the wrapped values so instances are readable in debug logs.
        joined = ", ".join(repr(value) for value in self.args)
        return f"{type(self).__name__}({joined})"
class DataChannel(Protocol):
    """Structural type for an object that can send a text message.

    Any object exposing a compatible ``send`` method satisfies this
    protocol; no inheritance is required.
    """

    def send(self, message: str) -> None: ...
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
    """Split a handler's return value into its payload and extra outputs.

    Parameters:
        data: The value returned by a handler. Either a bare payload, or a
            two-element tuple ``(payload, AdditionalOutputs(...))``.

    Returns:
        A ``(payload, additional_outputs)`` pair. ``additional_outputs`` is
        ``None`` when ``data`` carries no ``AdditionalOutputs`` instance.

    Raises:
        ValueError: If ``data`` is a tuple that is neither a bare audio
            tuple nor a ``(payload, AdditionalOutputs)`` pair.
    """
    if not isinstance(data, tuple):
        return data, None

    # Bare audio case: a 2- or 3-element tuple whose second element is an
    # ndarray is itself the payload (presumably sample rate + samples, with
    # an optional third element — confirm against the audio handlers).
    if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
        return data, None

    if len(data) != 2:
        raise ValueError(
            "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
        )

    payload, extra = data
    if not isinstance(extra, AdditionalOutputs):
        raise ValueError(
            "The last element of the tuple must be an instance of AdditionalOutputs."
        )
    # The isinstance check above already narrows the type; no cast needed.
    return payload, extra
async def player_worker_decode( async def player_worker_decode(
next_frame: Callable, next_frame: Callable,
queue: asyncio.Queue, queue: asyncio.Queue,
thread_quit: asyncio.Event, thread_quit: asyncio.Event,
channel: Callable[[], DataChannel | None] | None,
set_additional_outputs: Callable | None,
quit_on_none: bool = False, quit_on_none: bool = False,
sample_rate: int = 48000, sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME), frame_size: int = int(48000 * AUDIO_PTIME),
@@ -31,7 +60,17 @@ async def player_worker_decode(
while not thread_quit.is_set(): while not thread_quit.is_set():
try: try:
# Get next frame # Get next frame
frame = await asyncio.wait_for(next_frame(), timeout=60) frame, outputs = split_output(
await asyncio.wait_for(next_frame(), timeout=60)
)
if (
isinstance(outputs, AdditionalOutputs)
and set_additional_outputs
and channel
and channel()
):
set_additional_outputs(outputs)
cast(DataChannel, channel()).send("change")
if frame is None: if frame is None:
if quit_on_none: if quit_on_none:
@@ -65,7 +104,7 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples audio_samples += processed_frame.samples
await queue.put(processed_frame) await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize()) logger.debug("Queue size utils.py: %s", queue.qsize())
except (TimeoutError, asyncio.TimeoutError): except (TimeoutError, asyncio.TimeoutError):
logger.warning( logger.warning(

View File

@@ -11,7 +11,19 @@ import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast from typing import (
TYPE_CHECKING,
Any,
Callable,
Concatenate,
Generator,
Iterable,
Literal,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import anyio.to_thread import anyio.to_thread
import av import av
@@ -29,7 +41,12 @@ from gradio import wasm_utils
from gradio.components.base import Component, server from gradio.components.base import Component, server
from gradio_client import handle_file from gradio_client import handle_file
from gradio_webrtc.utils import player_worker_decode from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
player_worker_decode,
split_output,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.blocks import Block from gradio.blocks import Block
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
self, self,
track: MediaStreamTrack, track: MediaStreamTrack,
event_handler: Callable, event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None: ) -> None:
super().__init__() # don't forget this! super().__init__() # don't forget this!
self.track = track self.track = track
self.event_handler = event_handler self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def add_frame_to_payload( def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None self, args: list[Any], frame: np.ndarray | None
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array) args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array = self.event_handler(*args) array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
new_frame = self.array_to_frame(array) new_frame = self.array_to_frame(array)
if frame: if frame:
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
self, self,
track: MediaStreamTrack, track: MediaStreamTrack,
event_handler: StreamHandler, event_handler: StreamHandler,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None: ) -> None:
self.track = track self.track = track
self.event_handler = event_handler self.event_handler = event_handler
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
self._start: float | None = None self._start: float | None = None
self.has_started = False self.has_started = False
self.last_timestamp = 0 self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__() super().__init__()
async def process_input_frames(self) -> None: async def process_input_frames(self) -> None:
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
callable, callable,
self.queue, self.queue,
self.thread_quit, self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
False, False,
self.event_handler.output_sample_rate, self.event_handler.output_sample_rate,
self.event_handler.output_frame_size, self.event_handler.output_frame_size,
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
def __init__( def __init__(
self, self,
event_handler: Callable, event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None: ) -> None:
super().__init__() # don't forget this! super().__init__() # don't forget this!
self.event_handler = event_handler self.event_handler = event_handler
self.args_set = asyncio.Event() self.args_set = asyncio.Event()
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def array_to_frame(self, array: np.ndarray) -> VideoFrame: def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24") return VideoFrame.from_ndarray(array, format="bgr24")
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
) )
try: try:
next_array = next(self.generator) next_array, outputs = split_output(next(self.generator))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
except StopIteration: except StopIteration:
self.stop() self.stop()
return return
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
def __init__( def __init__(
self, self,
event_handler: Callable, event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None: ) -> None:
self.generator: Generator[Any, None, Any] | None = None self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler self.event_handler = event_handler
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set = threading.Event() self.args_set = threading.Event()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.thread_quit = asyncio.Event() self.thread_quit = asyncio.Event()
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.has_started = False self.has_started = False
self._start: float | None = None self._start: float | None = None
super().__init__() super().__init__()
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
callable, callable,
self.queue, self.queue,
self.thread_quit, self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
True, True,
) )
) )
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
super().stop() super().stop()
# For the return type
R = TypeVar("R")
# For the parameter specification
P = ParamSpec("P")
class WebRTC(Component): class WebRTC(Component):
""" """
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -369,8 +426,10 @@ class WebRTC(Component):
connections: dict[ connections: dict[
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {} ] = {}
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, AdditionalOutputs] = {}
EVENTS = ["tick"] EVENTS = ["tick", "state_change"]
def __init__( def __init__(
self, self,
@@ -470,6 +529,14 @@ class WebRTC(Component):
value=value, value=value,
) )
def set_additional_outputs(
    self, webrtc_id: str
) -> Callable[[AdditionalOutputs], None]:
    """Build a setter that stores outputs under the given connection id.

    Parameters:
        webrtc_id: Key under which outputs are stored in
            ``self.additional_outputs``.

    Returns:
        A one-argument callable that overwrites the entry for ``webrtc_id``
        with the ``AdditionalOutputs`` it is given.
    """

    def _store(outputs: AdditionalOutputs):
        # Latest value wins: any earlier entry for this id is replaced.
        self.additional_outputs[webrtc_id] = outputs

    return _store
def preprocess(self, payload: str) -> str: def preprocess(self, payload: str) -> str:
""" """
Parameters: Parameters:
@@ -498,6 +565,38 @@ class WebRTC(Component):
self.connections[webrtc_id].latest_args = list(args) self.connections[webrtc_id].latest_args = list(args)
self.connections[webrtc_id].args_set.set() # type: ignore self.connections[webrtc_id].args_set.set() # type: ignore
def change(
    self,
    fn: Callable[Concatenate[P], R],
    inputs: Block | Sequence[Block] | set[Block] | None = None,
    outputs: Block | Sequence[Block] | set[Block] | None = None,
    js: str | None = None,
    concurrency_limit: int | None | Literal["default"] = "default",
    concurrency_id: str | None = None,
):
    """Register ``fn`` to run on this component's ``state_change`` event.

    ``fn`` is called with the values of ``inputs`` followed by the ``args``
    of the ``AdditionalOutputs`` most recently stored for the connection.
    This component is always passed as the first input so the handler
    receives the connection's ``webrtc_id``.

    Parameters:
        fn: Callback to invoke with the input values plus the stored
            additional outputs.
        inputs: Extra input component(s), if any.
        outputs: Component(s) updated with the return value of ``fn``.
        js: Forwarded unchanged to ``state_change``.
        concurrency_limit: Forwarded unchanged to ``state_change``.
        concurrency_id: Forwarded unchanged to ``state_change``.

    Returns:
        Whatever ``self.state_change`` returns.
    """
    if not inputs:
        extra_inputs = []
    elif isinstance(inputs, Iterable):
        extra_inputs = list(inputs)
    else:
        # A single component was passed; normalise it to a list.
        extra_inputs = [inputs]

    def empty_result():
        # Value returned when nothing has been stored for the connection yet.
        if not isinstance(outputs, Iterable):
            return None
        count = len(outputs)
        # NOTE(review): Gradio is understood to wrap the return value itself
        # when exactly one output is registered, so a bare None is used there
        # rather than a 1-tuple — confirm against the Gradio version in use.
        return None if count == 1 else (None,) * count

    def handler(webrtc_id: str, *args):
        if webrtc_id in self.additional_outputs:
            return fn(*args, *self.additional_outputs[webrtc_id].args)  # type: ignore
        return empty_result()

    return self.state_change(  # type: ignore
        fn=handler,
        inputs=[self, *extra_inputs],
        outputs=outputs,
        js=js,
        concurrency_limit=concurrency_limit,
        concurrency_id=concurrency_id,
    )
def stream( def stream(
self, self,
fn: Callable[..., Any] | StreamHandler | None = None, fn: Callable[..., Any] | StreamHandler | None = None,
@@ -599,6 +698,8 @@ class WebRTC(Component):
pc = RTCPeerConnection() pc = RTCPeerConnection()
self.pcs.add(pc) self.pcs.add(pc)
set_outputs = self.set_additional_outputs(body["webrtc_id"])
@pc.on("iceconnectionstatechange") @pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange(): async def on_iceconnectionstatechange():
logger.debug("ICE connection state change %s", pc.iceConnectionState) logger.debug("ICE connection state change %s", pc.iceConnectionState)
@@ -627,27 +728,61 @@ class WebRTC(Component):
cb = VideoCallback( cb = VideoCallback(
relay.subscribe(track), relay.subscribe(track),
event_handler=cast(Callable, self.event_handler), event_handler=cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
) )
elif self.modality == "audio": elif self.modality == "audio":
cb = AudioCallback( cb = AudioCallback(
relay.subscribe(track), relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler).copy(), event_handler=cast(StreamHandler, self.event_handler).copy(),
set_additional_outputs=set_outputs,
) )
self.connections[body["webrtc_id"]] = cb self.connections[body["webrtc_id"]] = cb
if body["webrtc_id"] in self.data_channels:
self.connections[body["webrtc_id"]].channel = self.data_channels[
body["webrtc_id"]
]
logger.debug("Adding track to peer connection %s", cb) logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb) pc.addTrack(cb)
if self.mode == "receive": if self.mode == "receive":
if self.modality == "video": if self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler)) cb = ServerToClientVideo(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
elif self.modality == "audio": elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler)) cb = ServerToClientAudio(
cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
)
logger.debug("Adding track to peer connection %s", cb) logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb) pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None)) cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
@pc.on("datachannel")
def on_datachannel(channel):
    """Record a newly opened data channel and attach it to its connection.

    The channel is stored in ``self.data_channels`` under the request's
    ``webrtc_id``, then attached to the matching entry of
    ``self.connections`` once that entry exists.
    """
    logger.debug("Data channel established: %s", channel.label)
    self.data_channels[body["webrtc_id"]] = channel

    async def set_channel(webrtc_id: str):
        # Poll until a connection has been registered for this id.
        # NOTE(review): there is no timeout, so this loops for as long as
        # no connection is ever registered under ``webrtc_id``.
        logger.debug("Waiting for connection with webrtc_id %s", webrtc_id)
        while not self.connections.get(webrtc_id):
            await asyncio.sleep(0.05)
        logger.debug("Setting channel for webrtc_id %s", webrtc_id)
        self.connections[webrtc_id].channel = channel

    # NOTE(review): the task is not referenced anywhere after creation; the
    # asyncio docs recommend keeping a reference so it cannot be
    # garbage-collected before it finishes.
    asyncio.create_task(set_channel(body["webrtc_id"]))

    @channel.on("message")
    def on_message(message):
        logger.debug("Received message: %s", message)
        # Only reply while the channel is still open.
        if channel.readyState == "open":
            channel.send(f"Server received: {message}")
# handle offer # handle offer
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)

View File

@@ -34,6 +34,10 @@
export let mode: "send-receive" | "receive" = "send-receive"; export let mode: "send-receive" | "receive" = "send-receive";
export let track_constraints: MediaTrackConstraints = {}; export let track_constraints: MediaTrackConstraints = {};
// Callback handed to the child components as the `on_change_cb` prop;
// when invoked it dispatches this component's "state_change" event.
const on_change_cb = () => {
gradio.dispatch("state_change");
}
let dragging = false; let dragging = false;
$: console.log("value", value); $: console.log("value", value);
@@ -63,6 +67,7 @@
{#if mode == "receive" && modality === "video"} {#if mode == "receive" && modality === "video"}
<StaticVideo <StaticVideo
bind:value={value} bind:value={value}
{on_change_cb}
{label} {label}
{show_label} {show_label}
{server} {server}
@@ -73,6 +78,7 @@
{:else if mode == "receive" && modality === "audio"} {:else if mode == "receive" && modality === "audio"}
<StaticAudio <StaticAudio
bind:value={value} bind:value={value}
{on_change_cb}
{label} {label}
{show_label} {show_label}
{server} {server}
@@ -91,6 +97,7 @@
{server} {server}
{rtc_configuration} {rtc_configuration}
{time_limit} {time_limit}
{on_change_cb}
on:clear={() => gradio.dispatch("clear")} on:clear={() => gradio.dispatch("clear")}
on:play={() => gradio.dispatch("play")} on:play={() => gradio.dispatch("play")}
on:pause={() => gradio.dispatch("pause")} on:pause={() => gradio.dispatch("pause")}
@@ -109,6 +116,7 @@
{:else if mode === "send-receive" && modality === "audio"} {:else if mode === "send-receive" && modality === "audio"}
<InteractiveAudio <InteractiveAudio
bind:value={value} bind:value={value}
{on_change_cb}
{label} {label}
{show_label} {show_label}
{server} {server}

View File

@@ -25,9 +25,9 @@
export let i18n: I18nFormatter; export let i18n: I18nFormatter;
export let time_limit: number | null = null; export let time_limit: number | null = null;
export let track_constraints: MediaTrackConstraints = {}; export let track_constraints: MediaTrackConstraints = {};
let _time_limit: number | null = null; export let on_change_cb: () => void;
$: console.log("time_limit", time_limit); let _time_limit: number | null = null;
export let server: { export let server: {
offer: (body: any) => Promise<any>; offer: (body: any) => Promise<any>;
@@ -41,12 +41,15 @@
const dispatch = createEventDispatcher<{ const dispatch = createEventDispatcher<{
tick: undefined; tick: undefined;
state_change: undefined;
error: string error: string
play: undefined; play: undefined;
stop: undefined; stop: undefined;
}>(); }>();
onMount(() => { onMount(() => {
window.setInterval(() => { window.setInterval(() => {
if (stream_state == "open") { if (stream_state == "open") {
@@ -103,7 +106,7 @@
} }
if (stream == null) return; if (stream == null) return;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => { start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
pc = connection; pc = connection;
}).catch(() => { }).catch(() => {
console.info("catching") console.info("catching")

View File

@@ -21,6 +21,7 @@
}; };
export let rtc_configuration: Object; export let rtc_configuration: Object;
export let track_constraints: MediaTrackConstraints = {}; export let track_constraints: MediaTrackConstraints = {};
export let on_change_cb: () => void;
const dispatch = createEventDispatcher<{ const dispatch = createEventDispatcher<{
change: FileData | null; change: FileData | null;
@@ -50,6 +51,7 @@
{include_audio} {include_audio}
{time_limit} {time_limit}
{track_constraints} {track_constraints}
{on_change_cb}
on:error on:error
on:start_recording on:start_recording
on:stop_recording on:stop_recording

View File

@@ -17,6 +17,7 @@
export let show_label = true; export let show_label = true;
export let rtc_configuration: Object | null = null; export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter; export let i18n: I18nFormatter;
export let on_change_cb: () => void;
export let server: { export let server: {
offer: (body: any) => Promise<any>; offer: (body: any) => Promise<any>;
@@ -68,7 +69,7 @@
} }
) )
let stream = null; let stream = null;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => { start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
pc = connection; pc = connection;
}).catch(() => { }).catch(() => {
console.info("catching") console.info("catching")

View File

@@ -13,6 +13,7 @@
export let label: string | undefined = undefined; export let label: string | undefined = undefined;
export let show_label = true; export let show_label = true;
export let rtc_configuration: Object | null = null; export let rtc_configuration: Object | null = null;
export let on_change_cb: () => void;
export let server: { export let server: {
offer: (body: any) => Promise<any>; offer: (body: any) => Promise<any>;
}; };
@@ -59,7 +60,7 @@
} }
} }
) )
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => { start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => {
pc = connection; pc = connection;
}).catch(() => { }).catch(() => {
console.log("catching") console.log("catching")

View File

@@ -24,6 +24,7 @@
let _time_limit: number | null = null; let _time_limit: number | null = null;
export let time_limit: number | null = null; export let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed"; let stream_state: "open" | "waiting" | "closed" = "closed";
export let on_change_cb: () => void;
const _webrtc_id = Math.random().toString(36).substring(2); const _webrtc_id = Math.random().toString(36).substring(2);
export const modify_stream: (state: "open" | "closed" | "waiting") => void = ( export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
@@ -139,7 +140,7 @@
) )
stream_state = "waiting" stream_state = "waiting"
webrtc_id = Math.random().toString(36).substring(2); webrtc_id = Math.random().toString(36).substring(2);
start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => { start(stream, pc, video_source, server.offer, webrtc_id, "video", on_change_cb).then((connection) => {
pc = connection; pc = connection;
}).catch(() => { }).catch(() => {
console.info("catching") console.info("catching")

View File

@@ -44,8 +44,24 @@ export function createPeerConnection(pc, node) {
return pc; return pc;
} }
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") { export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id,
modality: "video" | "audio" = "video", on_change_cb: () => void = () => {}) {
pc = createPeerConnection(pc, node); pc = createPeerConnection(pc, node);
const data_channel = pc.createDataChannel("text");
data_channel.onopen = () => {
console.debug("Data channel is open");
data_channel.send("handshake");
};
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
if (event.data === "change") {
console.debug("Change event received");
on_change_cb();
}
};
if (stream) { if (stream) {
stream.getTracks().forEach((track) => { stream.getTracks().forEach((track) => {
console.debug("Track stream callback", track); console.debug("Track stream callback", track);