Files
gradio-webrtc/backend/fastrtc/tracks.py
Freddy Boulton e231f793e8 trigger release (#201)
* trigger release

* add code
2025-03-20 21:01:13 -04:00

801 lines
25 KiB
Python

"""WebRTC tracks."""
from __future__ import annotations
import asyncio
import fractions
import functools
import inspect
import logging
import threading
import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
Tuple,
TypeAlias,
Union,
cast,
)
import anyio.to_thread
import av
import numpy as np
from aiortc import (
AudioStreamTrack,
MediaStreamTrack,
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
from numpy import typing as npt
from fastrtc.utils import (
AdditionalOutputs,
DataChannel,
WebRTCError,
create_message,
current_channel,
player_worker_decode,
split_output,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NumPy array types accepted for a single video frame: uint8, uint16 or
# float32 element types, with no constraint on shape.
VideoNDArray: TypeAlias = Union[
    np.ndarray[Any, np.dtype[np.uint8]],
    np.ndarray[Any, np.dtype[np.uint16]],
    np.ndarray[Any, np.dtype[np.float32]],
]

# What a video handler may produce: a frame array, a frame array paired with
# AdditionalOutputs, or AdditionalOutputs alone.
VideoEmitType = (
    VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs
)

# Signature of a plain video callback: one array-like argument in, a
# VideoEmitType out.
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
@dataclass
class VideoStreamHandler:
    """Bundle a video callback together with its frame-rate settings."""

    # The video callback itself (see VideoEventHandler).
    callable: VideoEventHandler
    # Frames per second; defaults to 30.
    fps: int = 30
    # Defaults to False.
    # NOTE(review): presumably forwarded to VideoCallback(skip_frames=...) —
    # confirm against the code that consumes this dataclass.
    skip_frames: bool = False
class VideoCallback(VideoStreamTrack):
    """Video track that runs an event handler on frames read from another track.

    A background task pulls frames from ``track`` and queues them. Each call
    to :meth:`recv` takes a queued frame, converts it to a BGR array, calls
    ``event_handler`` with that array plus the most recent extra arguments,
    and returns the handler's output as a new frame.

    This works for streaming input and output.
    """

    kind = "video"

    def __init__(
        self,
        track: MediaStreamTrack,
        event_handler: VideoEventHandler,
        channel: DataChannel | None = None,
        set_additional_outputs: Callable | None = None,
        mode: Literal["send-receive", "send"] = "send-receive",
        fps: int = 30,
        skip_frames: bool = False,
    ) -> None:
        """Create the track.

        Args:
            track: Source track whose frames are processed.
            event_handler: Callable invoked once per processed frame.
            channel: Optional data channel used to notify the peer.
            set_additional_outputs: Optional callback that receives any
                ``AdditionalOutputs`` produced by the handler.
            mode: In ``"send"`` mode a handler result without an array makes
                :meth:`recv` return ``None``.
            fps: Frame rate used by :meth:`next_timestamp`.
            skip_frames: When true, always process the most recent input
                frame instead of the oldest queued one.
        """
        super().__init__()
        self.track = track
        self.event_handler = event_handler
        # "not_set" is a sentinel meaning set_args() has not been called yet.
        self.latest_args: str | list[Any] = "not_set"
        self.channel = channel
        self.set_additional_outputs = set_additional_outputs
        self.thread_quit = asyncio.Event()
        self.mode = mode
        self.channel_set = asyncio.Event()
        self.has_started = False
        self.fps = fps
        self.frame_ptime = 1.0 / fps
        self.skip_frames = skip_frames
        self.frame_queue: asyncio.Queue[VideoFrame] = asyncio.Queue()
        self.latest_frame = None
        # The event loop only keeps weak references to tasks, so strong
        # references are held here to stop them being garbage-collected
        # while they are still running.
        self._process_frames_task: asyncio.Task | None = None
        self._accept_input_task: asyncio.Task | None = None

    def set_channel(self, channel: DataChannel):
        """Store the data channel, publish it in ``current_channel`` and
        release anything waiting in :meth:`wait_for_channel`."""
        self.channel = channel
        current_channel.set(channel)
        self.channel_set.set()

    def set_args(self, args: list[Any]):
        """Record the extra handler arguments.

        A ``"__webrtc_value__"`` placeholder is put first; it is replaced by
        the frame array in :meth:`add_frame_to_payload`.
        """
        self.latest_args = ["__webrtc_value__"] + list(args)

    def add_frame_to_payload(
        self, args: list[Any], frame: np.ndarray | None
    ) -> list[Any]:
        """Return a copy of ``args`` with every ``"__webrtc_value__"``
        placeholder replaced by ``frame``."""
        return [
            frame if isinstance(val, str) and val == "__webrtc_value__" else val
            for val in args
        ]

    def array_to_frame(self, array: np.ndarray) -> VideoFrame:
        """Wrap a BGR array in a ``VideoFrame``."""
        return VideoFrame.from_ndarray(array, format="bgr24")

    async def process_frames(self):
        """Call :meth:`recv` repeatedly until the track is stopped."""
        while not self.thread_quit.is_set():
            try:
                await self.recv()
            except TimeoutError:
                continue

    async def start(
        self,
    ):
        """Launch :meth:`process_frames` as a background task."""
        self._process_frames_task = asyncio.create_task(self.process_frames())

    def stop(self):
        """Stop the track and signal the background loops to exit."""
        super().stop()
        logger.debug("video callback stop")
        self.thread_quit.set()

    async def wait_for_channel(self):
        """Block until a channel has been set, then make sure the
        ``current_channel`` context variable points at it."""
        if not self.channel_set.is_set():
            await self.channel_set.wait()
        if current_channel.get() != self.channel:
            current_channel.set(self.channel)

    async def accept_input(self):
        """Read frames from the source track into the queue until stopped.

        A ``MediaStreamError`` from the source track stops this track.
        """
        self.has_started = True
        while not self.thread_quit.is_set():
            try:
                frame = cast(VideoFrame, await self.track.recv())
                self.latest_frame = frame
                self.frame_queue.put_nowait(frame)
            except MediaStreamError:
                self.stop()
                return

    def accept_input_in_background(self):
        """Start :meth:`accept_input` as a background task, at most once."""
        if not self.has_started:
            # Flip the flag before the task first runs so a second call made
            # in the meantime cannot spawn a duplicate reader.
            self.has_started = True
            self._accept_input_task = asyncio.create_task(self.accept_input())

    async def recv(self):  # type: ignore
        """Process one input frame and return the resulting output frame.

        Returns:
            The unmodified input frame if :meth:`set_args` has not been
            called yet; ``None`` if the handler produced no array in
            ``"send"`` mode; otherwise a new frame carrying the handler's
            output with the input frame's ``pts`` and ``time_base``.

        Raises:
            WebRTCError: Any failure is logged and re-raised as (or wrapped
                in) ``WebRTCError``.
        """
        self.accept_input_in_background()
        try:
            frame = await self.frame_queue.get()
            if self.skip_frames:
                # Drop the backlog so the queue cannot grow without bound and
                # so stale queue entries do not trigger reprocessing of the
                # same latest frame.
                while not self.frame_queue.empty():
                    self.frame_queue.get_nowait()
                frame = self.latest_frame

            await self.wait_for_channel()
            frame_array = frame.to_ndarray(format="bgr24")  # type: ignore
            if self.latest_args == "not_set":
                return frame

            args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)

            array, outputs = split_output(self.event_handler(*args))
            if (
                isinstance(outputs, AdditionalOutputs)
                and self.set_additional_outputs
                and self.channel
            ):
                self.set_additional_outputs(outputs)
                self.channel.send(create_message("fetch_output", []))
            if array is None and self.mode == "send":
                return

            new_frame = self.array_to_frame(array)
            if frame:
                new_frame.pts = frame.pts
                new_frame.time_base = frame.time_base
            else:
                pts, time_base = await self.next_timestamp()
                new_frame.pts = pts
                new_frame.time_base = time_base

            return new_frame
        except Exception as e:
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s", tb)
            if isinstance(e, WebRTCError):
                raise
            raise WebRTCError(str(e)) from e

    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
        """Override to control frame rate.

        Advances the timestamp by one frame period (``1 / fps``) and sleeps
        until that frame is due. The first call only initialises the clock.

        Raises:
            MediaStreamError: If the track is no longer live.
        """
        if self.readyState != "live":
            raise MediaStreamError

        if hasattr(self, "_timestamp"):
            self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
            wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        else:
            self._start = time.time()
            self._timestamp = 0
        return self._timestamp, VIDEO_TIME_BASE
class StreamHandlerBase(ABC):
    """Common state and helpers shared by the stream handler classes.

    Holds the audio configuration, the data channel, the latest extra
    arguments, and the events used to wait for the channel and arguments.
    """

    def __init__(
        self,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 960,
        input_sample_rate: int = 48000,
        fps: int = 30,
    ) -> None:
        """Store the configuration and create empty runtime state.

        Args:
            expected_layout: Channel layout passed to the resampler.
            output_sample_rate: Sample rate of emitted audio.
            output_frame_size: Number of samples per emitted frame.
            input_sample_rate: Sample rate incoming audio is resampled to.
            fps: Frame rate setting.
        """
        # Configuration.
        self.expected_layout = expected_layout
        self.output_sample_rate = output_sample_rate
        self.output_frame_size = output_frame_size
        self.input_sample_rate = input_sample_rate
        self.fps = fps

        # Runtime state.
        self.latest_args: list[Any] = []
        self._resampler = None  # built lazily by resample()
        self._channel: DataChannel | None = None
        self._loop: asyncio.AbstractEventLoop  # declared only; not assigned here
        self.args_set = asyncio.Event()
        self.channel_set = asyncio.Event()
        self._phone_mode = False
        self._clear_queue: Callable | None = None

    @property
    def clear_queue(self) -> Callable:
        """The stored ``_clear_queue`` callback (``None`` until one is set)."""
        return cast(Callable, self._clear_queue)

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """The event loop stored in ``_loop``."""
        return cast(asyncio.AbstractEventLoop, self._loop)

    @property
    def channel(self) -> DataChannel | None:
        """The data channel, or ``None`` if none has been set."""
        return self._channel

    @property
    def phone_mode(self) -> bool:
        """Whether phone mode is on; in phone mode no arguments are fetched."""
        return self._phone_mode

    @phone_mode.setter
    def phone_mode(self, value: bool):
        self._phone_mode = value

    def set_channel(self, channel: DataChannel):
        """Store the data channel and set ``channel_set``."""
        self._channel = channel
        self.channel_set.set()

    async def fetch_args(
        self,
    ):
        """Send a ``send_input`` message over the channel, if there is one."""
        if not self.channel:
            return
        self.channel.send(create_message("send_input", []))
        logger.debug("Sent send_input")

    async def wait_for_args(self):
        """Wait until arguments are available.

        In phone mode ``args_set`` is set immediately; otherwise the
        arguments are requested and ``args_set`` is awaited.
        """
        if self.phone_mode:
            self.args_set.set()
            return
        await self.fetch_args()
        await self.args_set.wait()

    def wait_for_args_sync(self):
        """Blocking variant of :meth:`wait_for_args` for non-loop threads.

        Any exception is printed rather than raised.
        """
        try:
            pending = asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop)
            pending.result()
        except Exception:
            import traceback

            traceback.print_exc()

    async def send_message(self, msg: str):
        """Send ``msg`` over the channel, if there is one."""
        if not self.channel:
            return
        self.channel.send(msg)
        logger.debug("Sent msg %s", msg)

    def send_message_sync(self, msg: str):
        """Blocking variant of :meth:`send_message` for non-loop threads.

        Any exception is logged at debug level rather than raised.
        """
        try:
            pending = asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop)
            pending.result()
            logger.debug("Sent msg %s", msg)
        except Exception as err:
            logger.debug("Exception sending msg %s", err)

    def set_args(self, args: list[Any]):
        """Record the extra arguments and set ``args_set``.

        A ``"__webrtc_value__"`` placeholder is stored as the first element.
        """
        logger.debug("setting args in audio callback %s", args)
        self.latest_args = ["__webrtc_value__", *args]
        self.args_set.set()

    def reset(self):
        """Clear ``args_set`` so the next wait blocks again."""
        self.args_set.clear()

    def shutdown(self):
        """No-op in the base class."""

    def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
        """Yield ``frame`` resampled to s16, the expected layout and the
        input sample rate.

        The resampler is created on first use, with its frame size taken
        from that first frame's sample count.
        """
        resampler = self._resampler
        if resampler is None:
            resampler = self._resampler = av.AudioResampler(  # type: ignore
                format="s16",
                layout=self.expected_layout,
                rate=self.input_sample_rate,
                frame_size=frame.samples,
            )
        yield from resampler.resample(frame)
# Return type of a handler's ``emit``. One of:
#   * (int, int16/float32 array)
#   * (int, int16/float32 array, "mono" | "stereo")
#   * AdditionalOutputs
#   * ((int, int16/float32 array), AdditionalOutputs)
#   * None
# NOTE(review): the int is presumably the sample rate — confirm against
# player_worker_decode, which consumes these values.
EmitType: TypeAlias = (
    tuple[int, npt.NDArray[np.int16 | np.float32]]
    | tuple[int, npt.NDArray[np.int16 | np.float32], Literal["mono", "stereo"]]
    | AdditionalOutputs
    | tuple[tuple[int, npt.NDArray[np.int16 | np.float32]], AdditionalOutputs]
    | None
)

# Alias of EmitType under an audio-specific name.
AudioEmitType = EmitType
class StreamHandler(StreamHandlerBase):
    """Interface for synchronous audio stream handlers.

    Subclasses must implement :meth:`receive`, :meth:`emit` and :meth:`copy`.
    """

    @abstractmethod
    def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
        """Consume one incoming ``(sample_rate, samples)`` audio frame."""

    @abstractmethod
    def emit(self) -> EmitType:
        """Produce the next output value (see ``EmitType``)."""

    @abstractmethod
    def copy(self) -> StreamHandler:
        """Return a new handler instance."""

    def start_up(self):
        """Optional start-up hook; does nothing by default."""
class AsyncStreamHandler(StreamHandlerBase):
    """Interface for asynchronous audio stream handlers.

    Subclasses must implement :meth:`receive` and :meth:`emit` as coroutines,
    plus a regular :meth:`copy`.
    """

    @abstractmethod
    async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
        """Consume one incoming ``(sample_rate, samples)`` audio frame."""

    @abstractmethod
    async def emit(self) -> EmitType:
        """Produce the next output value (see ``EmitType``)."""

    @abstractmethod
    def copy(self) -> AsyncStreamHandler:
        """Return a new handler instance."""

    async def start_up(self):
        """Optional start-up coroutine; does nothing by default."""
# Either flavour of audio stream handler, synchronous or asynchronous.
StreamHandlerImpl = StreamHandler | AsyncStreamHandler
class AudioVideoStreamHandler(StreamHandler):
    """Interface for synchronous handlers that process audio and video.

    Adds :meth:`video_receive` and :meth:`video_emit` to the audio methods
    inherited from ``StreamHandler``.
    """

    @abstractmethod
    def video_receive(self, frame: VideoNDArray) -> None:
        """Consume one incoming video frame.

        The parameter is annotated as a NumPy array rather than a
        ``VideoFrame`` because ``VideoStreamHandler_.process_frames``
        converts each frame with ``to_ndarray(format="bgr24")`` before
        calling this method.
        """

    @abstractmethod
    def video_emit(self) -> VideoEmitType:
        """Produce the next video output (see ``VideoEmitType``)."""

    @abstractmethod
    def copy(self) -> AudioVideoStreamHandler:
        """Return a new handler instance."""
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
    """Interface for asynchronous handlers that process audio and video.

    Adds coroutine :meth:`video_receive` and :meth:`video_emit` to the audio
    methods inherited from ``AsyncStreamHandler``.
    """

    @abstractmethod
    async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
        """Consume one incoming video frame, supplied as a NumPy array."""

    @abstractmethod
    async def video_emit(self) -> VideoEmitType:
        """Produce the next video output (see ``VideoEmitType``)."""

    @abstractmethod
    def copy(self) -> AsyncAudioVideoStreamHandler:
        """Return a new handler instance."""
# Either flavour of audio+video handler. The two names below are defined
# identically.
VideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler
AudioVideoStreamHandlerImpl = AudioVideoStreamHandler | AsyncAudioVideoStreamHandler

# Any asynchronous handler; used with isinstance() to pick the async code path.
AsyncHandler = AsyncStreamHandler | AsyncAudioVideoStreamHandler

# Union of every handler form: handler objects, plain callables, and the
# VideoStreamHandler dataclass.
HandlerType = (
    StreamHandlerImpl
    | VideoStreamHandlerImpl
    | VideoEventHandler
    | Callable
    | VideoStreamHandler
)
class VideoStreamHandler_(VideoCallback):
    """Video track driven by a handler object instead of a plain callable.

    A background task forwards incoming frames to the handler's
    ``video_receive``; :meth:`recv` builds outgoing frames from the
    handler's ``video_emit``. Both methods may be sync or async.
    """

    async def process_frames(self):
        """Feed frames from the source track to ``video_receive`` until the
        track is stopped.

        Waits for the channel to be set before reading. A
        ``MediaStreamError`` from the source track stops this track.
        """
        while not self.thread_quit.is_set():
            try:
                await self.channel_set.wait()
                frame = cast(VideoFrame, await self.track.recv())
                frame_array = frame.to_ndarray(format="bgr24")
                handler = cast(VideoStreamHandlerImpl, self.event_handler)
                if inspect.iscoroutinefunction(handler.video_receive):
                    await handler.video_receive(frame_array)
                else:
                    handler.video_receive(frame_array)  # type: ignore
            except MediaStreamError:
                self.stop()

    async def start(self):
        """Launch :meth:`process_frames` in the background, at most once."""
        if not self.has_started:
            # Hold a strong reference: the event loop only keeps weak
            # references to tasks, so an unreferenced task can be
            # garbage-collected before it finishes.
            self._process_frames_task = asyncio.create_task(self.process_frames())
            self.has_started = True

    async def recv(self):  # type: ignore
        """Return the next outgoing frame produced by ``video_emit``.

        Returns:
            A new frame stamped by ``next_timestamp``; ``None`` when the
            handler produced no array in ``"send"`` mode, or when any
            exception occurred (exceptions are logged at debug level and
            not raised).
        """
        await self.start()
        try:
            handler = cast(VideoStreamHandlerImpl, self.event_handler)
            if inspect.iscoroutinefunction(handler.video_emit):
                outputs = await handler.video_emit()
            else:
                outputs = handler.video_emit()

            array, outputs = split_output(outputs)
            if (
                isinstance(outputs, AdditionalOutputs)
                and self.set_additional_outputs
                and self.channel
            ):
                self.set_additional_outputs(outputs)
                self.channel.send(create_message("fetch_output", []))
            if array is None and self.mode == "send":
                return

            new_frame = self.array_to_frame(array)

            # Will probably have to give developer ability to set pts and time_base
            pts, time_base = await self.next_timestamp()
            new_frame.pts = pts
            new_frame.time_base = time_base

            return new_frame
        except Exception as e:
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s", tb)
class AudioCallback(AudioStreamTrack):
    """Audio track that feeds incoming audio to a stream handler and plays
    back what the handler emits.

    On the first :meth:`recv` three background tasks are started: one reads
    and resamples input frames for the handler's ``receive``, one runs the
    handler's ``start_up``, and one runs ``player_worker_decode`` to fill
    the output queue from the handler's ``emit``.
    """

    kind = "audio"

    def __init__(
        self,
        track: MediaStreamTrack,
        event_handler: StreamHandlerBase,
        channel: DataChannel | None = None,
        set_additional_outputs: Callable | None = None,
    ) -> None:
        """Create the track.

        Args:
            track: Source track whose audio is passed to the handler.
            event_handler: Handler object; it is given this track's
                :meth:`clear_queue` as its ``_clear_queue`` callback.
            channel: Optional data channel.
            set_additional_outputs: Optional callback handed to
                ``player_worker_decode``.
        """
        super().__init__()
        self.track = track
        self.event_handler = cast(StreamHandlerImpl, event_handler)
        self.event_handler._clear_queue = self.clear_queue
        self.current_timestamp = 0
        self.latest_args: str | list[Any] = "not_set"
        self.queue = asyncio.Queue()
        self.thread_quit = asyncio.Event()
        # Wall-clock origin for playback pacing; None means "re-anchor on
        # the next frame".
        self._start: float | None = None
        self.has_started = False
        self.last_timestamp = 0
        self.channel = channel
        self.set_additional_outputs = set_additional_outputs

    def clear_queue(self):
        """Discard every queued output frame and reset playback pacing."""
        logger.debug("clearing queue")
        logger.debug("queue size: %d", self.queue.qsize())
        i = 0
        while not self.queue.empty():
            self.queue.get_nowait()
            i += 1
        logger.debug("popped %d items from queue", i)
        self._start = None

    async def wait_for_channel(self):
        """Block until the handler has a channel, then make sure the
        ``current_channel`` context variable points at it."""
        if not self.event_handler.channel_set.is_set():
            await self.event_handler.channel_set.wait()
        if current_channel.get() != self.event_handler.channel:
            current_channel.set(self.event_handler.channel)

    def set_channel(self, channel: DataChannel):
        """Store the data channel here and on the handler."""
        self.channel = channel
        self.event_handler.set_channel(channel)

    def set_args(self, args: list[Any]):
        """Forward the extra arguments to the handler."""
        self.event_handler.set_args(args)

    def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Call the handler's sync ``receive`` with ``current_channel`` set.

        Used as the worker-thread entry point for sync handlers.
        """
        current_channel.set(self.event_handler.channel)
        return cast(Callable, self.event_handler.receive)(frame)

    def event_handler_emit(self) -> EmitType:
        """Call the handler's sync ``emit`` with ``current_channel`` set.

        Used as the executor entry point for sync handlers.
        """
        current_channel.set(self.event_handler.channel)
        return cast(Callable, self.event_handler.emit)()

    async def process_input_frames(self) -> None:
        """Read, resample and deliver input audio to the handler until the
        track is stopped or the source raises ``MediaStreamError``.

        Async handlers are awaited directly; sync handlers run in a worker
        thread.
        """
        while not self.thread_quit.is_set():
            try:
                frame = cast(AudioFrame, await self.track.recv())
                for frame in self.event_handler.resample(frame):
                    numpy_array = frame.to_ndarray()
                    if isinstance(self.event_handler, AsyncHandler):
                        await self.event_handler.receive(
                            (frame.sample_rate, numpy_array)  # type: ignore
                        )
                    else:
                        await anyio.to_thread.run_sync(
                            self.event_handler_receive, (frame.sample_rate, numpy_array)
                        )
            except MediaStreamError:
                logger.debug("MediaStreamError in process_input_frames")
                break

    async def start(self):
        """Start the input, start-up and decode tasks, at most once.

        Raises:
            WebRTCError: If an async handler's ``start_up`` does not return
                an awaitable.
        """
        if not self.has_started:
            loop = asyncio.get_running_loop()
            await self.wait_for_channel()
            if isinstance(self.event_handler, AsyncHandler):
                emit_fn = self.event_handler.emit
                start_up = self.event_handler.start_up()
                if not inspect.isawaitable(start_up):
                    raise WebRTCError(
                        "In AsyncStreamHandler, start_up must be a coroutine (async def)"
                    )
            else:
                # Sync handlers run in the default executor / a worker thread.
                emit_fn = functools.partial(
                    loop.run_in_executor, None, self.event_handler_emit
                )
                start_up = anyio.to_thread.run_sync(self.event_handler.start_up)
            self.process_input_task = asyncio.create_task(self.process_input_frames())
            self.process_input_task.add_done_callback(
                lambda _: logger.debug("process_input_done")
            )
            self.start_up_task = asyncio.create_task(start_up)
            self.start_up_task.add_done_callback(
                lambda _: logger.debug("start_up_done")
            )
            self.decode_task = asyncio.create_task(
                player_worker_decode(
                    emit_fn,
                    self.queue,
                    self.thread_quit,
                    lambda: self.channel,
                    self.set_additional_outputs,
                    False,
                    self.event_handler.output_sample_rate,
                    self.event_handler.output_frame_size,
                )
            )
            self.decode_task.add_done_callback(lambda _: logger.debug("decode_done"))
            self.has_started = True

    async def recv(self):  # type: ignore
        """Return the next output audio frame, paced to real time.

        Returns:
            The next queued frame, or ``None`` if an unexpected exception
            occurred (it is logged at debug level and not raised).

        Raises:
            MediaStreamError: If the track is no longer live.
        """
        try:
            if self.readyState != "live":
                raise MediaStreamError

            await self.wait_for_channel()
            await self.start()

            frame = await self.queue.get()
            logger.debug("frame %s", frame)

            data_time = frame.time

            # After a gap of more than ten frame durations, re-anchor the
            # playback clock instead of trying to catch up.
            if time.time() - self.last_timestamp > 10 * (
                self.event_handler.output_frame_size
                / self.event_handler.output_sample_rate
            ):
                self._start = None

            # control playback rate
            # Frames without a timestamp cannot be paced, so they are
            # returned immediately.
            if data_time is not None:
                if self._start is None:
                    self._start = time.time() - data_time  # type: ignore
                else:
                    wait = self._start + data_time - time.time()
                    await asyncio.sleep(wait)
            self.last_timestamp = time.time()
            return frame
        except MediaStreamError:
            # Propagate the end-of-track signal instead of logging it and
            # returning None.
            raise
        except Exception as e:
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s", tb)

    def stop(self):
        """Signal the background tasks to exit and stop the track."""
        logger.debug("audio callback stop")
        self.thread_quit.set()
        super().stop()
class ServerToClientVideo(VideoStreamTrack):
    """Video track whose frames come from a server-side generator.

    Once arguments have been set, ``event_handler(*args)`` is called once to
    create a generator, and each :meth:`recv` sends the generator's next
    item as a frame.
    """

    kind = "video"

    def __init__(
        self,
        event_handler: Callable,
        channel: DataChannel | None = None,
        set_additional_outputs: Callable | None = None,
        fps: int = 30,
    ) -> None:
        """Create the track.

        Args:
            event_handler: Callable returning a generator of frame outputs.
            channel: Optional data channel used to notify the peer.
            set_additional_outputs: Optional callback that receives any
                ``AdditionalOutputs`` the generator yields.
            fps: Frame rate used by :meth:`next_timestamp`.
        """
        super().__init__()  # don't forget this!
        self.event_handler = event_handler
        self.args_set = asyncio.Event()
        # "not_set" is a sentinel meaning set_args() has not been called yet.
        self.latest_args: str | list[Any] = "not_set"
        self.generator: Generator[Any, None, Any] | None = None
        self.channel = channel
        self.set_additional_outputs = set_additional_outputs
        self.fps = fps
        self.frame_ptime = 1.0 / fps

    def array_to_frame(self, array: np.ndarray) -> VideoFrame:
        """Wrap a BGR array in a ``VideoFrame``."""
        return VideoFrame.from_ndarray(array, format="bgr24")

    def set_channel(self, channel: DataChannel):
        """Store the data channel."""
        self.channel = channel

    def set_args(self, args: list[Any]):
        """Store the generator arguments and unblock :meth:`recv`."""
        self.latest_args = list(args)
        self.args_set.set()

    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
        """Override to control frame rate.

        Advances the timestamp by one frame period (``1 / fps``) and sleeps
        until that frame is due. The first call only initialises the clock.

        Raises:
            MediaStreamError: If the track is no longer live.
        """
        if self.readyState != "live":
            raise MediaStreamError

        if hasattr(self, "_timestamp"):
            self._timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE)
            wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        else:
            self._start = time.time()
            self._timestamp = 0
        return self._timestamp, VIDEO_TIME_BASE

    async def recv(self):  # type: ignore
        """Return the next frame from the generator.

        Returns:
            The next frame, or ``None`` once the generator is exhausted (the
            track is stopped at that point).

        Raises:
            MediaStreamError: If the track is no longer live.
            WebRTCError: Any other failure is logged and re-raised as (or
                wrapped in) ``WebRTCError``.
        """
        try:
            pts, time_base = await self.next_timestamp()
            await self.args_set.wait()
            current_channel.set(self.channel)
            if self.generator is None:
                # Create the generator lazily, once arguments are known.
                self.generator = cast(
                    Generator[Any, None, Any], self.event_handler(*self.latest_args)
                )
            try:
                next_array, outputs = split_output(next(self.generator))
                if (
                    isinstance(outputs, AdditionalOutputs)
                    and self.set_additional_outputs
                    and self.channel
                ):
                    self.set_additional_outputs(outputs)
                    self.channel.send(create_message("fetch_output", []))
            except StopIteration:
                self.stop()
                return

            next_frame = self.array_to_frame(next_array)
            next_frame.pts = pts
            next_frame.time_base = time_base
            return next_frame
        except MediaStreamError:
            # A track that is no longer live is the normal end-of-stream
            # signal, not a handler failure, so it is not wrapped in
            # WebRTCError.
            raise
        except Exception as e:
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s %s", e, tb)
            if isinstance(e, WebRTCError):
                raise
            raise WebRTCError(str(e)) from e
class ServerToClientAudio(AudioStreamTrack):
    """Audio track whose frames come from a server-side generator.

    Once arguments have been set, ``event_handler(*args)`` is called once to
    create a generator. A background ``player_worker_decode`` task pulls
    items from it through :meth:`next` and fills the queue that
    :meth:`recv` reads.
    """

    kind = "audio"

    def __init__(
        self,
        event_handler: Callable,
        channel: DataChannel | None = None,
        set_additional_outputs: Callable | None = None,
    ) -> None:
        """Create the track.

        Args:
            event_handler: Callable returning a generator of audio outputs;
                it is given this track's :meth:`clear_queue` as its
                ``_clear_queue`` attribute.
            channel: Optional data channel.
            set_additional_outputs: Optional callback handed to
                ``player_worker_decode``.
        """
        self.generator: Generator[Any, None, Any] | None = None
        self.event_handler = event_handler
        self.event_handler._clear_queue = self.clear_queue
        self.current_timestamp = 0
        # "not_set" is a sentinel meaning set_args() has not been called yet.
        self.latest_args: str | list[Any] = "not_set"
        # threading.Event because next() waits on it from an executor thread.
        self.args_set = threading.Event()
        self.queue = asyncio.Queue()
        self.thread_quit = asyncio.Event()
        self.channel = channel
        self.set_additional_outputs = set_additional_outputs
        self.has_started = False
        # Wall-clock origin for playback pacing; None means "re-anchor on
        # the next frame".
        self._start: float | None = None
        # Strong reference to the decode task; the event loop only keeps
        # weak references to tasks.
        self._decode_task: asyncio.Task | None = None
        super().__init__()

    def clear_queue(self):
        """Discard every queued frame and reset playback pacing."""
        while not self.queue.empty():
            self.queue.get_nowait()
        self._start = None

    def set_channel(self, channel: DataChannel):
        """Store the data channel."""
        self.channel = channel

    def set_args(self, args: list[Any]):
        """Store the generator arguments and unblock :meth:`next`."""
        self.latest_args = list(args)
        self.args_set.set()

    def next(self) -> tuple[int, np.ndarray] | None:
        """Return the generator's next item, blocking until arguments are set.

        Runs in an executor thread (see :meth:`start`). Returns ``None`` and
        sets ``thread_quit`` once the generator is exhausted.
        """
        self.args_set.wait()
        current_channel.set(self.channel)
        if self.generator is None:
            self.generator = self.event_handler(*self.latest_args)
        if self.generator is not None:
            try:
                frame = next(self.generator)
                return frame
            except StopIteration:
                self.thread_quit.set()

    async def start(self):
        """Start the ``player_worker_decode`` task, at most once."""
        if not self.has_started:
            loop = asyncio.get_running_loop()
            next_item = functools.partial(loop.run_in_executor, None, self.next)
            self._decode_task = asyncio.create_task(
                player_worker_decode(
                    next_item,
                    self.queue,
                    self.thread_quit,
                    lambda: self.channel,
                    self.set_additional_outputs,
                    True,
                )
            )
            self.has_started = True

    async def recv(self):  # type: ignore
        """Return the next queued audio frame, paced to real time.

        Returns:
            The next frame, or ``None`` when the queue delivers ``None``
            (the track is stopped at that point).

        Raises:
            MediaStreamError: If the track is no longer live.
            WebRTCError: Any other failure is logged and re-raised as (or
                wrapped in) ``WebRTCError``.
        """
        try:
            if self.readyState != "live":
                raise MediaStreamError

            await self.start()
            data = await self.queue.get()
            if data is None:
                self.stop()
                return

            data_time = data.time

            # control playback rate
            if data_time is not None:
                if self._start is None:
                    self._start = time.time() - data_time  # type: ignore
                else:
                    wait = self._start + data_time - time.time()
                    await asyncio.sleep(wait)

            return data
        except MediaStreamError:
            # A track that is no longer live is the normal end-of-stream
            # signal, not a handler failure, so it is not wrapped in
            # WebRTCError.
            raise
        except Exception as e:
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s", tb)
            if isinstance(e, WebRTCError):
                raise
            raise WebRTCError(str(e)) from e

    def stop(self):
        """Signal the decode task to exit and stop the track."""
        logger.debug("audio-to-client stop callback")
        self.thread_quit.set()
        super().stop()