mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)
* Code * test * code * user guide
This commit is contained in:
@@ -8,6 +8,8 @@ from .reply_on_stopwords import ReplyOnStopWords
|
||||
from .speech_to_text import stt, stt_for_chunks
|
||||
from .utils import (
|
||||
AdditionalOutputs,
|
||||
Warning,
|
||||
WebRTCError,
|
||||
aggregate_bytes_to_16bit,
|
||||
async_aggregate_bytes_to_16bit,
|
||||
audio_to_bytes,
|
||||
@@ -35,4 +37,6 @@ __all__ = [
|
||||
"stt_for_chunks",
|
||||
"StreamHandler",
|
||||
"WebRTC",
|
||||
"WebRTCError",
|
||||
"Warning",
|
||||
]
|
||||
|
||||
@@ -75,6 +75,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
||||
) -> bool:
|
||||
"""Take in the stream, determine if a pause happened"""
|
||||
import librosa
|
||||
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import fractions
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
@@ -29,6 +31,55 @@ class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
|
||||
current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||
"current_channel", default=None
|
||||
)
|
||||
|
||||
|
||||
def _send_log(message: str, type: str) -> None:
|
||||
async def _send(channel: DataChannel) -> None:
|
||||
channel.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": type,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if channel := current_channel.get():
|
||||
print("channel", channel)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
asyncio.run_coroutine_threadsafe(_send(channel), loop)
|
||||
except RuntimeError:
|
||||
asyncio.run(_send(channel))
|
||||
|
||||
|
||||
def Warning( # noqa: N802
|
||||
message: str = "Warning issued.",
|
||||
):
|
||||
"""
|
||||
Send a warning message that is deplayed in the UI of the application.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : str
|
||||
The warning message to send
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
_send_log(message, "warning")
|
||||
|
||||
|
||||
class WebRTCError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
_send_log(message, "error")
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, AdditionalOutputs):
|
||||
return None, data
|
||||
|
||||
@@ -44,6 +44,7 @@ from gradio_client import handle_file
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
current_channel,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
current_channel.set(channel)
|
||||
self.channel_set.set()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
logger.debug("video callback stop")
|
||||
self.thread_quit.set()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
if not self.channel_set.is_set():
|
||||
await self.channel_set.wait()
|
||||
if current_channel.get() != self.channel:
|
||||
current_channel.set(self.channel)
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
try:
|
||||
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
await self.wait_for_channel()
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
if self.latest_args == "not_set":
|
||||
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
|
||||
self._channel: DataChannel | None = None
|
||||
self._loop: asyncio.AbstractEventLoop
|
||||
self.args_set = asyncio.Event()
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self._channel = channel
|
||||
self.channel_set.set()
|
||||
|
||||
async def fetch_args(
|
||||
self,
|
||||
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
|
||||
await self.fetch_args()
|
||||
await self.args_set.wait()
|
||||
|
||||
def wait_for_args_sync(self):
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
logger.debug("setting args in audio callback %s", args)
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.current_timestamp = 0
|
||||
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
def set_args(self, args: list[Any]):
|
||||
self.event_handler.set_args(args)
|
||||
|
||||
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
return cast(Callable, self.event_handler.receive)(frame)
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
)
|
||||
else:
|
||||
await anyio.to_thread.run_sync(
|
||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||
self.event_handler_receive, (frame.sample_rate, numpy_array)
|
||||
)
|
||||
except MediaStreamError:
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
|
||||
self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
|
||||
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.generator = cast(
|
||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
)
|
||||
|
||||
current_channel.set(self.channel)
|
||||
try:
|
||||
next_array, outputs = split_output(next(self.generator))
|
||||
if (
|
||||
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
if self.generator is None:
|
||||
self.generator = self.event_handler(*self.latest_args)
|
||||
if self.generator is not None:
|
||||
@@ -946,6 +974,7 @@ class WebRTC(Component):
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
logger.debug("done handling offer about to return")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
|
||||
Reference in New Issue
Block a user