mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
ReplyOnPause and ReplyOnStopWords can be interrupted (#119)
* Add all this code * add code * Fix demo --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
@@ -59,6 +59,10 @@ class AppState:
|
||||
stopped: bool = False
|
||||
buffer: np.ndarray | None = None
|
||||
responded_audio: bool = False
|
||||
interrupted: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
|
||||
def new(self):
|
||||
return AppState()
|
||||
|
||||
|
||||
ReplyFnGenerator = (
|
||||
@@ -91,6 +95,7 @@ class ReplyOnPause(StreamHandler):
|
||||
fn: ReplyFnGenerator,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 480,
|
||||
@@ -102,6 +107,7 @@ class ReplyOnPause(StreamHandler):
|
||||
output_frame_size,
|
||||
input_sample_rate=input_sample_rate,
|
||||
)
|
||||
self.can_interrupt = can_interrupt
|
||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
@@ -123,6 +129,7 @@ class ReplyOnPause(StreamHandler):
|
||||
self.fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
@@ -170,11 +177,14 @@ class ReplyOnPause(StreamHandler):
|
||||
state.pause_detected = pause_detected
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
if self.state.responding:
|
||||
if self.state.responding and not self.can_interrupt:
|
||||
return
|
||||
self.process_audio(frame, self.state)
|
||||
if self.state.pause_detected:
|
||||
self.event.set()
|
||||
if self.can_interrupt:
|
||||
self.clear_queue()
|
||||
self.generator = None
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
@@ -207,6 +217,7 @@ class ReplyOnPause(StreamHandler):
|
||||
else:
|
||||
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
|
||||
logger.debug("Latest args: %s", self.latest_args)
|
||||
self.state = self.state.new()
|
||||
self.state.responding = True
|
||||
try:
|
||||
if self.is_async:
|
||||
|
||||
@@ -23,6 +23,9 @@ class ReplyOnStopWordsState(AppState):
|
||||
post_stop_word_buffer: np.ndarray | None = None
|
||||
started_talking_pre_stop_word: bool = False
|
||||
|
||||
def new(self):
|
||||
return ReplyOnStopWordsState()
|
||||
|
||||
|
||||
class ReplyOnStopWords(ReplyOnPause):
|
||||
def __init__(
|
||||
@@ -31,6 +34,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
stop_words: list[str],
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 480,
|
||||
@@ -40,6 +44,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
fn,
|
||||
algo_options=algo_options,
|
||||
model_options=model_options,
|
||||
can_interrupt=can_interrupt,
|
||||
expected_layout=expected_layout,
|
||||
output_sample_rate=output_sample_rate,
|
||||
output_frame_size=output_frame_size,
|
||||
@@ -144,6 +149,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
self.stop_words,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
|
||||
@@ -360,7 +360,7 @@ class Stream(WebRTCConnectionMixin):
|
||||
image = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=self.rtc_configuration,
|
||||
mode="send-receive",
|
||||
mode="send",
|
||||
modality="audio",
|
||||
icon=ui_args.get("icon"),
|
||||
icon_button_color=ui_args.get("icon_button_color"),
|
||||
@@ -505,7 +505,7 @@ class Stream(WebRTCConnectionMixin):
|
||||
return HTMLResponse(content=str(response), media_type="application/xml")
|
||||
|
||||
async def telephone_handler(self, websocket: WebSocket):
|
||||
handler = cast(StreamHandlerImpl, self.event_handler.copy())
|
||||
handler = cast(StreamHandlerImpl, self.event_handler.copy()) # type: ignore
|
||||
handler.phone_mode = True
|
||||
|
||||
async def set_handler(s: str, a: WebSocketHandler):
|
||||
@@ -528,7 +528,7 @@ class Stream(WebRTCConnectionMixin):
|
||||
await ws.handle_websocket(websocket)
|
||||
|
||||
async def websocket_offer(self, websocket: WebSocket):
|
||||
handler = cast(StreamHandlerImpl, self.event_handler.copy())
|
||||
handler = cast(StreamHandlerImpl, self.event_handler.copy()) # type: ignore
|
||||
handler.phone_mode = False
|
||||
|
||||
async def set_handler(s: str, a: WebSocketHandler):
|
||||
|
||||
@@ -188,6 +188,11 @@ class StreamHandlerBase(ABC):
|
||||
self.args_set = asyncio.Event()
|
||||
self.channel_set = asyncio.Event()
|
||||
self._phone_mode = False
|
||||
self._clear_queue: Callable | None = None
|
||||
|
||||
@property
|
||||
def clear_queue(self) -> Callable:
|
||||
return cast(Callable, self._clear_queue)
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
@@ -237,8 +242,11 @@ class StreamHandlerBase(ABC):
|
||||
logger.debug("Sent msg %s", msg)
|
||||
|
||||
def send_message_sync(self, msg: str):
|
||||
asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
|
||||
logger.debug("Sent msg %s", msg)
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
|
||||
logger.debug("Sent msg %s", msg)
|
||||
except Exception as e:
|
||||
logger.debug("Exception sending msg %s", e)
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
logger.debug("setting args in audio callback %s", args)
|
||||
@@ -411,6 +419,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = cast(StreamHandlerImpl, event_handler)
|
||||
self.event_handler._clear_queue = self.clear_queue
|
||||
self.current_timestamp = 0
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.queue = asyncio.Queue()
|
||||
@@ -421,6 +430,12 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
|
||||
def clear_queue(self):
|
||||
if self.queue:
|
||||
while not self.queue.empty():
|
||||
self.queue.get_nowait()
|
||||
self._start = None
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
self.event_handler.set_channel(channel)
|
||||
@@ -608,6 +623,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
) -> None:
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.event_handler = event_handler
|
||||
self.event_handler._clear_queue = self.clear_queue
|
||||
self.current_timestamp = 0
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.args_set = threading.Event()
|
||||
@@ -619,6 +635,11 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self._start: float | None = None
|
||||
super().__init__()
|
||||
|
||||
def clear_queue(self):
|
||||
while not self.queue.empty():
|
||||
self.queue.get_nowait()
|
||||
self._start = None
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
|
||||
|
||||
@@ -320,7 +320,7 @@ def audio_to_int16(
|
||||
>>> audio_int16 = audio_to_int16(audio_tuple)
|
||||
"""
|
||||
if audio[1].dtype == np.int16:
|
||||
return audio[1]
|
||||
return audio[1] # type: ignore
|
||||
elif audio[1].dtype == np.float32:
|
||||
# Convert float32 to int16 by scaling to the int16 range
|
||||
return (audio[1] * 32767.0).astype(np.int16)
|
||||
|
||||
@@ -55,6 +55,7 @@ class WebSocketHandler:
|
||||
],
|
||||
):
|
||||
self.stream_handler = stream_handler
|
||||
self.stream_handler._clear_queue = lambda: None
|
||||
self.websocket: Optional[WebSocket] = None
|
||||
self._emit_task: Optional[asyncio.Task] = None
|
||||
self.stream_id: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user