Freddy Boulton
2025-01-10 17:14:47 -05:00
committed by GitHub
parent b64e019323
commit 4d16634307
12 changed files with 431 additions and 103 deletions

View File

@@ -16,10 +16,21 @@ from .utils import (
     audio_to_file,
     audio_to_float32,
 )
-from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
+from .webrtc import (
+    AsyncAudioVideoStreamHandler,
+    AsyncStreamHandler,
+    AudioVideoStreamHandler,
+    StreamHandler,
+    WebRTC,
+    VideoEmitType,
+    AudioEmitType,
+)
 
 __all__ = [
     "AsyncStreamHandler",
+    "AudioVideoStreamHandler",
+    "AudioEmitType",
+    "AsyncAudioVideoStreamHandler",
     "AlgoOptions",
     "AdditionalOutputs",
     "aggregate_bytes_to_16bit",
@@ -36,6 +47,7 @@ __all__ = [
     "stt",
     "stt_for_chunks",
     "StreamHandler",
+    "VideoEmitType",
     "WebRTC",
     "WebRTCError",
    "Warning",

View File

@@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast
 import numpy as np
 
 from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
-from gradio_webrtc.utils import AdditionalOutputs
 from gradio_webrtc.webrtc import EmitType, StreamHandler
 
 logger = getLogger(__name__)

View File

@@ -147,16 +147,18 @@ async def player_worker_decode(
             logger.debug(
                 "received array with shape %s sample rate %s layout %s",
-                audio_array.shape,
+                audio_array.shape,  # type: ignore
                 sample_rate,
-                layout,
+                layout,  # type: ignore
             )
-            format = "s16" if audio_array.dtype == "int16" else "fltp"
+            format = "s16" if audio_array.dtype == "int16" else "fltp"  # type: ignore
 
             # Convert to audio frame and resample
             # This runs in the same timeout context
             frame = av.AudioFrame.from_ndarray(  # type: ignore
-                audio_array, format=format, layout=layout
+                audio_array,  # type: ignore
+                format=format,
+                layout=layout,  # type: ignore
             )
             frame.sample_rate = sample_rate
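
Note: the hunk above only reshapes the av.AudioFrame.from_ndarray call and adds type-ignore comments; the conversion logic itself is unchanged. A self-contained sketch of that logic (assumes PyAV is installed; the shape and sample rate are illustrative):

    import av
    import numpy as np

    audio_array = np.zeros((1, 960), dtype=np.int16)  # one mono 20 ms chunk at 48 kHz
    sample_rate, layout = 48000, "mono"

    # int16 data maps to the packed "s16" format; floating-point data to planar "fltp"
    format = "s16" if audio_array.dtype == "int16" else "fltp"
    frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout)
    frame.sample_rate = sample_rate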

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
 import asyncio
 import functools
+import inspect
 import logging
 import threading
 import time
 import traceback
 
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Callable
 from typing import (
     TYPE_CHECKING,
@@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError
 from gradio import wasm_utils
 from gradio.components.base import Component, server
 from gradio_client import handle_file
+from numpy import typing as npt
 
 from gradio_webrtc.utils import (
     AdditionalOutputs,
@@ -61,6 +64,11 @@ if wasm_utils.IS_WASM:
 logger = logging.getLogger(__name__)
 
+VideoEmitType = Union[
+    AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None
+]
+VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
+
 
 class VideoCallback(VideoStreamTrack):
     """
@@ -72,7 +80,7 @@
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: Callable,
+        event_handler: VideoEventHandler,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         mode: Literal["send-receive", "send"] = "send-receive",
@@ -86,6 +94,7 @@
         self.thread_quit = asyncio.Event()
         self.mode = mode
         self.channel_set = asyncio.Event()
+        self.has_started = False
 
     def set_channel(self, channel: DataChannel):
         self.channel = channel
@@ -132,7 +141,7 @@
         if current_channel.get() != self.channel:
             current_channel.set(self.channel)
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             try:
                 frame = cast(VideoFrame, await self.track.recv())
@@ -142,7 +151,6 @@
             await self.wait_for_channel()
 
             frame_array = frame.to_ndarray(format="bgr24")
-
             if self.latest_args == "not_set":
                 return frame
@@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[
     tuple[tuple[int, np.ndarray], AdditionalOutputs],
     None,
 ]
+AudioEmitType = EmitType
 
 
 class StreamHandler(StreamHandlerBase):
@@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase):
 StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
 
 
+class AudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+class AsyncAudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    async def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    async def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler]
+AudioVideoStreamHandlerImpl = Union[
+    AudioVideoStreamHandler, AsyncAudioVideoStreamHandler
+]
+AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler]
+
+
+class VideoStreamHander(VideoCallback):
+    async def process_frames(self):
+        while not self.thread_quit.is_set():
+            try:
+                await self.channel_set.wait()
+                frame = cast(VideoFrame, await self.track.recv())
+                frame_array = frame.to_ndarray(format="bgr24")
+                handler = cast(VideoStreamHandlerImpl, self.event_handler)
+                if inspect.iscoroutinefunction(handler.video_receive):
+                    await handler.video_receive(frame_array)
+                else:
+                    handler.video_receive(frame_array)
+            except MediaStreamError:
+                self.stop()
+
+    def start(self):
+        if not self.has_started:
+            asyncio.create_task(self.process_frames())
+            self.has_started = True
+
+    async def recv(self):  # type: ignore
+        self.start()
+        try:
+            handler = cast(VideoStreamHandlerImpl, self.event_handler)
+            if inspect.iscoroutinefunction(handler.video_emit):
+                outputs = await handler.video_emit()
+            else:
+                outputs = handler.video_emit()
+
+            array, outputs = split_output(outputs)
+            if (
+                isinstance(outputs, AdditionalOutputs)
+                and self.set_additional_outputs
+                and self.channel
+            ):
+                self.set_additional_outputs(outputs)
+                self.channel.send("change")
+            if array is None and self.mode == "send":
+                return
+
+            new_frame = self.array_to_frame(array)
+            # Will probably have to give developer ability to set pts and time_base
+            pts, time_base = await self.next_timestamp()
+            new_frame.pts = pts
+            new_frame.time_base = time_base
+            return new_frame
+        except Exception as e:
+            logger.debug("exception %s", e)
+            exec = traceback.format_exc()
+            logger.debug("traceback %s", exec)
+
+
 class AudioCallback(AudioStreamTrack):
     kind = "audio"
 
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: StreamHandlerImpl,
+        event_handler: StreamHandlerBase,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
         super().__init__()
         self.track = track
-        self.event_handler = event_handler
+        self.event_handler = cast(StreamHandlerImpl, event_handler)
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
@@ -322,7 +416,7 @@
                 frame = cast(AudioFrame, await self.track.recv())
                 for frame in self.event_handler.resample(frame):
                     numpy_array = frame.to_ndarray()
-                    if isinstance(self.event_handler, AsyncStreamHandler):
+                    if isinstance(self.event_handler, AsyncHandler):
                         await self.event_handler.receive(
                             (frame.sample_rate, numpy_array)
                         )
@@ -337,7 +431,7 @@
     def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
-            if isinstance(self.event_handler, AsyncStreamHandler):
+            if isinstance(self.event_handler, AsyncHandler):
                 callable = self.event_handler.emit
             else:
                 callable = functools.partial(
@@ -358,7 +452,7 @@
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -383,7 +477,7 @@
             # control playback rate
             if self._start is None:
-                self._start = time.time() - data_time
+                self._start = time.time() - data_time  # type: ignore
             else:
                 wait = self._start + data_time - time.time()
                 await asyncio.sleep(wait)
@@ -434,7 +528,7 @@
         self.latest_args = list(args)
         self.args_set.set()
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             pts, time_base = await self.next_timestamp()
             await self.args_set.wait()
@@ -523,7 +617,7 @@
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -539,7 +633,7 @@
             # control playback rate
             if data_time is not None:
                 if self._start is None:
-                    self._start = time.time() - data_time
+                    self._start = time.time() - data_time  # type: ignore
                 else:
                     wait = self._start + data_time - time.time()
                     await asyncio.sleep(wait)
@@ -576,10 +670,12 @@ class WebRTC(Component):
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
     connections: dict[
-        str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
-    ] = {}
+        str,
+        list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback],
+    ] = defaultdict(list)
     data_channels: dict[str, DataChannel] = {}
     additional_outputs: dict[str, list[AdditionalOutputs]] = {}
+    handlers: dict[str, StreamHandlerBase | Callable] = {}
 
     EVENTS = ["tick", "state_change"]
@@ -606,7 +702,7 @@
         track_constraints: dict[str, Any] | None = None,
         time_limit: float | None = None,
         mode: Literal["send-receive", "receive", "send"] = "send-receive",
-        modality: Literal["video", "audio"] = "video",
+        modality: Literal["video", "audio", "audio-video"] = "video",
         rtp_params: dict[str, Any] | None = None,
         icon: str | None = None,
         icon_button_color: str | None = None,
@@ -669,6 +765,23 @@
                 "height": {"ideal": 500},
                 "frameRate": {"ideal": 30},
             }
+        if track_constraints is None and modality == "audio-video":
+            track_constraints = {
+                "video": {
+                    "facingMode": "user",
+                    "width": {"ideal": 500},
+                    "height": {"ideal": 500},
+                    "frameRate": {"ideal": 30},
+                },
+                "audio": {
+                    "echoCancellation": True,
+                    "noiseSuppression": {"exact": True},
+                    "autoGainControl": {"exact": True},
+                    "sampleRate": {"ideal": 24000},
+                    "sampleSize": {"ideal": 16},
+                    "channelCount": {"exact": 1},
+                },
+            }
         self.track_constraints = track_constraints
         self.event_handler: Callable | StreamHandler | None = None
         super().__init__(
@@ -722,7 +835,8 @@
     def set_input(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
-            self.connections[webrtc_id].set_args(list(args))
+            for conn in self.connections[webrtc_id]:
+                conn.set_args(list(args))
 
     def on_additional_outputs(
         self,
@@ -767,7 +881,10 @@
     def stream(
         self,
-        fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
+        fn: Callable[..., Any]
+        | StreamHandlerImpl
+        | AudioVideoStreamHandlerImpl
+        | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -790,16 +907,16 @@
         self.concurrency_limit = (
             1 if concurrency_limit in ["default", None] else concurrency_limit
         )
-        self.event_handler = fn
+        self.event_handler = fn  # type: ignore
         self.time_limit = time_limit
 
         if (
             self.mode == "send-receive"
-            and self.modality == "audio"
-            and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
+            and self.modality in ["audio", "audio-video"]
+            and not isinstance(self.event_handler, StreamHandlerBase)
         ):
             raise ValueError(
-                "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
+                "In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase."
             )
 
         if self.mode == "send-receive" or self.mode == "send":
@@ -815,13 +932,23 @@
                 raise ValueError(
                     "In the webrtc stream event, the only output component must be the WebRTC component."
                 )
+            for input_component in inputs[1:]:  # type: ignore
+                if hasattr(input_component, "change"):
+                    input_component.change(  # type: ignore
+                        self.set_input,
+                        inputs=inputs,
+                        outputs=None,
+                        concurrency_id=concurrency_id,
+                        concurrency_limit=None,
+                        time_limit=None,
+                        js=js,
+                    )
             return self.tick(  # type: ignore
                 self.set_input,
                 inputs=inputs,
                 outputs=None,
                 concurrency_id=concurrency_id,
                 concurrency_limit=None,
-                stream_every=0.5,
                 time_limit=None,
                 js=js,
             )
@@ -855,9 +982,11 @@
             await pc.close()
 
     def clean_up(self, webrtc_id: str):
-        connection = self.connections.pop(webrtc_id, None)
-        if isinstance(connection, AudioCallback):
-            connection.event_handler.shutdown()
+        self.handlers.pop(webrtc_id, None)
+        connection = self.connections.pop(webrtc_id, [])
+        for conn in connection:
+            if isinstance(conn, AudioCallback):
+                conn.event_handler.shutdown()
         self.additional_outputs.pop(webrtc_id, None)
         self.data_channels.pop(webrtc_id, None)
         return connection
@@ -874,6 +1003,13 @@
         pc = RTCPeerConnection()
         self.pcs.add(pc)
 
+        if isinstance(self.event_handler, StreamHandlerBase):
+            handler = self.event_handler.copy()
+        else:
+            handler = cast(Callable, self.event_handler)
+
+        self.handlers[body["webrtc_id"]] = handler
+
         set_outputs = self.set_additional_outputs(body["webrtc_id"])
 
         @pc.on("iceconnectionstatechange")
@@ -891,7 +1027,8 @@
                 await pc.close()
                 connection = self.clean_up(body["webrtc_id"])
                 if connection:
-                    connection.stop()
+                    for conn in connection:
+                        conn.stop()
                 self.pcs.discard(pc)
             if pc.connectionState == "connected":
                 if self.time_limit is not None:
@@ -900,28 +1037,38 @@
         @pc.on("track")
         def on_track(track):
             relay = MediaRelay()
+            handler = self.handlers[body["webrtc_id"]]
 
-            if self.modality == "video":
+            if self.modality == "video" and track.kind == "video":
                 cb = VideoCallback(
                     relay.subscribe(track),
-                    event_handler=cast(Callable, self.event_handler),
+                    event_handler=cast(VideoEventHandler, handler),
                     set_additional_outputs=set_outputs,
                     mode=cast(Literal["send", "send-receive"], self.mode),
                 )
-            elif self.modality == "audio":
-                handler = cast(StreamHandler, self.event_handler).copy()
-                handler._loop = asyncio.get_running_loop()
+            elif self.modality == "audio-video" and track.kind == "video":
+                cb = VideoStreamHander(
+                    relay.subscribe(track),
+                    event_handler=handler,  # type: ignore
+                    set_additional_outputs=set_outputs,
+                )
+            elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
+                eh = cast(StreamHandlerImpl, handler)
+                eh._loop = asyncio.get_running_loop()
                 cb = AudioCallback(
                     relay.subscribe(track),
-                    event_handler=handler,
+                    event_handler=eh,
                     set_additional_outputs=set_outputs,
                 )
             else:
-                raise ValueError("Modality must be either video or audio")
-            self.connections[body["webrtc_id"]] = cb
+                raise ValueError("Modality must be either video, audio, or audio-video")
+            if body["webrtc_id"] not in self.connections:
+                self.connections[body["webrtc_id"]] = []
+            self.connections[body["webrtc_id"]].append(cb)
             if body["webrtc_id"] in self.data_channels:
-                self.connections[body["webrtc_id"]].set_channel(
-                    self.data_channels[body["webrtc_id"]]
-                )
+                for conn in self.connections[body["webrtc_id"]]:
+                    conn.set_channel(self.data_channels[body["webrtc_id"]])
             if self.mode == "send-receive":
                 logger.debug("Adding track to peer connection %s", cb)
                 pc.addTrack(cb)
@@ -944,7 +1091,7 @@
                 logger.debug("Adding track to peer connection %s", cb)
                 pc.addTrack(cb)
-            self.connections[body["webrtc_id"]] = cb
+            self.connections[body["webrtc_id"]].append(cb)
             cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
 
         @pc.on("datachannel")
@@ -957,7 +1104,8 @@
                 while not self.connections.get(webrtc_id):
                     await asyncio.sleep(0.05)
                 logger.debug("setting channel for webrtc id %s", webrtc_id)
-                self.connections[webrtc_id].set_channel(channel)
+                for conn in self.connections[webrtc_id]:
+                    conn.set_channel(channel)
 
             asyncio.create_task(set_channel(body["webrtc_id"]))
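
Note: taken together, the backend changes above let one handler object serve both tracks of a connection: AudioCallback drives its receive/emit while the new VideoStreamHander drives video_receive/video_emit. A hedged sketch of a handler written against this API (the audio-side signatures are inferred from how AudioCallback calls the handler; StreamHandlerBase itself is outside this diff, so treat the copy() and shutdown() details as assumptions):

    import gradio as gr
    import numpy as np
    from gradio_webrtc import (
        AsyncAudioVideoStreamHandler,
        AudioEmitType,
        VideoEmitType,
        WebRTC,
    )

    class EchoHandler(AsyncAudioVideoStreamHandler):
        """Echoes the latest webcam frame and microphone chunk back to the client."""

        def __init__(self) -> None:
            super().__init__()
            self.last_frame: np.ndarray | None = None
            self.last_audio: tuple[int, np.ndarray] | None = None

        # Video side: the abstract methods this commit introduces.
        async def video_receive(self, frame: np.ndarray) -> None:
            self.last_frame = frame  # arrives as a bgr24 ndarray (see process_frames)

        async def video_emit(self) -> VideoEmitType:
            return self.last_frame

        # Audio side: AudioCallback passes (sample_rate, ndarray) tuples to
        # receive() and awaits emit(), so these signatures are inferred.
        async def receive(self, frame: tuple[int, np.ndarray]) -> None:
            self.last_audio = frame

        async def emit(self) -> AudioEmitType:
            return self.last_audio

        def copy(self) -> "EchoHandler":
            return EchoHandler()  # connect() copies the handler per connection

        def shutdown(self) -> None:  # clean_up() calls this for audio callbacks
            pass

    with gr.Blocks() as demo:
        webrtc = WebRTC(modality="audio-video", mode="send-receive")
        webrtc.stream(EchoHandler(), inputs=[webrtc], outputs=[webrtc])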

View File

@@ -30,7 +30,7 @@
   export let gradio;
   export let rtc_configuration: Object;
   export let time_limit: number | null = null;
-  export let modality: "video" | "audio" = "video";
+  export let modality: "video" | "audio" | "audio-video" = "video";
   export let mode: "send-receive" | "receive" | "send" = "send-receive";
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
   export let track_constraints: MediaTrackConstraints = {};
@@ -52,18 +52,18 @@
 </script>
 
 <Block
   {visible}
   variant={"solid"}
   border_mode={dragging ? "focus" : "base"}
   padding={false}
   {elem_id}
   {elem_classes}
   {height}
   {width}
   {container}
   {scale}
   {min_width}
   allow_overflow={false}
 >
   <StatusTracker
     autoscroll={gradio.autoscroll}
@@ -99,13 +99,13 @@
     on:error={({ detail }) => gradio.dispatch("error", detail)}
   />
-{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
+{:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}
   <Video
     bind:value={value}
     {label}
     {show_label}
     active_source={"webcam"}
-    include_audio={false}
+    include_audio={modality === "audio-video"}
     {server}
     {rtc_configuration}
     {time_limit}
@@ -113,6 +113,9 @@
     {track_constraints}
     {rtp_params}
     {on_change_cb}
+    {icon}
+    {icon_button_color}
+    {pulse_color}
     on:clear={() => gradio.dispatch("clear")}
     on:play={() => gradio.dispatch("play")}
     on:pause={() => gradio.dispatch("pause")}

View File

@@ -1,10 +1,13 @@
 <script lang="ts">
   import { onDestroy } from 'svelte';
+  import type {ComponentType} from 'svelte';
+  import PulsingIcon from './PulsingIcon.svelte';
 
   export let numBars = 16;
   export let stream_state: "open" | "closed" | "waiting" = "closed";
   export let audio_source_callback: () => MediaStream;
-  export let icon: string | undefined = undefined;
+  export let icon: string | undefined | ComponentType = undefined;
   export let icon_button_color: string = "var(--color-accent)";
   export let pulse_color: string = "var(--color-accent)";
@@ -13,7 +16,6 @@
   let dataArray: Uint8Array;
   let animationId: number;
   let pulseScale = 1;
-  let pulseIntensity = 0;
 
   $: containerWidth = icon
     ? "128px"
@@ -47,53 +49,31 @@
   function updateVisualization() {
     analyser.getByteFrequencyData(dataArray);
 
-    if (icon) {
-      // Calculate average amplitude for pulse effect
-      const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
-      const normalizedAverage = average / 255;
-      pulseScale = 1 + (normalizedAverage * 0.15);
-      pulseIntensity = normalizedAverage;
-    } else {
-      // Update bars
-      const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
-      for (let i = 0; i < bars.length; i++) {
-        const barHeight = (dataArray[i] / 255) * 2;
-        bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
-      }
-    }
+    // Update bars
+    const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
+    for (let i = 0; i < bars.length; i++) {
+      const barHeight = (dataArray[i] / 255) * 2;
+      bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
+    }
 
     animationId = requestAnimationFrame(updateVisualization);
   }
-
-  $: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
 </script>
 
 <div class="gradio-webrtc-waveContainer">
   {#if icon}
     <div class="gradio-webrtc-icon-container">
-      {#if pulseIntensity > 0}
-        {#each Array(3) as _, i}
-          <div
-            class="pulse-ring"
-            style:background={pulse_color}
-            style:animation-delay={`${i * 0.4}s`}
-            style:--max-scale={maxPulseScale}
-            style:opacity={0.5 * pulseIntensity}
-          />
-        {/each}
-      {/if}
       <div
         class="gradio-webrtc-icon"
         style:transform={`scale(${pulseScale})`}
         style:background={icon_button_color}
       >
-        <img
-          src={icon}
-          alt="Audio visualization icon"
-          class="icon-image"
-        />
+        <PulsingIcon
+          {stream_state}
+          {pulse_color}
+          {icon}
+          {icon_button_color}
+          {audio_source_callback}/>
       </div>
     </div>
   {:else}

View File

@@ -1,5 +1,6 @@
 <script lang="ts">
   import { createEventDispatcher } from "svelte";
+  import type { ComponentType } from "svelte";
   import type { FileData, Client } from "@gradio/client";
   import { BlockLabel } from "@gradio/atoms";
   import Webcam from "./Webcam.svelte";
@@ -24,6 +25,9 @@
   export let mode: "send" | "send-receive";
   export let on_change_cb: (msg: "change" | "tick") => void;
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+  export let icon: string | undefined | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
 
   const dispatch = createEventDispatcher<{
     change: FileData | null;
@@ -56,6 +60,9 @@
     {mode}
     {rtp_params}
     {on_change_cb}
+    {icon}
+    {icon_button_color}
+    {pulse_color}
     on:error
     on:start_recording
     on:stop_recording

View File

@@ -0,0 +1,151 @@
+<script lang="ts">
+  import { onDestroy } from 'svelte';
+  import type {ComponentType} from 'svelte';
+
+  export let stream_state: "open" | "closed" | "waiting" = "closed";
+  export let audio_source_callback: () => MediaStream;
+  export let icon: string | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
+
+  let audioContext: AudioContext;
+  let analyser: AnalyserNode;
+  let dataArray: Uint8Array;
+  let animationId: number;
+  let pulseScale = 1;
+  let pulseIntensity = 0;
+
+  $: if(stream_state === "open") setupAudioContext();
+
+  onDestroy(() => {
+    if (animationId) {
+      cancelAnimationFrame(animationId);
+    }
+    if (audioContext) {
+      audioContext.close();
+    }
+  });
+
+  function setupAudioContext() {
+    audioContext = new (window.AudioContext || window.webkitAudioContext)();
+    analyser = audioContext.createAnalyser();
+    const source = audioContext.createMediaStreamSource(audio_source_callback());
+    source.connect(analyser);
+
+    analyser.fftSize = 64;
+    analyser.smoothingTimeConstant = 0.8;
+    dataArray = new Uint8Array(analyser.frequencyBinCount);
+
+    updateVisualization();
+  }
+
+  function updateVisualization() {
+    analyser.getByteFrequencyData(dataArray);
+
+    // Calculate average amplitude for pulse effect
+    const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
+    const normalizedAverage = average / 255;
+    pulseScale = 1 + (normalizedAverage * 0.15);
+    pulseIntensity = normalizedAverage;
+
+    animationId = requestAnimationFrame(updateVisualization);
+  }
+
+  $: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
+</script>
+
+<div class="gradio-webrtc-icon-wrapper">
+  <div class="gradio-webrtc-pulsing-icon-container">
+    {#if pulseIntensity > 0}
+      {#each Array(3) as _, i}
+        <div
+          class="pulse-ring"
+          style:background={pulse_color}
+          style:animation-delay={`${i * 0.4}s`}
+          style:--max-scale={maxPulseScale}
+          style:opacity={0.5 * pulseIntensity}
+        />
+      {/each}
+    {/if}
+    <div
+      class="gradio-webrtc-pulsing-icon"
+      style:transform={`scale(${pulseScale})`}
+      style:background={icon_button_color}
+    >
+      {#if typeof icon === "string"}
+        <img
+          src={icon}
+          alt="Audio visualization icon"
+          class="icon-image"
+        />
+      {:else}
+        <svelte:component this={icon} />
+      {/if}
+    </div>
+  </div>
+</div>
+
+<style>
+  .gradio-webrtc-icon-wrapper {
+    position: relative;
+    display: flex;
+    max-height: 128px;
+    justify-content: center;
+    align-items: center;
+  }
+
+  .gradio-webrtc-pulsing-icon-container {
+    position: relative;
+    width: 100%;
+    height: 100%;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+  }
+
+  .gradio-webrtc-pulsing-icon {
+    position: relative;
+    width: 100%;
+    height: 100%;
+    border-radius: 50%;
+    transition: transform 0.1s ease;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    z-index: 2;
+  }
+
+  .icon-image {
+    width: 100%;
+    height: 100%;
+    object-fit: contain;
+    filter: brightness(0) invert(1);
+  }
+
+  .pulse-ring {
+    position: absolute;
+    top: 50%;
+    left: 50%;
+    transform: translate(-50%, -50%);
+    width: 100%;
+    height: 100%;
+    border-radius: 50%;
+    animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+    opacity: 0.5;
+  }
+
+  @keyframes pulse {
+    0% {
+      transform: translate(-50%, -50%) scale(1);
+      opacity: 0.5;
+    }
+    100% {
+      transform: translate(-50%, -50%) scale(var(--max-scale, 3));
+      opacity: 0;
+    }
+  }
+</style>

View File

@@ -98,7 +98,7 @@
   />
   <audio
     class="standard-player"
-    class:hidden={value === "__webrtc_value__"}
+    class:hidden={true}
     on:load
     bind:this={audio_player}
    on:ended={() => dispatch("stop")}

View File

@@ -1,10 +1,12 @@
 <script lang="ts">
   import { createEventDispatcher, onMount } from "svelte";
+  import type { ComponentType } from "svelte";
   import {
     Circle,
     Square,
     DropdownArrow,
-    Spinner
+    Spinner,
+    Microphone as Mic
   } from "@gradio/icons";
   import type { I18nFormatter } from "@gradio/utils";
   import { StreamingBar } from "@gradio/statustracker";
@@ -15,8 +17,8 @@
     get_video_stream,
     set_available_devices
   } from "./stream_utils";
   import { start, stop } from "./webrtc_utils";
-
+  import PulsingIcon from "./PulsingIcon.svelte";
 
   let video_source: HTMLVideoElement;
   let available_video_devices: MediaDeviceInfo[] = [];
@@ -28,6 +30,9 @@
   export let mode: "send-receive" | "send";
   const _webrtc_id = Math.random().toString(36).substring(2);
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+  export let icon: string | undefined | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
 
   export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
     state: "open" | "closed" | "waiting"
@@ -156,14 +161,13 @@
       _time_limit = null;
       await access_webcam();
     }
   }
 
-  window.setInterval(() => {
-    if (stream_state == "open") {
-      dispatch("tick");
-    }
-  }, stream_every * 1000);
+  // window.setInterval(() => {
+  //   if (stream_state == "open") {
+  //     dispatch("tick");
+  //   }
+  // }, stream_every * 1000);
 
   let options_open = false;
@@ -192,16 +196,29 @@
     event.stopPropagation();
     options_open = false;
   }
+
+  const audio_source_callback = () => video_source.srcObject as MediaStream;
 </script>
 
 <div class="wrap">
   <StreamingBar time_limit={_time_limit} />
+  {#if stream_state === "open" && include_audio}
+    <div class="audio-indicator">
+      <PulsingIcon
+        stream_state={stream_state}
+        audio_source_callback={audio_source_callback}
+        icon={icon || Mic}
+        icon_button_color={icon_button_color}
+        pulse_color={pulse_color}
+      />
+    </div>
+  {/if}
   <!-- svelte-ignore a11y-media-has-caption -->
   <!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
   <video
     bind:this={video_source}
     class:hide={!webcam_accessed}
-    class:flip={(stream_state != "open")}
+    class:flip={(stream_state != "open") || (stream_state === "open" && include_audio)}
     autoplay={true}
     playsinline={true}
   />
@@ -324,6 +341,15 @@
     justify-content: space-evenly;
   }
 
+  .audio-indicator {
+    position: absolute;
+    top: var(--size-2);
+    right: var(--size-2);
+    z-index: var(--layer-2);
+    height: var(--size-5);
+    width: var(--size-5);
+  }
+
   @media (--screen-md) {
     button {
       bottom: var(--size-4);
View File

@@ -68,14 +68,14 @@ export async function start(
       try {
         event_json = JSON.parse(event.data);
       } catch (e) {
-        console.debug("Error parsing JSON")
+        console.debug("Error parsing JSON");
       }
       console.log("event_json", event_json);
       if (
         event.data === "change" ||
         event.data === "tick" ||
         event.data === "stopword" ||
         event_json?.type === "warning" ||
         event_json?.type === "error"
       ) {
         console.debug(`${event.data} event received`);

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 [project]
 name = "gradio_webrtc"
-version = "0.0.27"
+version = "0.0.28"
 description = "Stream images in realtime with webrtc"
 readme = "README.md"
 license = "apache-2.0"