mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add text mode (#321)
* Pretty good spot * Working draft * Fix other mode * Add js to git * Working * Add code * fix * Fix * Add code * Fix submit race condition * demo * fix * Fix * Fix
This commit is contained in:
@@ -26,7 +26,7 @@ from .tracks import (
|
||||
VideoEventHandler,
|
||||
VideoStreamHandler,
|
||||
)
|
||||
from .utils import RTCConfigurationCallable
|
||||
from .utils import RTCConfigurationCallable, WebRTCData, WebRTCModel
|
||||
from .webrtc_connection_mixin import WebRTCConnectionMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,7 +57,8 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
Demos: video_identity_2
|
||||
"""
|
||||
|
||||
EVENTS = ["tick", "state_change"]
|
||||
EVENTS = ["tick", "state_change", "submit"]
|
||||
data_model = WebRTCModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,6 +92,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
pulse_color: str | None = None,
|
||||
icon_radius: int | None = None,
|
||||
button_labels: dict | None = None,
|
||||
variant: Literal["textbox", "wave"] = "wave",
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -128,6 +130,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
|
||||
"""
|
||||
WebRTCConnectionMixin.__init__(self)
|
||||
self.variant = variant
|
||||
self.time_limit = time_limit
|
||||
self.height = height
|
||||
self.width = width
|
||||
@@ -206,14 +209,21 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
icon if not icon else cast(dict, self.serve_static_file(icon)).get("url")
|
||||
)
|
||||
|
||||
def preprocess(self, payload: str) -> str:
|
||||
def preprocess(self, payload: WebRTCModel) -> WebRTCData | str:
|
||||
"""
|
||||
Parameters:
|
||||
payload: An instance of VideoData containing the video and subtitle files.
|
||||
Returns:
|
||||
Passes the uploaded video as a `str` filepath or URL whose extension can be modified by `format`.
|
||||
"""
|
||||
return payload
|
||||
if self.variant == "textbox":
|
||||
return payload.root
|
||||
else:
|
||||
return (
|
||||
payload.root
|
||||
if isinstance(payload.root, str)
|
||||
else payload.root.webrtc_id
|
||||
)
|
||||
|
||||
def postprocess(self, value: Any) -> str:
|
||||
"""
|
||||
@@ -240,7 +250,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
inputs = [inputs]
|
||||
inputs = list(inputs)
|
||||
|
||||
async def handler(webrtc_id: str, *args):
|
||||
async def handler(webrtc_id: str | WebRTCData, *args):
|
||||
if isinstance(webrtc_id, WebRTCData):
|
||||
webrtc_id = webrtc_id.webrtc_id
|
||||
async for next_outputs in self.output_stream(webrtc_id):
|
||||
yield fn(*args, *next_outputs.args) # type: ignore
|
||||
|
||||
@@ -291,7 +303,8 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
)
|
||||
self.event_handler = fn # type: ignore
|
||||
self.time_limit = time_limit
|
||||
|
||||
if self.variant == "textbox":
|
||||
self.event_handler.needs_args = True # type: ignore
|
||||
if (
|
||||
self.mode == "send-receive"
|
||||
and self.modality in ["audio", "audio-video"]
|
||||
@@ -317,7 +330,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
for input_component in inputs[1:]: # type: ignore
|
||||
if hasattr(input_component, "change") and send_input_on == "change":
|
||||
input_component.change( # type: ignore
|
||||
self.set_input,
|
||||
self.set_input_gradio,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
@@ -327,13 +340,19 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
)
|
||||
if hasattr(input_component, "submit") and send_input_on == "submit":
|
||||
input_component.submit( # type: ignore
|
||||
self.set_input,
|
||||
self.set_input_gradio,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
)
|
||||
self.submit( # type: ignore
|
||||
self.set_input_on_submit,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
)
|
||||
return self.tick( # type: ignore
|
||||
self.set_input,
|
||||
self.set_input_gradio,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
@@ -359,7 +378,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
)
|
||||
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
|
||||
self.tick( # type: ignore
|
||||
self.set_input,
|
||||
self.set_input_gradio,
|
||||
inputs=[self] + list(inputs),
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
@@ -378,6 +397,12 @@ class WebRTC(Component, WebRTCConnectionMixin):
|
||||
body, self.set_additional_outputs(body["webrtc_id"])
|
||||
)
|
||||
|
||||
@server
|
||||
async def quit_output_stream(self, body):
|
||||
if body["webrtc_id"] in self.additional_outputs:
|
||||
self.additional_outputs[body["webrtc_id"]].quit.set()
|
||||
return {"success": True}
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"video": handle_file(
|
||||
|
||||
Reference in New Issue
Block a user