From c45febf3bf75adb5e33ef4ae1496854637fec6da Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 20 Dec 2024 12:46:17 -0500 Subject: [PATCH] Async stream handler support (#43) * async stream handler * Add code --- backend/gradio_webrtc/__init__.py | 14 +++- backend/gradio_webrtc/reply_on_pause.py | 44 +++-------- backend/gradio_webrtc/reply_on_stopwords.py | 4 +- backend/gradio_webrtc/utils.py | 40 ++++++++++ backend/gradio_webrtc/webrtc.py | 87 ++++++++++++++++----- pyproject.toml | 2 +- 6 files changed, 133 insertions(+), 58 deletions(-) diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index 98783db..2963e16 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -6,12 +6,22 @@ from .credentials import ( from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions from .reply_on_stopwords import ReplyOnStopWords from .speech_to_text import stt, stt_for_chunks -from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32 -from .webrtc import StreamHandler, WebRTC +from .utils import ( + AdditionalOutputs, + aggregate_bytes_to_16bit, + async_aggregate_bytes_to_16bit, + audio_to_bytes, + audio_to_file, + audio_to_float32, +) +from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC __all__ = [ + "AsyncStreamHandler", "AlgoOptions", "AdditionalOutputs", + "aggregate_bytes_to_16bit", + "async_aggregate_bytes_to_16bit", "audio_to_bytes", "audio_to_file", "audio_to_float32", diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 5cb425f..13c7c1e 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -10,7 +10,7 @@ import numpy as np from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions from gradio_webrtc.utils import AdditionalOutputs -from gradio_webrtc.webrtc import StreamHandler +from gradio_webrtc.webrtc import EmitType, StreamHandler logger = getLogger(__name__) @@ -47,25 +47,11 @@ ReplyFnGenerator = Union[ # For two arguments Callable[ [tuple[int, np.ndarray], list[dict[Any, Any]]], - Generator[ - tuple[int, np.ndarray] - | tuple[int, np.ndarray, Literal["mono", "stereo"]] - | AdditionalOutputs - | tuple[tuple[int, np.ndarray], AdditionalOutputs], - None, - None, - ], + Generator[EmitType, None, None], ], Callable[ [tuple[int, np.ndarray]], - Generator[ - tuple[int, np.ndarray] - | tuple[int, np.ndarray, Literal["mono", "stereo"]] - | AdditionalOutputs - | tuple[tuple[int, np.ndarray], AdditionalOutputs], - None, - None, - ], + Generator[EmitType, None, None], ], ] @@ -99,11 +85,9 @@ class ReplyOnPause(StreamHandler): self.is_async = inspect.isasyncgenfunction(fn) self.event = Event() self.state = AppState() - self.generator = None + self.generator: Generator[EmitType, None, None] | None = None self.model_options = model_options self.algo_options = algo_options or AlgoOptions() - self.latest_args: list[Any] = [] - self.args_set = Event() @property def _needs_additional_inputs(self) -> bool: @@ -168,23 +152,12 @@ class ReplyOnPause(StreamHandler): self.event.set() def reset(self): - self.args_set.clear() + super().reset() self.generator = None self.event.clear() self.state = AppState() - def set_args(self, args: list[Any]): - super().set_args(args) - self.args_set.set() - - async def fetch_args( - self, - ): - if self.channel: - self.channel.send("tick") - logger.debug("Sent tick") - - async def async_iterate(self, generator) -> Any: + async def async_iterate(self, generator) -> EmitType: return await anext(generator) def emit(self): @@ -193,8 +166,9 @@ class ReplyOnPause(StreamHandler): else: if not self.generator: if self._needs_additional_inputs and not self.args_set.is_set(): - asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop) - self.args_set.wait() + asyncio.run_coroutine_threadsafe( + self.wait_for_args(), self.loop + ).result() logger.debug("Creating generator") audio = cast(np.ndarray, self.state.stream).reshape(1, -1) if self._needs_additional_inputs: diff --git a/backend/gradio_webrtc/reply_on_stopwords.py b/backend/gradio_webrtc/reply_on_stopwords.py index a391e17..9ebc782 100644 --- a/backend/gradio_webrtc/reply_on_stopwords.py +++ b/backend/gradio_webrtc/reply_on_stopwords.py @@ -71,7 +71,7 @@ class ReplyOnStopWords(ReplyOnPause): def send_stopword(self): asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop) - def determine_pause( + def determine_pause( # type: ignore self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState ) -> bool: """Take in the stream, determine if a pause happened""" @@ -128,7 +128,7 @@ class ReplyOnStopWords(ReplyOnPause): return False def reset(self): - self.args_set.clear() + super().reset() self.generator = None self.event.clear() self.state = ReplyOnStopWordsState() diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index bb93032..b150490 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -218,3 +218,43 @@ def audio_to_float32(audio: tuple[int, np.ndarray]) -> np.ndarray: >>> audio_float32 = audio_to_float32(audio_tuple) """ return audio[1].astype(np.float32) / 32768.0 + + +def aggregate_bytes_to_16bit(chunks_iterator): + leftover = b"" # Store incomplete bytes between chunks + + for chunk in chunks_iterator: + # Combine with any leftover bytes from previous chunk + current_bytes = leftover + chunk + + # Calculate complete samples + n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes + bytes_to_process = n_complete_samples * 2 + + # Split into complete samples and leftover + to_process = current_bytes[:bytes_to_process] + leftover = current_bytes[bytes_to_process:] + + if to_process: # Only yield if we have complete samples + audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1) + yield audio_array + + +async def async_aggregate_bytes_to_16bit(chunks_iterator): + leftover = b"" # Store incomplete bytes between chunks + + async for chunk in chunks_iterator: + # Combine with any leftover bytes from previous chunk + current_bytes = leftover + chunk + + # Calculate complete samples + n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes + bytes_to_process = n_complete_samples * 2 + + # Split into complete samples and leftover + to_process = current_bytes[:bytes_to_process] + leftover = current_bytes[bytes_to_process:] + + if to_process: # Only yield if we have complete samples + audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1) + yield audio_array diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index bbd6674..778db02 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -19,7 +19,9 @@ from typing import ( Literal, ParamSpec, Sequence, + TypeAlias, TypeVar, + Union, cast, ) @@ -161,7 +163,7 @@ class VideoCallback(VideoStreamTrack): logger.debug("traceback %s", exec) -class StreamHandler(ABC): +class StreamHandlerBase(ABC): def __init__( self, expected_layout: Literal["mono", "stereo"] = "mono", @@ -173,10 +175,11 @@ class StreamHandler(ABC): self.output_sample_rate = output_sample_rate self.output_frame_size = output_frame_size self.input_sample_rate = input_sample_rate - self.latest_args: str | list[Any] = "not_set" + self.latest_args: list[Any] = [] self._resampler = None self._channel: DataChannel | None = None self._loop: asyncio.AbstractEventLoop + self.args_set = asyncio.Event() @property def loop(self) -> asyncio.AbstractEventLoop: @@ -189,15 +192,30 @@ class StreamHandler(ABC): def set_channel(self, channel: DataChannel): self._channel = channel + async def fetch_args( + self, + ): + if self.channel: + self.channel.send("tick") + logger.debug("Sent tick") + + async def wait_for_args(self): + await self.fetch_args() + await self.args_set.wait() + def set_args(self, args: list[Any]): logger.debug("setting args in audio callback %s", args) self.latest_args = ["__webrtc_value__"] + list(args) + self.args_set.set() + + def reset(self): + self.args_set.clear() def shutdown(self): pass @abstractmethod - def copy(self) -> "StreamHandler": + def copy(self) -> "StreamHandlerBase": pass def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: @@ -210,6 +228,17 @@ class StreamHandler(ABC): ) yield from self._resampler.resample(frame) + +EmitType: TypeAlias = Union[ + tuple[int, np.ndarray], + tuple[int, np.ndarray, Literal["mono", "stereo"]], + AdditionalOutputs, + tuple[tuple[int, np.ndarray], AdditionalOutputs], + None, +] + + +class StreamHandler(StreamHandlerBase): @abstractmethod def receive(self, frame: tuple[int, np.ndarray]) -> None: pass @@ -217,22 +246,32 @@ class StreamHandler(ABC): @abstractmethod def emit( self, - ) -> ( - tuple[int, np.ndarray] - | AdditionalOutputs - | None - | tuple[tuple[int, np.ndarray], AdditionalOutputs] - ): + ) -> EmitType: pass +class AsyncStreamHandler(StreamHandlerBase): + @abstractmethod + async def receive(self, frame: tuple[int, np.ndarray]) -> None: + pass + + @abstractmethod + async def emit( + self, + ) -> EmitType: + pass + + +StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler] + + class AudioCallback(AudioStreamTrack): kind = "audio" def __init__( self, track: MediaStreamTrack, - event_handler: StreamHandler, + event_handler: StreamHandlerImpl, channel: DataChannel | None = None, set_additional_outputs: Callable | None = None, ) -> None: @@ -262,9 +301,14 @@ class AudioCallback(AudioStreamTrack): frame = cast(AudioFrame, await self.track.recv()) for frame in self.event_handler.resample(frame): numpy_array = frame.to_ndarray() - await anyio.to_thread.run_sync( - self.event_handler.receive, (frame.sample_rate, numpy_array) - ) + if isinstance(self.event_handler, AsyncStreamHandler): + await self.event_handler.receive( + (frame.sample_rate, numpy_array) + ) + else: + await anyio.to_thread.run_sync( + self.event_handler.receive, (frame.sample_rate, numpy_array) + ) except MediaStreamError: logger.debug("MediaStreamError in process_input_frames") break @@ -272,9 +316,12 @@ class AudioCallback(AudioStreamTrack): def start(self): if not self.has_started: loop = asyncio.get_running_loop() - callable = functools.partial( - loop.run_in_executor, None, self.event_handler.emit - ) + if isinstance(self.event_handler, AsyncStreamHandler): + callable = self.event_handler.emit + else: + callable = functools.partial( + loop.run_in_executor, None, self.event_handler.emit + ) asyncio.create_task(self.process_input_frames()) asyncio.create_task( player_worker_decode( @@ -692,7 +739,7 @@ class WebRTC(Component): def stream( self, - fn: Callable[..., Any] | StreamHandler | None = None, + fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None, inputs: Block | Sequence[Block] | set[Block] | None = None, outputs: Block | Sequence[Block] | set[Block] | None = None, js: str | None = None, @@ -721,7 +768,7 @@ class WebRTC(Component): if ( self.mode == "send-receive" and self.modality == "audio" - and not isinstance(self.event_handler, StreamHandler) + and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler)) ): raise ValueError( "In the send-receive mode for audio, the event handler must be an instance of StreamHandler." @@ -840,6 +887,8 @@ class WebRTC(Component): event_handler=handler, set_additional_outputs=set_outputs, ) + else: + raise ValueError("Modality must be either video or audio") self.connections[body["webrtc_id"]] = cb if body["webrtc_id"] in self.data_channels: self.connections[body["webrtc_id"]].set_channel( @@ -862,6 +911,8 @@ class WebRTC(Component): cast(Callable, self.event_handler), set_additional_outputs=set_outputs, ) + else: + raise ValueError("Modality must be either video or audio") logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) diff --git a/pyproject.toml b/pyproject.toml index a8e039e..7f5d68a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.23" +version = "0.0.24" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0"