mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
add code
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .reply_on_pause import ReplyOnPause
|
||||
from .utils import AdditionalOutputs
|
||||
from .webrtc import StreamHandler, WebRTC
|
||||
|
||||
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||
__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import fractions
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Protocol, cast
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
|
||||
AUDIO_PTIME = 0.020
|
||||
|
||||
|
||||
class AdditionalOutputs:
|
||||
def __init__(self, *args) -> None:
|
||||
self.args = args
|
||||
|
||||
|
||||
class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, tuple):
|
||||
# handle the bare audio case
|
||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||
return data, None
|
||||
if not len(data) == 2:
|
||||
raise ValueError(
|
||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||
)
|
||||
if not isinstance(data[-1], AdditionalOutputs):
|
||||
raise ValueError(
|
||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||
)
|
||||
return data[0], cast(AdditionalOutputs, data[1])
|
||||
return data, None
|
||||
|
||||
|
||||
async def player_worker_decode(
|
||||
next_frame: Callable,
|
||||
queue: asyncio.Queue,
|
||||
thread_quit: asyncio.Event,
|
||||
channel: Callable[[], DataChannel | None] | None,
|
||||
set_additional_outputs: Callable | None,
|
||||
quit_on_none: bool = False,
|
||||
sample_rate: int = 48000,
|
||||
frame_size: int = int(48000 * AUDIO_PTIME),
|
||||
@@ -31,7 +60,17 @@ async def player_worker_decode(
|
||||
while not thread_quit.is_set():
|
||||
try:
|
||||
# Get next frame
|
||||
frame = await asyncio.wait_for(next_frame(), timeout=60)
|
||||
frame, outputs = split_output(
|
||||
await asyncio.wait_for(next_frame(), timeout=60)
|
||||
)
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and set_additional_outputs
|
||||
and channel
|
||||
and channel()
|
||||
):
|
||||
set_additional_outputs(outputs)
|
||||
cast(DataChannel, channel()).send("change")
|
||||
|
||||
if frame is None:
|
||||
if quit_on_none:
|
||||
@@ -65,7 +104,7 @@ async def player_worker_decode(
|
||||
processed_frame.time_base = audio_time_base
|
||||
audio_samples += processed_frame.samples
|
||||
await queue.put(processed_frame)
|
||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
logger.warning(
|
||||
|
||||
@@ -11,7 +11,19 @@ import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Generator,
|
||||
Iterable,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio.to_thread
|
||||
import av
|
||||
@@ -29,7 +41,12 @@ from gradio import wasm_utils
|
||||
from gradio.components.base import Component, server
|
||||
from gradio_client import handle_file
|
||||
|
||||
from gradio_webrtc.utils import player_worker_decode
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
|
||||
def add_frame_to_payload(
|
||||
self, args: list[Any], frame: np.ndarray | None
|
||||
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
|
||||
|
||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||
|
||||
array = self.event_handler(*args)
|
||||
array, outputs = split_output(self.event_handler(*args))
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and self.set_additional_outputs
|
||||
and self.channel
|
||||
):
|
||||
self.set_additional_outputs(outputs)
|
||||
self.channel.send("change")
|
||||
|
||||
new_frame = self.array_to_frame(array)
|
||||
if frame:
|
||||
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandler,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
self._start: float | None = None
|
||||
self.has_started = False
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
|
||||
callable,
|
||||
self.queue,
|
||||
self.thread_quit,
|
||||
lambda: self.channel,
|
||||
self.set_additional_outputs,
|
||||
False,
|
||||
self.event_handler.output_sample_rate,
|
||||
self.event_handler.output_frame_size,
|
||||
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.event_handler = event_handler
|
||||
self.args_set = asyncio.Event()
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
)
|
||||
|
||||
try:
|
||||
next_array = next(self.generator)
|
||||
next_array, outputs = split_output(next(self.generator))
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and self.set_additional_outputs
|
||||
and self.channel
|
||||
):
|
||||
self.set_additional_outputs(outputs)
|
||||
self.channel.send("change")
|
||||
except StopIteration:
|
||||
self.stop()
|
||||
return
|
||||
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.event_handler = event_handler
|
||||
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self.args_set = threading.Event()
|
||||
self.queue = asyncio.Queue()
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.has_started = False
|
||||
self._start: float | None = None
|
||||
super().__init__()
|
||||
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
callable,
|
||||
self.queue,
|
||||
self.thread_quit,
|
||||
lambda: self.channel,
|
||||
self.set_additional_outputs,
|
||||
True,
|
||||
)
|
||||
)
|
||||
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
super().stop()
|
||||
|
||||
|
||||
# For the return type
|
||||
R = TypeVar("R")
|
||||
# For the parameter specification
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class WebRTC(Component):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
@@ -369,8 +426,10 @@ class WebRTC(Component):
|
||||
connections: dict[
|
||||
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
|
||||
] = {}
|
||||
data_channels: dict[str, DataChannel] = {}
|
||||
additional_outputs: dict[str, AdditionalOutputs] = {}
|
||||
|
||||
EVENTS = ["tick"]
|
||||
EVENTS = ["tick", "state_change"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -470,6 +529,14 @@ class WebRTC(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def set_additional_outputs(
|
||||
self, webrtc_id: str
|
||||
) -> Callable[[AdditionalOutputs], None]:
|
||||
def set_outputs(outputs: AdditionalOutputs):
|
||||
self.additional_outputs[webrtc_id] = outputs
|
||||
|
||||
return set_outputs
|
||||
|
||||
def preprocess(self, payload: str) -> str:
|
||||
"""
|
||||
Parameters:
|
||||
@@ -498,6 +565,38 @@ class WebRTC(Component):
|
||||
self.connections[webrtc_id].latest_args = list(args)
|
||||
self.connections[webrtc_id].args_set.set() # type: ignore
|
||||
|
||||
def change(
|
||||
self,
|
||||
fn: Callable[Concatenate[P], R],
|
||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
outputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
js: str | None = None,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
):
|
||||
inputs = inputs or []
|
||||
if inputs and not isinstance(inputs, Iterable):
|
||||
inputs = [inputs]
|
||||
inputs = list(inputs)
|
||||
|
||||
def handler(webrtc_id: str, *args):
|
||||
if webrtc_id in self.additional_outputs:
|
||||
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
|
||||
return (
|
||||
tuple([None for _ in range(len(outputs))])
|
||||
if isinstance(outputs, Iterable)
|
||||
else None
|
||||
)
|
||||
|
||||
return self.state_change( # type: ignore
|
||||
fn=handler,
|
||||
inputs=[self] + cast(list, inputs),
|
||||
outputs=outputs,
|
||||
js=js,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable[..., Any] | StreamHandler | None = None,
|
||||
@@ -599,6 +698,8 @@ class WebRTC(Component):
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
|
||||
set_outputs = self.set_additional_outputs(body["webrtc_id"])
|
||||
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
logger.debug("ICE connection state change %s", pc.iceConnectionState)
|
||||
@@ -627,27 +728,61 @@ class WebRTC(Component):
|
||||
cb = VideoCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = AudioCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(StreamHandler, self.event_handler).copy(),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
if body["webrtc_id"] in self.data_channels:
|
||||
self.connections[body["webrtc_id"]].channel = self.data_channels[
|
||||
body["webrtc_id"]
|
||||
]
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = ServerToClientAudio(cast(Callable, self.event_handler))
|
||||
cb = ServerToClientAudio(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
print("data channel established")
|
||||
logger.debug(f"Data channel established: {channel.label}")
|
||||
|
||||
self.data_channels[body["webrtc_id"]] = channel
|
||||
|
||||
async def set_channel(webrtc_id: str):
|
||||
print("webrtc_id", webrtc_id)
|
||||
while not self.connections.get(webrtc_id):
|
||||
await asyncio.sleep(0.05)
|
||||
print("setting channel")
|
||||
self.connections[webrtc_id].channel = channel
|
||||
|
||||
asyncio.create_task(set_channel(body["webrtc_id"]))
|
||||
|
||||
@channel.on("message")
|
||||
def on_message(message):
|
||||
logger.debug(f"Received message: {message}")
|
||||
if channel.readyState == "open":
|
||||
channel.send(f"Server received: {message}")
|
||||
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
|
||||
@@ -34,6 +34,10 @@
|
||||
export let mode: "send-receive" | "receive" = "send-receive";
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
|
||||
const on_change_cb = () => {
|
||||
gradio.dispatch("state_change");
|
||||
}
|
||||
|
||||
let dragging = false;
|
||||
|
||||
$: console.log("value", value);
|
||||
@@ -63,6 +67,7 @@
|
||||
{#if mode == "receive" && modality === "video"}
|
||||
<StaticVideo
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
@@ -73,6 +78,7 @@
|
||||
{:else if mode == "receive" && modality === "audio"}
|
||||
<StaticAudio
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
@@ -91,6 +97,7 @@
|
||||
{server}
|
||||
{rtc_configuration}
|
||||
{time_limit}
|
||||
{on_change_cb}
|
||||
on:clear={() => gradio.dispatch("clear")}
|
||||
on:play={() => gradio.dispatch("play")}
|
||||
on:pause={() => gradio.dispatch("pause")}
|
||||
@@ -109,6 +116,7 @@
|
||||
{:else if mode === "send-receive" && modality === "audio"}
|
||||
<InteractiveAudio
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
|
||||
@@ -25,9 +25,9 @@
|
||||
export let i18n: I18nFormatter;
|
||||
export let time_limit: number | null = null;
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
let _time_limit: number | null = null;
|
||||
export let on_change_cb: () => void;
|
||||
|
||||
$: console.log("time_limit", time_limit);
|
||||
let _time_limit: number | null = null;
|
||||
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
@@ -41,12 +41,15 @@
|
||||
|
||||
const dispatch = createEventDispatcher<{
|
||||
tick: undefined;
|
||||
state_change: undefined;
|
||||
error: string
|
||||
play: undefined;
|
||||
stop: undefined;
|
||||
}>();
|
||||
|
||||
|
||||
|
||||
|
||||
onMount(() => {
|
||||
window.setInterval(() => {
|
||||
if (stream_state == "open") {
|
||||
@@ -103,7 +106,7 @@
|
||||
}
|
||||
if (stream == null) return;
|
||||
|
||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
|
||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
|
||||
pc = connection;
|
||||
}).catch(() => {
|
||||
console.info("catching")
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
};
|
||||
export let rtc_configuration: Object;
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
export let on_change_cb: () => void;
|
||||
|
||||
const dispatch = createEventDispatcher<{
|
||||
change: FileData | null;
|
||||
@@ -50,6 +51,7 @@
|
||||
{include_audio}
|
||||
{time_limit}
|
||||
{track_constraints}
|
||||
{on_change_cb}
|
||||
on:error
|
||||
on:start_recording
|
||||
on:stop_recording
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
export let show_label = true;
|
||||
export let rtc_configuration: Object | null = null;
|
||||
export let i18n: I18nFormatter;
|
||||
export let on_change_cb: () => void;
|
||||
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
@@ -68,7 +69,7 @@
|
||||
}
|
||||
)
|
||||
let stream = null;
|
||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
|
||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
|
||||
pc = connection;
|
||||
}).catch(() => {
|
||||
console.info("catching")
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
export let label: string | undefined = undefined;
|
||||
export let show_label = true;
|
||||
export let rtc_configuration: Object | null = null;
|
||||
export let on_change_cb: () => void;
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
};
|
||||
@@ -59,7 +60,7 @@
|
||||
}
|
||||
}
|
||||
)
|
||||
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
|
||||
start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => {
|
||||
pc = connection;
|
||||
}).catch(() => {
|
||||
console.log("catching")
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
let _time_limit: number | null = null;
|
||||
export let time_limit: number | null = null;
|
||||
let stream_state: "open" | "waiting" | "closed" = "closed";
|
||||
export let on_change_cb: () => void;
|
||||
const _webrtc_id = Math.random().toString(36).substring(2);
|
||||
|
||||
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
|
||||
@@ -139,7 +140,7 @@
|
||||
)
|
||||
stream_state = "waiting"
|
||||
webrtc_id = Math.random().toString(36).substring(2);
|
||||
start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => {
|
||||
start(stream, pc, video_source, server.offer, webrtc_id, "video", on_change_cb).then((connection) => {
|
||||
pc = connection;
|
||||
}).catch(() => {
|
||||
console.info("catching")
|
||||
|
||||
@@ -44,8 +44,24 @@ export function createPeerConnection(pc, node) {
|
||||
return pc;
|
||||
}
|
||||
|
||||
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") {
|
||||
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id,
|
||||
modality: "video" | "audio" = "video", on_change_cb: () => void = () => {}) {
|
||||
pc = createPeerConnection(pc, node);
|
||||
const data_channel = pc.createDataChannel("text");
|
||||
|
||||
data_channel.onopen = () => {
|
||||
console.debug("Data channel is open");
|
||||
data_channel.send("handshake");
|
||||
};
|
||||
|
||||
data_channel.onmessage = (event) => {
|
||||
console.debug("Received message:", event.data);
|
||||
if (event.data === "change") {
|
||||
console.debug("Change event received");
|
||||
on_change_cb();
|
||||
}
|
||||
};
|
||||
|
||||
if (stream) {
|
||||
stream.getTracks().forEach((track) => {
|
||||
console.debug("Track stream callback", track);
|
||||
|
||||
Reference in New Issue
Block a user