Add text mode (#321)

* Pretty good spot

* Working draft

* Fix other mode

* Add js to git

* Working

* Add code

* fix

* Fix

* Add code

* Fix submit race condition

* demo

* fix

* Fix

* Fix
This commit is contained in:
Freddy Boulton
2025-06-03 19:24:21 -04:00
committed by GitHub
parent 1179f8ef21
commit 1877720231
69 changed files with 110161 additions and 22889 deletions

View File

@@ -141,6 +141,15 @@ class Stream(WebRTCConnectionMixin):
self.modality = modality
self.rtp_params = rtp_params
self.event_handler = handler
if (
ui_args
and ui_args.get("variant") == "textbox"
and hasattr(handler, "needs_args")
):
self.event_handler.needs_args = True # type: ignore
else:
self.event_handler.needs_args = False # type: ignore
self.concurrency_limit = cast(
(int),
1 if concurrency_limit in ["default", None] else concurrency_limit,
@@ -574,28 +583,58 @@ class Stream(WebRTCConnectionMixin):
</div>
"""
)
with gr.Row():
with gr.Column():
with gr.Group():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = image
for component in additional_input_components:
if component not in same_components:
if ui_args.get("variant", "textbox"):
with gr.Row():
if additional_input_components:
with gr.Column():
for component in additional_input_components:
component.render()
if additional_output_components:
diff_output_components = [
component
for component in additional_output_components
if component not in same_components
]
if diff_output_components:
with gr.Column():
for component in diff_output_components:
component.render()
with gr.Row():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
variant=ui_args.get("variant", "wave"),
)
else:
with gr.Row():
with gr.Column():
for component in additional_output_components:
component.render()
with gr.Group():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
variant=ui_args.get("variant", "wave"),
)
for component in additional_input_components:
if component not in same_components:
component.render()
if additional_output_components:
with gr.Column():
for component in additional_output_components:
component.render()
self.webrtc_component = image
image.stream(
fn=self.event_handler,
inputs=[image] + additional_input_components,
@@ -630,45 +669,89 @@ class Stream(WebRTCConnectionMixin):
</div>
"""
)
with gr.Row():
with gr.Column():
with gr.Group():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send-receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = image
for component in additional_input_components:
if component not in same_components:
if ui_args.get("variant", "") == "textbox":
with gr.Row():
if additional_input_components:
with gr.Column():
for component in additional_input_components:
component.render()
if additional_output_components:
with gr.Column():
for component in additional_output_components:
component.render()
image.stream(
fn=self.event_handler,
inputs=[image] + additional_input_components,
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
image.on_additional_outputs(
self.additional_outputs_handler,
inputs=additional_output_components,
outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
diff_output_components = [
component
for component in additional_output_components
if component not in same_components
]
if diff_output_components:
with gr.Column():
for component in diff_output_components:
component.render()
with gr.Row():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send-receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
variant=ui_args.get("variant", "wave"),
)
else:
if additional_output_components:
with gr.Row():
with gr.Column():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send-receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
for component in additional_input_components:
if component not in same_components:
component.render()
with gr.Column():
for component in additional_output_components:
component.render()
else:
with gr.Row():
with gr.Column():
image = WebRTC(
label="Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="send-receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
for component in additional_input_components:
if component not in same_components:
component.render()
self.webrtc_component = image
image.stream(
fn=self.event_handler,
inputs=[image] + additional_input_components,
outputs=[image],
time_limit=self.time_limit,
concurrency_limit=self.concurrency_limit, # type: ignore
send_input_on=ui_args.get("send_input_on", "change"),
)
if additional_output_components:
assert self.additional_outputs_handler
image.on_additional_outputs(
self.additional_outputs_handler,
inputs=additional_output_components,
outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
)
elif self.modality == "audio-video" and self.mode == "send-receive":
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""