diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py index 78fedd0..9b6405c 100644 --- a/backend/fastrtc/reply_on_pause.py +++ b/backend/fastrtc/reply_on_pause.py @@ -184,10 +184,11 @@ class ReplyOnPause(StreamHandler): self.process_audio(frame, self.state) if self.state.pause_detected: self.event.set() - if self.can_interrupt: - self.clear_queue() + if self.can_interrupt and self.state.responding: self._close_generator() self.generator = None + if self.can_interrupt: + self.clear_queue() def _close_generator(self): """Properly close the generator to ensure resources are released.""" diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index c2c6c8e..ab501e7 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -509,7 +509,7 @@ class Stream(WebRTCConnectionMixin): handler.phone_mode = True async def set_handler(s: str, a: WebSocketHandler): - if len(self.connections) >= self.concurrency_limit: + if len(self.connections) >= self.concurrency_limit: # type: ignore await cast(WebSocket, a.websocket).send_json( { "status": "failed", @@ -532,7 +532,7 @@ class Stream(WebRTCConnectionMixin): handler.phone_mode = False async def set_handler(s: str, a: WebSocketHandler): - if len(self.connections) >= self.concurrency_limit: + if len(self.connections) >= self.concurrency_limit: # type: ignore await cast(WebSocket, a.websocket).send_json( { "status": "failed", diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py index 4ef50bc..9001e75 100644 --- a/backend/fastrtc/tracks.py +++ b/backend/fastrtc/tracks.py @@ -431,10 +431,14 @@ class AudioCallback(AudioStreamTrack): self.set_additional_outputs = set_additional_outputs def clear_queue(self): - if self.queue: - while not self.queue.empty(): - self.queue.get_nowait() - self._start = None + logger.debug("clearing queue") + logger.debug("queue size: %d", self.queue.qsize()) + i = 0 + while not self.queue.empty(): + self.queue.get_nowait() + i += 1 + logger.debug("popped %d items from queue", i) + self._start = None def set_channel(self, channel: DataChannel): self.channel = channel diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py index 738e2c0..4d182b9 100644 --- a/backend/fastrtc/websocket.py +++ b/backend/fastrtc/websocket.py @@ -55,7 +55,7 @@ class WebSocketHandler: ], ): self.stream_handler = stream_handler - self.stream_handler._clear_queue = lambda: None + self.stream_handler._clear_queue = self._clear_queue self.websocket: Optional[WebSocket] = None self._emit_task: Optional[asyncio.Task] = None self.stream_id: Optional[str] = None @@ -64,6 +64,20 @@ class WebSocketHandler: self.set_handler = set_handler self.quit = asyncio.Event() self.clean_up = clean_up + self.queue = asyncio.Queue() + + def _clear_queue(self): + old_queue = self.queue + self.queue = asyncio.Queue() + logger.debug("clearing queue") + i = 0 + while not old_queue.empty(): + try: + old_queue.get_nowait() + i += 1 + except asyncio.QueueEmpty: + break + logger.debug("popped %d items from queue", i) def set_args(self, args: list[Any]): self.stream_handler.set_args(args) @@ -77,6 +91,7 @@ class WebSocketHandler: self.stream_handler._loop = loop self.stream_handler.set_channel(self.data_channel) self._emit_task = asyncio.create_task(self._emit_loop()) + self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue()) if isinstance(self.stream_handler, AsyncStreamHandler): start_up = self.stream_handler.start_up() else: @@ -137,17 +152,32 @@ class WebSocketHandler: finally: if self._emit_task: self._emit_task.cancel() + if self._emit_to_queue_task: + self._emit_to_queue_task.cancel() if self.start_up_task: self.start_up_task.cancel() await websocket.close() - async def _emit_loop(self): + async def _emit_to_queue(self): try: while not self.quit.is_set(): if isinstance(self.stream_handler, AsyncStreamHandler): output = await self.stream_handler.emit() else: output = await run_sync(self.stream_handler.emit) + self.queue.put_nowait(output) + except asyncio.CancelledError: + logger.debug("Emit loop cancelled") + except Exception as e: + import traceback + + traceback.print_exc() + logger.debug("Error in emit loop: %s", e) + + async def _emit_loop(self): + try: + while not self.quit.is_set(): + output = await self.queue.get() if output is not None: frame, output = split_output(output)