Merge remote-tracking branch 'origin/main' into open-avatar-chat-0.4.0

bingochaos
2025-06-17 20:39:40 +08:00
142 changed files with 117010 additions and 814 deletions


@@ -12,15 +12,12 @@ import time
import traceback
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
Tuple,
TypeAlias,
Union,
cast,
)
@@ -51,11 +48,11 @@ from fastrtc.utils import (
logger = logging.getLogger(__name__)
VideoNDArray: TypeAlias = Union[
np.ndarray[Any, np.dtype[np.uint8]],
np.ndarray[Any, np.dtype[np.uint16]],
np.ndarray[Any, np.dtype[np.float32]],
]
VideoNDArray: TypeAlias = (
np.ndarray[Any, np.dtype[np.uint8]]
| np.ndarray[Any, np.dtype[np.uint16]]
| np.ndarray[Any, np.dtype[np.float32]]
)
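For reference, the same alias pattern in isolation; PEP 604 `|` unions are valid at runtime from Python 3.10, and the alias name below is purely illustrative:

from typing import Any, TypeAlias

import numpy as np

# Illustrative only: a union of ndarray dtypes written with the PEP 604 `|`
# operator instead of typing.Union, mirroring the VideoNDArray change above.
FrameArray: TypeAlias = (
    np.ndarray[Any, np.dtype[np.uint8]]
    | np.ndarray[Any, np.dtype[np.float32]]
)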
VideoEmitType = (
VideoNDArray
@@ -117,14 +114,14 @@ class VideoCallback(VideoStreamTrack):
self.channel_set.set()
def set_args(self, args: list[Any]):
self.latest_args = ["__webrtc_value__"] + list(args)
self.latest_args = list(args)
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
) -> list[Any]:
new_args = []
for val in args:
if isinstance(val, str) and val == "__webrtc_value__":
for i, val in enumerate(args):
if i == 0:
new_args.append(frame)
else:
new_args.append(val)
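A standalone sketch of what the index-based substitution above amounts to; the helper name is invented for illustration:

import numpy as np

def substitute_frame(args: list, frame: np.ndarray | None) -> list:
    # Slot 0 of the argument list is reserved for the current WebRTC frame;
    # every other argument passes through unchanged.
    return [frame if i == 0 else val for i, val in enumerate(args)]

frame = np.zeros((480, 640, 3), dtype=np.uint8)
print(substitute_frame([None, 0.5, "label"], frame))  # [array(...), 0.5, 'label']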
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
else:
raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
@@ -236,14 +233,48 @@ class VideoCallback(VideoStreamTrack):
class StreamHandlerBase(ABC):
"""
Base class for handling media streams in FastRTC.
Provides common attributes and methods for managing stream state,
communication channels, and basic configuration. This class is intended
to be subclassed by concrete stream handlers like `StreamHandler` or
`AsyncStreamHandler`.
Attributes:
expected_layout (Literal["mono", "stereo"]): The expected channel layout
of the input audio ('mono' or 'stereo').
output_sample_rate (int): The target sample rate for the output audio.
output_frame_size (int): The desired number of samples per output audio frame.
input_sample_rate (int): The expected sample rate of the input audio.
channel (DataChannel | None): The WebRTC data channel for communication.
channel_set (asyncio.Event): Event indicating if the data channel is set.
args_set (asyncio.Event): Event indicating if additional arguments are set.
latest_args (str | list[Any]): Stores the latest arguments received.
loop (asyncio.AbstractEventLoop): The asyncio event loop.
_resampler (av.AudioResampler | None): Internal audio resampler instance.
_clear_queue (Callable | None): Callback to clear the processing queue.
phone_mode (bool): Flag indicating if operating in telephone mode.
"""
def __init__(
self,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int | None = None,
output_frame_size: int | None = None, # Deprecated
input_sample_rate: int = 48000,
fps: int = 30,
) -> None:
"""
Initializes the StreamHandlerBase.
Args:
expected_layout: Expected input audio layout ('mono' or 'stereo').
output_sample_rate: Target output audio sample rate.
output_frame_size: Deprecated. Frame size is now derived from sample rate.
input_sample_rate: Expected input audio sample rate.
fps: The desired frame rate for the output audio.
"""
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.input_sample_rate = input_sample_rate
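The deprecation comment implies the output frame size is now derived from the sample rate and fps rather than passed in. A back-of-the-envelope version of that derivation (an assumption about the exact formula, not the library's code):

output_sample_rate = 24000
fps = 30

# If each emitted audio frame covers 1/fps seconds, its size in samples is:
output_frame_size = output_sample_rate // fps
print(output_frame_size)  # 800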
@@ -302,6 +333,12 @@ class StreamHandlerBase(ABC):
self._phone_mode = value
def set_channel(self, channel: DataChannel):
"""
Sets the data channel for communication and signals readiness.
Args:
channel: The WebRTC DataChannel instance.
"""
self._channel = channel
self.channel_set.set()
@@ -328,11 +365,25 @@ class StreamHandlerBase(ABC):
traceback.print_exc()
async def send_message(self, msg: str):
"""
Asynchronously sends a message over the data channel.
Args:
msg: The string message to send.
"""
if self.channel:
self.channel.send(msg)
logger.debug("Sent msg %s", msg)
def send_message_sync(self, msg: str):
"""
Synchronously sends a message over the data channel.
Runs the async `send_message` in the event loop and waits for completion.
Args:
msg: The string message to send.
"""
try:
asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
logger.debug("Sent msg %s", msg)
@@ -340,17 +391,36 @@ class StreamHandlerBase(ABC):
logger.debug("Exception sending msg %s", e)
def set_args(self, args: list[Any]):
"""
Sets additional arguments received (e.g., from UI components).
Args:
args: A list of arguments.
"""
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
self.latest_args = list(args)
self.args_set.set()
def reset(self):
"""Resets the argument set event."""
self.args_set.clear()
def shutdown(self):
"""Placeholder for shutdown logic. Subclasses can override."""
pass
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
"""
Resamples an incoming audio frame to the target format and sample rate.
Initializes the resampler on the first call.
Args:
frame: The input AudioFrame.
Yields:
Resampled AudioFrame(s).
"""
if self._resampler is None:
self._resampler = av.AudioResampler( # type: ignore
format="s16",
@@ -370,42 +440,109 @@ EmitType: TypeAlias = (
| tuple[int, npt.NDArray[np.int16 | np.float32], Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, npt.NDArray[np.int16 | np.float32]], AdditionalOutputs]
| CloseStream
| None
)
AudioEmitType = EmitType
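Concrete values that satisfy EmitType include, for example (sample rate and array contents are arbitrary; AdditionalOutputs and CloseStream are assumed importable from the top-level fastrtc package):

import numpy as np
from fastrtc import AdditionalOutputs, CloseStream

sr = 24000
chunk = np.zeros((1, 480), dtype=np.int16)

emit_audio = (sr, chunk, "mono")                            # audio chunk with explicit layout
emit_with_outputs = ((sr, chunk), AdditionalOutputs("hi"))  # audio plus extra UI outputs
emit_close = CloseStream()                                  # ask the stream to shut down
emit_nothing = None                                         # nothing to send this tick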
class StreamHandler(StreamHandlerBase):
"""
Abstract base class for synchronous stream handlers.
Inherits from `StreamHandlerBase` and defines the core synchronous interface
for processing audio streams. Subclasses must implement `receive`, `emit`,
and `copy`.
"""
@abstractmethod
def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
"""
Process an incoming audio frame synchronously.
Args:
frame: A tuple containing the sample rate (int) and the audio data
as a numpy array (int16).
"""
pass
@abstractmethod
def emit(self) -> EmitType:
"""
Produce the next output chunk synchronously.
This method is called to generate the output to be sent back over the stream.
Returns:
An output item conforming to `EmitType`, which could be audio data,
additional outputs, control signals (like `CloseStream`), or None.
"""
pass
@abstractmethod
def copy(self, **kwargs) -> StreamHandler:
"""
Create a copy of this stream handler instance.
Used to create a new handler for each connection.
Returns:
A new instance of the concrete StreamHandler subclass.
"""
pass
def start_up(self):
"""Optional synchronous startup logic. Can be overridden by subclasses."""
pass
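A minimal, illustrative (not production-ready) synchronous echo handler built on this interface might look like:

import numpy as np
import numpy.typing as npt
from fastrtc import StreamHandler

class EchoHandler(StreamHandler):
    """Toy handler: buffer incoming audio chunks and play them straight back."""

    def __init__(self) -> None:
        super().__init__()
        self._buffer: list[tuple[int, npt.NDArray[np.int16]]] = []

    def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
        self._buffer.append(frame)

    def emit(self):
        # Returning None is a valid EmitType value when nothing is ready.
        return self._buffer.pop(0) if self._buffer else None

    def copy(self, **kwargs) -> "EchoHandler":
        # A fresh instance is created for each connection.
        return EchoHandler()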
class AsyncStreamHandler(StreamHandlerBase):
"""
Abstract base class for asynchronous stream handlers.
Inherits from `StreamHandlerBase` and defines the core asynchronous interface
for processing audio streams using `async`/`await`. Subclasses must implement
`receive`, `emit`, and `copy`. The `start_up` method must also be a coroutine.
"""
@abstractmethod
async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
"""
Process an incoming audio frame asynchronously.
Args:
frame: A tuple containing the sample rate (int) and the audio data
as a numpy array (int16).
"""
pass
@abstractmethod
async def emit(self) -> EmitType:
"""
Produce the next output chunk asynchronously.
This coroutine is called to generate the output to be sent back over the stream.
Returns:
An output item conforming to `EmitType`, which could be audio data,
additional outputs, control signals (like `CloseStream`), or None.
"""
pass
@abstractmethod
def copy(self, **kwargs) -> AsyncStreamHandler:
"""
Create a copy of this asynchronous stream handler instance.
Used to create a new handler for each connection.
Returns:
A new instance of the concrete AsyncStreamHandler subclass.
"""
pass
async def start_up(self):
"""Optional asynchronous startup logic. Must be a coroutine (async def)."""
pass
async def on_chat_datachannel(self, message: dict, channel):
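A comparably minimal asynchronous sketch, with an asyncio.Queue standing in for the list so emit can await input (again illustrative only):

import asyncio

import numpy as np
import numpy.typing as npt
from fastrtc import AsyncStreamHandler

class AsyncEchoHandler(AsyncStreamHandler):
    """Toy async handler: await each incoming chunk and echo it back."""

    def __init__(self) -> None:
        super().__init__()
        self._queue: asyncio.Queue[tuple[int, npt.NDArray[np.int16]]] = asyncio.Queue()

    async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
        await self._queue.put(frame)

    async def emit(self):
        return await self._queue.get()

    def copy(self, **kwargs) -> "AsyncEchoHandler":
        return AsyncEchoHandler()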
@@ -416,30 +553,88 @@ StreamHandlerImpl = StreamHandler | AsyncStreamHandler
class AudioVideoStreamHandler(StreamHandler):
"""
Abstract base class for synchronous handlers processing both audio and video.
Inherits from `StreamHandler` (synchronous audio) and adds abstract methods
for handling video frames synchronously. Subclasses must implement the audio
methods (`receive`, `emit`) and the video methods (`video_receive`, `video_emit`),
as well as `copy`.
"""
@abstractmethod
def video_receive(self, frame: VideoFrame) -> None:
"""
Process an incoming video frame synchronously.
Args:
frame: The incoming aiortc `VideoFrame`.
"""
pass
@abstractmethod
def video_emit(self) -> VideoEmitType:
"""
Produce the next output video frame synchronously.
Returns:
An output item conforming to `VideoEmitType`, typically a numpy array
representing the video frame, or None.
"""
pass
@abstractmethod
def copy(self, **kwargs) -> AudioVideoStreamHandler:
"""
Create a copy of this audio-video stream handler instance.
Returns:
A new instance of the concrete AudioVideoStreamHandler subclass.
"""
pass
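For reference, a video_emit return value is usually just a frame array; an illustrative helper (dimensions arbitrary):

import numpy as np

def blank_frame(height: int = 480, width: int = 640) -> np.ndarray:
    # uint8 HxWx3 arrays fall under the VideoNDArray alias defined above.
    return np.zeros((height, width, 3), dtype=np.uint8)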
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
"""
Abstract base class for asynchronous handlers processing both audio and video.
Inherits from `AsyncStreamHandler` (asynchronous audio) and adds abstract
coroutines for handling video frames asynchronously. Subclasses must implement
the async audio methods (`receive`, `emit`, `start_up`) and the async video
methods (`video_receive`, `video_emit`), as well as `copy`.
"""
@abstractmethod
async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
"""
Process an incoming video frame asynchronously.
Args:
frame: The video frame data as a numpy array (float32).
Note: The type hint differs from the synchronous version.
Consider standardizing if possible.
"""
pass
@abstractmethod
async def video_emit(self) -> VideoEmitType:
"""
Produce the next output video frame asynchronously.
Returns:
An output item conforming to `VideoEmitType`, typically a numpy array
representing the video frame, or None.
"""
pass
@abstractmethod
def copy(self, **kwargs) -> AsyncAudioVideoStreamHandler:
"""
Create a copy of this asynchronous audio-video stream handler instance.
Returns:
A new instance of the concrete AsyncAudioVideoStreamHandler subclass.
"""
pass
@@ -716,7 +911,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
@@ -739,7 +934,7 @@ class ServerToClientVideo(VideoStreamTrack):
current_context.set(self.context)
if self.generator is None:
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
Generator[Any, None, Any], self.event_handler(*self.latest_args[1:])
)
try:
next_array, outputs = split_output(next(self.generator))
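The new [1:] slice appears to mirror the set_args change earlier in this diff: latest_args now keeps the WebRTC component's own value in slot 0 instead of a "__webrtc_value__" sentinel, so only the remaining user inputs reach the event handler. Roughly:

latest_args = ["<webrtc component value>", "a user input", 0.7]
handler_inputs = latest_args[1:]   # ["a user input", 0.7]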
@@ -786,7 +981,7 @@ class ServerToClientAudio(AudioStreamTrack):
) -> None:
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
self.event_handler._clear_queue = self.clear_queue
self.event_handler._clear_queue = self.clear_queue # pyright: ignore
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event()
@@ -816,7 +1011,7 @@ class ServerToClientAudio(AudioStreamTrack):
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
self.generator = self.event_handler(*self.latest_args[1:])
if self.generator is not None:
try:
frame = next(self.generator)