Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)

* Code

* test

* code

* user guide
This commit is contained in:
Freddy Boulton
2024-12-23 15:21:10 -05:00
committed by GitHub
parent e057fc1502
commit 5812fd5aeb
11 changed files with 289 additions and 7 deletions

View File

@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import (
AdditionalOutputs,
Warning,
WebRTCError,
aggregate_bytes_to_16bit,
async_aggregate_bytes_to_16bit,
audio_to_bytes,
@@ -35,4 +37,6 @@ __all__ = [
"stt_for_chunks",
"StreamHandler",
"WebRTC",
"WebRTCError",
"Warning",
]

View File

@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
) -> bool:
"""Take in the stream, determine if a pause happened"""
import librosa
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:

View File

@@ -1,8 +1,10 @@
import asyncio
import fractions
import io
import json
import logging
import tempfile
from contextvars import ContextVar
from typing import Any, Callable, Protocol, TypedDict, cast
import av
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
def send(self, message: str) -> None: ...
# Data channel bound to the current execution context; ``None`` until one is set.
current_channel: ContextVar[DataChannel | None] = ContextVar(
    "current_channel",
    default=None,
)
def _send_log(message: str, type: str) -> None:
    """
    Send a JSON log message over the data channel of the current context.

    The payload is a JSON object with the keys ``"type"`` and ``"message"``.
    Nothing is sent when ``current_channel`` holds no channel.

    Parameters
    ----------
    message : str
        The text of the log message.
    type : str
        The kind of message, e.g. ``"warning"`` or ``"error"``.

    Returns
    -------
    None
    """

    async def _send(channel: DataChannel) -> None:
        channel.send(json.dumps({"type": type, "message": message}))

    if channel := current_channel.get():
        # Debug-level logging instead of printing to stdout on every message.
        logging.getLogger(__name__).debug(
            "sending %s message over channel %s", type, channel
        )
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: drive the coroutine here.
            asyncio.run(_send(channel))
        else:
            # A loop is running in this thread: schedule the send on it.
            asyncio.run_coroutine_threadsafe(_send(channel), loop)
def Warning(  # noqa: N802
    message: str = "Warning issued.",
) -> None:
    """
    Send a warning message that is displayed in the UI of the application.

    Parameters
    ----------
    message : str
        The warning message to send

    Returns
    -------
    None
    """
    _send_log(message, "warning")
class WebRTCError(Exception):
    """Exception whose message is also sent as an ``"error"`` log on creation."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Report the message through the same path used for warnings.
        _send_log(message=message, type="error")
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, AdditionalOutputs):
return None, data

View File

@@ -44,6 +44,7 @@ from gradio_client import handle_file
from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
current_channel,
player_worker_decode,
split_output,
)
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.thread_quit = asyncio.Event()
self.mode = mode
self.channel_set = asyncio.Event()
def set_channel(self, channel: DataChannel):
    """Attach ``channel`` to this track and signal that it is available.

    The channel is stored on the instance, published through the
    ``current_channel`` context variable, and the ``channel_set`` event is set.
    """
    self.channel = channel
    current_channel.set(channel)
    # Set the event last, after the channel has been stored.
    self.channel_set.set()
def set_args(self, args: list[Any]):
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
logger.debug("video callback stop")
self.thread_quit.set()
# Wait for the `channel_set` event, and point the `current_channel` context
# variable at this track's channel when it holds a different value.
# NOTE(review): indentation is lost in this view, so it cannot be confirmed here
# whether the second check is nested under the first one.
async def wait_for_channel(self):
# Only await the event when it has not been set yet.
if not self.channel_set.is_set():
await self.channel_set.wait()
# Update the context variable only when it differs from this track's channel.
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self):
try:
try:
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
except MediaStreamError:
self.stop()
return
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
self._channel: DataChannel | None = None
self._loop: asyncio.AbstractEventLoop
self.args_set = asyncio.Event()
self.channel_set = asyncio.Event()
@property
def loop(self) -> asyncio.AbstractEventLoop:
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
def set_channel(self, channel: DataChannel):
    """Store ``channel`` on the handler and set the ``channel_set`` event."""
    self._channel = channel
    self.channel_set.set()
async def fetch_args(
self,
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
await self.fetch_args()
await self.args_set.wait()
def wait_for_args_sync(self):
    """Block the calling thread until ``wait_for_args`` finishes on ``self.loop``."""
    pending = asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop)
    pending.result()
def set_args(self, args: list[Any]):
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__()
self.track = track
self.event_handler = event_handler
self.current_timestamp = 0
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__()
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
def set_args(self, args: list[Any]):
    """Forward ``args`` to the event handler's ``set_args``."""
    handler = self.event_handler
    handler.set_args(args)
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
    """Publish the handler's channel to ``current_channel``, then call ``receive``.

    The context variable is set here, immediately before the handler's
    ``receive`` is called with ``frame``.
    """
    current_channel.set(self.event_handler.channel)
    receive = cast(Callable, self.event_handler.receive)
    return receive(frame)
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
)
else:
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
self.event_handler_receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames")
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
current_channel.set(self.channel)
try:
next_array, outputs = split_output(next(self.generator))
if (
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
def next(self) -> tuple[int, np.ndarray] | None:
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
@@ -946,6 +974,7 @@ class WebRTC(Component):
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
logger.debug("done handling offer about to return")
await asyncio.sleep(0.1)
return {
"sdp": pc.localDescription.sdp,