mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
feat: Add optional startup function to ReplyOnPause (#170)
* feat: Add optional startup function to ReplyOnPause * feat: Implement startup_fn in ReplyOnStopWords * refactor: Remove redundant startup_fn implementation in ReplyOnStopWords * tweaks * revert --------- Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
|
||||
def __init__(
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
startup_fn: Callable | None = None,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: ModelOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
|
||||
) = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
self.startup_fn = startup_fn
|
||||
|
||||
@property
|
||||
def _needs_additional_inputs(self) -> bool:
|
||||
return len(inspect.signature(self.fn).parameters) > 1
|
||||
|
||||
def start_up(self):
|
||||
if self.startup_fn:
|
||||
if self._needs_additional_inputs:
|
||||
self.wait_for_args_sync()
|
||||
args = self.latest_args[1:]
|
||||
else:
|
||||
args = ()
|
||||
self.generator = self.startup_fn(*args)
|
||||
self.event.set()
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(
|
||||
self.fn,
|
||||
self.startup_fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
from typing import Callable, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
stop_words: list[str],
|
||||
startup_fn: Callable | None = None,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: ModelOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
super().__init__(
|
||||
fn,
|
||||
algo_options=algo_options,
|
||||
startup_fn=startup_fn,
|
||||
model_options=model_options,
|
||||
can_interrupt=can_interrupt,
|
||||
expected_layout=expected_layout,
|
||||
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
return ReplyOnStopWords(
|
||||
self.fn,
|
||||
self.stop_words,
|
||||
self.startup_fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
|
||||
@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
def start(
|
||||
async def start(
|
||||
self,
|
||||
):
|
||||
asyncio.create_task(self.process_frames())
|
||||
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
asyncio.create_task(self.process_frames())
|
||||
self.has_started = True
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
self.start()
|
||||
await self.start()
|
||||
try:
|
||||
handler = cast(VideoStreamHandlerImpl, self.event_handler)
|
||||
if inspect.iscoroutinefunction(handler.video_emit):
|
||||
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("popped %d items from queue", i)
|
||||
self._start = None
|
||||
|
||||
async def wait_for_channel(self):
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
self.event_handler.set_channel(channel)
|
||||
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
break
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
await self.wait_for_channel()
|
||||
if isinstance(self.event_handler, AsyncHandler):
|
||||
callable = self.event_handler.emit
|
||||
start_up = self.event_handler.start_up()
|
||||
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
self.start()
|
||||
await self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
except StopIteration:
|
||||
self.thread_quit.set()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
callable = functools.partial(loop.run_in_executor, None, self.next)
|
||||
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
self.start()
|
||||
await self.start()
|
||||
data = await self.queue.get()
|
||||
if data is None:
|
||||
self.stop()
|
||||
|
||||
@@ -156,6 +156,11 @@ async def player_worker_decode(
|
||||
break
|
||||
continue
|
||||
|
||||
if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
|
||||
raise WebRTCError(
|
||||
"The frame must be a tuple containing a sample rate and a numpy array."
|
||||
)
|
||||
|
||||
if len(frame) == 2:
|
||||
sample_rate, audio_array = frame
|
||||
layout = "mono"
|
||||
@@ -199,6 +204,9 @@ async def player_worker_decode(
|
||||
exec = traceback.format_exc()
|
||||
print("traceback %s", exec)
|
||||
print("Error processing frame: %s", str(e))
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
elif self.mode == "send":
|
||||
cast(AudioCallback | VideoCallback, cb).start()
|
||||
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
|
||||
@@ -76,6 +76,37 @@ stream = Stream(
|
||||
!!! tip "Muting Response Audio"
|
||||
You can directly talk over the output audio and the interruption will still work. However, in these cases, the audio transcription may be incorrect. To prevent this, it's best practice to mute the output audio before talking over it.
|
||||
|
||||
### Startup Function
|
||||
|
||||
You can pass in a `startup_fn` to the `ReplyOnPause` class. This function will be called when the connection is first established. It is helpful for generating intial responses.
|
||||
|
||||
```python
|
||||
from fastrtc import get_tts_model, Stream, ReplyOnPause
|
||||
|
||||
tts_client = get_tts_model()
|
||||
|
||||
|
||||
def detection(audio: tuple[int, np.ndarray]):
|
||||
# Implement any iterator that yields audio
|
||||
# See "LLM Voice Chat" for a more complete example
|
||||
yield audio
|
||||
|
||||
|
||||
def startup():
|
||||
for chunk in tts_client.stream_tts_sync("Welcome to the echo audio demo!"):
|
||||
yield chunk
|
||||
|
||||
|
||||
stream = Stream(
|
||||
handler=ReplyOnPause(detection, startup_fn=startup),
|
||||
modality="audio",
|
||||
mode="send-receive",
|
||||
ui_args={"title": "Echo Audio"},
|
||||
)
|
||||
```
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/c6b1cb51-5790-4522-80c3-e24e58ef9f11" controls style="text-align: center"></video>
|
||||
|
||||
## Reply On Stopwords
|
||||
|
||||
You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
|
||||
|
||||
Reference in New Issue
Block a user