Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git, synced 2026-02-05 18:09:23 +08:00
code (#48)
@@ -16,10 +16,21 @@ from .utils import (
     audio_to_file,
     audio_to_float32,
 )
-from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
+from .webrtc import (
+    AsyncAudioVideoStreamHandler,
+    AsyncStreamHandler,
+    AudioVideoStreamHandler,
+    StreamHandler,
+    WebRTC,
+    VideoEmitType,
+    AudioEmitType,
+)
 
 __all__ = [
     "AsyncStreamHandler",
+    "AudioVideoStreamHandler",
+    "AudioEmitType",
+    "AsyncAudioVideoStreamHandler",
     "AlgoOptions",
     "AdditionalOutputs",
     "aggregate_bytes_to_16bit",
@@ -36,6 +47,7 @@ __all__ = [
     "stt",
     "stt_for_chunks",
     "StreamHandler",
+    "VideoEmitType",
     "WebRTC",
     "WebRTCError",
     "Warning",
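For context: after this change the combined audio-video handler types become part of the package root. A minimal sketch of what downstream code can now import, based on the updated `__all__` (the comments are mine):

```python
# Sketch: new public names exported by this commit (see the updated __all__).
from gradio_webrtc import (
    AsyncAudioVideoStreamHandler,  # async handler for a combined audio + video stream
    AudioVideoStreamHandler,       # synchronous counterpart
    AudioEmitType,                 # alias of EmitType, for audio emit() return values
    VideoEmitType,                 # return type for video_emit() / video event handlers
)
```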
@@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast
 import numpy as np
 
 from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
-from gradio_webrtc.utils import AdditionalOutputs
 from gradio_webrtc.webrtc import EmitType, StreamHandler
 
 logger = getLogger(__name__)
@@ -147,16 +147,18 @@ async def player_worker_decode(
 
         logger.debug(
             "received array with shape %s sample rate %s layout %s",
-            audio_array.shape,
+            audio_array.shape,  # type: ignore
             sample_rate,
-            layout,
+            layout,  # type: ignore
         )
-        format = "s16" if audio_array.dtype == "int16" else "fltp"
+        format = "s16" if audio_array.dtype == "int16" else "fltp"  # type: ignore
 
         # Convert to audio frame and resample
         # This runs in the same timeout context
         frame = av.AudioFrame.from_ndarray(  # type: ignore
-            audio_array, format=format, layout=layout
+            audio_array,  # type: ignore
+            format=format,
+            layout=layout,  # type: ignore
         )
         frame.sample_rate = sample_rate
 
@@ -4,11 +4,13 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import inspect
 import logging
 import threading
 import time
 import traceback
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Callable
 from typing import (
     TYPE_CHECKING,
@@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError
 from gradio import wasm_utils
 from gradio.components.base import Component, server
 from gradio_client import handle_file
+from numpy import typing as npt
 
 from gradio_webrtc.utils import (
     AdditionalOutputs,
@@ -61,6 +64,11 @@ if wasm_utils.IS_WASM:
 
 logger = logging.getLogger(__name__)
 
+VideoEmitType = Union[
+    AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None
+]
+VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
+
 
 class VideoCallback(VideoStreamTrack):
     """
@@ -72,7 +80,7 @@ class VideoCallback(VideoStreamTrack):
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: Callable,
+        event_handler: VideoEventHandler,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
         mode: Literal["send-receive", "send"] = "send-receive",
@@ -86,6 +94,7 @@ class VideoCallback(VideoStreamTrack):
         self.thread_quit = asyncio.Event()
         self.mode = mode
         self.channel_set = asyncio.Event()
+        self.has_started = False
 
     def set_channel(self, channel: DataChannel):
         self.channel = channel
@@ -132,7 +141,7 @@ class VideoCallback(VideoStreamTrack):
         if current_channel.get() != self.channel:
             current_channel.set(self.channel)
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             try:
                 frame = cast(VideoFrame, await self.track.recv())
@@ -142,7 +151,6 @@ class VideoCallback(VideoStreamTrack):
 
             await self.wait_for_channel()
             frame_array = frame.to_ndarray(format="bgr24")
 
             if self.latest_args == "not_set":
                 return frame
-
@@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[
     tuple[tuple[int, np.ndarray], AdditionalOutputs],
     None,
 ]
+AudioEmitType = EmitType
 
 
 class StreamHandler(StreamHandlerBase):
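For orientation, the new `VideoEmitType` (and its audio alias `AudioEmitType` from the hunk above) admits a bare array, an `(array, AdditionalOutputs)` pair, `AdditionalOutputs` alone, or `None`. A sketch of a plain video event handler typed against it; the flip operation is illustrative, not part of this commit:

```python
import numpy as np
from numpy import typing as npt

from gradio_webrtc import VideoEmitType


def flip_vertically(frame: npt.ArrayLike) -> VideoEmitType:
    # Returning a bare array is one valid VideoEmitType variant; an
    # (array, AdditionalOutputs) tuple or None would also type-check.
    return np.flipud(np.asarray(frame))
```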
@@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase):
 StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
 
 
+class AudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+class AsyncAudioVideoStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    async def video_receive(self, frame: npt.NDArray) -> None:
+        pass
+
+    @abstractmethod
+    async def video_emit(
+        self,
+    ) -> VideoEmitType:
+        pass
+
+
+VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler]
+AudioVideoStreamHandlerImpl = Union[
+    AudioVideoStreamHandler, AsyncAudioVideoStreamHandler
+]
+AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler]
+
+
+class VideoStreamHander(VideoCallback):
+    async def process_frames(self):
+        while not self.thread_quit.is_set():
+            try:
+                await self.channel_set.wait()
+                frame = cast(VideoFrame, await self.track.recv())
+                frame_array = frame.to_ndarray(format="bgr24")
+                handler = cast(VideoStreamHandlerImpl, self.event_handler)
+                if inspect.iscoroutinefunction(handler.video_receive):
+                    await handler.video_receive(frame_array)
+                else:
+                    handler.video_receive(frame_array)
+            except MediaStreamError:
+                self.stop()
+
+    def start(self):
+        if not self.has_started:
+            asyncio.create_task(self.process_frames())
+            self.has_started = True
+
+    async def recv(self):  # type: ignore
+        self.start()
+        try:
+            handler = cast(VideoStreamHandlerImpl, self.event_handler)
+            if inspect.iscoroutinefunction(handler.video_emit):
+                outputs = await handler.video_emit()
+            else:
+                outputs = handler.video_emit()
+
+            array, outputs = split_output(outputs)
+            if (
+                isinstance(outputs, AdditionalOutputs)
+                and self.set_additional_outputs
+                and self.channel
+            ):
+                self.set_additional_outputs(outputs)
+                self.channel.send("change")
+            if array is None and self.mode == "send":
+                return
+
+            new_frame = self.array_to_frame(array)
+
+            # Will probably have to give developer ability to set pts and time_base
+            pts, time_base = await self.next_timestamp()
+            new_frame.pts = pts
+            new_frame.time_base = time_base
+
+            return new_frame
+        except Exception as e:
+            logger.debug("exception %s", e)
+            exec = traceback.format_exc()
+            logger.debug("traceback %s", exec)
+
+
 class AudioCallback(AudioStreamTrack):
     kind = "audio"
 
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: StreamHandlerImpl,
+        event_handler: StreamHandlerBase,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
         super().__init__()
         self.track = track
-        self.event_handler = event_handler
+        self.event_handler = cast(StreamHandlerImpl, event_handler)
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
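A minimal sketch of what the new abstract surface asks of a subclass. The pass-through behavior is illustrative, not from this commit, and the audio half still comes from the existing `StreamHandlerBase` contract:

```python
from numpy import typing as npt

from gradio_webrtc import AsyncAudioVideoStreamHandler, VideoEmitType


class EchoHandler(AsyncAudioVideoStreamHandler):
    """Illustrative only: emits the most recent webcam frame unchanged."""

    def __init__(self) -> None:
        super().__init__()
        self.last_frame = None

    async def video_receive(self, frame: npt.NDArray) -> None:
        # Called once per decoded frame (see VideoStreamHander.process_frames).
        self.last_frame = frame

    async def video_emit(self) -> VideoEmitType:
        # None is a legal VideoEmitType while no frame has arrived yet.
        return self.last_frame

    # The audio-side receive()/emit()/copy() required by StreamHandlerBase
    # are omitted here for brevity.
```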
@@ -322,7 +416,7 @@ class AudioCallback(AudioStreamTrack):
             frame = cast(AudioFrame, await self.track.recv())
             for frame in self.event_handler.resample(frame):
                 numpy_array = frame.to_ndarray()
-                if isinstance(self.event_handler, AsyncStreamHandler):
+                if isinstance(self.event_handler, AsyncHandler):
                     await self.event_handler.receive(
                         (frame.sample_rate, numpy_array)
                     )
@@ -337,7 +431,7 @@ class AudioCallback(AudioStreamTrack):
     def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
-            if isinstance(self.event_handler, AsyncStreamHandler):
+            if isinstance(self.event_handler, AsyncHandler):
                 callable = self.event_handler.emit
             else:
                 callable = functools.partial(
@@ -358,7 +452,7 @@ class AudioCallback(AudioStreamTrack):
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -383,7 +477,7 @@ class AudioCallback(AudioStreamTrack):
 
             # control playback rate
             if self._start is None:
-                self._start = time.time() - data_time
+                self._start = time.time() - data_time  # type: ignore
             else:
                 wait = self._start + data_time - time.time()
                 await asyncio.sleep(wait)
@@ -434,7 +528,7 @@ class ServerToClientVideo(VideoStreamTrack):
         self.latest_args = list(args)
         self.args_set.set()
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             pts, time_base = await self.next_timestamp()
             await self.args_set.wait()
@@ -523,7 +617,7 @@ class ServerToClientAudio(AudioStreamTrack):
             )
             self.has_started = True
 
-    async def recv(self):
+    async def recv(self):  # type: ignore
         try:
             if self.readyState != "live":
                 raise MediaStreamError
@@ -539,7 +633,7 @@ class ServerToClientAudio(AudioStreamTrack):
             # control playback rate
             if data_time is not None:
                 if self._start is None:
-                    self._start = time.time() - data_time
+                    self._start = time.time() - data_time  # type: ignore
                 else:
                     wait = self._start + data_time - time.time()
                     await asyncio.sleep(wait)
@@ -576,10 +670,12 @@ class WebRTC(Component):
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
     connections: dict[
-        str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
-    ] = {}
+        str,
+        list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback],
+    ] = defaultdict(list)
     data_channels: dict[str, DataChannel] = {}
     additional_outputs: dict[str, list[AdditionalOutputs]] = {}
+    handlers: dict[str, StreamHandlerBase | Callable] = {}
 
     EVENTS = ["tick", "state_change"]
 
@@ -606,7 +702,7 @@ class WebRTC(Component):
         track_constraints: dict[str, Any] | None = None,
         time_limit: float | None = None,
         mode: Literal["send-receive", "receive", "send"] = "send-receive",
-        modality: Literal["video", "audio"] = "video",
+        modality: Literal["video", "audio", "audio-video"] = "video",
         rtp_params: dict[str, Any] | None = None,
         icon: str | None = None,
         icon_button_color: str | None = None,
@@ -669,6 +765,23 @@ class WebRTC(Component):
                 "height": {"ideal": 500},
                 "frameRate": {"ideal": 30},
             }
+        if track_constraints is None and modality == "audio-video":
+            track_constraints = {
+                "video": {
+                    "facingMode": "user",
+                    "width": {"ideal": 500},
+                    "height": {"ideal": 500},
+                    "frameRate": {"ideal": 30},
+                },
+                "audio": {
+                    "echoCancellation": True,
+                    "noiseSuppression": {"exact": True},
+                    "autoGainControl": {"exact": True},
+                    "sampleRate": {"ideal": 24000},
+                    "sampleSize": {"ideal": 16},
+                    "channelCount": {"exact": 1},
+                },
+            }
         self.track_constraints = track_constraints
         self.event_handler: Callable | StreamHandler | None = None
         super().__init__(
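Given the new defaults above, wiring up the combined modality might look like the following sketch. `EchoHandler` is the illustrative subclass from earlier; the Blocks scaffolding is ordinary Gradio usage, not part of this diff:

```python
import gradio as gr
from gradio_webrtc import WebRTC

with gr.Blocks() as demo:
    webrtc = WebRTC(
        modality="audio-video",  # new literal accepted as of this commit
        mode="send-receive",
        # track_constraints left unset: the new audio-video defaults apply
        # (user-facing camera at ~500px/30fps, mono 24 kHz processed audio).
    )
    webrtc.stream(EchoHandler(), inputs=[webrtc], outputs=[webrtc])

demo.launch()
```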
@@ -722,7 +835,8 @@ class WebRTC(Component):
 
     def set_input(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
-            self.connections[webrtc_id].set_args(list(args))
+            for conn in self.connections[webrtc_id]:
+                conn.set_args(list(args))
 
     def on_additional_outputs(
         self,
@@ -767,7 +881,10 @@ class WebRTC(Component):
 
     def stream(
         self,
-        fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
+        fn: Callable[..., Any]
+        | StreamHandlerImpl
+        | AudioVideoStreamHandlerImpl
+        | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -790,16 +907,16 @@ class WebRTC(Component):
         self.concurrency_limit = (
             1 if concurrency_limit in ["default", None] else concurrency_limit
         )
-        self.event_handler = fn
+        self.event_handler = fn  # type: ignore
         self.time_limit = time_limit
 
         if (
             self.mode == "send-receive"
-            and self.modality == "audio"
-            and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
+            and self.modality in ["audio", "audio-video"]
+            and not isinstance(self.event_handler, StreamHandlerBase)
         ):
             raise ValueError(
-                "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
+                "In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase."
             )
 
         if self.mode == "send-receive" or self.mode == "send":
@@ -815,13 +932,23 @@ class WebRTC(Component):
                 raise ValueError(
                     "In the webrtc stream event, the only output component must be the WebRTC component."
                 )
+            for input_component in inputs[1:]:  # type: ignore
+                if hasattr(input_component, "change"):
+                    input_component.change(  # type: ignore
+                        self.set_input,
+                        inputs=inputs,
+                        outputs=None,
+                        concurrency_id=concurrency_id,
+                        concurrency_limit=None,
+                        time_limit=None,
+                        js=js,
+                    )
             return self.tick(  # type: ignore
                 self.set_input,
                 inputs=inputs,
                 outputs=None,
                 concurrency_id=concurrency_id,
                 concurrency_limit=None,
-                stream_every=0.5,
                 time_limit=None,
                 js=js,
             )
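The hunk above replaces the fixed `stream_every=0.5` polling tick with per-component `change` listeners, so extra inputs now push their values to `set_input` as soon as they change. A sketch of the user-facing effect (the slider and the detector callback are hypothetical):

```python
import gradio as gr
from gradio_webrtc import WebRTC

with gr.Blocks() as demo:
    threshold = gr.Slider(0.0, 1.0, value=0.5, label="threshold")
    webrtc = WebRTC(modality="video", mode="send-receive")

    def detect(frame, threshold):
        # Hypothetical per-frame callback; `threshold` now refreshes via the
        # slider's change event rather than a 0.5 s polling loop.
        return frame

    webrtc.stream(detect, inputs=[webrtc, threshold], outputs=[webrtc])

demo.launch()
```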
@@ -855,9 +982,11 @@ class WebRTC(Component):
         await pc.close()
 
     def clean_up(self, webrtc_id: str):
-        connection = self.connections.pop(webrtc_id, None)
-        if isinstance(connection, AudioCallback):
-            connection.event_handler.shutdown()
+        self.handlers.pop(webrtc_id, None)
+        connection = self.connections.pop(webrtc_id, [])
+        for conn in connection:
+            if isinstance(conn, AudioCallback):
+                conn.event_handler.shutdown()
         self.additional_outputs.pop(webrtc_id, None)
         self.data_channels.pop(webrtc_id, None)
         return connection
@@ -874,6 +1003,13 @@ class WebRTC(Component):
         pc = RTCPeerConnection()
         self.pcs.add(pc)
 
+        if isinstance(self.event_handler, StreamHandlerBase):
+            handler = self.event_handler.copy()
+        else:
+            handler = cast(Callable, self.event_handler)
+
+        self.handlers[body["webrtc_id"]] = handler
+
         set_outputs = self.set_additional_outputs(body["webrtc_id"])
 
         @pc.on("iceconnectionstatechange")
@@ -891,7 +1027,8 @@ class WebRTC(Component):
                 await pc.close()
                 connection = self.clean_up(body["webrtc_id"])
                 if connection:
-                    connection.stop()
+                    for conn in connection:
+                        conn.stop()
                 self.pcs.discard(pc)
             if pc.connectionState == "connected":
                 if self.time_limit is not None:
@@ -900,28 +1037,38 @@ class WebRTC(Component):
         @pc.on("track")
         def on_track(track):
             relay = MediaRelay()
-            if self.modality == "video":
+            handler = self.handlers[body["webrtc_id"]]
+
+            if self.modality == "video" and track.kind == "video":
                 cb = VideoCallback(
                     relay.subscribe(track),
-                    event_handler=cast(Callable, self.event_handler),
+                    event_handler=cast(VideoEventHandler, handler),
                     set_additional_outputs=set_outputs,
                     mode=cast(Literal["send", "send-receive"], self.mode),
                 )
-            elif self.modality == "audio":
-                handler = cast(StreamHandler, self.event_handler).copy()
-                handler._loop = asyncio.get_running_loop()
+            elif self.modality == "audio-video" and track.kind == "video":
+                cb = VideoStreamHander(
+                    relay.subscribe(track),
+                    event_handler=handler,  # type: ignore
+                    set_additional_outputs=set_outputs,
+                )
+            elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
+                eh = cast(StreamHandlerImpl, handler)
+                eh._loop = asyncio.get_running_loop()
                 cb = AudioCallback(
                     relay.subscribe(track),
-                    event_handler=handler,
+                    event_handler=eh,
                     set_additional_outputs=set_outputs,
                 )
             else:
-                raise ValueError("Modality must be either video or audio")
-            self.connections[body["webrtc_id"]] = cb
+                raise ValueError("Modality must be either video, audio, or audio-video")
+            if body["webrtc_id"] not in self.connections:
+                self.connections[body["webrtc_id"]] = []
+
+            self.connections[body["webrtc_id"]].append(cb)
             if body["webrtc_id"] in self.data_channels:
-                self.connections[body["webrtc_id"]].set_channel(
-                    self.data_channels[body["webrtc_id"]]
-                )
+                for conn in self.connections[body["webrtc_id"]]:
+                    conn.set_channel(self.data_channels[body["webrtc_id"]])
             if self.mode == "send-receive":
                 logger.debug("Adding track to peer connection %s", cb)
                 pc.addTrack(cb)
@@ -944,7 +1091,7 @@ class WebRTC(Component):
 
             logger.debug("Adding track to peer connection %s", cb)
             pc.addTrack(cb)
-            self.connections[body["webrtc_id"]] = cb
+            self.connections[body["webrtc_id"]].append(cb)
             cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
 
         @pc.on("datachannel")
@@ -957,7 +1104,8 @@ class WebRTC(Component):
             while not self.connections.get(webrtc_id):
                 await asyncio.sleep(0.05)
             logger.debug("setting channel for webrtc id %s", webrtc_id)
-            self.connections[webrtc_id].set_channel(channel)
+            for conn in self.connections[webrtc_id]:
+                conn.set_channel(channel)
 
         asyncio.create_task(set_channel(body["webrtc_id"]))
 
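The thread running through the backend hunks above: one `webrtc_id` can now own several callbacks (an audio and a video track for `audio-video`), so the registry becomes a `defaultdict(list)` and every lookup fans out over the list. A toy sketch of the shape change (the values are illustrative strings, not real callback objects):

```python
from collections import defaultdict

# Before: connections[webrtc_id] was a single callback object.
# After: it is a list, one entry per track on the connection.
connections: dict[str, list[str]] = defaultdict(list)
connections["webrtc-abc"].append("VideoStreamHander for the video track")
connections["webrtc-abc"].append("AudioCallback for the audio track")

# set_args / set_channel / stop now iterate over every track:
for conn in connections["webrtc-abc"]:
    print("apply to", conn)
```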
@@ -30,7 +30,7 @@
   export let gradio;
   export let rtc_configuration: Object;
   export let time_limit: number | null = null;
-  export let modality: "video" | "audio" = "video";
+  export let modality: "video" | "audio" | "audio-video" = "video";
   export let mode: "send-receive" | "receive" | "send" = "send-receive";
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
   export let track_constraints: MediaTrackConstraints = {};
@@ -52,18 +52,18 @@
 </script>
 
 <Block
   {visible}
   variant={"solid"}
   border_mode={dragging ? "focus" : "base"}
   padding={false}
   {elem_id}
   {elem_classes}
   {height}
   {width}
   {container}
   {scale}
   {min_width}
   allow_overflow={false}
 >
   <StatusTracker
     autoscroll={gradio.autoscroll}
@@ -99,13 +99,13 @@
     on:error={({ detail }) => gradio.dispatch("error", detail)}
 
   />
-{:else if (mode === "send-receive" || mode == "send") && modality === "video"}
+{:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}
   <Video
     bind:value={value}
     {label}
     {show_label}
     active_source={"webcam"}
-    include_audio={false}
+    include_audio={modality === "audio-video"}
     {server}
     {rtc_configuration}
     {time_limit}
@@ -113,6 +113,9 @@
     {track_constraints}
     {rtp_params}
     {on_change_cb}
+    {icon}
+    {icon_button_color}
+    {pulse_color}
     on:clear={() => gradio.dispatch("clear")}
     on:play={() => gradio.dispatch("play")}
     on:pause={() => gradio.dispatch("pause")}
@@ -1,10 +1,13 @@
 <script lang="ts">
   import { onDestroy } from 'svelte';
+  import type {ComponentType} from 'svelte';
 
+  import PulsingIcon from './PulsingIcon.svelte';
+
   export let numBars = 16;
   export let stream_state: "open" | "closed" | "waiting" = "closed";
   export let audio_source_callback: () => MediaStream;
-  export let icon: string | undefined = undefined;
+  export let icon: string | undefined | ComponentType = undefined;
   export let icon_button_color: string = "var(--color-accent)";
   export let pulse_color: string = "var(--color-accent)";
 
@@ -13,7 +16,6 @@
   let dataArray: Uint8Array;
   let animationId: number;
   let pulseScale = 1;
-  let pulseIntensity = 0;
 
   $: containerWidth = icon
     ? "128px"
@@ -47,53 +49,31 @@
   function updateVisualization() {
     analyser.getByteFrequencyData(dataArray);
-
-    if (icon) {
-      // Calculate average amplitude for pulse effect
-      const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
-      const normalizedAverage = average / 255;
-      pulseScale = 1 + (normalizedAverage * 0.15);
-      pulseIntensity = normalizedAverage;
-    } else {
     // Update bars
     const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
     for (let i = 0; i < bars.length; i++) {
       const barHeight = (dataArray[i] / 255) * 2;
       bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
     }
-    }
 
     animationId = requestAnimationFrame(updateVisualization);
   }
 
-  $: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
-
 </script>
 
 <div class="gradio-webrtc-waveContainer">
   {#if icon}
     <div class="gradio-webrtc-icon-container">
-      {#if pulseIntensity > 0}
-        {#each Array(3) as _, i}
-          <div
-            class="pulse-ring"
-            style:background={pulse_color}
-            style:animation-delay={`${i * 0.4}s`}
-            style:--max-scale={maxPulseScale}
-            style:opacity={0.5 * pulseIntensity}
-          />
-        {/each}
-      {/if}
-
       <div
         class="gradio-webrtc-icon"
        style:transform={`scale(${pulseScale})`}
         style:background={icon_button_color}
       >
-        <img
-          src={icon}
-          alt="Audio visualization icon"
-          class="icon-image"
-        />
+        <PulsingIcon
+          {stream_state}
+          {pulse_color}
+          {icon}
+          {icon_button_color}
+          {audio_source_callback}/>
       </div>
     </div>
   {:else}
@@ -1,5 +1,6 @@
 <script lang="ts">
   import { createEventDispatcher } from "svelte";
+  import type { ComponentType } from "svelte";
   import type { FileData, Client } from "@gradio/client";
   import { BlockLabel } from "@gradio/atoms";
   import Webcam from "./Webcam.svelte";
@@ -24,6 +25,9 @@
   export let mode: "send" | "send-receive";
   export let on_change_cb: (msg: "change" | "tick") => void;
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+  export let icon: string | undefined | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
 
   const dispatch = createEventDispatcher<{
     change: FileData | null;
@@ -56,6 +60,9 @@
   {mode}
   {rtp_params}
   {on_change_cb}
+  {icon}
+  {icon_button_color}
+  {pulse_color}
   on:error
   on:start_recording
   on:stop_recording
frontend/shared/PulsingIcon.svelte (new file, 151 lines)
@@ -0,0 +1,151 @@
+<script lang="ts">
+  import { onDestroy } from 'svelte';
+  import type {ComponentType} from 'svelte';
+
+  export let stream_state: "open" | "closed" | "waiting" = "closed";
+  export let audio_source_callback: () => MediaStream;
+  export let icon: string | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
+
+  let audioContext: AudioContext;
+  let analyser: AnalyserNode;
+  let dataArray: Uint8Array;
+  let animationId: number;
+  let pulseScale = 1;
+  let pulseIntensity = 0;
+
+  $: if(stream_state === "open") setupAudioContext();
+
+  onDestroy(() => {
+    if (animationId) {
+      cancelAnimationFrame(animationId);
+    }
+    if (audioContext) {
+      audioContext.close();
+    }
+  });
+
+  function setupAudioContext() {
+    audioContext = new (window.AudioContext || window.webkitAudioContext)();
+    analyser = audioContext.createAnalyser();
+    const source = audioContext.createMediaStreamSource(audio_source_callback());
+
+    source.connect(analyser);
+
+    analyser.fftSize = 64;
+    analyser.smoothingTimeConstant = 0.8;
+    dataArray = new Uint8Array(analyser.frequencyBinCount);
+
+    updateVisualization();
+  }
+
+  function updateVisualization() {
+
+    analyser.getByteFrequencyData(dataArray);
+
+    // Calculate average amplitude for pulse effect
+    const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
+    const normalizedAverage = average / 255;
+    pulseScale = 1 + (normalizedAverage * 0.15);
+    pulseIntensity = normalizedAverage;
+    animationId = requestAnimationFrame(updateVisualization);
+
+  }
+
+  $: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
+
+
+</script>
+
+<div class="gradio-webrtc-icon-wrapper">
+  <div class="gradio-webrtc-pulsing-icon-container">
+    {#if pulseIntensity > 0}
+      {#each Array(3) as _, i}
+        <div
+          class="pulse-ring"
+          style:background={pulse_color}
+          style:animation-delay={`${i * 0.4}s`}
+          style:--max-scale={maxPulseScale}
+          style:opacity={0.5 * pulseIntensity}
+        />
+      {/each}
+    {/if}
+
+    <div
+      class="gradio-webrtc-pulsing-icon"
+      style:transform={`scale(${pulseScale})`}
+      style:background={icon_button_color}
+    >
+      {#if typeof icon === "string"}
+        <img
+          src={icon}
+          alt="Audio visualization icon"
+          class="icon-image"
+        />
+      {:else}
+        <svelte:component this={icon} />
+      {/if}
+    </div>
+  </div>
+</div>
+
+<style>
+  .gradio-webrtc-icon-wrapper {
+    position: relative;
+    display: flex;
+    max-height: 128px;
+    justify-content: center;
+    align-items: center;
+  }
+
+  .gradio-webrtc-pulsing-icon-container {
+    position: relative;
+    width: 100%;
+    height: 100%;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+  }
+
+  .gradio-webrtc-pulsing-icon {
+    position: relative;
+    width: 100%;
+    height: 100%;
+    border-radius: 50%;
+    transition: transform 0.1s ease;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    z-index: 2;
+  }
+
+  .icon-image {
+    width: 100%;
+    height: 100%;
+    object-fit: contain;
+    filter: brightness(0) invert(1);
+  }
+
+  .pulse-ring {
+    position: absolute;
+    top: 50%;
+    left: 50%;
+    transform: translate(-50%, -50%);
+    width: 100%;
+    height: 100%;
+    border-radius: 50%;
+    animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+    opacity: 0.5;
+  }
+
+  @keyframes pulse {
+    0% {
+      transform: translate(-50%, -50%) scale(1);
+      opacity: 0.5;
+    }
+    100% {
+      transform: translate(-50%, -50%) scale(var(--max-scale, 3));
+      opacity: 0;
+    }
+  }
+</style>
@@ -98,7 +98,7 @@
   />
   <audio
     class="standard-player"
-    class:hidden={value === "__webrtc_value__"}
+    class:hidden={true}
     on:load
     bind:this={audio_player}
     on:ended={() => dispatch("stop")}
@@ -1,10 +1,12 @@
 <script lang="ts">
   import { createEventDispatcher, onMount } from "svelte";
+  import type { ComponentType } from "svelte";
   import {
     Circle,
     Square,
     DropdownArrow,
-    Spinner
+    Spinner,
+    Microphone as Mic
   } from "@gradio/icons";
   import type { I18nFormatter } from "@gradio/utils";
   import { StreamingBar } from "@gradio/statustracker";
@@ -15,8 +17,8 @@
     get_video_stream,
     set_available_devices
   } from "./stream_utils";
-
   import { start, stop } from "./webrtc_utils";
+  import PulsingIcon from "./PulsingIcon.svelte";
 
   let video_source: HTMLVideoElement;
   let available_video_devices: MediaDeviceInfo[] = [];
@@ -28,6 +30,9 @@
   export let mode: "send-receive" | "send";
   const _webrtc_id = Math.random().toString(36).substring(2);
   export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
+  export let icon: string | undefined | ComponentType = undefined;
+  export let icon_button_color: string = "var(--color-accent)";
+  export let pulse_color: string = "var(--color-accent)";
 
   export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
     state: "open" | "closed" | "waiting"
@@ -156,14 +161,13 @@
       _time_limit = null;
       await access_webcam();
     }
 
   }
 
-  window.setInterval(() => {
-    if (stream_state == "open") {
-      dispatch("tick");
-    }
-  }, stream_every * 1000);
+  // window.setInterval(() => {
+  //   if (stream_state == "open") {
+  //     dispatch("tick");
+  //   }
+  // }, stream_every * 1000);
-
   let options_open = false;
 
@@ -192,16 +196,29 @@
     event.stopPropagation();
     options_open = false;
   }
+
+  const audio_source_callback = () => video_source.srcObject as MediaStream;
 </script>
 
 <div class="wrap">
   <StreamingBar time_limit={_time_limit} />
+  {#if stream_state === "open" && include_audio}
+    <div class="audio-indicator">
+      <PulsingIcon
+        stream_state={stream_state}
+        audio_source_callback={audio_source_callback}
+        icon={icon || Mic}
+        icon_button_color={icon_button_color}
+        pulse_color={pulse_color}
+      />
+    </div>
+  {/if}
   <!-- svelte-ignore a11y-media-has-caption -->
   <!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
   <video
     bind:this={video_source}
     class:hide={!webcam_accessed}
-    class:flip={(stream_state != "open")}
+    class:flip={(stream_state != "open") || (stream_state === "open" && include_audio)}
     autoplay={true}
     playsinline={true}
   />
@@ -324,6 +341,15 @@
     justify-content: space-evenly;
   }
 
+  .audio-indicator {
+    position: absolute;
+    top: var(--size-2);
+    right: var(--size-2);
+    z-index: var(--layer-2);
+    height: var(--size-5);
+    width: var(--size-5);
+  }
+
   @media (--screen-md) {
     button {
       bottom: var(--size-4);
@@ -68,14 +68,14 @@ export async function start(
       try {
         event_json = JSON.parse(event.data);
       } catch (e) {
-        console.debug("Error parsing JSON")
+        console.debug("Error parsing JSON");
       }
       console.log("event_json", event_json);
       if (
         event.data === "change" ||
         event.data === "tick" ||
         event.data === "stopword" ||
         event_json?.type === "warning" ||
         event_json?.type === "error"
       ) {
         console.debug(`${event.data} event received`);
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "gradio_webrtc"
-version = "0.0.27"
+version = "0.0.28"
 description = "Stream images in realtime with webrtc"
 readme = "README.md"
 license = "apache-2.0"