Audio in only (#15)

* Audio + Video / test Audio

* Add code

* Fix demo

* support additional inputs

* Add code

* Add code
This commit is contained in:
Freddy Boulton
2024-10-30 13:08:09 -04:00
committed by GitHub
parent 2068b91854
commit 3bf4a437fb
29 changed files with 1613 additions and 416 deletions

View File

@@ -2,11 +2,14 @@ from dataclasses import dataclass
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Callable, Generator, Literal, cast
import inspect
from typing import Any, Callable, Generator, Literal, Union, cast
import asyncio
import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.utils import AdditionalOutputs
from gradio_webrtc.webrtc import StreamHandler
logger = getLogger(__name__)
@@ -40,12 +43,29 @@ class AppState:
buffer: np.ndarray | None = None
ReplyFnGenerator = Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
None,
None,
ReplyFnGenerator = Union[
# For two arguments
Callable[
[tuple[int, np.ndarray], list[dict[Any, Any]]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
],
Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
],
]
@@ -71,6 +91,12 @@ class ReplyOnPause(StreamHandler):
self.generator = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
self.latest_args: list[Any] = []
self.args_set = Event()
@property
def _needs_additional_inputs(self) -> bool:
return len(inspect.signature(self.fn).parameters) > 1
def copy(self):
return ReplyOnPause(
@@ -130,17 +156,38 @@ class ReplyOnPause(StreamHandler):
self.event.set()
def reset(self):
self.args_set.clear()
self.generator = None
self.event.clear()
self.state = AppState()
def set_args(self, args: list[Any]):
super().set_args(args)
self.args_set.set()
async def fetch_args(
self,
):
if self.channel:
self.channel.send("tick")
logger.debug("Sent tick")
def emit(self):
if not self.event.is_set():
return None
else:
if not self.generator:
if self._needs_additional_inputs and not self.args_set.is_set():
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
self.args_set.wait()
logger.debug("Creating generator")
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
self.generator = self.fn((self.state.sampling_rate, audio))
if self._needs_additional_inputs:
self.latest_args[0] = (self.state.sampling_rate, audio)
self.generator = self.fn(*self.latest_args)
else:
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
logger.debug("Latest args: %s", self.latest_args)
self.state.responding = True
try:
return next(self.generator)