mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add text mode (#321)
* Pretty good spot * Working draft * Fix other mode * Add js to git * Working * Add code * fix * Fix * Add code * Fix submit race condition * demo * fix * Fix * Fix
This commit is contained in:
@@ -11,7 +11,7 @@ from numpy.typing import NDArray
|
||||
|
||||
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
||||
from .tracks import EmitType, StreamHandler
|
||||
from .utils import AdditionalOutputs, create_message, split_output
|
||||
from .utils import AdditionalOutputs, WebRTCData, create_message, split_output
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -67,6 +67,14 @@ ReplyFnGenerator = (
|
||||
[tuple[int, NDArray[np.int16]], Any],
|
||||
AsyncGenerator[EmitType, None],
|
||||
]
|
||||
| Callable[
|
||||
[WebRTCData],
|
||||
Generator[EmitType, None, None],
|
||||
]
|
||||
| Callable[
|
||||
[WebRTCData, Any],
|
||||
AsyncGenerator[EmitType, None],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -115,6 +123,7 @@ class ReplyOnPause(StreamHandler):
|
||||
output_frame_size: int | None = None, # Deprecated
|
||||
input_sample_rate: int = 48000,
|
||||
model: PauseDetectionModel | None = None,
|
||||
needs_args: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the ReplyOnPause handler.
|
||||
@@ -132,6 +141,7 @@ class ReplyOnPause(StreamHandler):
|
||||
output_frame_size: Deprecated.
|
||||
input_sample_rate: The expected sample rate of incoming audio.
|
||||
model: An optional pre-initialized VAD model instance.
|
||||
needs_args: Whether the reply function expects additional arguments.
|
||||
"""
|
||||
super().__init__(
|
||||
expected_layout,
|
||||
@@ -152,11 +162,12 @@ class ReplyOnPause(StreamHandler):
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
self.startup_fn = startup_fn
|
||||
self.needs_args = needs_args
|
||||
|
||||
@property
|
||||
def _needs_additional_inputs(self) -> bool:
|
||||
"""Checks if the reply function `fn` expects additional arguments."""
|
||||
return len(inspect.signature(self.fn).parameters) > 1
|
||||
return len(inspect.signature(self.fn).parameters) > 1 or self.needs_args
|
||||
|
||||
def start_up(self):
|
||||
"""
|
||||
@@ -187,6 +198,7 @@ class ReplyOnPause(StreamHandler):
|
||||
self.output_frame_size,
|
||||
self.input_sample_rate,
|
||||
self.model,
|
||||
self.needs_args,
|
||||
)
|
||||
|
||||
def determine_pause(
|
||||
@@ -361,19 +373,21 @@ class ReplyOnPause(StreamHandler):
|
||||
else:
|
||||
if not self.generator:
|
||||
self.send_message_sync(create_message("log", "pause_detected"))
|
||||
if self._needs_additional_inputs and not self.args_set.is_set():
|
||||
if not self.phone_mode:
|
||||
self.wait_for_args_sync()
|
||||
else:
|
||||
self.latest_args = [None]
|
||||
self.args_set.set()
|
||||
logger.debug("Creating generator")
|
||||
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
|
||||
if self._needs_additional_inputs:
|
||||
self.latest_args[0] = (self.state.sampling_rate, audio)
|
||||
self.generator = self.fn(*self.latest_args) # type: ignore
|
||||
if self._needs_additional_inputs and not self.phone_mode:
|
||||
self.wait_for_args_sync()
|
||||
else:
|
||||
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
|
||||
self.latest_args = [None]
|
||||
self.args_set.set()
|
||||
logger.debug("Creating generator")
|
||||
if self.state.stream is not None and self.state.stream.size > 0:
|
||||
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
|
||||
else:
|
||||
audio = np.array([[]], dtype=np.int16)
|
||||
if isinstance(self.latest_args[0], WebRTCData):
|
||||
self.latest_args[0].audio = (self.state.sampling_rate, audio)
|
||||
else:
|
||||
self.latest_args[0] = (self.state.sampling_rate, audio)
|
||||
self.generator = self.fn(*self.latest_args) # type: ignore
|
||||
logger.debug("Latest args: %s", self.latest_args)
|
||||
self.state = self.state.new()
|
||||
self.state.responding = True
|
||||
|
||||
Reference in New Issue
Block a user