Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-04 17:39:23 +08:00)

Commit: code (#48)
@@ -16,10 +16,21 @@ from .utils import (
     audio_to_file,
     audio_to_float32,
 )
-from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
+from .webrtc import (
+    AsyncAudioVideoStreamHandler,
+    AsyncStreamHandler,
+    AudioVideoStreamHandler,
+    StreamHandler,
+    WebRTC,
+    VideoEmitType,
+    AudioEmitType,
+)
 
 __all__ = [
     "AsyncStreamHandler",
+    "AudioVideoStreamHandler",
+    "AudioEmitType",
+    "AsyncAudioVideoStreamHandler",
     "AlgoOptions",
     "AdditionalOutputs",
     "aggregate_bytes_to_16bit",
@@ -36,6 +47,7 @@ __all__ = [
     "stt",
     "stt_for_chunks",
     "StreamHandler",
+    "VideoEmitType",
     "WebRTC",
     "WebRTCError",
     "Warning",
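Taken together, the export changes above put the new audio-video symbols on the package's public surface. A minimal sketch of the resulting user-facing imports (assuming the package is installed as gradio_webrtc):

from gradio_webrtc import (
    AsyncAudioVideoStreamHandler,
    AudioEmitType,
    AudioVideoStreamHandler,
    VideoEmitType,
    WebRTC,
)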
@@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast
 
 import numpy as np
 
 from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
 from gradio_webrtc.utils import AdditionalOutputs
 from gradio_webrtc.webrtc import EmitType, StreamHandler
 
 logger = getLogger(__name__)
@@ -147,16 +147,18 @@ async def player_worker_decode(
 
         logger.debug(
             "received array with shape %s sample rate %s layout %s",
-            audio_array.shape,
+            audio_array.shape,  # type: ignore
             sample_rate,
-            layout,
+            layout,  # type: ignore
         )
-        format = "s16" if audio_array.dtype == "int16" else "fltp"
+        format = "s16" if audio_array.dtype == "int16" else "fltp"  # type: ignore
 
         # Convert to audio frame and resample
         # This runs in the same timeout context
         frame = av.AudioFrame.from_ndarray(  # type: ignore
-            audio_array, format=format, layout=layout
+            audio_array,  # type: ignore
+            format=format,
+            layout=layout,  # type: ignore
         )
         frame.sample_rate = sample_rate
 
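For context on the format branch above: PyAV's AudioFrame.from_ndarray expects packed "s16" data for int16 arrays (shape (1, samples) for mono) and planar "fltp" for float32. A small illustrative sketch, independent of this codebase:

import av
import numpy as np

# One 10 ms chunk of a 440 Hz tone at 48 kHz; int16 maps to packed "s16".
samples = (np.sin(2 * np.pi * 440 * np.arange(480) / 48000) * 32767).astype(np.int16)
frame = av.AudioFrame.from_ndarray(samples.reshape(1, -1), format="s16", layout="mono")
frame.sample_rate = 48000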
@@ -4,11 +4,13 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import inspect
 import logging
 import threading
 import time
 import traceback
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Callable
 from typing import (
     TYPE_CHECKING,
@@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError
 from gradio import wasm_utils
 from gradio.components.base import Component, server
 from gradio_client import handle_file
+from numpy import typing as npt
 
 from gradio_webrtc.utils import (
     AdditionalOutputs,
@@ -61,6 +64,11 @@ if wasm_utils.IS_WASM:
 
 logger = logging.getLogger(__name__)
 
+VideoEmitType = Union[
+    AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None
+]
+VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
+
 
 class VideoCallback(VideoStreamTrack):
     """
@@ -72,7 +80,7 @@ class VideoCallback(VideoStreamTrack):
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: Callable,
+        event_handler: VideoEventHandler,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         mode: Literal["send-receive", "send"] = "send-receive",
@@ -86,6 +94,7 @@ class VideoCallback(VideoStreamTrack):
         self.thread_quit = asyncio.Event()
         self.mode = mode
         self.channel_set = asyncio.Event()
+        self.has_started = False
 
     def set_channel(self, channel: DataChannel):
         self.channel = channel
@@ -132,7 +141,7 @@ class VideoCallback(VideoStreamTrack):
         if current_channel.get() != self.channel:
             current_channel.set(self.channel)
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             try:
                 frame = cast(VideoFrame, await self.track.recv())
@@ -142,7 +151,6 @@ class VideoCallback(VideoStreamTrack):
 
             await self.wait_for_channel()
             frame_array = frame.to_ndarray(format="bgr24")
 
             if self.latest_args == "not_set":
                 return frame
 
@@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[
     tuple[tuple[int, np.ndarray], AdditionalOutputs],
     None,
 ]
+AudioEmitType = EmitType
 
 
 class StreamHandler(StreamHandlerBase):
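AudioEmitType is just a readable alias for the existing EmitType union: an audio handler's emit() may return a plain (sample_rate, ndarray) chunk, pair it with AdditionalOutputs, return AdditionalOutputs alone, or return None to skip a frame. A hypothetical helper showing the most common shape:

import numpy as np

def silent_chunk(sample_rate: int = 24000) -> tuple[int, np.ndarray]:
    # 20 ms of int16 mono silence, shaped (channels, samples).
    return sample_rate, np.zeros((1, sample_rate // 50), dtype=np.int16)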
@@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase):
 StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
 
 
+class AudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+class AsyncAudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    async def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    async def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler]
+AudioVideoStreamHandlerImpl = Union[
+    AudioVideoStreamHandler, AsyncAudioVideoStreamHandler
+]
+AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler]
+
+
+class VideoStreamHander(VideoCallback):
+    async def process_frames(self):
+        while not self.thread_quit.is_set():
+            try:
+                await self.channel_set.wait()
+                frame = cast(VideoFrame, await self.track.recv())
+                frame_array = frame.to_ndarray(format="bgr24")
+                handler = cast(VideoStreamHandlerImpl, self.event_handler)
+                if inspect.iscoroutinefunction(handler.video_receive):
+                    await handler.video_receive(frame_array)
+                else:
+                    handler.video_receive(frame_array)
+            except MediaStreamError:
+                self.stop()
+
+    def start(self):
+        if not self.has_started:
+            asyncio.create_task(self.process_frames())
+            self.has_started = True
+
+    async def recv(self):  # type: ignore
+        self.start()
+        try:
+            handler = cast(VideoStreamHandlerImpl, self.event_handler)
+            if inspect.iscoroutinefunction(handler.video_emit):
+                outputs = await handler.video_emit()
+            else:
+                outputs = handler.video_emit()
+
+            array, outputs = split_output(outputs)
+            if (
+                isinstance(outputs, AdditionalOutputs)
+                and self.set_additional_outputs
+                and self.channel
+            ):
+                self.set_additional_outputs(outputs)
+                self.channel.send("change")
+            if array is None and self.mode == "send":
+                return
+
+            new_frame = self.array_to_frame(array)
+
+            # Will probably have to give developer ability to set pts and time_base
+            pts, time_base = await self.next_timestamp()
+            new_frame.pts = pts
+            new_frame.time_base = time_base
+
+            return new_frame
+        except Exception as e:
+            logger.debug("exception %s", e)
+            exec = traceback.format_exc()
+            logger.debug("traceback %s", exec)
 
 
 class AudioCallback(AudioStreamTrack):
     kind = "audio"
 
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: StreamHandlerImpl,
+        event_handler: StreamHandlerBase,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
         super().__init__()
         self.track = track
-        self.event_handler = event_handler
+        self.event_handler = cast(StreamHandlerImpl, event_handler)
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
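The two new ABCs define the contract for combined streams: the familiar audio receive()/emit() pair from StreamHandlerBase plus video_receive()/video_emit() for the camera track. A hedged sketch of a subclass, using only the method names visible in this diff; the EchoHandler name, its loop-back behavior, and the bare super().__init__() call are illustrative assumptions, not part of the commit:

import numpy as np

from gradio_webrtc import AsyncAudioVideoStreamHandler, AudioEmitType, VideoEmitType


class EchoHandler(AsyncAudioVideoStreamHandler):
    """Hypothetical handler that loops the client's own audio and video back."""

    def __init__(self) -> None:
        super().__init__()  # assuming StreamHandlerBase's defaults are acceptable
        self.last_frame: np.ndarray | None = None
        self.last_audio: tuple[int, np.ndarray] | None = None

    async def video_receive(self, frame: np.ndarray) -> None:
        # Called with each decoded bgr24 frame from the client camera.
        self.last_frame = frame

    async def video_emit(self) -> VideoEmitType:
        # Whatever is returned here is encoded and sent back to the client.
        return self.last_frame

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        # Audio arrives as (sample_rate, ndarray), as AudioCallback shows below.
        self.last_audio = frame

    async def emit(self) -> AudioEmitType:
        return self.last_audio

    def copy(self) -> "EchoHandler":
        # The component stores one copy per connection (see the setup hunks below).
        return EchoHandler()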
@@ -322,7 +416,7 @@ class AudioCallback(AudioStreamTrack):
                 frame = cast(AudioFrame, await self.track.recv())
                 for frame in self.event_handler.resample(frame):
                     numpy_array = frame.to_ndarray()
-                    if isinstance(self.event_handler, AsyncStreamHandler):
+                    if isinstance(self.event_handler, AsyncHandler):
                         await self.event_handler.receive(
                             (frame.sample_rate, numpy_array)
                         )
@@ -337,7 +431,7 @@ class AudioCallback(AudioStreamTrack):
     def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
-            if isinstance(self.event_handler, AsyncStreamHandler):
+            if isinstance(self.event_handler, AsyncHandler):
                 callable = self.event_handler.emit
             else:
                 callable = functools.partial(
@@ -358,7 +452,7 @@ class AudioCallback(AudioStreamTrack):
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -383,7 +477,7 @@ class AudioCallback(AudioStreamTrack):
 
             # control playback rate
             if self._start is None:
-                self._start = time.time() - data_time
+                self._start = time.time() - data_time  # type: ignore
             else:
                 wait = self._start + data_time - time.time()
                 await asyncio.sleep(wait)
@@ -434,7 +528,7 @@ class ServerToClientVideo(VideoStreamTrack):
         self.latest_args = list(args)
         self.args_set.set()
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             pts, time_base = await self.next_timestamp()
             await self.args_set.wait()
@@ -523,7 +617,7 @@ class ServerToClientAudio(AudioStreamTrack):
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -539,7 +633,7 @@ class ServerToClientAudio(AudioStreamTrack):
             # control playback rate
             if data_time is not None:
                 if self._start is None:
-                    self._start = time.time() - data_time
+                    self._start = time.time() - data_time  # type: ignore
                 else:
                     wait = self._start + data_time - time.time()
                     await asyncio.sleep(wait)
@@ -576,10 +670,12 @@ class WebRTC(Component):
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
     connections: dict[
-        str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
-    ] = {}
+        str,
+        list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback],
+    ] = defaultdict(list)
     data_channels: dict[str, DataChannel] = {}
     additional_outputs: dict[str, list[AdditionalOutputs]] = {}
+    handlers: dict[str, StreamHandlerBase | Callable] = {}
 
     EVENTS = ["tick", "state_change"]
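The switch from one callback per webrtc_id to a list is what makes audio-video possible: a single connection now registers two tracks (an AudioCallback and a VideoStreamHander), and every later lookup must fan out over both. A toy illustration of the shape change:

from collections import defaultdict

# Before: connections[webrtc_id] = cb          (exactly one track per id)
# After:  connections[webrtc_id].append(cb)    (audio and video tracks together)
connections: dict[str, list[str]] = defaultdict(list)
connections["abc123"].append("audio-callback")
connections["abc123"].append("video-callback")
for conn in connections["abc123"]:
    print("stopping", conn)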
@@ -606,7 +702,7 @@ class WebRTC(Component):
         track_constraints: dict[str, Any] | None = None,
         time_limit: float | None = None,
         mode: Literal["send-receive", "receive", "send"] = "send-receive",
-        modality: Literal["video", "audio"] = "video",
+        modality: Literal["video", "audio", "audio-video"] = "video",
         rtp_params: dict[str, Any] | None = None,
         icon: str | None = None,
         icon_button_color: str | None = None,
@@ -669,6 +765,23 @@ class WebRTC(Component):
                 "height": {"ideal": 500},
                 "frameRate": {"ideal": 30},
             }
+        if track_constraints is None and modality == "audio-video":
+            track_constraints = {
+                "video": {
+                    "facingMode": "user",
+                    "width": {"ideal": 500},
+                    "height": {"ideal": 500},
+                    "frameRate": {"ideal": 30},
+                },
+                "audio": {
+                    "echoCancellation": True,
+                    "noiseSuppression": {"exact": True},
+                    "autoGainControl": {"exact": True},
+                    "sampleRate": {"ideal": 24000},
+                    "sampleSize": {"ideal": 16},
+                    "channelCount": {"exact": 1},
+                },
+            }
         self.track_constraints = track_constraints
         self.event_handler: Callable | StreamHandler | None = None
         super().__init__(
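These defaults only apply when the caller passes nothing; apps can still supply their own constraints. A hedged usage sketch (parameter names follow the __init__ signature in this diff; the values and the surrounding Blocks scaffold are illustrative):

import gradio as gr

from gradio_webrtc import WebRTC

with gr.Blocks() as demo:
    webrtc = WebRTC(
        modality="audio-video",
        mode="send-receive",
        track_constraints={
            "video": {"facingMode": "user", "width": {"ideal": 1280}},
            "audio": {"echoCancellation": True, "channelCount": {"exact": 1}},
        },
    )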
@@ -722,7 +835,8 @@ class WebRTC(Component):
 
     def set_input(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
-            self.connections[webrtc_id].set_args(list(args))
+            for conn in self.connections[webrtc_id]:
+                conn.set_args(list(args))
 
     def on_additional_outputs(
         self,
@@ -767,7 +881,10 @@ class WebRTC(Component):
 
     def stream(
         self,
-        fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
+        fn: Callable[..., Any]
+        | StreamHandlerImpl
+        | AudioVideoStreamHandlerImpl
+        | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -790,16 +907,16 @@ class WebRTC(Component):
         self.concurrency_limit = (
             1 if concurrency_limit in ["default", None] else concurrency_limit
         )
-        self.event_handler = fn
+        self.event_handler = fn  # type: ignore
         self.time_limit = time_limit
 
         if (
             self.mode == "send-receive"
-            and self.modality == "audio"
-            and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
+            and self.modality in ["audio", "audio-video"]
+            and not isinstance(self.event_handler, StreamHandlerBase)
         ):
             raise ValueError(
-                "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
+                "In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase."
             )
 
         if self.mode == "send-receive" or self.mode == "send":
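With the check relaxed to StreamHandlerBase, any of the four handler flavors (sync or async, audio-only or audio-video) passes validation for audio and audio-video modes. A sketch of the wiring, reusing the hypothetical EchoHandler and the webrtc component from the earlier sketches:

webrtc.stream(
    fn=EchoHandler(),
    inputs=[webrtc],
    outputs=[webrtc],
    time_limit=90,
)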
@@ -815,13 +932,23 @@ class WebRTC(Component):
                 raise ValueError(
                     "In the webrtc stream event, the only output component must be the WebRTC component."
                 )
-            return self.tick(  # type: ignore
-                self.set_input,
-                inputs=inputs,
-                outputs=None,
-                concurrency_id=concurrency_id,
-                concurrency_limit=None,
-                stream_every=0.5,
-                time_limit=None,
-                js=js,
-            )
+            for input_component in inputs[1:]:  # type: ignore
+                if hasattr(input_component, "change"):
+                    input_component.change(  # type: ignore
+                        self.set_input,
+                        inputs=inputs,
+                        outputs=None,
+                        concurrency_id=concurrency_id,
+                        concurrency_limit=None,
+                        time_limit=None,
+                        js=js,
+                    )
@@ -855,9 +982,11 @@ class WebRTC(Component):
             await pc.close()
 
     def clean_up(self, webrtc_id: str):
-        connection = self.connections.pop(webrtc_id, None)
-        if isinstance(connection, AudioCallback):
-            connection.event_handler.shutdown()
+        self.handlers.pop(webrtc_id, None)
+        connection = self.connections.pop(webrtc_id, [])
+        for conn in connection:
+            if isinstance(conn, AudioCallback):
+                conn.event_handler.shutdown()
         self.additional_outputs.pop(webrtc_id, None)
         self.data_channels.pop(webrtc_id, None)
         return connection
@@ -874,6 +1003,13 @@ class WebRTC(Component):
         pc = RTCPeerConnection()
         self.pcs.add(pc)
 
+        if isinstance(self.event_handler, StreamHandlerBase):
+            handler = self.event_handler.copy()
+        else:
+            handler = cast(Callable, self.event_handler)
+
+        self.handlers[body["webrtc_id"]] = handler
+
         set_outputs = self.set_additional_outputs(body["webrtc_id"])
 
         @pc.on("iceconnectionstatechange")
@@ -891,7 +1027,8 @@ class WebRTC(Component):
                 await pc.close()
                 connection = self.clean_up(body["webrtc_id"])
                 if connection:
-                    connection.stop()
+                    for conn in connection:
+                        conn.stop()
                 self.pcs.discard(pc)
         if pc.connectionState == "connected":
             if self.time_limit is not None:
@@ -900,28 +1037,38 @@ class WebRTC(Component):
         @pc.on("track")
         def on_track(track):
-            relay = MediaRelay()
-            if self.modality == "video":
+            handler = self.handlers[body["webrtc_id"]]
+
+            if self.modality == "video" and track.kind == "video":
                 cb = VideoCallback(
                     relay.subscribe(track),
-                    event_handler=cast(Callable, self.event_handler),
+                    event_handler=cast(VideoEventHandler, handler),
                     set_additional_outputs=set_outputs,
                     mode=cast(Literal["send", "send-receive"], self.mode),
                 )
-            elif self.modality == "audio":
-                handler = cast(StreamHandler, self.event_handler).copy()
-                handler._loop = asyncio.get_running_loop()
+            elif self.modality == "audio-video" and track.kind == "video":
+                cb = VideoStreamHander(
+                    relay.subscribe(track),
+                    event_handler=handler,  # type: ignore
+                    set_additional_outputs=set_outputs,
+                )
+            elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
+                eh = cast(StreamHandlerImpl, handler)
+                eh._loop = asyncio.get_running_loop()
                 cb = AudioCallback(
                     relay.subscribe(track),
-                    event_handler=handler,
+                    event_handler=eh,
                     set_additional_outputs=set_outputs,
                 )
             else:
-                raise ValueError("Modality must be either video or audio")
-            self.connections[body["webrtc_id"]] = cb
+                raise ValueError("Modality must be either video, audio, or audio-video")
+            if body["webrtc_id"] not in self.connections:
+                self.connections[body["webrtc_id"]] = []
+
+            self.connections[body["webrtc_id"]].append(cb)
             if body["webrtc_id"] in self.data_channels:
-                self.connections[body["webrtc_id"]].set_channel(
-                    self.data_channels[body["webrtc_id"]]
-                )
+                for conn in self.connections[body["webrtc_id"]]:
+                    conn.set_channel(self.data_channels[body["webrtc_id"]])
             if self.mode == "send-receive":
                 logger.debug("Adding track to peer connection %s", cb)
                 pc.addTrack(cb)
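The branching above now keys on both the configured modality and the kind of the incoming track, because an audio-video connection fires on_track twice (once per track). The same routing, reduced to a lookup table for clarity (illustrative only; the real code constructs the callbacks inline):

DISPATCH = {
    ("video", "video"): "VideoCallback",
    ("audio-video", "video"): "VideoStreamHander",
    ("audio", "audio"): "AudioCallback",
    ("audio-video", "audio"): "AudioCallback",
}

def pick_callback(modality: str, kind: str) -> str:
    try:
        return DISPATCH[(modality, kind)]
    except KeyError:
        raise ValueError("Modality must be either video, audio, or audio-video")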
@@ -944,7 +1091,7 @@ class WebRTC(Component):
 
             logger.debug("Adding track to peer connection %s", cb)
             pc.addTrack(cb)
-            self.connections[body["webrtc_id"]] = cb
+            self.connections[body["webrtc_id"]].append(cb)
             cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
 
         @pc.on("datachannel")
@@ -957,7 +1104,8 @@ class WebRTC(Component):
             while not self.connections.get(webrtc_id):
                 await asyncio.sleep(0.05)
             logger.debug("setting channel for webrtc id %s", webrtc_id)
-            self.connections[webrtc_id].set_channel(channel)
+            for conn in self.connections[webrtc_id]:
+                conn.set_channel(channel)
 
         asyncio.create_task(set_channel(body["webrtc_id"]))
@@ -30,7 +30,7 @@
	export let gradio;
	export let rtc_configuration: Object;
	export let time_limit: number | null = null;
-	export let modality: "video" | "audio" = "video";
+	export let modality: "video" | "audio" | "audio-video" = "video";
	export let mode: "send-receive" | "receive" | "send" = "send-receive";
	export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
	export let track_constraints: MediaTrackConstraints = {};
@@ -52,18 +52,18 @@
 </script>
 
 <Block
	{visible}
	variant={"solid"}
	border_mode={dragging ? "focus" : "base"}
	padding={false}
	{elem_id}
	{elem_classes}
	{height}
	{width}
	{container}
	{scale}
	{min_width}
	allow_overflow={false}
 >
	<StatusTracker
		autoscroll={gradio.autoscroll}
@@ -99,13 +99,13 @@
			on:error={({ detail }) => gradio.dispatch("error", detail)}
		/>
-	{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
+	{:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}
		<Video
			bind:value={value}
			{label}
			{show_label}
			active_source={"webcam"}
-			include_audio={false}
+			include_audio={modality === "audio-video"}
			{server}
			{rtc_configuration}
			{time_limit}
@@ -113,6 +113,9 @@
			{track_constraints}
			{rtp_params}
			{on_change_cb}
+			{icon}
+			{icon_button_color}
+			{pulse_color}
			on:clear={() => gradio.dispatch("clear")}
			on:play={() => gradio.dispatch("play")}
			on:pause={() => gradio.dispatch("pause")}
@@ -1,10 +1,13 @@
 <script lang="ts">
	import { onDestroy } from 'svelte';
+	import type {ComponentType} from 'svelte';
+
+	import PulsingIcon from './PulsingIcon.svelte';
 
	export let numBars = 16;
	export let stream_state: "open" | "closed" | "waiting" = "closed";
	export let audio_source_callback: () => MediaStream;
-	export let icon: string | undefined = undefined;
+	export let icon: string | undefined | ComponentType = undefined;
	export let icon_button_color: string = "var(--color-accent)";
	export let pulse_color: string = "var(--color-accent)";
@@ -13,7 +16,6 @@
	let dataArray: Uint8Array;
	let animationId: number;
	let pulseScale = 1;
	let pulseIntensity = 0;
 
	$: containerWidth = icon
		? "128px"
@@ -47,53 +49,31 @@
	function updateVisualization() {
		analyser.getByteFrequencyData(dataArray);
 
-		if (icon) {
-			// Calculate average amplitude for pulse effect
-			const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
-			const normalizedAverage = average / 255;
-			pulseScale = 1 + (normalizedAverage * 0.15);
-			pulseIntensity = normalizedAverage;
-		} else {
			// Update bars
			const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
			for (let i = 0; i < bars.length; i++) {
				const barHeight = (dataArray[i] / 255) * 2;
				bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
			}
-		}
 
		animationId = requestAnimationFrame(updateVisualization);
	}
 
-	$: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
 
 </script>
 
 <div class="gradio-webrtc-waveContainer">
	{#if icon}
		<div class="gradio-webrtc-icon-container">
-			{#if pulseIntensity > 0}
-				{#each Array(3) as _, i}
-					<div
-						class="pulse-ring"
-						style:background={pulse_color}
-						style:animation-delay={`${i * 0.4}s`}
-						style:--max-scale={maxPulseScale}
-						style:opacity={0.5 * pulseIntensity}
-					/>
-				{/each}
-			{/if}
 
			<div
				class="gradio-webrtc-icon"
				style:transform={`scale(${pulseScale})`}
				style:background={icon_button_color}
			>
-				<img
-					src={icon}
-					alt="Audio visualization icon"
-					class="icon-image"
-				/>
+				<PulsingIcon
+					{stream_state}
+					{pulse_color}
+					{icon}
+					{icon_button_color}
+					{audio_source_callback}/>
			</div>
		</div>
	{:else}
@@ -1,5 +1,6 @@
 <script lang="ts">
	import { createEventDispatcher } from "svelte";
+	import type { ComponentType } from "svelte";
	import type { FileData, Client } from "@gradio/client";
	import { BlockLabel } from "@gradio/atoms";
	import Webcam from "./Webcam.svelte";
@@ -24,6 +25,9 @@
	export let mode: "send" | "send-receive";
	export let on_change_cb: (msg: "change" | "tick") => void;
	export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+	export let icon: string | undefined | ComponentType = undefined;
+	export let icon_button_color: string = "var(--color-accent)";
+	export let pulse_color: string = "var(--color-accent)";
 
	const dispatch = createEventDispatcher<{
		change: FileData | null;
@@ -56,6 +60,9 @@
		{mode}
		{rtp_params}
		{on_change_cb}
+		{icon}
+		{icon_button_color}
+		{pulse_color}
		on:error
		on:start_recording
		on:stop_recording
frontend/shared/PulsingIcon.svelte (new file, 151 lines)
@@ -0,0 +1,151 @@
<script lang="ts">
	import { onDestroy } from 'svelte';
	import type {ComponentType} from 'svelte';

	export let stream_state: "open" | "closed" | "waiting" = "closed";
	export let audio_source_callback: () => MediaStream;
	export let icon: string | ComponentType = undefined;
	export let icon_button_color: string = "var(--color-accent)";
	export let pulse_color: string = "var(--color-accent)";

	let audioContext: AudioContext;
	let analyser: AnalyserNode;
	let dataArray: Uint8Array;
	let animationId: number;
	let pulseScale = 1;
	let pulseIntensity = 0;

	$: if(stream_state === "open") setupAudioContext();

	onDestroy(() => {
		if (animationId) {
			cancelAnimationFrame(animationId);
		}
		if (audioContext) {
			audioContext.close();
		}
	});

	function setupAudioContext() {
		audioContext = new (window.AudioContext || window.webkitAudioContext)();
		analyser = audioContext.createAnalyser();
		const source = audioContext.createMediaStreamSource(audio_source_callback());

		source.connect(analyser);

		analyser.fftSize = 64;
		analyser.smoothingTimeConstant = 0.8;
		dataArray = new Uint8Array(analyser.frequencyBinCount);

		updateVisualization();
	}

	function updateVisualization() {
		analyser.getByteFrequencyData(dataArray);

		// Calculate average amplitude for pulse effect
		const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
		const normalizedAverage = average / 255;
		pulseScale = 1 + (normalizedAverage * 0.15);
		pulseIntensity = normalizedAverage;
		animationId = requestAnimationFrame(updateVisualization);
	}

	$: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
</script>

<div class="gradio-webrtc-icon-wrapper">
	<div class="gradio-webrtc-pulsing-icon-container">
		{#if pulseIntensity > 0}
			{#each Array(3) as _, i}
				<div
					class="pulse-ring"
					style:background={pulse_color}
					style:animation-delay={`${i * 0.4}s`}
					style:--max-scale={maxPulseScale}
					style:opacity={0.5 * pulseIntensity}
				/>
			{/each}
		{/if}

		<div
			class="gradio-webrtc-pulsing-icon"
			style:transform={`scale(${pulseScale})`}
			style:background={icon_button_color}
		>
			{#if typeof icon === "string"}
				<img
					src={icon}
					alt="Audio visualization icon"
					class="icon-image"
				/>
			{:else}
				<svelte:component this={icon} />
			{/if}
		</div>
	</div>
</div>

<style>
	.gradio-webrtc-icon-wrapper {
		position: relative;
		display: flex;
		max-height: 128px;
		justify-content: center;
		align-items: center;
	}

	.gradio-webrtc-pulsing-icon-container {
		position: relative;
		width: 100%;
		height: 100%;
		display: flex;
		justify-content: center;
		align-items: center;
	}

	.gradio-webrtc-pulsing-icon {
		position: relative;
		width: 100%;
		height: 100%;
		border-radius: 50%;
		transition: transform 0.1s ease;
		display: flex;
		justify-content: center;
		align-items: center;
		z-index: 2;
	}

	.icon-image {
		width: 100%;
		height: 100%;
		object-fit: contain;
		filter: brightness(0) invert(1);
	}

	.pulse-ring {
		position: absolute;
		top: 50%;
		left: 50%;
		transform: translate(-50%, -50%);
		width: 100%;
		height: 100%;
		border-radius: 50%;
		animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
		opacity: 0.5;
	}

	@keyframes pulse {
		0% {
			transform: translate(-50%, -50%) scale(1);
			opacity: 0.5;
		}
		100% {
			transform: translate(-50%, -50%) scale(var(--max-scale, 3));
			opacity: 0;
		}
	}
</style>
@@ -98,7 +98,7 @@
 />
 <audio
	class="standard-player"
-	class:hidden={value === "__webrtc_value__"}
+	class:hidden={true}
	on:load
	bind:this={audio_player}
	on:ended={() => dispatch("stop")}
@@ -1,10 +1,12 @@
 <script lang="ts">
	import { createEventDispatcher, onMount } from "svelte";
+	import type { ComponentType } from "svelte";
	import {
		Circle,
		Square,
		DropdownArrow,
-		Spinner
+		Spinner,
+		Microphone as Mic
	} from "@gradio/icons";
	import type { I18nFormatter } from "@gradio/utils";
	import { StreamingBar } from "@gradio/statustracker";
@@ -15,8 +17,8 @@
		get_video_stream,
		set_available_devices
	} from "./stream_utils";
 
	import { start, stop } from "./webrtc_utils";
+	import PulsingIcon from "./PulsingIcon.svelte";
 
	let video_source: HTMLVideoElement;
	let available_video_devices: MediaDeviceInfo[] = [];
@@ -28,6 +30,9 @@
	export let mode: "send-receive" | "send";
	const _webrtc_id = Math.random().toString(36).substring(2);
	export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+	export let icon: string | undefined | ComponentType = undefined;
+	export let icon_button_color: string = "var(--color-accent)";
+	export let pulse_color: string = "var(--color-accent)";
 
	export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
		state: "open" | "closed" | "waiting"
@@ -156,14 +161,13 @@
			_time_limit = null;
			await access_webcam();
		}
	}
 
-	window.setInterval(() => {
-		if (stream_state == "open") {
-			dispatch("tick");
-		}
-	}, stream_every * 1000);
+	// window.setInterval(() => {
+	// 	if (stream_state == "open") {
+	// 		dispatch("tick");
+	// 	}
+	// }, stream_every * 1000);
 
	let options_open = false;
@@ -192,16 +196,29 @@
		event.stopPropagation();
		options_open = false;
	}
+
+	const audio_source_callback = () => video_source.srcObject as MediaStream;
 </script>
 
 <div class="wrap">
	<StreamingBar time_limit={_time_limit} />
+	{#if stream_state === "open" && include_audio}
+		<div class="audio-indicator">
+			<PulsingIcon
+				stream_state={stream_state}
+				audio_source_callback={audio_source_callback}
+				icon={icon || Mic}
+				icon_button_color={icon_button_color}
+				pulse_color={pulse_color}
+			/>
+		</div>
+	{/if}
	<!-- svelte-ignore a11y-media-has-caption -->
	<!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
	<video
		bind:this={video_source}
		class:hide={!webcam_accessed}
-		class:flip={(stream_state != "open")}
+		class:flip={(stream_state != "open") || (stream_state === "open" && include_audio)}
		autoplay={true}
		playsinline={true}
	/>
@@ -324,6 +341,15 @@
		justify-content: space-evenly;
	}
 
+	.audio-indicator {
+		position: absolute;
+		top: var(--size-2);
+		right: var(--size-2);
+		z-index: var(--layer-2);
+		height: var(--size-5);
+		width: var(--size-5);
+	}
+
	@media (--screen-md) {
		button {
			bottom: var(--size-4);
@@ -68,14 +68,14 @@ export async function start(
		try {
			event_json = JSON.parse(event.data);
		} catch (e) {
-			console.debug("Error parsing JSON")
+			console.debug("Error parsing JSON");
		}
+		console.log("event_json", event_json);
		if (
			event.data === "change" ||
			event.data === "tick" ||
			event.data === "stopword" ||
			event_json?.type === "warning" ||
-			event_json?.type === "warning" ||
			event_json?.type === "error"
		) {
			console.debug(`${event.data} event received`);
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "gradio_webrtc"
-version = "0.0.27"
+version = "0.0.28"
 description = "Stream images in realtime with webrtc"
 readme = "README.md"
 license = "apache-2.0"