From d07bb41a9ec57b189a316d13a3efcda52095b35d Mon Sep 17 00:00:00 2001
From: Freddy Boulton
Date: Wed, 2 Apr 2025 18:31:08 -0400
Subject: [PATCH] Fix (#242)

---
Review notes (not part of the commit; text between the separator above and
the first "diff --git" line is discarded by `git am`):

* reply_on_pause.py: in phone mode, latest_args is now refreshed only when
  additional_outputs is an AdditionalOutputs instance, instead of on any
  truthy value.
* websocket.py: while websocket and stream_id are set, each emitted frame
  appends len(audio_array) / sample_rate to playing_durations.
  _cleanup_frames_loop sleeps for the first queued duration and then pops
  it (or sleeps 0.1 s when the list is empty). A CloseStream output starts
  _wait_for_audio_completion, which sleeps for the sum of the queued
  durations and then sets self.quit. Both new tasks are cancelled next to
  the existing emit tasks.
* NOTE(review): len(audio_array) measures the first axis only. Presumably
  frames are 1-D sample arrays at this point; confirm for 2-D
  (channels, samples) input, where the computed duration would be wrong.
* NOTE(review): every CloseStream output overwrites
  _graceful_shutdown_task without cancelling a task that may already be
  running. Confirm CloseStream can arrive only once per stream.
* NOTE(review): the `except Exception` branch of _cleanup_frames_loop logs
  at debug level and leaves the loop, after which playing_durations is no
  longer drained.
* NOTE(review): _wait_for_audio_completion sums every queued duration,
  including the entry currently being slept on by _cleanup_frames_loop, so
  the wait can exceed the real remaining playback by up to one frame.
* NOTE(review): the new logger.debug calls format f-strings eagerly; the
  surrounding code uses lazy %-style arguments.
* NOTE(formatting): this copy was rebuilt from a whitespace-collapsed
  paste. Indentation of context lines and the position of blank context
  lines (notably in the -196,10 and -211,6 hunks) were reconstructed, with
  hunk line counts used as a check. Verify against the upstream commit
  before applying.

 backend/fastrtc/reply_on_pause.py |  4 +-
 backend/fastrtc/websocket.py      | 63 ++++++++++++++++++++++++++++++-
 2 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py
index 64623dc..705e5d1 100644
--- a/backend/fastrtc/reply_on_pause.py
+++ b/backend/fastrtc/reply_on_pause.py
@@ -10,7 +10,7 @@ from numpy.typing import NDArray
 
 from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
 from .tracks import EmitType, StreamHandler
-from .utils import create_message, split_output
+from .utils import AdditionalOutputs, create_message, split_output
 
 logger = getLogger(__name__)
 
@@ -243,7 +243,7 @@ class ReplyOnPause(StreamHandler):
                     self.send_message_sync(create_message("log", "response_starting"))
                     self.state.responded_audio = True
                 if self.phone_mode:
-                    if additional_outputs:
+                    if isinstance(additional_outputs, AdditionalOutputs):
                         self.latest_args = [None] + list(additional_outputs.args)
                 return output
             except (StopIteration, StopAsyncIteration):
diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py
index 5d015de..2bf60a6 100644
--- a/backend/fastrtc/websocket.py
+++ b/backend/fastrtc/websocket.py
@@ -11,7 +11,7 @@ from fastapi import WebSocket
 from fastapi.websockets import WebSocketDisconnect, WebSocketState
 
 from .tracks import AsyncStreamHandler, StreamHandlerImpl
-from .utils import AdditionalOutputs, DataChannel, split_output
+from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
 
 
 class WebSocketDataChannel(DataChannel):
@@ -66,6 +66,9 @@ class WebSocketHandler:
         self.quit = asyncio.Event()
         self.clean_up = clean_up
         self.queue = asyncio.Queue()
+        self.playing_durations = []  # Track durations of frames being played
+        self._frame_cleanup_task: Optional[asyncio.Task] = None
+        self._graceful_shutdown_task: Optional[asyncio.Task] = None
 
     def _clear_queue(self):
         old_queue = self.queue
@@ -93,6 +96,7 @@ class WebSocketHandler:
         self.stream_handler.set_channel(self.data_channel)
         self._emit_task = asyncio.create_task(self._emit_loop())
         self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
+        self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
         if isinstance(self.stream_handler, AsyncStreamHandler):
             start_up = self.stream_handler.start_up()
         else:
@@ -165,6 +169,10 @@ class WebSocketHandler:
             self._emit_task.cancel()
         if self._emit_to_queue_task:
             self._emit_to_queue_task.cancel()
+        if self._frame_cleanup_task:
+            self._frame_cleanup_task.cancel()
+        if self._graceful_shutdown_task:
+            self._graceful_shutdown_task.cancel()
         if self.start_up_task:
             self.start_up_task.cancel()
 
@@ -189,6 +197,45 @@ class WebSocketHandler:
             traceback.print_exc()
             logger.debug("Error in emit loop: %s", e)
 
+    async def _cleanup_frames_loop(self):
+        """Background task that removes frames from playing_durations after they've finished playing."""
+        try:
+            while not self.quit.is_set():
+                if self.playing_durations:
+                    duration = self.playing_durations[0]
+                    await asyncio.sleep(duration)
+                    if self.playing_durations:
+                        self.playing_durations.pop(0)
+                else:
+                    await asyncio.sleep(0.1)
+        except asyncio.CancelledError:
+            logger.debug("Frame cleanup loop cancelled")
+        except Exception as e:
+            logger.debug(f"Error in frame cleanup loop: {e}")
+
+    async def _wait_for_audio_completion(self):
+        """Wait for all queued audio to finish playing before shutting down."""
+        try:
+            if not self.playing_durations:
+                self.quit.set()
+                return
+
+            # Calculate total remaining playback time
+            total_wait = sum(self.playing_durations)
+
+            if total_wait > 0:
+                logger.debug(
+                    f"Waiting {total_wait:.2f}s for audio to complete before closing"
+                )
+                await asyncio.sleep(total_wait)
+
+            self.quit.set()
+        except asyncio.CancelledError:
+            logger.debug("Graceful shutdown cancelled")
+        except Exception as e:
+            logger.debug(f"Error in graceful shutdown: {e}")
+            self.quit.set()
+
     async def _emit_loop(self):
         try:
             while not self.quit.is_set():
@@ -196,10 +243,17 @@ class WebSocketHandler:
 
                 if output is not None:
                     frame, output = split_output(output)
-                    if output is not None:
+                    if isinstance(output, AdditionalOutputs):
                         self.set_additional_outputs(output)
+                    elif isinstance(output, CloseStream):
+                        self._graceful_shutdown_task = asyncio.create_task(
+                            self._wait_for_audio_completion()
+                        )
+                        continue
+
                     if not isinstance(frame, tuple):
                         continue
+
                     target_rate = (
                         self.stream_handler.output_sample_rate
                         if not self.stream_handler.phone_mode
@@ -211,6 +265,11 @@ class WebSocketHandler:
                     audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
 
                     if self.websocket and self.stream_id:
+                        sample_rate, audio_array = frame
+                        duration = len(audio_array) / sample_rate
+
+                        self.playing_durations.append(duration)
+
                         payload = {
                             "event": "media",
                             "media": {"payload": audio_payload},