This commit is contained in:
Freddy Boulton
2025-03-25 14:42:46 -04:00
committed by GitHub
parent e231f793e8
commit 7692ffad00
3 changed files with 13 additions and 2 deletions

View File

@@ -77,9 +77,12 @@ class Stream(WebRTCConnectionMixin):
self.rtp_params = rtp_params self.rtp_params = rtp_params
self.event_handler = handler self.event_handler = handler
self.concurrency_limit = cast( self.concurrency_limit = cast(
(int | float), (int),
1 if concurrency_limit in ["default", None] else concurrency_limit, 1 if concurrency_limit in ["default", None] else concurrency_limit,
) )
self.concurrency_limit_gradio = cast(
int | Literal["default"] | None, concurrency_limit
)
self.time_limit = time_limit self.time_limit = time_limit
self.additional_output_components = additional_outputs self.additional_output_components = additional_outputs
self.additional_input_components = additional_inputs self.additional_input_components = additional_inputs
@@ -242,6 +245,7 @@ class Stream(WebRTCConnectionMixin):
assert self.additional_outputs_handler assert self.additional_outputs_handler
output_video.on_additional_outputs( output_video.on_additional_outputs(
self.additional_outputs_handler, self.additional_outputs_handler,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
) )
@@ -289,6 +293,7 @@ class Stream(WebRTCConnectionMixin):
assert self.additional_outputs_handler assert self.additional_outputs_handler
output_video.on_additional_outputs( output_video.on_additional_outputs(
self.additional_outputs_handler, self.additional_outputs_handler,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
) )
@@ -342,6 +347,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_outputs_handler, self.additional_outputs_handler,
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
) )
elif self.modality == "audio" and self.mode == "receive": elif self.modality == "audio" and self.mode == "receive":
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -395,6 +401,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_outputs_handler, self.additional_outputs_handler,
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
) )
elif self.modality == "audio" and self.mode == "send": elif self.modality == "audio" and self.mode == "send":
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -447,6 +454,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_outputs_handler, self.additional_outputs_handler,
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
) )
elif self.modality == "audio" and self.mode == "send-receive": elif self.modality == "audio" and self.mode == "send-receive":
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -500,6 +508,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_outputs_handler, self.additional_outputs_handler,
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
) )
elif self.modality == "audio-video" and self.mode == "send-receive": elif self.modality == "audio-video" and self.mode == "send-receive":
css = """.my-group {max-width: 600px !important; max-height: 600 !important;} css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
@@ -555,6 +564,7 @@ class Stream(WebRTCConnectionMixin):
self.additional_outputs_handler, self.additional_outputs_handler,
inputs=additional_output_components, inputs=additional_output_components,
outputs=additional_output_components, outputs=additional_output_components,
concurrency_limit=self.concurrency_limit_gradio, # type: ignore
) )
else: else:
raise ValueError(f"Invalid modality: {self.modality} and mode: {self.mode}") raise ValueError(f"Invalid modality: {self.modality} and mode: {self.mode}")

View File

@@ -233,6 +233,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
inputs = list(inputs) inputs = list(inputs)
async def handler(webrtc_id: str, *args): async def handler(webrtc_id: str, *args):
print("webrtc_id", webrtc_id)
async for next_outputs in self.output_stream(webrtc_id): async for next_outputs in self.output_stream(webrtc_id):
yield fn(*args, *next_outputs.args) # type: ignore yield fn(*args, *next_outputs.args) # type: ignore

View File

@@ -75,7 +75,7 @@ class WebRTCConnectionMixin:
self.handlers = {} self.handlers = {}
self.connection_timeouts = defaultdict(asyncio.Event) self.connection_timeouts = defaultdict(asyncio.Event)
# These attributes should be set by subclasses: # These attributes should be set by subclasses:
self.concurrency_limit: int | float | None self.concurrency_limit: int | None
self.event_handler: HandlerType | None self.event_handler: HandlerType | None
self.time_limit: float | None self.time_limit: float | None
self.modality: Literal["video", "audio", "audio-video"] self.modality: Literal["video", "audio", "audio-video"]