mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
feat: Add optional startup function to ReplyOnPause (#170)
* feat: Add optional startup function to ReplyOnPause * feat: Implement startup_fn in ReplyOnStopWords * refactor: Remove redundant startup_fn implementation in ReplyOnStopWords * tweaks * revert --------- Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
@@ -111,7 +111,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
def start(
|
||||
async def start(
|
||||
self,
|
||||
):
|
||||
asyncio.create_task(self.process_frames())
|
||||
@@ -371,13 +371,13 @@ class VideoStreamHandler(VideoCallback):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
asyncio.create_task(self.process_frames())
|
||||
self.has_started = True
|
||||
|
||||
async def recv(self): # type: ignore
|
||||
self.start()
|
||||
await self.start()
|
||||
try:
|
||||
handler = cast(VideoStreamHandlerImpl, self.event_handler)
|
||||
if inspect.iscoroutinefunction(handler.video_emit):
|
||||
@@ -444,6 +444,12 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("popped %d items from queue", i)
|
||||
self._start = None
|
||||
|
||||
async def wait_for_channel(self):
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
self.event_handler.set_channel(channel)
|
||||
@@ -477,9 +483,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
break
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
await self.wait_for_channel()
|
||||
if isinstance(self.event_handler, AsyncHandler):
|
||||
callable = self.event_handler.emit
|
||||
start_up = self.event_handler.start_up()
|
||||
@@ -525,7 +532,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
self.start()
|
||||
await self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
@@ -671,7 +678,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
except StopIteration:
|
||||
self.thread_quit.set()
|
||||
|
||||
def start(self):
|
||||
async def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
callable = functools.partial(loop.run_in_executor, None, self.next)
|
||||
@@ -692,7 +699,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
self.start()
|
||||
await self.start()
|
||||
data = await self.queue.get()
|
||||
if data is None:
|
||||
self.stop()
|
||||
|
||||
Reference in New Issue
Block a user