feat: Add optional startup function to ReplyOnPause (#170)

* feat: Add optional startup function to ReplyOnPause

* feat: Implement startup_fn in ReplyOnStopWords

* refactor: Remove redundant startup_fn implementation in ReplyOnStopWords

* tweaks

* revert

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Sofian Mejjoute
2025-03-12 00:11:29 +01:00
committed by GitHub
parent 514310691d
commit 66f0a81b76
6 changed files with 72 additions and 10 deletions

View File

@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
except TimeoutError:
continue
def start(
async def start(
self,
):
asyncio.create_task(self.process_frames())
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
except MediaStreamError:
self.stop()
def start(self):
async def start(self):
if not self.has_started:
asyncio.create_task(self.process_frames())
self.has_started = True
async def recv(self): # type: ignore
self.start()
await self.start()
try:
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_emit):
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
logger.debug("popped %d items from queue", i)
self._start = None
async def wait_for_channel(self):
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
def set_channel(self, channel: DataChannel):
self.channel = channel
self.event_handler.set_channel(channel)
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
logger.debug("MediaStreamError in process_input_frames")
break
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
await self.wait_for_channel()
if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit
start_up = self.event_handler.start_up()
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
await self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
except StopIteration:
self.thread_quit.set()
def start(self):
async def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(loop.run_in_executor, None, self.next)
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
self.start()
await self.start()
data = await self.queue.get()
if data is None:
self.stop()