mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Audio in only (#15)
* Audio + Video / test Audio * Add code * Fix demo * support additional inputs * Add code * Add code
This commit is contained in:
@@ -2,11 +2,14 @@ from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Callable, Generator, Literal, cast
|
||||
import inspect
|
||||
from typing import Any, Callable, Generator, Literal, Union, cast
|
||||
import asyncio
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
|
||||
from gradio_webrtc.utils import AdditionalOutputs
|
||||
from gradio_webrtc.webrtc import StreamHandler
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -40,12 +43,29 @@ class AppState:
|
||||
buffer: np.ndarray | None = None
|
||||
|
||||
|
||||
ReplyFnGenerator = Callable[
|
||||
[tuple[int, np.ndarray]],
|
||||
Generator[
|
||||
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
|
||||
None,
|
||||
None,
|
||||
ReplyFnGenerator = Union[
|
||||
# For two arguments
|
||||
Callable[
|
||||
[tuple[int, np.ndarray], list[dict[Any, Any]]],
|
||||
Generator[
|
||||
tuple[int, np.ndarray]
|
||||
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
|
||||
| AdditionalOutputs
|
||||
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
|
||||
None,
|
||||
None,
|
||||
],
|
||||
],
|
||||
Callable[
|
||||
[tuple[int, np.ndarray]],
|
||||
Generator[
|
||||
tuple[int, np.ndarray]
|
||||
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
|
||||
| AdditionalOutputs
|
||||
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
|
||||
None,
|
||||
None,
|
||||
],
|
||||
],
|
||||
]
|
||||
|
||||
@@ -71,6 +91,12 @@ class ReplyOnPause(StreamHandler):
|
||||
self.generator = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
self.latest_args: list[Any] = []
|
||||
self.args_set = Event()
|
||||
|
||||
@property
|
||||
def _needs_additional_inputs(self) -> bool:
|
||||
return len(inspect.signature(self.fn).parameters) > 1
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(
|
||||
@@ -130,17 +156,38 @@ class ReplyOnPause(StreamHandler):
|
||||
self.event.set()
|
||||
|
||||
def reset(self):
|
||||
self.args_set.clear()
|
||||
self.generator = None
|
||||
self.event.clear()
|
||||
self.state = AppState()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
super().set_args(args)
|
||||
self.args_set.set()
|
||||
|
||||
async def fetch_args(
|
||||
self,
|
||||
):
|
||||
if self.channel:
|
||||
self.channel.send("tick")
|
||||
logger.debug("Sent tick")
|
||||
|
||||
def emit(self):
|
||||
if not self.event.is_set():
|
||||
return None
|
||||
else:
|
||||
if not self.generator:
|
||||
if self._needs_additional_inputs and not self.args_set.is_set():
|
||||
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
|
||||
self.args_set.wait()
|
||||
logger.debug("Creating generator")
|
||||
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
|
||||
self.generator = self.fn((self.state.sampling_rate, audio))
|
||||
if self._needs_additional_inputs:
|
||||
self.latest_args[0] = (self.state.sampling_rate, audio)
|
||||
self.generator = self.fn(*self.latest_args)
|
||||
else:
|
||||
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
|
||||
logger.debug("Latest args: %s", self.latest_args)
|
||||
self.state.responding = True
|
||||
try:
|
||||
return next(self.generator)
|
||||
|
||||
Reference in New Issue
Block a user