mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Raise errors automatically (#69)
* Add auto errors * change code --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -25,6 +25,7 @@ from .utils import (
|
||||
audio_to_bytes,
|
||||
audio_to_file,
|
||||
audio_to_float32,
|
||||
wait_for_item,
|
||||
)
|
||||
from .webrtc import (
|
||||
WebRTC,
|
||||
@@ -58,4 +59,5 @@ __all__ = [
|
||||
"Warning",
|
||||
"get_tts_model",
|
||||
"KokoroTTSOptions",
|
||||
"wait_for_item",
|
||||
]
|
||||
|
||||
@@ -233,3 +233,4 @@ class ReplyOnPause(StreamHandler):
|
||||
traceback.print_exc()
|
||||
logger.debug("Error in ReplyOnPause: %s", e)
|
||||
self.reset()
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""gr.WebRTC() component."""
|
||||
"""WebRTC tracks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -277,9 +277,7 @@ class StreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def emit(
|
||||
self,
|
||||
) -> EmitType:
|
||||
def emit(self) -> EmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -296,9 +294,7 @@ class AsyncStreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def emit(
|
||||
self,
|
||||
) -> EmitType:
|
||||
async def emit(self) -> EmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -312,15 +308,13 @@ class AsyncStreamHandler(StreamHandlerBase):
|
||||
StreamHandlerImpl = StreamHandler | AsyncStreamHandler
|
||||
|
||||
|
||||
class AudioVideoStreamHandler(StreamHandlerBase):
|
||||
class AudioVideoStreamHandler(StreamHandler):
|
||||
@abstractmethod
|
||||
def video_receive(self, frame: VideoFrame) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def video_emit(
|
||||
self,
|
||||
) -> VideoEmitType:
|
||||
def video_emit(self) -> VideoEmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -328,15 +322,13 @@ class AudioVideoStreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncAudioVideoStreamHandler(StreamHandlerBase):
|
||||
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
|
||||
@abstractmethod
|
||||
async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def video_emit(
|
||||
self,
|
||||
) -> VideoEmitType:
|
||||
async def video_emit(self) -> VideoEmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -6,7 +6,9 @@ import logging
|
||||
import tempfile
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
||||
|
||||
import functools
|
||||
import traceback
|
||||
import inspect
|
||||
import av
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -353,3 +355,47 @@ async def async_aggregate_bytes_to_16bit(chunks_iterator):
|
||||
if to_process:
|
||||
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
|
||||
yield audio_array
|
||||
|
||||
|
||||
def webrtc_error_handler(func):
|
||||
"""Decorator to catch exceptions and raise WebRTCError with stacktrace."""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
|
||||
"""
|
||||
Wait for an item from an asyncio.Queue with a timeout.
|
||||
|
||||
This function attempts to retrieve an item from the queue using asyncio.wait_for.
|
||||
If the timeout is reached, it returns None.
|
||||
|
||||
This is useful to avoid blocking `emit` when the queue is empty.
|
||||
"""
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(queue.get(), timeout=timeout)
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
return None
|
||||
|
||||
@@ -37,6 +37,7 @@ from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
create_message,
|
||||
webrtc_error_handler,
|
||||
)
|
||||
|
||||
Track = (
|
||||
@@ -148,8 +149,16 @@ class WebRTCConnectionMixin:
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerBase):
|
||||
handler = self.event_handler.copy()
|
||||
handler.emit = webrtc_error_handler(handler.emit)
|
||||
handler.receive = webrtc_error_handler(handler.receive)
|
||||
handler.start_up = webrtc_error_handler(handler.start_up)
|
||||
handler.shutdown = webrtc_error_handler(handler.shutdown)
|
||||
if hasattr(handler, "video_receive"):
|
||||
handler.video_receive = webrtc_error_handler(handler.video_receive)
|
||||
if hasattr(handler, "video_emit"):
|
||||
handler.video_emit = webrtc_error_handler(handler.video_emit)
|
||||
else:
|
||||
handler = cast(Callable, self.event_handler)
|
||||
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
||||
|
||||
self.handlers[body["webrtc_id"]] = handler
|
||||
|
||||
|
||||
Reference in New Issue
Block a user