mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
feat: Add optional startup function to ReplyOnPause (#170)
* feat: Add optional startup function to ReplyOnPause * feat: Implement startup_fn in ReplyOnStopWords * refactor: Remove redundant startup_fn implementation in ReplyOnStopWords * tweaks * revert --------- Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
@@ -68,6 +68,7 @@ class ReplyOnPause(StreamHandler):
|
||||
def __init__(
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
startup_fn: Callable | None = None,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: ModelOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
@@ -97,14 +98,26 @@ class ReplyOnPause(StreamHandler):
|
||||
) = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
self.startup_fn = startup_fn
|
||||
|
||||
@property
def _needs_additional_inputs(self) -> bool:
    """Report whether ``self.fn`` declares more than one parameter.

    The count comes from ``inspect.signature``; anything beyond the first
    parameter is treated as an additional input by ``start_up``.
    """
    fn_parameters = inspect.signature(self.fn).parameters
    return len(fn_parameters) > 1
|
||||
|
||||
def start_up(self):
    """Run the optional startup function and publish its result.

    Does nothing when no ``startup_fn`` was supplied. Otherwise the
    function is called, its return value is stored on ``self.generator``
    and ``self.event`` is set.
    """
    if not self.startup_fn:
        return

    extra_args = ()
    if self._needs_additional_inputs:
        # Block until the arguments have arrived, then drop the first
        # entry (presumably the media value itself -- confirm with caller).
        self.wait_for_args_sync()
        extra_args = self.latest_args[1:]

    self.generator = self.startup_fn(*extra_args)
    self.event.set()
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(
|
||||
self.fn,
|
||||
self.startup_fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
from typing import Callable, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -33,6 +33,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
stop_words: list[str],
|
||||
startup_fn: Callable | None = None,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: ModelOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
@@ -45,6 +46,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
super().__init__(
|
||||
fn,
|
||||
algo_options=algo_options,
|
||||
startup_fn=startup_fn,
|
||||
model_options=model_options,
|
||||
can_interrupt=can_interrupt,
|
||||
expected_layout=expected_layout,
|
||||
@@ -149,6 +151,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
return ReplyOnStopWords(
|
||||
self.fn,
|
||||
self.stop_words,
|
||||
self.startup_fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.can_interrupt,
|
||||
|
||||
@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
def start(
|
||||
async def start(
|
||||
self,
|
||||
):
|
||||
asyncio.create_task(self.process_frames())
|
||||
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
asyncio.create_task(self.process_frames())
|
||||
self.has_started = True
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
self.start()
|
||||
await self.start()
|
||||
try:
|
||||
handler = cast(VideoStreamHandlerImpl, self.event_handler)
|
||||
if inspect.iscoroutinefunction(handler.video_emit):
|
||||
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("popped %d items from queue", i)
|
||||
self._start = None
|
||||
|
||||
async def wait_for_channel(self):
    """Wait for the handler's channel, then make it the current channel.

    Awaits ``event_handler.channel_set`` when it is not yet set, and
    afterwards updates ``current_channel`` if it does not already hold the
    handler's channel.
    """
    channel_ready = self.event_handler.channel_set
    if not channel_ready.is_set():
        await channel_ready.wait()

    handler_channel = self.event_handler.channel
    if current_channel.get() != handler_channel:
        current_channel.set(handler_channel)
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
self.event_handler.set_channel(channel)
|
||||
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
break
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
await self.wait_for_channel()
|
||||
if isinstance(self.event_handler, AsyncHandler):
|
||||
callable = self.event_handler.emit
|
||||
start_up = self.event_handler.start_up()
|
||||
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
self.start()
|
||||
await self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
except StopIteration:
|
||||
self.thread_quit.set()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
callable = functools.partial(loop.run_in_executor, None, self.next)
|
||||
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
self.start()
|
||||
await self.start()
|
||||
data = await self.queue.get()
|
||||
if data is None:
|
||||
self.stop()
|
||||
|
||||
@@ -156,6 +156,11 @@ async def player_worker_decode(
|
||||
break
|
||||
continue
|
||||
|
||||
# Reject anything that is not a (sample_rate, ndarray[, layout]) tuple.
# The checks are joined with `or`: with `and`, a tuple carrying a
# non-array payload passed validation, and a non-tuple value was indexed
# via frame[1] instead of producing this error. The length guard keeps a
# too-short tuple from raising IndexError before the message is shown.
if (
    not isinstance(frame, tuple)
    or len(frame) < 2
    or not isinstance(frame[1], np.ndarray)
):
    raise WebRTCError(
        "The frame must be a tuple containing a sample rate and a numpy array."
    )
|
||||
|
||||
if len(frame) == 2:
|
||||
sample_rate, audio_array = frame
|
||||
layout = "mono"
|
||||
@@ -199,7 +204,10 @@ async def player_worker_decode(
|
||||
exec = traceback.format_exc()
|
||||
print("traceback %s", exec)
|
||||
print("Error processing frame: %s", str(e))
|
||||
continue
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:
|
||||
|
||||
@@ -241,7 +241,7 @@ class WebRTCConnectionMixin:
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
elif self.mode == "send":
|
||||
cast(AudioCallback | VideoCallback, cb).start()
|
||||
asyncio.create_task(cast(AudioCallback | VideoCallback, cb).start())
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
|
||||
Reference in New Issue
Block a user