feat: Add optional startup function to ReplyOnPause (#170)

* feat: Add optional startup function to ReplyOnPause

* feat: Implement startup_fn in ReplyOnStopWords

* refactor: Remove redundant startup_fn implementation in ReplyOnStopWords

* tweaks

* revert

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Sofian Mejjoute
2025-03-12 00:11:29 +01:00
committed by GitHub
parent 514310691d
commit 66f0a81b76
6 changed files with 72 additions and 10 deletions

View File

@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
def __init__(
self,
fn: ReplyFnGenerator,
startup_fn: Callable | None = None,
algo_options: AlgoOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
) = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
self.startup_fn = startup_fn
@property
def _needs_additional_inputs(self) -> bool:
return len(inspect.signature(self.fn).parameters) > 1
def start_up(self):
if self.startup_fn:
if self._needs_additional_inputs:
self.wait_for_args_sync()
args = self.latest_args[1:]
else:
args = ()
self.generator = self.startup_fn(*args)
self.event.set()
def copy(self):
return ReplyOnPause(
self.fn,
self.startup_fn,
self.algo_options,
self.model_options,
self.can_interrupt,

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
import re
from typing import Literal
from typing import Callable, Literal
import numpy as np
@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
self,
fn: ReplyFnGenerator,
stop_words: list[str],
startup_fn: Callable | None = None,
algo_options: AlgoOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
super().__init__(
fn,
algo_options=algo_options,
startup_fn=startup_fn,
model_options=model_options,
can_interrupt=can_interrupt,
expected_layout=expected_layout,
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
return ReplyOnStopWords(
self.fn,
self.stop_words,
self.startup_fn,
self.algo_options,
self.model_options,
self.can_interrupt,

View File

@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
except TimeoutError:
continue
def start(
async def start(
self,
):
asyncio.create_task(self.process_frames())
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
except MediaStreamError:
self.stop()
def start(self):
async def start(self):
if not self.has_started:
asyncio.create_task(self.process_frames())
self.has_started = True
async def recv(self): # type: ignore
self.start()
await self.start()
try:
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_emit):
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
logger.debug("popped %d items from queue", i)
self._start = None
async def wait_for_channel(self):
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
def set_channel(self, channel: DataChannel):
self.channel = channel
self.event_handler.set_channel(channel)
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
logger.debug("MediaStreamError in process_input_frames")
break
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
await self.wait_for_channel()
if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit
start_up = self.event_handler.start_up()
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
await self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
except StopIteration:
self.thread_quit.set()
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(loop.run_in_executor, None, self.next)
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
self.start()
await self.start()
data = await self.queue.get()
if data is None:
self.stop()

View File

@@ -156,6 +156,11 @@ async def player_worker_decode(
break
continue
if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
raise WebRTCError(
"The frame must be a tuple containing a sample rate and a numpy array."
)
if len(frame) == 2:
sample_rate, audio_array = frame
layout = "mono"
@@ -199,7 +204,10 @@ async def player_worker_decode(
exec = traceback.format_exc()
print("traceback %s", exec)
print("Error processing frame: %s", str(e))
continue
if isinstance(e, WebRTCError):
raise e
else:
continue
def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:

View File

@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
elif self.mode == "send":
cast(AudioCallback | VideoCallback, cb).start()
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
if self.mode == "receive":
if self.modality == "video":

View File

@@ -76,6 +76,37 @@ stream = Stream(
!!! tip "Muting Response Audio"
You can directly talk over the output audio and the interruption will still work. However, in these cases, the audio transcription may be incorrect. To prevent this, it's best practice to mute the output audio before talking over it.
### Startup Function
You can pass in a `startup_fn` to the `ReplyOnPause` class. This function will be called when the connection is first established. It is helpful for generating intial responses.
```python
from fastrtc import get_tts_model, Stream, ReplyOnPause
tts_client = get_tts_model()
def detection(audio: tuple[int, np.ndarray]):
# Implement any iterator that yields audio
# See "LLM Voice Chat" for a more complete example
yield audio
def startup():
for chunk in tts_client.stream_tts_sync("Welcome to the echo audio demo!"):
yield chunk
stream = Stream(
handler=ReplyOnPause(detection, startup_fn=startup),
modality="audio",
mode="send-receive",
ui_args={"title": "Echo Audio"},
)
```
<video width=98% src="https://github.com/user-attachments/assets/c6b1cb51-5790-4522-80c3-e24e58ef9f11" controls style="text-align: center"></video>
## Reply On Stopwords
You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.