Add text mode (#321)

* Pretty good spot

* Working draft

* Fix other mode

* Add js to git

* Working

* Add code

* fix

* Fix

* Add code

* Fix submit race condition

* demo

* fix

* Fix

* Fix
This commit is contained in:
Freddy Boulton
2025-06-03 19:24:21 -04:00
committed by GitHub
parent 1179f8ef21
commit 1877720231
69 changed files with 110161 additions and 22889 deletions

View File

@@ -26,7 +26,7 @@ from .tracks import (
VideoEventHandler,
VideoStreamHandler,
)
from .utils import RTCConfigurationCallable
from .utils import RTCConfigurationCallable, WebRTCData, WebRTCModel
from .webrtc_connection_mixin import WebRTCConnectionMixin
if TYPE_CHECKING:
@@ -57,7 +57,8 @@ class WebRTC(Component, WebRTCConnectionMixin):
Demos: video_identity_2
"""
EVENTS = ["tick", "state_change"]
EVENTS = ["tick", "state_change", "submit"]
data_model = WebRTCModel
def __init__(
self,
@@ -91,6 +92,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
pulse_color: str | None = None,
icon_radius: int | None = None,
button_labels: dict | None = None,
variant: Literal["textbox", "wave"] = "wave",
):
"""
Parameters:
@@ -128,6 +130,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
"""
WebRTCConnectionMixin.__init__(self)
self.variant = variant
self.time_limit = time_limit
self.height = height
self.width = width
@@ -206,14 +209,21 @@ class WebRTC(Component, WebRTCConnectionMixin):
icon if not icon else cast(dict, self.serve_static_file(icon)).get("url")
)
def preprocess(self, payload: str) -> str:
def preprocess(self, payload: WebRTCModel) -> WebRTCData | str:
"""
Convert the payload received from the frontend into the value handed to the event handler.

Parameters:
payload: The component's data model instance. Its `root` is either a plain `str` or an object exposing a `webrtc_id` attribute (presumably a `WebRTCData`, per the return annotation; confirm against `WebRTCModel`).
Returns:
When `self.variant == "textbox"`, `payload.root` unchanged. For any other variant, `payload.root` itself if it is a `str`, otherwise `payload.root.webrtc_id`.
"""
return payload
if self.variant == "textbox":
return payload.root
else:
return (
payload.root
if isinstance(payload.root, str)
else payload.root.webrtc_id
)
def postprocess(self, value: Any) -> str:
"""
@@ -240,7 +250,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
inputs = [inputs]
inputs = list(inputs)
async def handler(webrtc_id: str, *args):
async def handler(webrtc_id: str | WebRTCData, *args):
if isinstance(webrtc_id, WebRTCData):
webrtc_id = webrtc_id.webrtc_id
async for next_outputs in self.output_stream(webrtc_id):
yield fn(*args, *next_outputs.args) # type: ignore
@@ -291,7 +303,8 @@ class WebRTC(Component, WebRTCConnectionMixin):
)
self.event_handler = fn # type: ignore
self.time_limit = time_limit
if self.variant == "textbox":
self.event_handler.needs_args = True # type: ignore
if (
self.mode == "send-receive"
and self.modality in ["audio", "audio-video"]
@@ -317,7 +330,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
for input_component in inputs[1:]: # type: ignore
if hasattr(input_component, "change") and send_input_on == "change":
input_component.change( # type: ignore
self.set_input,
self.set_input_gradio,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
@@ -327,13 +340,19 @@ class WebRTC(Component, WebRTCConnectionMixin):
)
if hasattr(input_component, "submit") and send_input_on == "submit":
input_component.submit( # type: ignore
self.set_input,
self.set_input_gradio,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
)
self.submit( # type: ignore
self.set_input_on_submit,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
)
return self.tick( # type: ignore
self.set_input,
self.set_input_gradio,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
@@ -359,7 +378,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
)
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
self.tick( # type: ignore
self.set_input,
self.set_input_gradio,
inputs=[self] + list(inputs),
outputs=None,
concurrency_id=concurrency_id,
@@ -378,6 +397,12 @@ class WebRTC(Component, WebRTCConnectionMixin):
body, self.set_additional_outputs(body["webrtc_id"])
)
@server
async def quit_output_stream(self, body):
    """Ask the output stream of one connection to stop.

    Parameters:
        body: Mapping that must contain a "webrtc_id" key naming the connection.
    Returns:
        Always `{"success": True}`, whether or not the id was registered in
        `self.additional_outputs`.
    """
    webrtc_id = body["webrtc_id"]
    if webrtc_id in self.additional_outputs:
        # Only ids with a registered entry are signalled; unknown ids are ignored.
        self.additional_outputs[webrtc_id].quit.set()
    return {"success": True}
def example_payload(self) -> Any:
return {
"video": handle_file(