mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Merge pull request #13 from freddyaboulton/return-other-components
Be able to update other components
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from .reply_on_pause import ReplyOnPause
|
from .reply_on_pause import ReplyOnPause
|
||||||
|
from .utils import AdditionalOutputs
|
||||||
from .webrtc import StreamHandler, WebRTC
|
from .webrtc import StreamHandler, WebRTC
|
||||||
|
|
||||||
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
|
__all__ = ["AdditionalOutputs", "ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import fractions
|
import fractions
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable
|
from typing import Any, Callable, Protocol, cast
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
|
|||||||
AUDIO_PTIME = 0.020
|
AUDIO_PTIME = 0.020
|
||||||
|
|
||||||
|
|
||||||
|
class AdditionalOutputs:
|
||||||
|
def __init__(self, *args) -> None:
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
|
||||||
|
class DataChannel(Protocol):
|
||||||
|
def send(self, message: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
# handle the bare audio case
|
||||||
|
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||||
|
return data, None
|
||||||
|
if not len(data) == 2:
|
||||||
|
raise ValueError(
|
||||||
|
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||||
|
)
|
||||||
|
if not isinstance(data[-1], AdditionalOutputs):
|
||||||
|
raise ValueError(
|
||||||
|
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||||
|
)
|
||||||
|
return data[0], cast(AdditionalOutputs, data[1])
|
||||||
|
return data, None
|
||||||
|
|
||||||
|
|
||||||
async def player_worker_decode(
|
async def player_worker_decode(
|
||||||
next_frame: Callable,
|
next_frame: Callable,
|
||||||
queue: asyncio.Queue,
|
queue: asyncio.Queue,
|
||||||
thread_quit: asyncio.Event,
|
thread_quit: asyncio.Event,
|
||||||
|
channel: Callable[[], DataChannel | None] | None,
|
||||||
|
set_additional_outputs: Callable | None,
|
||||||
quit_on_none: bool = False,
|
quit_on_none: bool = False,
|
||||||
sample_rate: int = 48000,
|
sample_rate: int = 48000,
|
||||||
frame_size: int = int(48000 * AUDIO_PTIME),
|
frame_size: int = int(48000 * AUDIO_PTIME),
|
||||||
@@ -31,7 +60,17 @@ async def player_worker_decode(
|
|||||||
while not thread_quit.is_set():
|
while not thread_quit.is_set():
|
||||||
try:
|
try:
|
||||||
# Get next frame
|
# Get next frame
|
||||||
frame = await asyncio.wait_for(next_frame(), timeout=60)
|
frame, outputs = split_output(
|
||||||
|
await asyncio.wait_for(next_frame(), timeout=60)
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
isinstance(outputs, AdditionalOutputs)
|
||||||
|
and set_additional_outputs
|
||||||
|
and channel
|
||||||
|
and channel()
|
||||||
|
):
|
||||||
|
set_additional_outputs(outputs)
|
||||||
|
cast(DataChannel, channel()).send("change")
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
if quit_on_none:
|
if quit_on_none:
|
||||||
@@ -65,7 +104,7 @@ async def player_worker_decode(
|
|||||||
processed_frame.time_base = audio_time_base
|
processed_frame.time_base = audio_time_base
|
||||||
audio_samples += processed_frame.samples
|
audio_samples += processed_frame.samples
|
||||||
await queue.put(processed_frame)
|
await queue.put(processed_frame)
|
||||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||||
|
|
||||||
except (TimeoutError, asyncio.TimeoutError):
|
except (TimeoutError, asyncio.TimeoutError):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -11,7 +11,19 @@ import traceback
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Concatenate,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
ParamSpec,
|
||||||
|
Sequence,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import anyio.to_thread
|
import anyio.to_thread
|
||||||
import av
|
import av
|
||||||
@@ -29,7 +41,12 @@ from gradio import wasm_utils
|
|||||||
from gradio.components.base import Component, server
|
from gradio.components.base import Component, server
|
||||||
from gradio_client import handle_file
|
from gradio_client import handle_file
|
||||||
|
|
||||||
from gradio_webrtc.utils import player_worker_decode
|
from gradio_webrtc.utils import (
|
||||||
|
AdditionalOutputs,
|
||||||
|
DataChannel,
|
||||||
|
player_worker_decode,
|
||||||
|
split_output,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.blocks import Block
|
from gradio.blocks import Block
|
||||||
@@ -55,11 +72,15 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
self,
|
self,
|
||||||
track: MediaStreamTrack,
|
track: MediaStreamTrack,
|
||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
|
channel: DataChannel | None = None,
|
||||||
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__() # don't forget this!
|
super().__init__() # don't forget this!
|
||||||
self.track = track
|
self.track = track
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
|
self.channel = channel
|
||||||
|
self.set_additional_outputs = set_additional_outputs
|
||||||
|
|
||||||
def add_frame_to_payload(
|
def add_frame_to_payload(
|
||||||
self, args: list[Any], frame: np.ndarray | None
|
self, args: list[Any], frame: np.ndarray | None
|
||||||
@@ -88,7 +109,14 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
|
|
||||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||||
|
|
||||||
array = self.event_handler(*args)
|
array, outputs = split_output(self.event_handler(*args))
|
||||||
|
if (
|
||||||
|
isinstance(outputs, AdditionalOutputs)
|
||||||
|
and self.set_additional_outputs
|
||||||
|
and self.channel
|
||||||
|
):
|
||||||
|
self.set_additional_outputs(outputs)
|
||||||
|
self.channel.send("change")
|
||||||
|
|
||||||
new_frame = self.array_to_frame(array)
|
new_frame = self.array_to_frame(array)
|
||||||
if frame:
|
if frame:
|
||||||
@@ -152,6 +180,8 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self,
|
self,
|
||||||
track: MediaStreamTrack,
|
track: MediaStreamTrack,
|
||||||
event_handler: StreamHandler,
|
event_handler: StreamHandler,
|
||||||
|
channel: DataChannel | None = None,
|
||||||
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.track = track
|
self.track = track
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
@@ -162,6 +192,8 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self._start: float | None = None
|
self._start: float | None = None
|
||||||
self.has_started = False
|
self.has_started = False
|
||||||
self.last_timestamp = 0
|
self.last_timestamp = 0
|
||||||
|
self.channel = channel
|
||||||
|
self.set_additional_outputs = set_additional_outputs
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
async def process_input_frames(self) -> None:
|
async def process_input_frames(self) -> None:
|
||||||
@@ -189,6 +221,8 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
callable,
|
callable,
|
||||||
self.queue,
|
self.queue,
|
||||||
self.thread_quit,
|
self.thread_quit,
|
||||||
|
lambda: self.channel,
|
||||||
|
self.set_additional_outputs,
|
||||||
False,
|
False,
|
||||||
self.event_handler.output_sample_rate,
|
self.event_handler.output_sample_rate,
|
||||||
self.event_handler.output_frame_size,
|
self.event_handler.output_frame_size,
|
||||||
@@ -242,12 +276,16 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
|
channel: DataChannel | None = None,
|
||||||
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__() # don't forget this!
|
super().__init__() # don't forget this!
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.args_set = asyncio.Event()
|
self.args_set = asyncio.Event()
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
self.generator: Generator[Any, None, Any] | None = None
|
self.generator: Generator[Any, None, Any] | None = None
|
||||||
|
self.channel = channel
|
||||||
|
self.set_additional_outputs = set_additional_outputs
|
||||||
|
|
||||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
|
||||||
return VideoFrame.from_ndarray(array, format="bgr24")
|
return VideoFrame.from_ndarray(array, format="bgr24")
|
||||||
@@ -262,7 +300,14 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
next_array = next(self.generator)
|
next_array, outputs = split_output(next(self.generator))
|
||||||
|
if (
|
||||||
|
isinstance(outputs, AdditionalOutputs)
|
||||||
|
and self.set_additional_outputs
|
||||||
|
and self.channel
|
||||||
|
):
|
||||||
|
self.set_additional_outputs(outputs)
|
||||||
|
self.channel.send("change")
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.stop()
|
self.stop()
|
||||||
return
|
return
|
||||||
@@ -283,6 +328,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
event_handler: Callable,
|
event_handler: Callable,
|
||||||
|
channel: DataChannel | None = None,
|
||||||
|
set_additional_outputs: Callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.generator: Generator[Any, None, Any] | None = None
|
self.generator: Generator[Any, None, Any] | None = None
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
@@ -291,6 +338,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
self.args_set = threading.Event()
|
self.args_set = threading.Event()
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.thread_quit = asyncio.Event()
|
self.thread_quit = asyncio.Event()
|
||||||
|
self.channel = channel
|
||||||
|
self.set_additional_outputs = set_additional_outputs
|
||||||
self.has_started = False
|
self.has_started = False
|
||||||
self._start: float | None = None
|
self._start: float | None = None
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -315,6 +364,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
callable,
|
callable,
|
||||||
self.queue,
|
self.queue,
|
||||||
self.thread_quit,
|
self.thread_quit,
|
||||||
|
lambda: self.channel,
|
||||||
|
self.set_additional_outputs,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -353,6 +404,12 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
super().stop()
|
super().stop()
|
||||||
|
|
||||||
|
|
||||||
|
# For the return type
|
||||||
|
R = TypeVar("R")
|
||||||
|
# For the parameter specification
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
class WebRTC(Component):
|
class WebRTC(Component):
|
||||||
"""
|
"""
|
||||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||||
@@ -369,8 +426,10 @@ class WebRTC(Component):
|
|||||||
connections: dict[
|
connections: dict[
|
||||||
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
|
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
|
||||||
] = {}
|
] = {}
|
||||||
|
data_channels: dict[str, DataChannel] = {}
|
||||||
|
additional_outputs: dict[str, AdditionalOutputs] = {}
|
||||||
|
|
||||||
EVENTS = ["tick"]
|
EVENTS = ["tick", "state_change"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -470,6 +529,14 @@ class WebRTC(Component):
|
|||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_additional_outputs(
|
||||||
|
self, webrtc_id: str
|
||||||
|
) -> Callable[[AdditionalOutputs], None]:
|
||||||
|
def set_outputs(outputs: AdditionalOutputs):
|
||||||
|
self.additional_outputs[webrtc_id] = outputs
|
||||||
|
|
||||||
|
return set_outputs
|
||||||
|
|
||||||
def preprocess(self, payload: str) -> str:
|
def preprocess(self, payload: str) -> str:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -498,6 +565,38 @@ class WebRTC(Component):
|
|||||||
self.connections[webrtc_id].latest_args = list(args)
|
self.connections[webrtc_id].latest_args = list(args)
|
||||||
self.connections[webrtc_id].args_set.set() # type: ignore
|
self.connections[webrtc_id].args_set.set() # type: ignore
|
||||||
|
|
||||||
|
def change(
|
||||||
|
self,
|
||||||
|
fn: Callable[Concatenate[P], R],
|
||||||
|
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||||
|
outputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||||
|
js: str | None = None,
|
||||||
|
concurrency_limit: int | None | Literal["default"] = "default",
|
||||||
|
concurrency_id: str | None = None,
|
||||||
|
):
|
||||||
|
inputs = inputs or []
|
||||||
|
if inputs and not isinstance(inputs, Iterable):
|
||||||
|
inputs = [inputs]
|
||||||
|
inputs = list(inputs)
|
||||||
|
|
||||||
|
def handler(webrtc_id: str, *args):
|
||||||
|
if webrtc_id in self.additional_outputs:
|
||||||
|
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
|
||||||
|
return (
|
||||||
|
tuple([None for _ in range(len(outputs))])
|
||||||
|
if isinstance(outputs, Iterable)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.state_change( # type: ignore
|
||||||
|
fn=handler,
|
||||||
|
inputs=[self] + cast(list, inputs),
|
||||||
|
outputs=outputs,
|
||||||
|
js=js,
|
||||||
|
concurrency_limit=concurrency_limit,
|
||||||
|
concurrency_id=concurrency_id,
|
||||||
|
)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
fn: Callable[..., Any] | StreamHandler | None = None,
|
fn: Callable[..., Any] | StreamHandler | None = None,
|
||||||
@@ -599,6 +698,8 @@ class WebRTC(Component):
|
|||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
self.pcs.add(pc)
|
self.pcs.add(pc)
|
||||||
|
|
||||||
|
set_outputs = self.set_additional_outputs(body["webrtc_id"])
|
||||||
|
|
||||||
@pc.on("iceconnectionstatechange")
|
@pc.on("iceconnectionstatechange")
|
||||||
async def on_iceconnectionstatechange():
|
async def on_iceconnectionstatechange():
|
||||||
logger.debug("ICE connection state change %s", pc.iceConnectionState)
|
logger.debug("ICE connection state change %s", pc.iceConnectionState)
|
||||||
@@ -627,27 +728,61 @@ class WebRTC(Component):
|
|||||||
cb = VideoCallback(
|
cb = VideoCallback(
|
||||||
relay.subscribe(track),
|
relay.subscribe(track),
|
||||||
event_handler=cast(Callable, self.event_handler),
|
event_handler=cast(Callable, self.event_handler),
|
||||||
|
set_additional_outputs=set_outputs,
|
||||||
)
|
)
|
||||||
elif self.modality == "audio":
|
elif self.modality == "audio":
|
||||||
cb = AudioCallback(
|
cb = AudioCallback(
|
||||||
relay.subscribe(track),
|
relay.subscribe(track),
|
||||||
event_handler=cast(StreamHandler, self.event_handler).copy(),
|
event_handler=cast(StreamHandler, self.event_handler).copy(),
|
||||||
|
set_additional_outputs=set_outputs,
|
||||||
)
|
)
|
||||||
self.connections[body["webrtc_id"]] = cb
|
self.connections[body["webrtc_id"]] = cb
|
||||||
|
if body["webrtc_id"] in self.data_channels:
|
||||||
|
self.connections[body["webrtc_id"]].channel = self.data_channels[
|
||||||
|
body["webrtc_id"]
|
||||||
|
]
|
||||||
logger.debug("Adding track to peer connection %s", cb)
|
logger.debug("Adding track to peer connection %s", cb)
|
||||||
pc.addTrack(cb)
|
pc.addTrack(cb)
|
||||||
|
|
||||||
if self.mode == "receive":
|
if self.mode == "receive":
|
||||||
if self.modality == "video":
|
if self.modality == "video":
|
||||||
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
cb = ServerToClientVideo(
|
||||||
|
cast(Callable, self.event_handler),
|
||||||
|
set_additional_outputs=set_outputs,
|
||||||
|
)
|
||||||
elif self.modality == "audio":
|
elif self.modality == "audio":
|
||||||
cb = ServerToClientAudio(cast(Callable, self.event_handler))
|
cb = ServerToClientAudio(
|
||||||
|
cast(Callable, self.event_handler),
|
||||||
|
set_additional_outputs=set_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("Adding track to peer connection %s", cb)
|
logger.debug("Adding track to peer connection %s", cb)
|
||||||
pc.addTrack(cb)
|
pc.addTrack(cb)
|
||||||
self.connections[body["webrtc_id"]] = cb
|
self.connections[body["webrtc_id"]] = cb
|
||||||
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
|
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
|
||||||
|
|
||||||
|
@pc.on("datachannel")
|
||||||
|
def on_datachannel(channel):
|
||||||
|
print("data channel established")
|
||||||
|
logger.debug(f"Data channel established: {channel.label}")
|
||||||
|
|
||||||
|
self.data_channels[body["webrtc_id"]] = channel
|
||||||
|
|
||||||
|
async def set_channel(webrtc_id: str):
|
||||||
|
print("webrtc_id", webrtc_id)
|
||||||
|
while not self.connections.get(webrtc_id):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
print("setting channel")
|
||||||
|
self.connections[webrtc_id].channel = channel
|
||||||
|
|
||||||
|
asyncio.create_task(set_channel(body["webrtc_id"]))
|
||||||
|
|
||||||
|
@channel.on("message")
|
||||||
|
def on_message(message):
|
||||||
|
logger.debug(f"Received message: {message}")
|
||||||
|
if channel.readyState == "open":
|
||||||
|
channel.send(f"Server received: {message}")
|
||||||
|
|
||||||
# handle offer
|
# handle offer
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,10 @@
|
|||||||
export let mode: "send-receive" | "receive" = "send-receive";
|
export let mode: "send-receive" | "receive" = "send-receive";
|
||||||
export let track_constraints: MediaTrackConstraints = {};
|
export let track_constraints: MediaTrackConstraints = {};
|
||||||
|
|
||||||
|
const on_change_cb = () => {
|
||||||
|
gradio.dispatch("state_change");
|
||||||
|
}
|
||||||
|
|
||||||
let dragging = false;
|
let dragging = false;
|
||||||
|
|
||||||
$: console.log("value", value);
|
$: console.log("value", value);
|
||||||
@@ -63,6 +67,7 @@
|
|||||||
{#if mode == "receive" && modality === "video"}
|
{#if mode == "receive" && modality === "video"}
|
||||||
<StaticVideo
|
<StaticVideo
|
||||||
bind:value={value}
|
bind:value={value}
|
||||||
|
{on_change_cb}
|
||||||
{label}
|
{label}
|
||||||
{show_label}
|
{show_label}
|
||||||
{server}
|
{server}
|
||||||
@@ -73,6 +78,7 @@
|
|||||||
{:else if mode == "receive" && modality === "audio"}
|
{:else if mode == "receive" && modality === "audio"}
|
||||||
<StaticAudio
|
<StaticAudio
|
||||||
bind:value={value}
|
bind:value={value}
|
||||||
|
{on_change_cb}
|
||||||
{label}
|
{label}
|
||||||
{show_label}
|
{show_label}
|
||||||
{server}
|
{server}
|
||||||
@@ -91,6 +97,7 @@
|
|||||||
{server}
|
{server}
|
||||||
{rtc_configuration}
|
{rtc_configuration}
|
||||||
{time_limit}
|
{time_limit}
|
||||||
|
{on_change_cb}
|
||||||
on:clear={() => gradio.dispatch("clear")}
|
on:clear={() => gradio.dispatch("clear")}
|
||||||
on:play={() => gradio.dispatch("play")}
|
on:play={() => gradio.dispatch("play")}
|
||||||
on:pause={() => gradio.dispatch("pause")}
|
on:pause={() => gradio.dispatch("pause")}
|
||||||
@@ -109,6 +116,7 @@
|
|||||||
{:else if mode === "send-receive" && modality === "audio"}
|
{:else if mode === "send-receive" && modality === "audio"}
|
||||||
<InteractiveAudio
|
<InteractiveAudio
|
||||||
bind:value={value}
|
bind:value={value}
|
||||||
|
{on_change_cb}
|
||||||
{label}
|
{label}
|
||||||
{show_label}
|
{show_label}
|
||||||
{server}
|
{server}
|
||||||
|
|||||||
@@ -25,9 +25,9 @@
|
|||||||
export let i18n: I18nFormatter;
|
export let i18n: I18nFormatter;
|
||||||
export let time_limit: number | null = null;
|
export let time_limit: number | null = null;
|
||||||
export let track_constraints: MediaTrackConstraints = {};
|
export let track_constraints: MediaTrackConstraints = {};
|
||||||
let _time_limit: number | null = null;
|
export let on_change_cb: () => void;
|
||||||
|
|
||||||
$: console.log("time_limit", time_limit);
|
let _time_limit: number | null = null;
|
||||||
|
|
||||||
export let server: {
|
export let server: {
|
||||||
offer: (body: any) => Promise<any>;
|
offer: (body: any) => Promise<any>;
|
||||||
@@ -41,12 +41,15 @@
|
|||||||
|
|
||||||
const dispatch = createEventDispatcher<{
|
const dispatch = createEventDispatcher<{
|
||||||
tick: undefined;
|
tick: undefined;
|
||||||
|
state_change: undefined;
|
||||||
error: string
|
error: string
|
||||||
play: undefined;
|
play: undefined;
|
||||||
stop: undefined;
|
stop: undefined;
|
||||||
}>();
|
}>();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
onMount(() => {
|
onMount(() => {
|
||||||
window.setInterval(() => {
|
window.setInterval(() => {
|
||||||
if (stream_state == "open") {
|
if (stream_state == "open") {
|
||||||
@@ -103,7 +106,7 @@
|
|||||||
}
|
}
|
||||||
if (stream == null) return;
|
if (stream == null) return;
|
||||||
|
|
||||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
|
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
|
||||||
pc = connection;
|
pc = connection;
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
console.info("catching")
|
console.info("catching")
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
};
|
};
|
||||||
export let rtc_configuration: Object;
|
export let rtc_configuration: Object;
|
||||||
export let track_constraints: MediaTrackConstraints = {};
|
export let track_constraints: MediaTrackConstraints = {};
|
||||||
|
export let on_change_cb: () => void;
|
||||||
|
|
||||||
const dispatch = createEventDispatcher<{
|
const dispatch = createEventDispatcher<{
|
||||||
change: FileData | null;
|
change: FileData | null;
|
||||||
@@ -50,6 +51,7 @@
|
|||||||
{include_audio}
|
{include_audio}
|
||||||
{time_limit}
|
{time_limit}
|
||||||
{track_constraints}
|
{track_constraints}
|
||||||
|
{on_change_cb}
|
||||||
on:error
|
on:error
|
||||||
on:start_recording
|
on:start_recording
|
||||||
on:stop_recording
|
on:stop_recording
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
export let show_label = true;
|
export let show_label = true;
|
||||||
export let rtc_configuration: Object | null = null;
|
export let rtc_configuration: Object | null = null;
|
||||||
export let i18n: I18nFormatter;
|
export let i18n: I18nFormatter;
|
||||||
|
export let on_change_cb: () => void;
|
||||||
|
|
||||||
export let server: {
|
export let server: {
|
||||||
offer: (body: any) => Promise<any>;
|
offer: (body: any) => Promise<any>;
|
||||||
@@ -68,7 +69,7 @@
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
let stream = null;
|
let stream = null;
|
||||||
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
|
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
|
||||||
pc = connection;
|
pc = connection;
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
console.info("catching")
|
console.info("catching")
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
export let label: string | undefined = undefined;
|
export let label: string | undefined = undefined;
|
||||||
export let show_label = true;
|
export let show_label = true;
|
||||||
export let rtc_configuration: Object | null = null;
|
export let rtc_configuration: Object | null = null;
|
||||||
|
export let on_change_cb: () => void;
|
||||||
export let server: {
|
export let server: {
|
||||||
offer: (body: any) => Promise<any>;
|
offer: (body: any) => Promise<any>;
|
||||||
};
|
};
|
||||||
@@ -59,7 +60,7 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
|
start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => {
|
||||||
pc = connection;
|
pc = connection;
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
console.log("catching")
|
console.log("catching")
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
let _time_limit: number | null = null;
|
let _time_limit: number | null = null;
|
||||||
export let time_limit: number | null = null;
|
export let time_limit: number | null = null;
|
||||||
let stream_state: "open" | "waiting" | "closed" = "closed";
|
let stream_state: "open" | "waiting" | "closed" = "closed";
|
||||||
|
export let on_change_cb: () => void;
|
||||||
const _webrtc_id = Math.random().toString(36).substring(2);
|
const _webrtc_id = Math.random().toString(36).substring(2);
|
||||||
|
|
||||||
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
|
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
|
||||||
@@ -139,7 +140,7 @@
|
|||||||
)
|
)
|
||||||
stream_state = "waiting"
|
stream_state = "waiting"
|
||||||
webrtc_id = Math.random().toString(36).substring(2);
|
webrtc_id = Math.random().toString(36).substring(2);
|
||||||
start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => {
|
start(stream, pc, video_source, server.offer, webrtc_id, "video", on_change_cb).then((connection) => {
|
||||||
pc = connection;
|
pc = connection;
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
console.info("catching")
|
console.info("catching")
|
||||||
|
|||||||
@@ -44,8 +44,24 @@ export function createPeerConnection(pc, node) {
|
|||||||
return pc;
|
return pc;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") {
|
export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id,
|
||||||
|
modality: "video" | "audio" = "video", on_change_cb: () => void = () => {}) {
|
||||||
pc = createPeerConnection(pc, node);
|
pc = createPeerConnection(pc, node);
|
||||||
|
const data_channel = pc.createDataChannel("text");
|
||||||
|
|
||||||
|
data_channel.onopen = () => {
|
||||||
|
console.debug("Data channel is open");
|
||||||
|
data_channel.send("handshake");
|
||||||
|
};
|
||||||
|
|
||||||
|
data_channel.onmessage = (event) => {
|
||||||
|
console.debug("Received message:", event.data);
|
||||||
|
if (event.data === "change") {
|
||||||
|
console.debug("Change event received");
|
||||||
|
on_change_cb();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if (stream) {
|
if (stream) {
|
||||||
stream.getTracks().forEach((track) => {
|
stream.getTracks().forEach((track) => {
|
||||||
console.debug("Track stream callback", track);
|
console.debug("Track stream callback", track);
|
||||||
|
|||||||
Reference in New Issue
Block a user