Raise errors automatically (#69)

* Add auto errors

* change code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-02-24 20:21:25 -05:00
committed by GitHub
parent c36fb8da50
commit 5a4693ee0b
20 changed files with 297 additions and 344 deletions

View File

@@ -25,6 +25,7 @@ from .utils import (
audio_to_bytes,
audio_to_file,
audio_to_float32,
wait_for_item,
)
from .webrtc import (
WebRTC,
@@ -58,4 +59,5 @@ __all__ = [
"Warning",
"get_tts_model",
"KokoroTTSOptions",
"wait_for_item",
]

View File

@@ -233,3 +233,4 @@ class ReplyOnPause(StreamHandler):
traceback.print_exc()
logger.debug("Error in ReplyOnPause: %s", e)
self.reset()
raise e

View File

@@ -1,4 +1,4 @@
"""gr.WebRTC() component."""
"""WebRTC tracks."""
from __future__ import annotations
@@ -277,9 +277,7 @@ class StreamHandler(StreamHandlerBase):
pass
@abstractmethod
def emit(
self,
) -> EmitType:
def emit(self) -> EmitType:
pass
@abstractmethod
@@ -296,9 +294,7 @@ class AsyncStreamHandler(StreamHandlerBase):
pass
@abstractmethod
async def emit(
self,
) -> EmitType:
async def emit(self) -> EmitType:
pass
@abstractmethod
@@ -312,15 +308,13 @@ class AsyncStreamHandler(StreamHandlerBase):
StreamHandlerImpl = StreamHandler | AsyncStreamHandler
class AudioVideoStreamHandler(StreamHandlerBase):
class AudioVideoStreamHandler(StreamHandler):
@abstractmethod
def video_receive(self, frame: VideoFrame) -> None:
pass
@abstractmethod
def video_emit(
self,
) -> VideoEmitType:
def video_emit(self) -> VideoEmitType:
pass
@abstractmethod
@@ -328,15 +322,13 @@ class AudioVideoStreamHandler(StreamHandlerBase):
pass
class AsyncAudioVideoStreamHandler(StreamHandlerBase):
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
@abstractmethod
async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
pass
@abstractmethod
async def video_emit(
self,
) -> VideoEmitType:
async def video_emit(self) -> VideoEmitType:
pass
@abstractmethod

View File

@@ -6,7 +6,9 @@ import logging
import tempfile
from contextvars import ContextVar
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
import functools
import traceback
import inspect
import av
import numpy as np
from numpy.typing import NDArray
@@ -353,3 +355,47 @@ async def async_aggregate_bytes_to_16bit(chunks_iterator):
if to_process:
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
yield audio_array
def webrtc_error_handler(func):
"""Decorator to catch exceptions and raise WebRTCError with stacktrace."""
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
traceback.print_exc()
if isinstance(e, WebRTCError):
raise e
else:
raise WebRTCError(str(e)) from e
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
traceback.print_exc()
if isinstance(e, WebRTCError):
raise e
else:
raise WebRTCError(str(e)) from e
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
"""
Wait for an item from an asyncio.Queue with a timeout.
This function attempts to retrieve an item from the queue using asyncio.wait_for.
If the timeout is reached, it returns None.
This is useful to avoid blocking `emit` when the queue is empty.
"""
try:
return await asyncio.wait_for(queue.get(), timeout=timeout)
except (TimeoutError, asyncio.TimeoutError):
return None

View File

@@ -37,6 +37,7 @@ from fastrtc.utils import (
AdditionalOutputs,
DataChannel,
create_message,
webrtc_error_handler,
)
Track = (
@@ -148,8 +149,16 @@ class WebRTCConnectionMixin:
if isinstance(self.event_handler, StreamHandlerBase):
handler = self.event_handler.copy()
handler.emit = webrtc_error_handler(handler.emit)
handler.receive = webrtc_error_handler(handler.receive)
handler.start_up = webrtc_error_handler(handler.start_up)
handler.shutdown = webrtc_error_handler(handler.shutdown)
if hasattr(handler, "video_receive"):
handler.video_receive = webrtc_error_handler(handler.video_receive)
if hasattr(handler, "video_emit"):
handler.video_emit = webrtc_error_handler(handler.video_emit)
else:
handler = cast(Callable, self.event_handler)
handler = webrtc_error_handler(cast(Callable, self.event_handler))
self.handlers[body["webrtc_id"]] = handler