feat: Add optional startup function to ReplyOnPause (#170)

* feat: Add optional startup function to ReplyOnPause

* feat: Implement startup_fn in ReplyOnStopWords

* refactor: Remove redundant startup_fn implementation in ReplyOnStopWords

* tweaks

* revert

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Sofian Mejjoute
2025-03-12 00:11:29 +01:00
committed by GitHub
parent 514310691d
commit 66f0a81b76
6 changed files with 72 additions and 10 deletions

View File

@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
def __init__(
self,
fn: ReplyFnGenerator,
startup_fn: Callable | None = None,
algo_options: AlgoOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
) = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
self.startup_fn = startup_fn
@property
def _needs_additional_inputs(self) -> bool:
return len(inspect.signature(self.fn).parameters) > 1
def start_up(self):
if self.startup_fn:
if self._needs_additional_inputs:
self.wait_for_args_sync()
args = self.latest_args[1:]
else:
args = ()
self.generator = self.startup_fn(*args)
self.event.set()
def copy(self):
return ReplyOnPause(
self.fn,
self.startup_fn,
self.algo_options,
self.model_options,
self.can_interrupt,

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
import re
from typing import Literal
from typing import Callable, Literal
import numpy as np
@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
self,
fn: ReplyFnGenerator,
stop_words: list[str],
startup_fn: Callable | None = None,
algo_options: AlgoOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
super().__init__(
fn,
algo_options=algo_options,
startup_fn=startup_fn,
model_options=model_options,
can_interrupt=can_interrupt,
expected_layout=expected_layout,
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
return ReplyOnStopWords(
self.fn,
self.stop_words,
self.startup_fn,
self.algo_options,
self.model_options,
self.can_interrupt,

View File

@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
except TimeoutError:
continue
def start(
async def start(
self,
):
asyncio.create_task(self.process_frames())
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
except MediaStreamError:
self.stop()
def start(self):
async def start(self):
if not self.has_started:
asyncio.create_task(self.process_frames())
self.has_started = True
async def recv(self): # type: ignore
self.start()
await self.start()
try:
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_emit):
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
logger.debug("popped %d items from queue", i)
self._start = None
async def wait_for_channel(self):
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
def set_channel(self, channel: DataChannel):
self.channel = channel
self.event_handler.set_channel(channel)
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
logger.debug("MediaStreamError in process_input_frames")
break
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
await self.wait_for_channel()
if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit
start_up = self.event_handler.start_up()
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
await self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
except StopIteration:
self.thread_quit.set()
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(loop.run_in_executor, None, self.next)
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
self.start()
await self.start()
data = await self.queue.get()
if data is None:
self.stop()

View File

@@ -156,6 +156,11 @@ async def player_worker_decode(
break
continue
if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
raise WebRTCError(
"The frame must be a tuple containing a sample rate and a numpy array."
)
if len(frame) == 2:
sample_rate, audio_array = frame
layout = "mono"
@@ -199,7 +204,10 @@ async def player_worker_decode(
exec = traceback.format_exc()
print("traceback %s", exec)
print("Error processing frame: %s", str(e))
continue
if isinstance(e, WebRTCError):
raise e
else:
continue
def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:

View File

@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
elif self.mode == "send":
cast(AudioCallback | VideoCallback, cb).start()
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
if self.mode == "receive":
if self.modality == "video":