diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py index 1fd22b1..1e3215d 100644 --- a/backend/fastrtc/reply_on_pause.py +++ b/backend/fastrtc/reply_on_pause.py @@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler): def __init__( self, fn: ReplyFnGenerator, + startup_fn: Callable | None = None, algo_options: AlgoOptions | None = None, model_options: ModelOptions | None = None, can_interrupt: bool = True, @@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler): ) = None self.model_options = model_options self.algo_options = algo_options or AlgoOptions() + self.startup_fn = startup_fn @property def _needs_additional_inputs(self) -> bool: return len(inspect.signature(self.fn).parameters) > 1 + def start_up(self): + if self.startup_fn: + if self._needs_additional_inputs: + self.wait_for_args_sync() + args = self.latest_args[1:] + else: + args = () + self.generator = self.startup_fn(*args) + self.event.set() + def copy(self): return ReplyOnPause( self.fn, + self.startup_fn, self.algo_options, self.model_options, self.can_interrupt, diff --git a/backend/fastrtc/reply_on_stopwords.py b/backend/fastrtc/reply_on_stopwords.py index fead1f5..7d082d7 100644 --- a/backend/fastrtc/reply_on_stopwords.py +++ b/backend/fastrtc/reply_on_stopwords.py @@ -1,7 +1,7 @@ import asyncio import logging import re -from typing import Literal +from typing import Callable, Literal import numpy as np @@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause): self, fn: ReplyFnGenerator, stop_words: list[str], + startup_fn: Callable | None = None, algo_options: AlgoOptions | None = None, model_options: ModelOptions | None = None, can_interrupt: bool = True, @@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause): super().__init__( fn, algo_options=algo_options, + startup_fn=startup_fn, model_options=model_options, can_interrupt=can_interrupt, expected_layout=expected_layout, @@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause): return ReplyOnStopWords( self.fn, 
self.stop_words, + self.startup_fn, self.algo_options, self.model_options, self.can_interrupt, diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py index 5a18cce..0a9c5aa 100644 --- a/backend/fastrtc/tracks.py +++ b/backend/fastrtc/tracks.py @@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack): except TimeoutError: continue - def start( + async def start( self, ): asyncio.create_task(self.process_frames()) @@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback): except MediaStreamError: self.stop() - def start(self): + async def start(self): if not self.has_started: asyncio.create_task(self.process_frames()) self.has_started = True async def recv(self): # type: ignore - self.start() + await self.start() try: handler = cast(VideoStreamHandlerImpl, self.event_handler) if inspect.iscoroutinefunction(handler.video_emit): @@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack): logger.debug("popped %d items from queue", i) self._start = None + async def wait_for_channel(self): + if not self.event_handler.channel_set.is_set(): + await self.event_handler.channel_set.wait() + if current_channel.get() != self.event_handler.channel: + current_channel.set(self.event_handler.channel) + def set_channel(self, channel: DataChannel): self.channel = channel self.event_handler.set_channel(channel) @@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack): logger.debug("MediaStreamError in process_input_frames") break - def start(self): + async def start(self): if not self.has_started: loop = asyncio.get_running_loop() + await self.wait_for_channel() if isinstance(self.event_handler, AsyncHandler): callable = self.event_handler.emit start_up = self.event_handler.start_up() @@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack): await self.event_handler.channel_set.wait() if current_channel.get() != self.event_handler.channel: current_channel.set(self.event_handler.channel) - self.start() + await self.start() frame = await self.queue.get() 
logger.debug("frame %s", frame) @@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack): except StopIteration: self.thread_quit.set() - def start(self): + async def start(self): if not self.has_started: loop = asyncio.get_running_loop() callable = functools.partial(loop.run_in_executor, None, self.next) @@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack): if self.readyState != "live": raise MediaStreamError - self.start() + await self.start() data = await self.queue.get() if data is None: self.stop() diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py index dafc882..bd1772c 100644 --- a/backend/fastrtc/utils.py +++ b/backend/fastrtc/utils.py @@ -156,6 +156,11 @@ async def player_worker_decode( break continue + if not isinstance(frame, tuple) or not isinstance(frame[1], np.ndarray): + raise WebRTCError( + "The frame must be a tuple containing a sample rate and a numpy array." + ) + if len(frame) == 2: sample_rate, audio_array = frame layout = "mono" @@ -199,7 +204,10 @@ async def player_worker_decode( exec = traceback.format_exc() print("traceback %s", exec) print("Error processing frame: %s", str(e)) - continue + if isinstance(e, WebRTCError): + raise e + else: + continue def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes: diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index 05026e7..b7e5733 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -241,7 +241,7 @@ class WebRTCConnectionMixin: logger.debug("Adding track to peer connection %s", cb) pc.addTrack(cb) elif self.mode == "send": - cast(AudioCallback | VideoCallback, cb).start() + asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start()) if self.mode == "receive": if self.modality == "video": diff --git a/docs/userguide/audio.md b/docs/userguide/audio.md index e809dd8..81cf657 100644 --- a/docs/userguide/audio.md +++ 
b/docs/userguide/audio.md @@ -76,6 +76,37 @@ stream = Stream( !!! tip "Muting Response Audio" You can directly talk over the output audio and the interruption will still work. However, in these cases, the audio transcription may be incorrect. To prevent this, it's best practice to mute the output audio before talking over it. +### Startup Function + +You can pass in a `startup_fn` to the `ReplyOnPause` class. This function will be called when the connection is first established. It is helpful for generating initial responses. + +```python +from fastrtc import get_tts_model, Stream, ReplyOnPause + +tts_client = get_tts_model() + + +def detection(audio: tuple[int, np.ndarray]): + # Implement any iterator that yields audio + # See "LLM Voice Chat" for a more complete example + yield audio + + +def startup(): + for chunk in tts_client.stream_tts_sync("Welcome to the echo audio demo!"): + yield chunk + + +stream = Stream( + handler=ReplyOnPause(detection, startup_fn=startup), + modality="audio", + mode="send-receive", + ui_args={"title": "Echo Audio"}, +) +``` + + + ## Reply On Stopwords You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.