Async stream handler support (#43)

* async stream handler

* Add code
Author: Freddy Boulton (committed via GitHub)
Date: 2024-12-20 12:46:17 -05:00
Parent: 8a5c1f1bb3
Commit: c45febf3bf
6 changed files with 133 additions and 58 deletions
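
Note on the API this commit adds: AsyncStreamHandler mirrors StreamHandler but makes receive and emit coroutines, so a handler can await async APIs directly instead of being run in a worker thread. A minimal sketch of a subclass follows; the EchoHandler name and its queue-based echo logic are illustrative assumptions, and it relies on StreamHandlerBase keeping default constructor arguments, which the diff below does not show in full:

import asyncio

import numpy as np

from gradio_webrtc import AsyncStreamHandler


class EchoHandler(AsyncStreamHandler):
    def __init__(self) -> None:
        super().__init__(expected_layout="mono")
        # Buffer incoming (sample_rate, ndarray) frames on the event loop.
        self.queue: asyncio.Queue = asyncio.Queue()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        await self.queue.put(frame)

    async def emit(self):
        # EmitType allows returning None when no frame is ready yet.
        try:
            return self.queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

    def copy(self) -> "EchoHandler":
        return EchoHandler()

An instance like this would be passed as the fn argument of WebRTC.stream(), which the diff below extends to accept either handler type in send-receive audio mode.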

File: gradio_webrtc/__init__.py

@@ -6,12 +6,22 @@ from .credentials import (
 from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
 from .reply_on_stopwords import ReplyOnStopWords
 from .speech_to_text import stt, stt_for_chunks
-from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
-from .webrtc import StreamHandler, WebRTC
+from .utils import (
+    AdditionalOutputs,
+    aggregate_bytes_to_16bit,
+    async_aggregate_bytes_to_16bit,
+    audio_to_bytes,
+    audio_to_file,
+    audio_to_float32,
+)
+from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
 
 __all__ = [
+    "AsyncStreamHandler",
     "AlgoOptions",
     "AdditionalOutputs",
+    "aggregate_bytes_to_16bit",
+    "async_aggregate_bytes_to_16bit",
     "audio_to_bytes",
     "audio_to_file",
     "audio_to_float32",

File: gradio_webrtc/reply_on_pause.py

@@ -10,7 +10,7 @@ import numpy as np
 from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
 from gradio_webrtc.utils import AdditionalOutputs
-from gradio_webrtc.webrtc import StreamHandler
+from gradio_webrtc.webrtc import EmitType, StreamHandler
 
 logger = getLogger(__name__)
@@ -47,25 +47,11 @@ ReplyFnGenerator = Union[
     # For two arguments
     Callable[
         [tuple[int, np.ndarray], list[dict[Any, Any]]],
-        Generator[
-            tuple[int, np.ndarray]
-            | tuple[int, np.ndarray, Literal["mono", "stereo"]]
-            | AdditionalOutputs
-            | tuple[tuple[int, np.ndarray], AdditionalOutputs],
-            None,
-            None,
-        ],
+        Generator[EmitType, None, None],
     ],
     Callable[
         [tuple[int, np.ndarray]],
-        Generator[
-            tuple[int, np.ndarray]
-            | tuple[int, np.ndarray, Literal["mono", "stereo"]]
-            | AdditionalOutputs
-            | tuple[tuple[int, np.ndarray], AdditionalOutputs],
-            None,
-            None,
-        ],
+        Generator[EmitType, None, None],
     ],
 ]
@@ -99,11 +85,9 @@ class ReplyOnPause(StreamHandler):
         self.is_async = inspect.isasyncgenfunction(fn)
         self.event = Event()
         self.state = AppState()
-        self.generator = None
+        self.generator: Generator[EmitType, None, None] | None = None
         self.model_options = model_options
         self.algo_options = algo_options or AlgoOptions()
-        self.latest_args: list[Any] = []
-        self.args_set = Event()
 
     @property
     def _needs_additional_inputs(self) -> bool:
@@ -168,23 +152,12 @@ class ReplyOnPause(StreamHandler):
         self.event.set()
 
     def reset(self):
-        self.args_set.clear()
+        super().reset()
         self.generator = None
         self.event.clear()
         self.state = AppState()
 
-    def set_args(self, args: list[Any]):
-        super().set_args(args)
-        self.args_set.set()
-
-    async def fetch_args(
-        self,
-    ):
-        if self.channel:
-            self.channel.send("tick")
-            logger.debug("Sent tick")
-
-    async def async_iterate(self, generator) -> Any:
+    async def async_iterate(self, generator) -> EmitType:
         return await anext(generator)
 
     def emit(self):
@@ -193,8 +166,9 @@ class ReplyOnPause(StreamHandler):
         else:
             if not self.generator:
                 if self._needs_additional_inputs and not self.args_set.is_set():
-                    asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
-                    self.args_set.wait()
+                    asyncio.run_coroutine_threadsafe(
+                        self.wait_for_args(), self.loop
+                    ).result()
                 logger.debug("Creating generator")
                 audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
                 if self._needs_additional_inputs:
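
For reference, the Generator[EmitType, None, None] simplification above relies on the EmitType alias added in webrtc.py later in this diff; a reply function keeps yielding the same shapes as before. A small illustration (the respond function below is invented for the example):

import numpy as np

from gradio_webrtc import AdditionalOutputs, ReplyOnPause


def respond(audio: tuple[int, np.ndarray]):
    sample_rate, data = audio
    yield (sample_rate, data)  # plain audio chunk
    yield (sample_rate, data, "mono")  # audio chunk with a layout hint
    yield AdditionalOutputs("extra output for the app")  # additional outputs only
    yield ((sample_rate, data), AdditionalOutputs("audio plus extras"))


handler = ReplyOnPause(respond)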

File: gradio_webrtc/reply_on_stopwords.py

@@ -71,7 +71,7 @@ class ReplyOnStopWords(ReplyOnPause):
     def send_stopword(self):
         asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
 
-    def determine_pause(
+    def determine_pause(  # type: ignore
         self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
     ) -> bool:
         """Take in the stream, determine if a pause happened"""
@@ -128,7 +128,7 @@ class ReplyOnStopWords(ReplyOnPause):
         return False
 
     def reset(self):
-        self.args_set.clear()
+        super().reset()
         self.generator = None
         self.event.clear()
         self.state = ReplyOnStopWordsState()

File: gradio_webrtc/utils.py

@@ -218,3 +218,43 @@ def audio_to_float32(audio: tuple[int, np.ndarray]) -> np.ndarray:
     >>> audio_float32 = audio_to_float32(audio_tuple)
     """
     return audio[1].astype(np.float32) / 32768.0
+
+
+def aggregate_bytes_to_16bit(chunks_iterator):
+    leftover = b""  # Store incomplete bytes between chunks
+
+    for chunk in chunks_iterator:
+        # Combine with any leftover bytes from previous chunk
+        current_bytes = leftover + chunk
+
+        # Calculate complete samples
+        n_complete_samples = len(current_bytes) // 2  # int16 = 2 bytes
+        bytes_to_process = n_complete_samples * 2
+
+        # Split into complete samples and leftover
+        to_process = current_bytes[:bytes_to_process]
+        leftover = current_bytes[bytes_to_process:]
+
+        if to_process:  # Only yield if we have complete samples
+            audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
+            yield audio_array
+
+
+async def async_aggregate_bytes_to_16bit(chunks_iterator):
+    leftover = b""  # Store incomplete bytes between chunks
+
+    async for chunk in chunks_iterator:
+        # Combine with any leftover bytes from previous chunk
+        current_bytes = leftover + chunk
+
+        # Calculate complete samples
+        n_complete_samples = len(current_bytes) // 2  # int16 = 2 bytes
+        bytes_to_process = n_complete_samples * 2
+
+        # Split into complete samples and leftover
+        to_process = current_bytes[:bytes_to_process]
+        leftover = current_bytes[bytes_to_process:]
+
+        if to_process:  # Only yield if we have complete samples
+            audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
+            yield audio_array
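
A quick usage sketch for the two aggregators added above: they turn an iterator of raw little-endian int16 PCM byte chunks into (1, n) int16 numpy frames, carrying any odd leftover byte over to the next chunk. The pcm_chunks generator below is invented for the example:

import numpy as np

from gradio_webrtc import aggregate_bytes_to_16bit


def pcm_chunks():
    pcm = np.arange(6, dtype=np.int16).tobytes()  # 12 bytes = 6 samples
    yield pcm[:5]  # ends mid-sample; the trailing byte is carried over
    yield pcm[5:]


for frame in aggregate_bytes_to_16bit(pcm_chunks()):
    print(frame.shape, frame.dtype)  # (1, 2) int16, then (1, 4) int16

async_aggregate_bytes_to_16bit behaves the same way over an async iterator, which is the variant an AsyncStreamHandler would typically consume.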

File: gradio_webrtc/webrtc.py

@@ -19,7 +19,9 @@ from typing import (
     Literal,
     ParamSpec,
     Sequence,
+    TypeAlias,
     TypeVar,
+    Union,
     cast,
 )
@@ -161,7 +163,7 @@ class VideoCallback(VideoStreamTrack):
logger.debug("traceback %s", exec) logger.debug("traceback %s", exec)
class StreamHandler(ABC): class StreamHandlerBase(ABC):
def __init__( def __init__(
self, self,
expected_layout: Literal["mono", "stereo"] = "mono", expected_layout: Literal["mono", "stereo"] = "mono",
@@ -173,10 +175,11 @@ class StreamHandler(ABC):
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
         self.input_sample_rate = input_sample_rate
-        self.latest_args: str | list[Any] = "not_set"
+        self.latest_args: list[Any] = []
         self._resampler = None
         self._channel: DataChannel | None = None
         self._loop: asyncio.AbstractEventLoop
+        self.args_set = asyncio.Event()
 
     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -189,15 +192,30 @@ class StreamHandler(ABC):
     def set_channel(self, channel: DataChannel):
         self._channel = channel
 
+    async def fetch_args(
+        self,
+    ):
+        if self.channel:
+            self.channel.send("tick")
+            logger.debug("Sent tick")
+
+    async def wait_for_args(self):
+        await self.fetch_args()
+        await self.args_set.wait()
+
     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
         self.latest_args = ["__webrtc_value__"] + list(args)
+        self.args_set.set()
+
+    def reset(self):
+        self.args_set.clear()
 
     def shutdown(self):
         pass
 
     @abstractmethod
-    def copy(self) -> "StreamHandler":
+    def copy(self) -> "StreamHandlerBase":
         pass
 
     def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
@@ -210,6 +228,17 @@ class StreamHandler(ABC):
             )
         yield from self._resampler.resample(frame)
 
+
+EmitType: TypeAlias = Union[
+    tuple[int, np.ndarray],
+    tuple[int, np.ndarray, Literal["mono", "stereo"]],
+    AdditionalOutputs,
+    tuple[tuple[int, np.ndarray], AdditionalOutputs],
+    None,
+]
+
+
+class StreamHandler(StreamHandlerBase):
     @abstractmethod
     def receive(self, frame: tuple[int, np.ndarray]) -> None:
         pass
@@ -217,22 +246,32 @@ class StreamHandler(ABC):
     @abstractmethod
     def emit(
         self,
-    ) -> (
-        tuple[int, np.ndarray]
-        | AdditionalOutputs
-        | None
-        | tuple[tuple[int, np.ndarray], AdditionalOutputs]
-    ):
+    ) -> EmitType:
         pass
 
 
+class AsyncStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        pass
+
+    @abstractmethod
+    async def emit(
+        self,
+    ) -> EmitType:
+        pass
+
+
+StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
+
+
 class AudioCallback(AudioStreamTrack):
     kind = "audio"
 
     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: StreamHandler,
+        event_handler: StreamHandlerImpl,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
@@ -262,9 +301,14 @@ class AudioCallback(AudioStreamTrack):
                 frame = cast(AudioFrame, await self.track.recv())
                 for frame in self.event_handler.resample(frame):
                     numpy_array = frame.to_ndarray()
-                    await anyio.to_thread.run_sync(
-                        self.event_handler.receive, (frame.sample_rate, numpy_array)
-                    )
+                    if isinstance(self.event_handler, AsyncStreamHandler):
+                        await self.event_handler.receive(
+                            (frame.sample_rate, numpy_array)
+                        )
+                    else:
+                        await anyio.to_thread.run_sync(
+                            self.event_handler.receive, (frame.sample_rate, numpy_array)
+                        )
             except MediaStreamError:
                 logger.debug("MediaStreamError in process_input_frames")
                 break
@@ -272,9 +316,12 @@ class AudioCallback(AudioStreamTrack):
     def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
-            callable = functools.partial(
-                loop.run_in_executor, None, self.event_handler.emit
-            )
+            if isinstance(self.event_handler, AsyncStreamHandler):
+                callable = self.event_handler.emit
+            else:
+                callable = functools.partial(
+                    loop.run_in_executor, None, self.event_handler.emit
+                )
             asyncio.create_task(self.process_input_frames())
             asyncio.create_task(
                 player_worker_decode(
@@ -692,7 +739,7 @@ class WebRTC(Component):
     def stream(
         self,
-        fn: Callable[..., Any] | StreamHandler | None = None,
+        fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -721,7 +768,7 @@ class WebRTC(Component):
         if (
             self.mode == "send-receive"
             and self.modality == "audio"
-            and not isinstance(self.event_handler, StreamHandler)
+            and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
         ):
             raise ValueError(
                 "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
@@ -840,6 +887,8 @@ class WebRTC(Component):
                 event_handler=handler,
                 set_additional_outputs=set_outputs,
             )
+        else:
+            raise ValueError("Modality must be either video or audio")
         self.connections[body["webrtc_id"]] = cb
         if body["webrtc_id"] in self.data_channels:
             self.connections[body["webrtc_id"]].set_channel(
@@ -862,6 +911,8 @@ class WebRTC(Component):
                 cast(Callable, self.event_handler),
                 set_additional_outputs=set_outputs,
             )
+        else:
+            raise ValueError("Modality must be either video or audio")
         logger.debug("Adding track to peer connection %s", cb)
         pc.addTrack(cb)
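
Because AudioCallback now awaits receive and emit directly when the handler is an AsyncStreamHandler (see the process_input_frames and start hunks above), an async handler can pull from an async byte stream without a thread pool. A hedged sketch combining this with async_aggregate_bytes_to_16bit; the byte_source async generator is a stand-in for a real streaming API and is not part of this commit:

import numpy as np

from gradio_webrtc import AsyncStreamHandler, async_aggregate_bytes_to_16bit


async def byte_source():
    # Stand-in for e.g. a streaming TTS response yielding int16 PCM bytes.
    yield np.zeros(480, dtype=np.int16).tobytes()


class PcmStreamer(AsyncStreamHandler):
    def __init__(self) -> None:
        # Assumes StreamHandlerBase keeps default constructor arguments.
        super().__init__()
        self.frames = async_aggregate_bytes_to_16bit(byte_source())

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        pass  # incoming audio is ignored in this sketch

    async def emit(self):
        try:
            # output_sample_rate is set in StreamHandlerBase.__init__.
            return (self.output_sample_rate, await anext(self.frames))
        except StopAsyncIteration:
            return None

    def copy(self) -> "PcmStreamer":
        return PcmStreamer()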

File: pyproject.toml

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 [project]
 name = "gradio_webrtc"
-version = "0.0.23"
+version = "0.0.24"
 description = "Stream images in realtime with webrtc"
 readme = "README.md"
 license = "apache-2.0"