feat: Add optional startup function to ReplyOnPause (#170)
* feat: Add optional startup function to ReplyOnPause
* feat: Implement startup_fn in ReplyOnStopWords
* refactor: Remove redundant startup_fn implementation in ReplyOnStopWords
* tweaks
* revert

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
```diff
@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
     def __init__(
         self,
         fn: ReplyFnGenerator,
+        startup_fn: Callable | None = None,
         algo_options: AlgoOptions | None = None,
         model_options: ModelOptions | None = None,
         can_interrupt: bool = True,
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
         ) = None
         self.model_options = model_options
         self.algo_options = algo_options or AlgoOptions()
+        self.startup_fn = startup_fn

     @property
     def _needs_additional_inputs(self) -> bool:
         return len(inspect.signature(self.fn).parameters) > 1

+    def start_up(self):
+        if self.startup_fn:
+            if self._needs_additional_inputs:
+                self.wait_for_args_sync()
+                args = self.latest_args[1:]
+            else:
+                args = ()
+            self.generator = self.startup_fn(*args)
+            self.event.set()
+
     def copy(self):
         return ReplyOnPause(
             self.fn,
+            self.startup_fn,
             self.algo_options,
             self.model_options,
             self.can_interrupt,
```
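Because `start_up` passes `self.latest_args[1:]` through, a `startup_fn` can consume the same additional Gradio inputs as the main handler. A minimal sketch of that path, assuming a `Stream` wired up with a `gr.Textbox` additional input (the `respond`/`greet` names and the textbox are illustrative, not part of the commit):

```python
import gradio as gr
import numpy as np
from fastrtc import ReplyOnPause, Stream, get_tts_model

tts = get_tts_model()


def respond(audio: tuple[int, np.ndarray], name: str):
    # Main handler: runs on every pause; here it just echoes the audio.
    yield audio


def greet(name: str):
    # Startup function taking the same extra input as `respond`.
    # ReplyOnPause waits for the textbox value (wait_for_args_sync),
    # then calls this with latest_args[1:] once the connection opens.
    for chunk in tts.stream_tts_sync(f"Hello {name}, start talking!"):
        yield chunk


stream = Stream(
    handler=ReplyOnPause(respond, startup_fn=greet),
    modality="audio",
    mode="send-receive",
    additional_inputs=[gr.Textbox(label="Name", value="there")],
)
```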
```diff
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 import re
-from typing import Literal
+from typing import Callable, Literal

 import numpy as np

@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
         self,
         fn: ReplyFnGenerator,
         stop_words: list[str],
+        startup_fn: Callable | None = None,
         algo_options: AlgoOptions | None = None,
         model_options: ModelOptions | None = None,
         can_interrupt: bool = True,
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
         super().__init__(
             fn,
             algo_options=algo_options,
+            startup_fn=startup_fn,
             model_options=model_options,
             can_interrupt=can_interrupt,
             expected_layout=expected_layout,
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
         return ReplyOnStopWords(
             self.fn,
             self.stop_words,
+            self.startup_fn,
             self.algo_options,
             self.model_options,
             self.can_interrupt,
```
```diff
@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
             except TimeoutError:
                 continue

-    def start(
+    async def start(
         self,
     ):
         asyncio.create_task(self.process_frames())
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
         except MediaStreamError:
             self.stop()

-    def start(self):
+    async def start(self):
         if not self.has_started:
             asyncio.create_task(self.process_frames())
             self.has_started = True

     async def recv(self):  # type: ignore
-        self.start()
+        await self.start()
         try:
             handler = cast(VideoStreamHandlerImpl, self.event_handler)
             if inspect.iscoroutinefunction(handler.video_emit):
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
         logger.debug("popped %d items from queue", i)
         self._start = None

+    async def wait_for_channel(self):
+        if not self.event_handler.channel_set.is_set():
+            await self.event_handler.channel_set.wait()
+        if current_channel.get() != self.event_handler.channel:
+            current_channel.set(self.event_handler.channel)
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
         self.event_handler.set_channel(channel)
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
             logger.debug("MediaStreamError in process_input_frames")
             break

-    def start(self):
+    async def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
+            await self.wait_for_channel()
             if isinstance(self.event_handler, AsyncHandler):
                 callable = self.event_handler.emit
                 start_up = self.event_handler.start_up()
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
                 await self.event_handler.channel_set.wait()
             if current_channel.get() != self.event_handler.channel:
                 current_channel.set(self.event_handler.channel)
-        self.start()
+        await self.start()

         frame = await self.queue.get()
         logger.debug("frame %s", frame)
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
         except StopIteration:
             self.thread_quit.set()

-    def start(self):
+    async def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
             callable = functools.partial(loop.run_in_executor, None, self.next)
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
         if self.readyState != "live":
             raise MediaStreamError

-        self.start()
+        await self.start()
         data = await self.queue.get()
         if data is None:
             self.stop()
```
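Making `start` a coroutine lets `AudioCallback` await the data channel before the handler's `start_up` runs, so startup audio cannot be emitted into a channel that does not exist yet. A stripped-down sketch of that ordering (illustrative names, not fastrtc internals):

```python
import asyncio


class TrackSketch:
    def __init__(self) -> None:
        self.channel_set = asyncio.Event()  # set when the data channel opens
        self.has_started = False

    async def wait_for_channel(self) -> None:
        if not self.channel_set.is_set():
            await self.channel_set.wait()

    async def start(self) -> None:
        if not self.has_started:
            await self.wait_for_channel()  # gate start_up on a live channel
            self.has_started = True
            print("start_up can now safely emit over the channel")

    async def recv(self) -> None:
        await self.start()  # callers now await start() instead of calling it


async def main() -> None:
    track = TrackSketch()
    # Simulate the channel opening shortly after the first recv().
    asyncio.get_running_loop().call_later(0.01, track.channel_set.set)
    await track.recv()


asyncio.run(main())
```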
```diff
@@ -156,6 +156,11 @@ async def player_worker_decode(
                     break
                 continue

+            if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
+                raise WebRTCError(
+                    "The frame must be a tuple containing a sample rate and a numpy array."
+                )
+
             if len(frame) == 2:
                 sample_rate, audio_array = frame
                 layout = "mono"
@@ -199,7 +204,10 @@ async def player_worker_decode(
             exec = traceback.format_exc()
             print("traceback %s", exec)
             print("Error processing frame: %s", str(e))
-            continue
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                continue


 def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:
```
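A self-contained sketch of the new frame guard, with a stand-in `WebRTCError`. Note the commit joins the two checks with `and`, so a non-tuple frame whose second element is still an ndarray (e.g. a list) slips through; a stricter guard would use `or`:

```python
import numpy as np


class WebRTCError(Exception):
    """Stand-in for fastrtc's WebRTCError, for illustration only."""


def validate_frame(frame):
    # Mirrors the commit's guard: raises only when the frame is not a
    # tuple AND its second element is not an ndarray. A frame like
    # [24000, np.zeros(480)] (a list) would still pass this check.
    if not isinstance(frame, tuple) and not isinstance(frame[1], np.ndarray):
        raise WebRTCError(
            "The frame must be a tuple containing a sample rate and a numpy array."
        )


validate_frame((24000, np.zeros(480, dtype=np.int16)))  # OK: (sample_rate, array)
```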
```diff
@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
                 logger.debug("Adding track to peer connection %s", cb)
                 pc.addTrack(cb)
             elif self.mode == "send":
-                cast(AudioCallback | VideoCallback, cb).start()
+                asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())

         if self.mode == "receive":
             if self.modality == "video":
```
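With `start` now async, connection setup schedules it via `asyncio.create_task` rather than calling it inline. A minimal illustration of that fire-and-forget pattern (not fastrtc code):

```python
import asyncio


async def start():
    # Stand-in for AudioCallback.start / VideoCallback.start.
    await asyncio.sleep(0)
    print("track started")


async def setup_connection():
    # Schedule the coroutine without blocking connection setup; the
    # task runs on the event loop once setup yields control.
    asyncio.create_task(start())
    await asyncio.sleep(0.1)  # keep the loop alive for the demo


asyncio.run(setup_connection())
```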
````diff
@@ -76,6 +76,37 @@ stream = Stream(
 !!! tip "Muting Response Audio"
     You can directly talk over the output audio and the interruption will still work. However, in these cases, the audio transcription may be incorrect. To prevent this, it's best practice to mute the output audio before talking over it.

+### Startup Function
+
+You can pass a `startup_fn` to the `ReplyOnPause` class. This function will be called when the connection is first established. It is helpful for generating initial responses.
+
+```python
+from fastrtc import get_tts_model, Stream, ReplyOnPause
+
+tts_client = get_tts_model()
+
+
+def detection(audio: tuple[int, np.ndarray]):
+    # Implement any iterator that yields audio
+    # See "LLM Voice Chat" for a more complete example
+    yield audio
+
+
+def startup():
+    for chunk in tts_client.stream_tts_sync("Welcome to the echo audio demo!"):
+        yield chunk
+
+
+stream = Stream(
+    handler=ReplyOnPause(detection, startup_fn=startup),
+    modality="audio",
+    mode="send-receive",
+    ui_args={"title": "Echo Audio"},
+)
+```
+
+<video width=98% src="https://github.com/user-attachments/assets/c6b1cb51-5790-4522-80c3-e24e58ef9f11" controls style="text-align: center"></video>
+
 ## Reply On Stopwords

 You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
````
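This commit plumbs the same `startup_fn` through `ReplyOnStopWords`, so a wake-word handler can announce itself on connect. A sketch under that assumption (handler bodies are illustrative):

```python
import numpy as np
from fastrtc import ReplyOnStopWords, Stream, get_tts_model

tts_client = get_tts_model()


def response(audio: tuple[int, np.ndarray]):
    # Runs once a stop word has been heard and the user pauses.
    yield audio


def startup():
    # Played as soon as the connection is established.
    for chunk in tts_client.stream_tts_sync("Say 'computer' to wake me up."):
        yield chunk


stream = Stream(
    handler=ReplyOnStopWords(
        response,
        stop_words=["computer"],
        startup_fn=startup,
    ),
    modality="audio",
    mode="send-receive",
)
```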