diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index f1a72cf..8f0db37 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -16,10 +16,21 @@ from .utils import ( audio_to_file, audio_to_float32, ) -from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC +from .webrtc import ( + AsyncAudioVideoStreamHandler, + AsyncStreamHandler, + AudioVideoStreamHandler, + StreamHandler, + WebRTC, + VideoEmitType, + AudioEmitType, +) __all__ = [ "AsyncStreamHandler", + "AudioVideoStreamHandler", + "AudioEmitType", + "AsyncAudioVideoStreamHandler", "AlgoOptions", "AdditionalOutputs", "aggregate_bytes_to_16bit", @@ -36,6 +47,7 @@ __all__ = [ "stt", "stt_for_chunks", "StreamHandler", + "VideoEmitType", "WebRTC", "WebRTCError", "Warning", diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 13c7c1e..5733fd4 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast import numpy as np from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions -from gradio_webrtc.utils import AdditionalOutputs from gradio_webrtc.webrtc import EmitType, StreamHandler logger = getLogger(__name__) diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 895655e..a2823ad 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -147,16 +147,18 @@ async def player_worker_decode( logger.debug( "received array with shape %s sample rate %s layout %s", - audio_array.shape, + audio_array.shape, # type: ignore sample_rate, - layout, + layout, # type: ignore ) - format = "s16" if audio_array.dtype == "int16" else "fltp" + format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore # Convert to audio frame and resample # This runs in the same timeout context frame = av.AudioFrame.from_ndarray( # type: ignore - audio_array, format=format, layout=layout + audio_array, # type: ignore + format=format, + layout=layout, # type: ignore ) frame.sample_rate = sample_rate diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 158872e..093efe9 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -4,11 +4,13 @@ from __future__ import annotations import asyncio import functools +import inspect import logging import threading import time import traceback from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Callable from typing import ( TYPE_CHECKING, @@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError from gradio import wasm_utils from gradio.components.base import Component, server from gradio_client import handle_file +from numpy import typing as npt from gradio_webrtc.utils import ( AdditionalOutputs, @@ -61,6 +64,11 @@ if wasm_utils.IS_WASM: logger = logging.getLogger(__name__) +VideoEmitType = Union[ + AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None +] +VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType] + class VideoCallback(VideoStreamTrack): """ @@ -72,7 +80,7 @@ class VideoCallback(VideoStreamTrack): def __init__( self, track: MediaStreamTrack, - event_handler: Callable, + event_handler: VideoEventHandler, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, mode: Literal["send-receive", "send"] = "send-receive", @@ -86,6 +94,7 @@ class VideoCallback(VideoStreamTrack): self.thread_quit = asyncio.Event() self.mode = mode self.channel_set = asyncio.Event() + self.has_started = False def set_channel(self, channel: DataChannel): self.channel = channel @@ -132,7 +141,7 @@ class VideoCallback(VideoStreamTrack): if current_channel.get() != self.channel: current_channel.set(self.channel) - async def recv(self): + async def recv(self): # type: ignore try: try: frame = cast(VideoFrame, await self.track.recv()) @@ -142,7 +151,6 @@ class VideoCallback(VideoStreamTrack): await self.wait_for_channel() frame_array = frame.to_ndarray(format="bgr24") - if self.latest_args == "not_set": return frame @@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[ tuple[tuple[int, np.ndarray], AdditionalOutputs], None, ] +AudioEmitType = EmitType class StreamHandler(StreamHandlerBase): @@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase): StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler] +class AudioVideoStreamHandler(StreamHandlerBase): + @abstractmethod + def video_receive(self, frame: npt.NDArray) -> None: + pass + + @abstractmethod + def video_emit( + self, + ) -> VideoEmitType: + pass + + +class AsyncAudioVideoStreamHandler(StreamHandlerBase): + @abstractmethod + async def video_receive(self, frame: npt.NDArray) -> None: + pass + + @abstractmethod + async def video_emit( + self, + ) -> VideoEmitType: + pass + + +VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler] +AudioVideoStreamHandlerImpl = Union[ + AudioVideoStreamHandler, AsyncAudioVideoStreamHandler +] +AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler] + + +class VideoStreamHander(VideoCallback): + async def process_frames(self): + while not self.thread_quit.is_set(): + try: + await self.channel_set.wait() + frame = cast(VideoFrame, await self.track.recv()) + frame_array = frame.to_ndarray(format="bgr24") + handler = cast(VideoStreamHandlerImpl, self.event_handler) + if inspect.iscoroutinefunction(handler.video_receive): + await handler.video_receive(frame_array) + else: + handler.video_receive(frame_array) + except MediaStreamError: + self.stop() + + def start(self): + if not self.has_started: + asyncio.create_task(self.process_frames()) + self.has_started = True + + async def recv(self): # type: ignore + self.start() + try: + handler = cast(VideoStreamHandlerImpl, self.event_handler) + if inspect.iscoroutinefunction(handler.video_emit): + outputs = await handler.video_emit() + else: + outputs = handler.video_emit() + + array, outputs = split_output(outputs) + if ( + isinstance(outputs, AdditionalOutputs) + and self.set_additional_outputs + and self.channel + ): + self.set_additional_outputs(outputs) + self.channel.send("change") + if array is None and self.mode == "send": + return + + new_frame = self.array_to_frame(array) + + # Will probably have to give developer ability to set pts and time_base + pts, time_base = await self.next_timestamp() + new_frame.pts = pts + new_frame.time_base = time_base + + return new_frame + except Exception as e: + logger.debug("exception %s", e) + exec = traceback.format_exc() + logger.debug("traceback %s", exec) + + class AudioCallback(AudioStreamTrack): kind = "audio" def __init__( self, track: MediaStreamTrack, - event_handler: StreamHandlerImpl, + event_handler: StreamHandlerBase, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, ) -> None: super().__init__() self.track = track - self.event_handler = event_handler + self.event_handler = cast(StreamHandlerImpl, event_handler) self.current_timestamp = 0 self.latest_args: str | list[Any] = "not_set" self.queue = asyncio.Queue() @@ -322,7 +416,7 @@ class AudioCallback(AudioStreamTrack): frame = cast(AudioFrame, await self.track.recv()) for frame in self.event_handler.resample(frame): numpy_array = frame.to_ndarray() - if isinstance(self.event_handler, AsyncStreamHandler): + if isinstance(self.event_handler, AsyncHandler): await self.event_handler.receive( (frame.sample_rate, numpy_array) ) @@ -337,7 +431,7 @@ class AudioCallback(AudioStreamTrack): def start(self): if not self.has_started: loop = asyncio.get_running_loop() - if isinstance(self.event_handler, AsyncStreamHandler): + if isinstance(self.event_handler, AsyncHandler): callable = self.event_handler.emit else: callable = functools.partial( @@ -358,7 +452,7 @@ class AudioCallback(AudioStreamTrack): ) self.has_started = True - async def recv(self): + async def recv(self): # type: ignore try: if self.readyState != "live": raise MediaStreamError @@ -383,7 +477,7 @@ class AudioCallback(AudioStreamTrack): # control playback rate if self._start is None: - self._start = time.time() - data_time + self._start = time.time() - data_time # type: ignore else: wait = self._start + data_time - time.time() await asyncio.sleep(wait) @@ -434,7 +528,7 @@ class ServerToClientVideo(VideoStreamTrack): self.latest_args = list(args) self.args_set.set() - async def recv(self): + async def recv(self): # type: ignore try: pts, time_base = await self.next_timestamp() await self.args_set.wait() @@ -523,7 +617,7 @@ class ServerToClientAudio(AudioStreamTrack): ) self.has_started = True - async def recv(self): + async def recv(self): # type: ignore try: if self.readyState != "live": raise MediaStreamError @@ -539,7 +633,7 @@ class ServerToClientAudio(AudioStreamTrack): # control playback rate if data_time is not None: if self._start is None: - self._start = time.time() - data_time + self._start = time.time() - data_time # type: ignore else: wait = self._start + data_time - time.time() await asyncio.sleep(wait) @@ -576,10 +670,12 @@ class WebRTC(Component): pcs: set[RTCPeerConnection] = set([]) relay = MediaRelay() connections: dict[ - str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback - ] = {} + str, + list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback], + ] = defaultdict(list) data_channels: dict[str, DataChannel] = {} additional_outputs: dict[str, list[AdditionalOutputs]] = {} + handlers: dict[str, StreamHandlerBase | Callable] = {} EVENTS = ["tick", "state_change"] @@ -606,7 +702,7 @@ class WebRTC(Component): track_constraints: dict[str, Any] | None = None, time_limit: float | None = None, mode: Literal["send-receive", "receive", "send"] = "send-receive", - modality: Literal["video", "audio"] = "video", + modality: Literal["video", "audio", "audio-video"] = "video", rtp_params: dict[str, Any] | None = None, icon: str | None = None, icon_button_color: str | None = None, @@ -669,6 +765,23 @@ class WebRTC(Component): "height": {"ideal": 500}, "frameRate": {"ideal": 30}, } + if track_constraints is None and modality == "audio-video": + track_constraints = { + "video": { + "facingMode": "user", + "width": {"ideal": 500}, + "height": {"ideal": 500}, + "frameRate": {"ideal": 30}, + }, + "audio": { + "echoCancellation": True, + "noiseSuppression": {"exact": True}, + "autoGainControl": {"exact": True}, + "sampleRate": {"ideal": 24000}, + "sampleSize": {"ideal": 16}, + "channelCount": {"exact": 1}, + }, + } self.track_constraints = track_constraints self.event_handler: Callable | StreamHandler | None = None super().__init__( @@ -722,7 +835,8 @@ class WebRTC(Component): def set_input(self, webrtc_id: str, *args): if webrtc_id in self.connections: - self.connections[webrtc_id].set_args(list(args)) + for conn in self.connections[webrtc_id]: + conn.set_args(list(args)) def on_additional_outputs( self, @@ -767,7 +881,10 @@ class WebRTC(Component): def stream( self, - fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None, + fn: Callable[..., Any] + | StreamHandlerImpl + | AudioVideoStreamHandlerImpl + | None = None, inputs: Block | Sequence[Block] | set[Block] | None = None, outputs: Block | Sequence[Block] | set[Block] | None = None, js: str | None = None, @@ -790,16 +907,16 @@ class WebRTC(Component): self.concurrency_limit = ( 1 if concurrency_limit in ["default", None] else concurrency_limit ) - self.event_handler = fn + self.event_handler = fn # type: ignore self.time_limit = time_limit if ( self.mode == "send-receive" - and self.modality == "audio" - and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler)) + and self.modality in ["audio", "audio-video"] + and not isinstance(self.event_handler, StreamHandlerBase) ): raise ValueError( - "In the send-receive mode for audio, the event handler must be an instance of StreamHandler." + "In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase." ) if self.mode == "send-receive" or self.mode == "send": @@ -815,13 +932,23 @@ class WebRTC(Component): raise ValueError( "In the webrtc stream event, the only output component must be the WebRTC component." ) + for input_component in inputs[1:]: # type: ignore + if hasattr(input_component, "change"): + input_component.change( # type: ignore + self.set_input, + inputs=inputs, + outputs=None, + concurrency_id=concurrency_id, + concurrency_limit=None, + time_limit=None, + js=js, + ) return self.tick( # type: ignore self.set_input, inputs=inputs, outputs=None, concurrency_id=concurrency_id, concurrency_limit=None, - stream_every=0.5, time_limit=None, js=js, ) @@ -855,9 +982,11 @@ class WebRTC(Component): await pc.close() def clean_up(self, webrtc_id: str): - connection = self.connections.pop(webrtc_id, None) - if isinstance(connection, AudioCallback): - connection.event_handler.shutdown() + self.handlers.pop(webrtc_id, None) + connection = self.connections.pop(webrtc_id, []) + for conn in connection: + if isinstance(conn, AudioCallback): + conn.event_handler.shutdown() self.additional_outputs.pop(webrtc_id, None) self.data_channels.pop(webrtc_id, None) return connection @@ -874,6 +1003,13 @@ class WebRTC(Component): pc = RTCPeerConnection() self.pcs.add(pc) + if isinstance(self.event_handler, StreamHandlerBase): + handler = self.event_handler.copy() + else: + handler = cast(Callable, self.event_handler) + + self.handlers[body["webrtc_id"]] = handler + set_outputs = self.set_additional_outputs(body["webrtc_id"]) @pc.on("iceconnectionstatechange") @@ -891,7 +1027,8 @@ class WebRTC(Component): await pc.close() connection = self.clean_up(body["webrtc_id"]) if connection: - connection.stop() + for conn in connection: + conn.stop() self.pcs.discard(pc) if pc.connectionState == "connected": if self.time_limit is not None: @@ -900,28 +1037,38 @@ class WebRTC(Component): @pc.on("track") def on_track(track): relay = MediaRelay() - if self.modality == "video": + handler = self.handlers[body["webrtc_id"]] + + if self.modality == "video" and track.kind == "video": cb = VideoCallback( relay.subscribe(track), - event_handler=cast(Callable, self.event_handler), + event_handler=cast(VideoEventHandler, handler), set_additional_outputs=set_outputs, mode=cast(Literal["send", "send-receive"], self.mode), ) - elif self.modality == "audio": - handler = cast(StreamHandler, self.event_handler).copy() - handler._loop = asyncio.get_running_loop() + elif self.modality == "audio-video" and track.kind == "video": + cb = VideoStreamHander( + relay.subscribe(track), + event_handler=handler, # type: ignore + set_additional_outputs=set_outputs, + ) + elif self.modality in ["audio", "audio-video"] and track.kind == "audio": + eh = cast(StreamHandlerImpl, handler) + eh._loop = asyncio.get_running_loop() cb = AudioCallback( relay.subscribe(track), - event_handler=handler, + event_handler=eh, set_additional_outputs=set_outputs, ) else: - raise ValueError("Modality must be either video or audio") - self.connections[body["webrtc_id"]] = cb + raise ValueError("Modality must be either video, audio, or audio-video") + if body["webrtc_id"] not in self.connections: + self.connections[body["webrtc_id"]] = [] + + self.connections[body["webrtc_id"]].append(cb) if body["webrtc_id"] in self.data_channels: - self.connections[body["webrtc_id"]].set_channel( - self.data_channels[body["webrtc_id"]] - ) + for conn in self.connections[body["webrtc_id"]]: + conn.set_channel(self.data_channels[body["webrtc_id"]]) if self.mode == "send-receive": logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) @@ -944,7 +1091,7 @@ class WebRTC(Component): logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) - self.connections[body["webrtc_id"]] = cb + self.connections[body["webrtc_id"]].append(cb) cb.on("ended", lambda: self.clean_up(body["webrtc_id"])) @pc.on("datachannel") @@ -957,7 +1104,8 @@ class WebRTC(Component): while not self.connections.get(webrtc_id): await asyncio.sleep(0.05) logger.debug("setting channel for webrtc id %s", webrtc_id) - self.connections[webrtc_id].set_channel(channel) + for conn in self.connections[webrtc_id]: + conn.set_channel(channel) asyncio.create_task(set_channel(body["webrtc_id"])) diff --git a/frontend/Index.svelte b/frontend/Index.svelte index b3901ec..2ebfd69 100644 --- a/frontend/Index.svelte +++ b/frontend/Index.svelte @@ -30,7 +30,7 @@ export let gradio; export let rtc_configuration: Object; export let time_limit: number | null = null; - export let modality: "video" | "audio" = "video"; + export let modality: "video" | "audio" | "audio-video" = "video"; export let mode: "send-receive" | "receive" | "send" = "send-receive"; export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters; export let track_constraints: MediaTrackConstraints = {}; @@ -52,18 +52,18 @@ gradio.dispatch("error", detail)} /> - {:else if (mode === "send-receive" || mode == "send") && modality === "video"} + {:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}