Async stream handler support (#43)

* async stream handler

* Add code
This commit is contained in:
Freddy Boulton
2024-12-20 12:46:17 -05:00
committed by GitHub
parent 8a5c1f1bb3
commit c45febf3bf
6 changed files with 133 additions and 58 deletions

View File

@@ -10,7 +10,7 @@ import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.utils import AdditionalOutputs
from gradio_webrtc.webrtc import StreamHandler
from gradio_webrtc.webrtc import EmitType, StreamHandler
logger = getLogger(__name__)
@@ -47,25 +47,11 @@ ReplyFnGenerator = Union[
# For two arguments
Callable[
[tuple[int, np.ndarray], list[dict[Any, Any]]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
Generator[EmitType, None, None],
],
Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
Generator[EmitType, None, None],
],
]
@@ -99,11 +85,9 @@ class ReplyOnPause(StreamHandler):
self.is_async = inspect.isasyncgenfunction(fn)
self.event = Event()
self.state = AppState()
self.generator = None
self.generator: Generator[EmitType, None, None] | None = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
self.latest_args: list[Any] = []
self.args_set = Event()
@property
def _needs_additional_inputs(self) -> bool:
@@ -168,23 +152,12 @@ class ReplyOnPause(StreamHandler):
self.event.set()
def reset(self):
self.args_set.clear()
super().reset()
self.generator = None
self.event.clear()
self.state = AppState()
def set_args(self, args: list[Any]):
super().set_args(args)
self.args_set.set()
async def fetch_args(
self,
):
if self.channel:
self.channel.send("tick")
logger.debug("Sent tick")
async def async_iterate(self, generator) -> Any:
async def async_iterate(self, generator) -> EmitType:
return await anext(generator)
def emit(self):
@@ -193,8 +166,9 @@ class ReplyOnPause(StreamHandler):
else:
if not self.generator:
if self._needs_additional_inputs and not self.args_set.is_set():
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
self.args_set.wait()
asyncio.run_coroutine_threadsafe(
self.wait_for_args(), self.loop
).result()
logger.debug("Creating generator")
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
if self._needs_additional_inputs: