import asyncio import audioop import base64 import logging from collections.abc import Awaitable, Callable from typing import Any, cast import anyio import librosa import numpy as np from fastapi import WebSocket from fastapi.websockets import WebSocketDisconnect, WebSocketState from .tracks import AsyncStreamHandler, StreamHandlerImpl from .utils import ( AdditionalOutputs, CloseStream, Context, DataChannel, audio_to_float32, audio_to_int16, current_context, split_output, ) class WebSocketDataChannel(DataChannel): def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop): self.websocket = websocket self.loop = loop def send(self, message: str) -> None: asyncio.run_coroutine_threadsafe(self.websocket.send_text(message), self.loop) logger = logging.getLogger(__file__) def convert_to_mulaw( audio_data: np.ndarray, original_rate: int, target_rate: int ) -> bytes: """Convert audio data to 8kHz mu-law format""" audio_data = audio_to_float32(audio_data) if original_rate != target_rate: audio_data = librosa.resample( audio_data, orig_sr=original_rate, target_sr=target_rate ) audio_data = audio_to_int16(audio_data) return audioop.lin2ulaw(audio_data, 2) # type: ignore run_sync = anyio.to_thread.run_sync # type: ignore class WebSocketHandler: def __init__( self, stream_handler: StreamHandlerImpl, set_handler: Callable[[str, "WebSocketHandler"], Awaitable[None]], clean_up: Callable[[str], None], additional_outputs_factory: Callable[ [str], Callable[[AdditionalOutputs], None] ], ): self.stream_handler = stream_handler self.stream_handler._clear_queue = self._clear_queue self.websocket: WebSocket | None = None self._emit_task: asyncio.Task | None = None self.stream_id: str | None = None self.set_additional_outputs_factory = additional_outputs_factory self.set_additional_outputs: Callable[[AdditionalOutputs], None] self.set_handler = set_handler self.quit = asyncio.Event() self.clean_up = clean_up self.queue = asyncio.Queue() self.playing_durations = [] # Track durations of frames being played self._frame_cleanup_task: asyncio.Task | None = None self._graceful_shutdown_task: asyncio.Task | None = None def _clear_queue(self): old_queue = self.queue self.queue = asyncio.Queue() logger.debug("clearing queue") i = 0 while not old_queue.empty(): try: old_queue.get_nowait() i += 1 except asyncio.QueueEmpty: break logger.debug("popped %d items from queue", i) def set_args(self, args: list[Any]): self.stream_handler.set_args(args) async def handle_websocket(self, websocket: WebSocket): await websocket.accept() loop = asyncio.get_running_loop() self.loop = loop self.websocket = websocket self.data_channel = WebSocketDataChannel(websocket, loop) self.stream_handler._loop = loop self.stream_handler.set_channel(self.data_channel) self._emit_task = asyncio.create_task(self._emit_loop()) self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue()) self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop()) if isinstance(self.stream_handler, AsyncStreamHandler): start_up = self.stream_handler.start_up() else: start_up = anyio.to_thread.run_sync(self.stream_handler.start_up) # type: ignore was_disconnected = False self.start_up_task = asyncio.create_task(start_up) try: while not self.quit.is_set(): if websocket.application_state != WebSocketState.CONNECTED: was_disconnected = True break message = await websocket.receive_json() if message["event"] == "media": audio_payload = base64.b64decode(message["media"]["payload"]) audio_array = np.frombuffer( audioop.ulaw2lin(audio_payload, 2), dtype=np.int16 ) if ( self.stream_handler.phone_mode and self.stream_handler.input_sample_rate != 8000 ): audio_array = audio_to_float32(audio_array) audio_array = librosa.resample( audio_array, orig_sr=8000, target_sr=self.stream_handler.input_sample_rate, ) audio_array = audio_to_int16(audio_array) try: if isinstance(self.stream_handler, AsyncStreamHandler): await self.stream_handler.receive( (self.stream_handler.input_sample_rate, audio_array) ) else: await run_sync( self.receive_with_context, (self.stream_handler.input_sample_rate, audio_array), ) except Exception as e: print(e) import traceback traceback.print_exc() logger.debug("Error in websocket handler %s", e) elif message["event"] == "start": if self.stream_handler.phone_mode: self.stream_id = cast(str, message["streamSid"]) else: self.stream_id = cast(str, message["websocket_id"]) current_context.set(Context(webrtc_id=self.stream_id)) self.set_additional_outputs = self.set_additional_outputs_factory( self.stream_id ) await self.set_handler(self.stream_id, self) elif message["event"] == "stop": self.quit.set() return # Still runs the `finally` block elif message["event"] == "ping": await websocket.send_json({"event": "pong"}) except WebSocketDisconnect: # Surprisingly, this leaves `websocket.application_state` as CONNECTED # in the `finally` block, so we use this variable was_disconnected = True finally: if self._emit_task: self._emit_task.cancel() if self._emit_to_queue_task: self._emit_to_queue_task.cancel() if self._frame_cleanup_task: self._frame_cleanup_task.cancel() if self._graceful_shutdown_task: self._graceful_shutdown_task.cancel() if self.start_up_task: self.start_up_task.cancel() if not was_disconnected: await websocket.close() self.clean_up(cast(str, self.stream_id)) def emit_with_context(self): current_context.set(Context(webrtc_id=cast(str, self.stream_id))) return self.stream_handler.emit() def receive_with_context(self, frame: tuple[int, np.ndarray]): current_context.set(Context(webrtc_id=cast(str, self.stream_id))) return self.stream_handler.receive(frame) async def _emit_to_queue(self): try: while not self.quit.is_set(): if isinstance(self.stream_handler, AsyncStreamHandler): output = await self.stream_handler.emit() else: output = await run_sync(self.emit_with_context) self.queue.put_nowait(output) except asyncio.CancelledError: logger.debug("Emit loop cancelled") except Exception as e: import traceback traceback.print_exc() logger.debug("Error in emit loop: %s", e) async def _cleanup_frames_loop(self): """Background task that removes frames from playing_durations after they've finished playing.""" try: while not self.quit.is_set(): if self.playing_durations: duration = self.playing_durations[0] await asyncio.sleep(duration) if self.playing_durations: self.playing_durations.pop(0) else: await asyncio.sleep(0.1) except asyncio.CancelledError: logger.debug("Frame cleanup loop cancelled") except Exception as e: logger.debug(f"Error in frame cleanup loop: {e}") async def _wait_for_audio_completion(self): """Wait for all queued audio to finish playing before shutting down.""" try: if not self.playing_durations: self.quit.set() return # Calculate total remaining playback time total_wait = sum(self.playing_durations) if total_wait > 0: logger.debug( f"Waiting {total_wait:.2f}s for audio to complete before closing" ) await asyncio.sleep(total_wait) self.quit.set() except asyncio.CancelledError: logger.debug("Graceful shutdown cancelled") except Exception as e: logger.debug(f"Error in graceful shutdown: {e}") self.quit.set() async def _emit_loop(self): try: while not self.quit.is_set(): output = await self.queue.get() if output is not None: frame, output = split_output(output) if isinstance(output, AdditionalOutputs): self.set_additional_outputs(output) elif isinstance(output, CloseStream): self._graceful_shutdown_task = asyncio.create_task( self._wait_for_audio_completion() ) continue if not isinstance(frame, tuple): continue target_rate = ( 8_000 if self.stream_handler.phone_mode else self.stream_handler.output_sample_rate ) mulaw_audio = convert_to_mulaw( frame[1], frame[0], target_rate=target_rate, ) audio_payload = base64.b64encode(mulaw_audio).decode("utf-8") if self.websocket and self.stream_id: sample_rate, audio_array = frame[:2] duration = len(audio_array) / sample_rate self.playing_durations.append(duration) payload = { "event": "media", "media": {"payload": audio_payload}, } if self.stream_handler.phone_mode: payload["streamSid"] = self.stream_id await self.websocket.send_json(payload) await asyncio.sleep(0.02) except asyncio.CancelledError: logger.debug("Emit loop cancelled") except Exception as e: import traceback traceback.print_exc() logger.debug("Error in emit loop: %s", e)