From 7692ffad005e31305494e9b50150c58d86835c8d Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Tue, 25 Mar 2025 14:42:46 -0400 Subject: [PATCH] Add code (#211) --- backend/fastrtc/stream.py | 12 +++++++++++- backend/fastrtc/webrtc.py | 1 + backend/fastrtc/webrtc_connection_mixin.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index 1742fa2..329ba04 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -77,9 +77,12 @@ class Stream(WebRTCConnectionMixin): self.rtp_params = rtp_params self.event_handler = handler self.concurrency_limit = cast( - (int | float), + (int), 1 if concurrency_limit in ["default", None] else concurrency_limit, ) + self.concurrency_limit_gradio = cast( + int | Literal["default"] | None, concurrency_limit + ) self.time_limit = time_limit self.additional_output_components = additional_outputs self.additional_input_components = additional_inputs @@ -242,6 +245,7 @@ class Stream(WebRTCConnectionMixin): assert self.additional_outputs_handler output_video.on_additional_outputs( self.additional_outputs_handler, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore inputs=additional_output_components, outputs=additional_output_components, ) @@ -289,6 +293,7 @@ class Stream(WebRTCConnectionMixin): assert self.additional_outputs_handler output_video.on_additional_outputs( self.additional_outputs_handler, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore inputs=additional_output_components, outputs=additional_output_components, ) @@ -342,6 +347,7 @@ class Stream(WebRTCConnectionMixin): self.additional_outputs_handler, inputs=additional_output_components, outputs=additional_output_components, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore ) elif self.modality == "audio" and self.mode == "receive": with gr.Blocks() as demo: @@ -395,6 +401,7 @@ class Stream(WebRTCConnectionMixin): self.additional_outputs_handler, inputs=additional_output_components, outputs=additional_output_components, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore ) elif self.modality == "audio" and self.mode == "send": with gr.Blocks() as demo: @@ -447,6 +454,7 @@ class Stream(WebRTCConnectionMixin): self.additional_outputs_handler, inputs=additional_output_components, outputs=additional_output_components, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore ) elif self.modality == "audio" and self.mode == "send-receive": with gr.Blocks() as demo: @@ -500,6 +508,7 @@ class Stream(WebRTCConnectionMixin): self.additional_outputs_handler, inputs=additional_output_components, outputs=additional_output_components, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore ) elif self.modality == "audio-video" and self.mode == "send-receive": css = """.my-group {max-width: 600px !important; max-height: 600 !important;} @@ -555,6 +564,7 @@ class Stream(WebRTCConnectionMixin): self.additional_outputs_handler, inputs=additional_output_components, outputs=additional_output_components, + concurrency_limit=self.concurrency_limit_gradio, # type: ignore ) else: raise ValueError(f"Invalid modality: {self.modality} and mode: {self.mode}") diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py index 6363fcd..2883133 100644 --- a/backend/fastrtc/webrtc.py +++ b/backend/fastrtc/webrtc.py @@ -233,6 +233,7 @@ class WebRTC(Component, WebRTCConnectionMixin): inputs = list(inputs) async def handler(webrtc_id: str, *args): + print("webrtc_id", webrtc_id) async for next_outputs in self.output_stream(webrtc_id): yield fn(*args, *next_outputs.args) # type: ignore diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index 70762ad..3b744dc 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -75,7 +75,7 @@ class WebRTCConnectionMixin: self.handlers = {} self.connection_timeouts = defaultdict(asyncio.Event) # These attributes should be set by subclasses: - self.concurrency_limit: int | float | None + self.concurrency_limit: int | None self.event_handler: HandlerType | None self.time_limit: float | None self.modality: Literal["video", "audio", "audio-video"]