Files
gradio-webrtc/backend/fastrtc/tracks.py
Freddy Boulton 5a4693ee0b Raise errors automatically (#69)
* Add auto errors

* change code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
2025-02-24 20:21:25 -05:00

682 lines
21 KiB
Python

"""WebRTC tracks."""
from __future__ import annotations
import asyncio
import functools
import inspect
import logging
import threading
import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import (
Any,
Generator,
Literal,
TypeAlias,
Union,
cast,
)
import anyio.to_thread
import av
import numpy as np
from aiortc import (
AudioStreamTrack,
MediaStreamTrack,
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from numpy import typing as npt
from fastrtc.utils import (
AdditionalOutputs,
DataChannel,
create_message,
current_channel,
player_worker_decode,
split_output,
)
logger = logging.getLogger(__name__)
VideoNDArray: TypeAlias = Union[
np.ndarray[Any, np.dtype[np.uint8]],
np.ndarray[Any, np.dtype[np.uint16]],
np.ndarray[Any, np.dtype[np.float32]],
]
VideoEmitType = (
VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs
)
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
"""
kind = "video"
def __init__(
self,
track: MediaStreamTrack,
event_handler: VideoEventHandler,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
) -> None:
super().__init__() # don't forget this!
self.track = track
self.event_handler = event_handler
self.latest_args: str | list[Any] = "not_set"
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.thread_quit = asyncio.Event()
self.mode = mode
self.channel_set = asyncio.Event()
self.has_started = False
def set_channel(self, channel: DataChannel):
self.channel = channel
current_channel.set(channel)
self.channel_set.set()
def set_args(self, args: list[Any]):
self.latest_args = ["__webrtc_value__"] + list(args)
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
) -> list[Any]:
new_args = []
for val in args:
if isinstance(val, str) and val == "__webrtc_value__":
new_args.append(frame)
else:
new_args.append(val)
return new_args
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
async def process_frames(self):
while not self.thread_quit.is_set():
try:
await self.recv()
except TimeoutError:
continue
def start(
self,
):
asyncio.create_task(self.process_frames())
def stop(self):
super().stop()
logger.debug("video callback stop")
self.thread_quit.set()
async def wait_for_channel(self):
if not self.channel_set.is_set():
await self.channel_set.wait()
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self): # type: ignore
try:
try:
frame = cast(VideoFrame, await self.track.recv())
except MediaStreamError:
self.stop()
return
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
return frame
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send(create_message("fetch_output", []))
if array is None and self.mode == "send":
return
new_frame = self.array_to_frame(array)
if frame:
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
else:
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
return new_frame
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
class StreamHandlerBase(ABC):
def __init__(
self,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 960,
input_sample_rate: int = 48000,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate
self.latest_args: list[Any] = []
self._resampler = None
self._channel: DataChannel | None = None
self._loop: asyncio.AbstractEventLoop
self.args_set = asyncio.Event()
self.channel_set = asyncio.Event()
self._phone_mode = False
@property
def loop(self) -> asyncio.AbstractEventLoop:
return cast(asyncio.AbstractEventLoop, self._loop)
@property
def channel(self) -> DataChannel | None:
return self._channel
@property
def phone_mode(self) -> bool:
return self._phone_mode
@phone_mode.setter
def phone_mode(self, value: bool):
self._phone_mode = value
def set_channel(self, channel: DataChannel):
self._channel = channel
self.channel_set.set()
async def fetch_args(
self,
):
if self.channel:
self.channel.send(create_message("send_input", []))
logger.debug("Sent send_input")
async def wait_for_args(self):
if not self.phone_mode:
await self.fetch_args()
await self.args_set.wait()
else:
self.args_set.set()
def wait_for_args_sync(self):
try:
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
except Exception:
import traceback
traceback.print_exc()
async def send_message(self, msg: str):
if self.channel:
self.channel.send(msg)
logger.debug("Sent msg %s", msg)
def send_message_sync(self, msg: str):
asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
logger.debug("Sent msg %s", msg)
def set_args(self, args: list[Any]):
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
self.args_set.set()
def reset(self):
self.args_set.clear()
def shutdown(self):
pass
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
if self._resampler is None:
self._resampler = av.AudioResampler( # type: ignore
format="s16",
layout=self.expected_layout,
rate=self.input_sample_rate,
frame_size=frame.samples,
)
yield from self._resampler.resample(frame)
EmitType: TypeAlias = (
tuple[int, npt.NDArray[np.int16 | np.float32]]
| tuple[int, npt.NDArray[np.int16 | np.float32], Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, npt.NDArray[np.int16 | np.float32]], AdditionalOutputs]
| None
)
AudioEmitType = EmitType
class StreamHandler(StreamHandlerBase):
@abstractmethod
def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
pass
@abstractmethod
def emit(self) -> EmitType:
pass
@abstractmethod
def copy(self) -> StreamHandler:
pass
def start_up(self):
pass
class AsyncStreamHandler(StreamHandlerBase):
@abstractmethod
async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
pass
@abstractmethod
async def emit(self) -> EmitType:
pass
@abstractmethod
def copy(self) -> AsyncStreamHandler:
pass
async def start_up(self):
pass
StreamHandlerImpl = StreamHandler | AsyncStreamHandler
class AudioVideoStreamHandler(StreamHandler):
@abstractmethod
def video_receive(self, frame: VideoFrame) -> None:
pass
@abstractmethod
def video_emit(self) -> VideoEmitType:
pass
@abstractmethod
def copy(self) -> AudioVideoStreamHandler:
pass
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
@abstractmethod
async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
pass
@abstractmethod
async def video_emit(self) -> VideoEmitType:
pass
@abstractmethod
def copy(self) -> AsyncAudioVideoStreamHandler:
pass
VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler
HandlerType = StreamHandlerImpl | VideoStreamHandlerImpl | VideoEventHandler | Callable
class VideoStreamHandler(VideoCallback):
async def process_frames(self):
while not self.thread_quit.is_set():
try:
await self.channel_set.wait()
frame = cast(VideoFrame, await self.track.recv())
frame_array = frame.to_ndarray(format="bgr24")
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_receive):
await handler.video_receive(frame_array)
else:
handler.video_receive(frame_array) # type: ignore
except MediaStreamError:
self.stop()
def start(self):
if not self.has_started:
asyncio.create_task(self.process_frames())
self.has_started = True
async def recv(self): # type: ignore
self.start()
try:
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_emit):
outputs = await handler.video_emit()
else:
outputs = handler.video_emit()
array, outputs = split_output(outputs)
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send(create_message("fetch_output", []))
if array is None and self.mode == "send":
return
new_frame = self.array_to_frame(array)
# Will probably have to give developer ability to set pts and time_base
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
return new_frame
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
class AudioCallback(AudioStreamTrack):
kind = "audio"
def __init__(
self,
track: MediaStreamTrack,
event_handler: StreamHandlerBase,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__()
self.track = track
self.event_handler = cast(StreamHandlerImpl, event_handler)
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.queue = asyncio.Queue()
self.thread_quit = asyncio.Event()
self._start: float | None = None
self.has_started = False
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def set_channel(self, channel: DataChannel):
self.channel = channel
self.event_handler.set_channel(channel)
def set_args(self, args: list[Any]):
self.event_handler.set_args(args)
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
current_channel.set(self.event_handler.channel)
return cast(Callable, self.event_handler.receive)(frame)
def event_handler_emit(self) -> EmitType:
current_channel.set(self.event_handler.channel)
return cast(Callable, self.event_handler.emit)()
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
frame = cast(AudioFrame, await self.track.recv())
for frame in self.event_handler.resample(frame):
numpy_array = frame.to_ndarray()
if isinstance(self.event_handler, AsyncHandler):
await self.event_handler.receive(
(frame.sample_rate, numpy_array) # type: ignore
)
else:
await anyio.to_thread.run_sync(
self.event_handler_receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames")
break
def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit
start_up = self.event_handler.start_up()
else:
callable = functools.partial(
loop.run_in_executor, None, self.event_handler_emit
)
start_up = anyio.to_thread.run_sync(self.event_handler.start_up)
self.process_input_task = asyncio.create_task(self.process_input_frames())
self.process_input_task.add_done_callback(
lambda _: logger.debug("process_input_done")
)
self.start_up_task = asyncio.create_task(start_up)
self.start_up_task.add_done_callback(
lambda _: logger.debug("start_up_done")
)
self.decode_task = asyncio.create_task(
player_worker_decode(
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
False,
self.event_handler.output_sample_rate,
self.event_handler.output_frame_size,
)
)
self.decode_task.add_done_callback(lambda _: logger.debug("decode_done"))
self.has_started = True
async def recv(self): # type: ignore
try:
if self.readyState != "live":
raise MediaStreamError
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
data_time = frame.time
if time.time() - self.last_timestamp > 10 * (
self.event_handler.output_frame_size
/ self.event_handler.output_sample_rate
):
self._start = None
# control playback rate
if self._start is None:
self._start = time.time() - data_time # type: ignore
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
self.last_timestamp = time.time()
return frame
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
def stop(self):
logger.debug("audio callback stop")
self.thread_quit.set()
super().stop()
class ServerToClientVideo(VideoStreamTrack):
"""
This works for streaming input and output
"""
kind = "video"
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
self.args_set = asyncio.Event()
self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None
self.channel = channel
self.set_additional_outputs = set_additional_outputs
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
def set_channel(self, channel: DataChannel):
self.channel = channel
def set_args(self, args: list[Any]):
self.latest_args = list(args)
self.args_set.set()
async def recv(self): # type: ignore
try:
pts, time_base = await self.next_timestamp()
await self.args_set.wait()
if self.generator is None:
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
current_channel.set(self.channel)
try:
next_array, outputs = split_output(next(self.generator))
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send(create_message("fetch_output", []))
except StopIteration:
self.stop()
return
next_frame = self.array_to_frame(next_array)
next_frame.pts = pts
next_frame.time_base = time_base
return next_frame
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s ", exec)
class ServerToClientAudio(AudioStreamTrack):
kind = "audio"
def __init__(
self,
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event()
self.queue = asyncio.Queue()
self.thread_quit = asyncio.Event()
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.has_started = False
self._start: float | None = None
super().__init__()
def set_channel(self, channel: DataChannel):
self.channel = channel
def set_args(self, args: list[Any]):
self.latest_args = list(args)
self.args_set.set()
def next(self) -> tuple[int, np.ndarray] | None:
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
try:
frame = next(self.generator)
return frame
except StopIteration:
self.thread_quit.set()
def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(loop.run_in_executor, None, self.next)
asyncio.create_task(
player_worker_decode(
callable,
self.queue,
self.thread_quit,
lambda: self.channel,
self.set_additional_outputs,
True,
)
)
self.has_started = True
async def recv(self): # type: ignore
try:
if self.readyState != "live":
raise MediaStreamError
self.start()
data = await self.queue.get()
if data is None:
self.stop()
return
data_time = data.time
# control playback rate
if data_time is not None:
if self._start is None:
self._start = time.time() - data_time # type: ignore
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
return data
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
def stop(self):
logger.debug("audio-to-client stop callback")
self.thread_quit.set()
super().stop()