ReplyOnPause and ReplyOnStopWords can be interrupted (#119)

* Add all this code

* add code

* Fix demo

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-03-03 21:47:16 -05:00
committed by GitHub
parent 87954a62aa
commit 6ea54777af
13 changed files with 155 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
import asyncio
import inspect
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from logging import getLogger
from threading import Event
@@ -59,6 +59,10 @@ class AppState:
stopped: bool = False
buffer: np.ndarray | None = None
responded_audio: bool = False
interrupted: asyncio.Event = field(default_factory=asyncio.Event)
def new(self):
return AppState()
ReplyFnGenerator = (
@@ -91,6 +95,7 @@ class ReplyOnPause(StreamHandler):
fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
can_interrupt: bool = True,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
@@ -102,6 +107,7 @@ class ReplyOnPause(StreamHandler):
output_frame_size,
input_sample_rate=input_sample_rate,
)
self.can_interrupt = can_interrupt
self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
@@ -123,6 +129,7 @@ class ReplyOnPause(StreamHandler):
self.fn,
self.algo_options,
self.model_options,
self.can_interrupt,
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
@@ -170,11 +177,14 @@ class ReplyOnPause(StreamHandler):
state.pause_detected = pause_detected
def receive(self, frame: tuple[int, np.ndarray]) -> None:
if self.state.responding:
if self.state.responding and not self.can_interrupt:
return
self.process_audio(frame, self.state)
if self.state.pause_detected:
self.event.set()
if self.can_interrupt:
self.clear_queue()
self.generator = None
def reset(self):
super().reset()
@@ -207,6 +217,7 @@ class ReplyOnPause(StreamHandler):
else:
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
logger.debug("Latest args: %s", self.latest_args)
self.state = self.state.new()
self.state.responding = True
try:
if self.is_async: