This commit is contained in:
Freddy Boulton
2025-04-02 18:31:08 -04:00
committed by GitHub
parent cfde58fce6
commit d07bb41a9e
2 changed files with 63 additions and 4 deletions

View File

@@ -10,7 +10,7 @@ from numpy.typing import NDArray
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
from .tracks import EmitType, StreamHandler
from .utils import create_message, split_output
from .utils import AdditionalOutputs, create_message, split_output
logger = getLogger(__name__)
@@ -243,7 +243,7 @@ class ReplyOnPause(StreamHandler):
self.send_message_sync(create_message("log", "response_starting"))
self.state.responded_audio = True
if self.phone_mode:
if additional_outputs:
if isinstance(additional_outputs, AdditionalOutputs):
self.latest_args = [None] + list(additional_outputs.args)
return output
except (StopIteration, StopAsyncIteration):

View File

@@ -11,7 +11,7 @@ from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState
from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, DataChannel, split_output
from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
class WebSocketDataChannel(DataChannel):
@@ -66,6 +66,9 @@ class WebSocketHandler:
self.quit = asyncio.Event()
self.clean_up = clean_up
self.queue = asyncio.Queue()
self.playing_durations = [] # Track durations of frames being played
self._frame_cleanup_task: Optional[asyncio.Task] = None
self._graceful_shutdown_task: Optional[asyncio.Task] = None
def _clear_queue(self):
old_queue = self.queue
@@ -93,6 +96,7 @@ class WebSocketHandler:
self.stream_handler.set_channel(self.data_channel)
self._emit_task = asyncio.create_task(self._emit_loop())
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
@@ -165,6 +169,10 @@ class WebSocketHandler:
self._emit_task.cancel()
if self._emit_to_queue_task:
self._emit_to_queue_task.cancel()
if self._frame_cleanup_task:
self._frame_cleanup_task.cancel()
if self._graceful_shutdown_task:
self._graceful_shutdown_task.cancel()
if self.start_up_task:
self.start_up_task.cancel()
@@ -189,6 +197,45 @@ class WebSocketHandler:
traceback.print_exc()
logger.debug("Error in emit loop: %s", e)
async def _cleanup_frames_loop(self):
"""Background task that removes frames from playing_durations after they've finished playing."""
try:
while not self.quit.is_set():
if self.playing_durations:
duration = self.playing_durations[0]
await asyncio.sleep(duration)
if self.playing_durations:
self.playing_durations.pop(0)
else:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
logger.debug("Frame cleanup loop cancelled")
except Exception as e:
logger.debug(f"Error in frame cleanup loop: {e}")
async def _wait_for_audio_completion(self):
"""Wait for all queued audio to finish playing before shutting down."""
try:
if not self.playing_durations:
self.quit.set()
return
# Calculate total remaining playback time
total_wait = sum(self.playing_durations)
if total_wait > 0:
logger.debug(
f"Waiting {total_wait:.2f}s for audio to complete before closing"
)
await asyncio.sleep(total_wait)
self.quit.set()
except asyncio.CancelledError:
logger.debug("Graceful shutdown cancelled")
except Exception as e:
logger.debug(f"Error in graceful shutdown: {e}")
self.quit.set()
async def _emit_loop(self):
try:
while not self.quit.is_set():
@@ -196,10 +243,17 @@ class WebSocketHandler:
if output is not None:
frame, output = split_output(output)
if output is not None:
if isinstance(output, AdditionalOutputs):
self.set_additional_outputs(output)
elif isinstance(output, CloseStream):
self._graceful_shutdown_task = asyncio.create_task(
self._wait_for_audio_completion()
)
continue
if not isinstance(frame, tuple):
continue
target_rate = (
self.stream_handler.output_sample_rate
if not self.stream_handler.phone_mode
@@ -211,6 +265,11 @@ class WebSocketHandler:
audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
if self.websocket and self.stream_id:
sample_rate, audio_array = frame
duration = len(audio_array) / sample_rate
self.playing_durations.append(duration)
payload = {
"event": "media",
"media": {"payload": audio_payload},