feat: Add optional startup function to ReplyOnPause (#170)

* feat: Add optional startup function to ReplyOnPause

* feat: Implement startup_fn in ReplyOnStopWords

* refactor: Remove redundant startup_fn implementation in ReplyOnStopWords

* tweaks

* revert

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
Author: Sofian Mejjoute
Date: 2025-03-12 00:11:29 +01:00
Committed by: GitHub
Commit: 66f0a81b76
Parent: 514310691d

6 changed files with 72 additions and 10 deletions

View File

@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
     def __init__(
         self,
         fn: ReplyFnGenerator,
+        startup_fn: Callable | None = None,
         algo_options: AlgoOptions | None = None,
         model_options: ModelOptions | None = None,
         can_interrupt: bool = True,
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
         ) = None
         self.model_options = model_options
         self.algo_options = algo_options or AlgoOptions()
+        self.startup_fn = startup_fn

     @property
     def _needs_additional_inputs(self) -> bool:
         return len(inspect.signature(self.fn).parameters) > 1

+    def start_up(self):
+        if self.startup_fn:
+            if self._needs_additional_inputs:
+                self.wait_for_args_sync()
+                args = self.latest_args[1:]
+            else:
+                args = ()
+            self.generator = self.startup_fn(*args)
+            self.event.set()
+
     def copy(self):
         return ReplyOnPause(
             self.fn,
+            self.startup_fn,
             self.algo_options,
             self.model_options,
             self.can_interrupt,
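Note on the `start_up` hook added above: when the handler `fn` declares additional inputs beyond the audio argument, `start_up` waits for them and forwards `self.latest_args[1:]` to `startup_fn`. A minimal sketch of what that implies for user code, assuming a hypothetical `greeting_text` input and reusing fastrtc's `get_tts_model` helper:

```python
import numpy as np
from fastrtc import ReplyOnPause, get_tts_model

tts_client = get_tts_model()


def echo(audio: tuple[int, np.ndarray], greeting_text: str):
    # Handler with one additional input besides the audio frame.
    yield audio


def startup(greeting_text: str):
    # Called once on connect; receives the same additional input(s)
    # that start_up collected via wait_for_args_sync().
    for chunk in tts_client.stream_tts_sync(greeting_text):
        yield chunk


handler = ReplyOnPause(echo, startup_fn=startup)
```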

View File

@@ -1,7 +1,7 @@
 import asyncio
 import logging
 import re
-from typing import Literal
+from typing import Callable, Literal

 import numpy as np
@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
         self,
         fn: ReplyFnGenerator,
         stop_words: list[str],
+        startup_fn: Callable | None = None,
         algo_options: AlgoOptions | None = None,
         model_options: ModelOptions | None = None,
         can_interrupt: bool = True,
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
         super().__init__(
             fn,
             algo_options=algo_options,
+            startup_fn=startup_fn,
             model_options=model_options,
             can_interrupt=can_interrupt,
             expected_layout=expected_layout,
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
         return ReplyOnStopWords(
             self.fn,
             self.stop_words,
+            self.startup_fn,
             self.algo_options,
             self.model_options,
             self.can_interrupt,

View File

@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
             except TimeoutError:
                 continue

-    def start(
+    async def start(
         self,
     ):
         asyncio.create_task(self.process_frames())
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
             except MediaStreamError:
                 self.stop()

-    def start(self):
+    async def start(self):
         if not self.has_started:
             asyncio.create_task(self.process_frames())
             self.has_started = True

     async def recv(self):  # type: ignore
-        self.start()
+        await self.start()
         try:
             handler = cast(VideoStreamHandlerImpl, self.event_handler)
             if inspect.iscoroutinefunction(handler.video_emit):
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
logger.debug("popped %d items from queue", i) logger.debug("popped %d items from queue", i)
self._start = None self._start = None
async def wait_for_channel(self):
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
def set_channel(self, channel: DataChannel): def set_channel(self, channel: DataChannel):
self.channel = channel self.channel = channel
self.event_handler.set_channel(channel) self.event_handler.set_channel(channel)
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
logger.debug("MediaStreamError in process_input_frames") logger.debug("MediaStreamError in process_input_frames")
break break
def start(self): async def start(self):
if not self.has_started: if not self.has_started:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await self.wait_for_channel()
if isinstance(self.event_handler, AsyncHandler): if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit callable = self.event_handler.emit
start_up = self.event_handler.start_up() start_up = self.event_handler.start_up()
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
                 await self.event_handler.channel_set.wait()
             if current_channel.get() != self.event_handler.channel:
                 current_channel.set(self.event_handler.channel)
-        self.start()
+        await self.start()

         frame = await self.queue.get()
         logger.debug("frame %s", frame)
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
         except StopIteration:
             self.thread_quit.set()

-    def start(self):
+    async def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
             callable = functools.partial(loop.run_in_executor, None, self.next)
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
         if self.readyState != "live":
             raise MediaStreamError

-        self.start()
+        await self.start()
         data = await self.queue.get()
         if data is None:
             self.stop()
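The `start()` methods above become coroutines so that `AudioCallback.start` can await `wait_for_channel()` before launching the handler, ensuring the data channel is in place before emission begins. A stripped-down sketch of the underlying asyncio pattern (class and attribute names here are illustrative, not the library's):

```python
import asyncio


class TrackSketch:
    def __init__(self):
        self.channel_set = asyncio.Event()
        self.has_started = False

    def set_channel(self, channel):
        self.channel = channel
        self.channel_set.set()

    async def start(self):
        if not self.has_started:
            # Block until set_channel() has run, then start emitting.
            await self.channel_set.wait()
            asyncio.create_task(self.emit_loop())
            self.has_started = True

    async def emit_loop(self):
        ...
```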

View File

@@ -156,6 +156,11 @@ async def player_worker_decode(
                 break
             continue

+        if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
+            raise WebRTCError(
+                "The frame must be a tuple containing a sample rate and a numpy array."
+            )
+
         if len(frame) == 2:
             sample_rate, audio_array = frame
             layout = "mono"
@@ -199,7 +204,10 @@ async def player_worker_decode(
             exec = traceback.format_exc()
             print("traceback %s", exec)
             print("Error processing frame: %s", str(e))
-            continue
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                continue


 def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:
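The new check in `player_worker_decode` requires each chunk yielded by a handler to be a `(sample_rate, numpy_array)` tuple; anything else now raises `WebRTCError` instead of being silently skipped. A small illustrative generator that satisfies the check (the sample rate, frame length, and dtype are assumptions for this sketch, not requirements stated in the diff):

```python
import numpy as np


def silence(audio: tuple[int, np.ndarray]):
    sample_rate = 24000  # assumed output rate for this sketch
    # One 20 ms mono frame of int16 silence, yielded as (sample_rate, array).
    frame = np.zeros((1, sample_rate // 50), dtype=np.int16)
    yield (sample_rate, frame)
```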

View File

@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
logger.debug("Adding track to peer connection %s", cb) logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb) pc.addTrack(cb)
elif self.mode == "send": elif self.mode == "send":
cast(AudioCallback | VideoCallback, cb).start() asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
if self.mode == "receive": if self.mode == "receive":
if self.modality == "video": if self.modality == "video":

View File

@@ -76,6 +76,37 @@ stream = Stream(
 !!! tip "Muting Response Audio"
     You can directly talk over the output audio and the interruption will still work. However, in these cases, the audio transcription may be incorrect. To prevent this, it's best practice to mute the output audio before talking over it.

+### Startup Function
+
+You can pass a `startup_fn` to the `ReplyOnPause` class. This function will be called when the connection is first established. It is helpful for generating initial responses.
+
+```python
+from fastrtc import get_tts_model, Stream, ReplyOnPause
+
+tts_client = get_tts_model()
+
+
+def detection(audio: tuple[int, np.ndarray]):
+    # Implement any iterator that yields audio
+    # See "LLM Voice Chat" for a more complete example
+    yield audio
+
+
+def startup():
+    for chunk in tts_client.stream_tts_sync("Welcome to the echo audio demo!"):
+        yield chunk
+
+
+stream = Stream(
+    handler=ReplyOnPause(detection, startup_fn=startup),
+    modality="audio",
+    mode="send-receive",
+    ui_args={"title": "Echo Audio"},
+)
+```
+
+<video width=98% src="https://github.com/user-attachments/assets/c6b1cb51-5790-4522-80c3-e24e58ef9f11" controls style="text-align: center"></video>
+
 ## Reply On Stopwords

 You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
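Since `ReplyOnStopWords` now also accepts `startup_fn` (see the constructor change earlier in this commit), the same greeting pattern applies there. A minimal sketch; the stop word, greeting text, and handler body are illustrative:

```python
import numpy as np
from fastrtc import ReplyOnStopWords, Stream, get_tts_model

tts_client = get_tts_model()


def response(audio: tuple[int, np.ndarray]):
    # Runs after the stop word is heard and the user pauses.
    yield audio


def startup():
    for chunk in tts_client.stream_tts_sync("Say 'computer' to begin."):
        yield chunk


stream = Stream(
    handler=ReplyOnStopWords(response, stop_words=["computer"], startup_fn=startup),
    modality="audio",
    mode="send-receive",
)
```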