Add ability to trigger ReplyOnPause without waiting for pause (#250)

* Add code

* Send text or audio demo
This commit is contained in:
Freddy Boulton
2025-04-03 20:19:50 -04:00
committed by GitHub
parent aed34825e3
commit 8dd17d3216
5 changed files with 726 additions and 2 deletions

View File

@@ -206,6 +206,11 @@ class ReplyOnPause(StreamHandler):
self.event.clear()
self.state = AppState()
def trigger_response(self):
self.event.set()
if self.state.stream is None:
self.state.stream = np.array([], dtype=np.int16)
async def async_iterate(self, generator) -> EmitType:
return await anext(generator)

View File

@@ -91,6 +91,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_input_components = additional_inputs
self.additional_outputs_handler = additional_outputs_handler
self.track_constraints = track_constraints
self.webrtc_component: WebRTC
self.rtc_configuration = rtc_configuration
self._ui = self._generate_default_ui(ui_args)
self._ui.launch = self._wrap_gradio_launch(self._ui.launch)
@@ -234,6 +235,7 @@ class Stream(WebRTCConnectionMixin):
mode="receive",
modality="video",
)
self.webrtc_component = output_video
for component in additional_output_components:
if component not in same_components:
component.render()
@@ -284,6 +286,7 @@ class Stream(WebRTCConnectionMixin):
mode="send",
modality="video",
)
self.webrtc_component = output_video
for component in additional_output_components:
if component not in same_components:
component.render()
@@ -339,7 +342,7 @@ class Stream(WebRTCConnectionMixin):
for component in additional_output_components:
if component not in same_components:
component.render()
self.webrtc_component = image
image.stream(
fn=self.event_handler,
inputs=[image] + additional_input_components,
@@ -391,6 +394,7 @@ class Stream(WebRTCConnectionMixin):
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = output_video
for component in additional_output_components:
if component not in same_components:
component.render()
@@ -442,6 +446,7 @@ class Stream(WebRTCConnectionMixin):
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = image
for component in additional_input_components:
if component not in same_components:
component.render()
@@ -496,6 +501,7 @@ class Stream(WebRTCConnectionMixin):
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = image
for component in additional_input_components:
if component not in same_components:
component.render()
@@ -553,6 +559,7 @@ class Stream(WebRTCConnectionMixin):
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = image
for component in additional_input_components:
if component not in same_components:
component.render()

View File

@@ -73,7 +73,7 @@ class WebRTCConnectionMixin:
self.connections = defaultdict(list)
self.data_channels = {}
self.additional_outputs = defaultdict(OutputQueue)
self.handlers = {}
self.handlers: dict[str, HandlerType] = {}
self.connection_timeouts = defaultdict(asyncio.Event)
# These attributes should be set by subclasses:
self.concurrency_limit: int | None