From 54d07bc3c8f7ed589a9f216dbd9f62e73714203c Mon Sep 17 00:00:00 2001 From: Freddy Boulton <41651716+freddyaboulton@users.noreply.github.com> Date: Mon, 14 Apr 2025 09:57:15 -0400 Subject: [PATCH] Add code (#276) --- backend/fastrtc/templates/component/index.js | 2 +- backend/fastrtc/websocket.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/backend/fastrtc/templates/component/index.js b/backend/fastrtc/templates/component/index.js index 01b20e0..a7ea23f 100644 --- a/backend/fastrtc/templates/component/index.js +++ b/backend/fastrtc/templates/component/index.js @@ -22470,7 +22470,7 @@ function Bp(n, e, t) { const oe = (N) => { (N == null ? void 0 : N.type) === "info" || (N == null ? void 0 : N.type) === "warning" || (N == null ? void 0 : N.type) === "error" ? b.dispatch( (N == null ? void 0 : N.type) === "error" ? "error" : "warning", - N.message + N.data ) : (N == null ? void 0 : N.type) === "end_stream" ? b.dispatch("warning", N.data) : (N == null ? void 0 : N.type) === "fetch_output" ? b.dispatch("state_change") : (N == null ? void 0 : N.type) === "send_input" ? b.dispatch("tick") : (N == null ? void 0 : N.type) === "connection_timeout" && b.dispatch("warning", "Taking a while to connect. Are you on a VPN?"), N.type === "state_change" && b.dispatch(N === "change" ? "state_change" : "tick"); }, q = (N) => { var le, qt; diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py index 48f74e4..041779f 100644 --- a/backend/fastrtc/websocket.py +++ b/backend/fastrtc/websocket.py @@ -15,9 +15,11 @@ from .tracks import AsyncStreamHandler, StreamHandlerImpl from .utils import ( AdditionalOutputs, CloseStream, + Context, DataChannel, audio_to_float32, audio_to_int16, + current_context, split_output, ) @@ -141,7 +143,7 @@ class WebSocketHandler: ) else: await run_sync( - self.stream_handler.receive, + self.receive_with_context, (self.stream_handler.input_sample_rate, audio_array), ) except Exception as e: @@ -156,6 +158,7 @@ class WebSocketHandler: self.stream_id = cast(str, message["streamSid"]) else: self.stream_id = cast(str, message["websocket_id"]) + current_context.set(Context(webrtc_id=self.stream_id)) self.set_additional_outputs = self.set_additional_outputs_factory( self.stream_id ) @@ -186,13 +189,21 @@ class WebSocketHandler: self.clean_up(cast(str, self.stream_id)) + def emit_with_context(self): + current_context.set(Context(webrtc_id=cast(str, self.stream_id))) + return self.stream_handler.emit() + + def receive_with_context(self, frame: tuple[int, np.ndarray]): + current_context.set(Context(webrtc_id=cast(str, self.stream_id))) + return self.stream_handler.receive(frame) + async def _emit_to_queue(self): try: while not self.quit.is_set(): if isinstance(self.stream_handler, AsyncStreamHandler): output = await self.stream_handler.emit() else: - output = await run_sync(self.stream_handler.emit) + output = await run_sync(self.emit_with_context) self.queue.put_nowait(output) except asyncio.CancelledError: logger.debug("Emit loop cancelled") @@ -270,7 +281,7 @@ class WebSocketHandler: audio_payload = base64.b64encode(mulaw_audio).decode("utf-8") if self.websocket and self.stream_id: - sample_rate, audio_array = frame + sample_rate, audio_array = frame[:2] duration = len(audio_array) / sample_rate self.playing_durations.append(duration)