This commit is contained in:
Freddy Boulton
2025-01-10 17:14:47 -05:00
committed by GitHub
parent b64e019323
commit 4d16634307
12 changed files with 431 additions and 103 deletions

View File

@@ -16,10 +16,21 @@ from .utils import (
audio_to_file,
audio_to_float32,
)
from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
from .webrtc import (
AsyncAudioVideoStreamHandler,
AsyncStreamHandler,
AudioVideoStreamHandler,
StreamHandler,
WebRTC,
VideoEmitType,
AudioEmitType,
)
__all__ = [
"AsyncStreamHandler",
"AudioVideoStreamHandler",
"AudioEmitType",
"AsyncAudioVideoStreamHandler",
"AlgoOptions",
"AdditionalOutputs",
"aggregate_bytes_to_16bit",
@@ -36,6 +47,7 @@ __all__ = [
"stt",
"stt_for_chunks",
"StreamHandler",
"VideoEmitType",
"WebRTC",
"WebRTCError",
"Warning",

View File

@@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast
import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.utils import AdditionalOutputs
from gradio_webrtc.webrtc import EmitType, StreamHandler
logger = getLogger(__name__)

View File

@@ -147,16 +147,18 @@ async def player_worker_decode(
logger.debug(
"received array with shape %s sample rate %s layout %s",
audio_array.shape,
audio_array.shape, # type: ignore
sample_rate,
layout,
layout, # type: ignore
)
format = "s16" if audio_array.dtype == "int16" else "fltp"
format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore
# Convert to audio frame and resample
# This runs in the same timeout context
frame = av.AudioFrame.from_ndarray( # type: ignore
audio_array, format=format, layout=layout
audio_array, # type: ignore
format=format,
layout=layout, # type: ignore
)
frame.sample_rate = sample_rate

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import logging
import threading
import time
import traceback
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable
from typing import (
TYPE_CHECKING,
@@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError
from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
from numpy import typing as npt
from gradio_webrtc.utils import (
AdditionalOutputs,
@@ -61,6 +64,11 @@ if wasm_utils.IS_WASM:
logger = logging.getLogger(__name__)
VideoEmitType = Union[
AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None
]
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
class VideoCallback(VideoStreamTrack):
"""
@@ -72,7 +80,7 @@ class VideoCallback(VideoStreamTrack):
def __init__(
self,
track: MediaStreamTrack,
event_handler: Callable,
event_handler: VideoEventHandler,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
@@ -86,6 +94,7 @@ class VideoCallback(VideoStreamTrack):
self.thread_quit = asyncio.Event()
self.mode = mode
self.channel_set = asyncio.Event()
self.has_started = False
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -132,7 +141,7 @@ class VideoCallback(VideoStreamTrack):
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self):
async def recv(self): # type: ignore
try:
try:
frame = cast(VideoFrame, await self.track.recv())
@@ -142,7 +151,6 @@ class VideoCallback(VideoStreamTrack):
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
return frame
@@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[
tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
]
AudioEmitType = EmitType
class StreamHandler(StreamHandlerBase):
@@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase):
StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
class AudioVideoStreamHandler(StreamHandlerBase):
@abstractmethod
def video_receive(self, frame: npt.NDArray) -> None:
pass
@abstractmethod
def video_emit(
self,
) -> VideoEmitType:
pass
class AsyncAudioVideoStreamHandler(StreamHandlerBase):
@abstractmethod
async def video_receive(self, frame: npt.NDArray) -> None:
pass
@abstractmethod
async def video_emit(
self,
) -> VideoEmitType:
pass
VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler]
AudioVideoStreamHandlerImpl = Union[
AudioVideoStreamHandler, AsyncAudioVideoStreamHandler
]
AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler]
class VideoStreamHander(VideoCallback):
async def process_frames(self):
while not self.thread_quit.is_set():
try:
await self.channel_set.wait()
frame = cast(VideoFrame, await self.track.recv())
frame_array = frame.to_ndarray(format="bgr24")
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_receive):
await handler.video_receive(frame_array)
else:
handler.video_receive(frame_array)
except MediaStreamError:
self.stop()
def start(self):
if not self.has_started:
asyncio.create_task(self.process_frames())
self.has_started = True
async def recv(self): # type: ignore
self.start()
try:
handler = cast(VideoStreamHandlerImpl, self.event_handler)
if inspect.iscoroutinefunction(handler.video_emit):
outputs = await handler.video_emit()
else:
outputs = handler.video_emit()
array, outputs = split_output(outputs)
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
and self.channel
):
self.set_additional_outputs(outputs)
self.channel.send("change")
if array is None and self.mode == "send":
return
new_frame = self.array_to_frame(array)
# Will probably have to give developer ability to set pts and time_base
pts, time_base = await self.next_timestamp()
new_frame.pts = pts
new_frame.time_base = time_base
return new_frame
except Exception as e:
logger.debug("exception %s", e)
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
class AudioCallback(AudioStreamTrack):
kind = "audio"
def __init__(
self,
track: MediaStreamTrack,
event_handler: StreamHandlerImpl,
event_handler: StreamHandlerBase,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__()
self.track = track
self.event_handler = event_handler
self.event_handler = cast(StreamHandlerImpl, event_handler)
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.queue = asyncio.Queue()
@@ -322,7 +416,7 @@ class AudioCallback(AudioStreamTrack):
frame = cast(AudioFrame, await self.track.recv())
for frame in self.event_handler.resample(frame):
numpy_array = frame.to_ndarray()
if isinstance(self.event_handler, AsyncStreamHandler):
if isinstance(self.event_handler, AsyncHandler):
await self.event_handler.receive(
(frame.sample_rate, numpy_array)
)
@@ -337,7 +431,7 @@ class AudioCallback(AudioStreamTrack):
def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
if isinstance(self.event_handler, AsyncStreamHandler):
if isinstance(self.event_handler, AsyncHandler):
callable = self.event_handler.emit
else:
callable = functools.partial(
@@ -358,7 +452,7 @@ class AudioCallback(AudioStreamTrack):
)
self.has_started = True
async def recv(self):
async def recv(self): # type: ignore
try:
if self.readyState != "live":
raise MediaStreamError
@@ -383,7 +477,7 @@ class AudioCallback(AudioStreamTrack):
# control playback rate
if self._start is None:
self._start = time.time() - data_time
self._start = time.time() - data_time # type: ignore
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
@@ -434,7 +528,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
async def recv(self):
async def recv(self): # type: ignore
try:
pts, time_base = await self.next_timestamp()
await self.args_set.wait()
@@ -523,7 +617,7 @@ class ServerToClientAudio(AudioStreamTrack):
)
self.has_started = True
async def recv(self):
async def recv(self): # type: ignore
try:
if self.readyState != "live":
raise MediaStreamError
@@ -539,7 +633,7 @@ class ServerToClientAudio(AudioStreamTrack):
# control playback rate
if data_time is not None:
if self._start is None:
self._start = time.time() - data_time
self._start = time.time() - data_time # type: ignore
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
@@ -576,10 +670,12 @@ class WebRTC(Component):
pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay()
connections: dict[
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {}
str,
list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback],
] = defaultdict(list)
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, list[AdditionalOutputs]] = {}
handlers: dict[str, StreamHandlerBase | Callable] = {}
EVENTS = ["tick", "state_change"]
@@ -606,7 +702,7 @@ class WebRTC(Component):
track_constraints: dict[str, Any] | None = None,
time_limit: float | None = None,
mode: Literal["send-receive", "receive", "send"] = "send-receive",
modality: Literal["video", "audio"] = "video",
modality: Literal["video", "audio", "audio-video"] = "video",
rtp_params: dict[str, Any] | None = None,
icon: str | None = None,
icon_button_color: str | None = None,
@@ -669,6 +765,23 @@ class WebRTC(Component):
"height": {"ideal": 500},
"frameRate": {"ideal": 30},
}
if track_constraints is None and modality == "audio-video":
track_constraints = {
"video": {
"facingMode": "user",
"width": {"ideal": 500},
"height": {"ideal": 500},
"frameRate": {"ideal": 30},
},
"audio": {
"echoCancellation": True,
"noiseSuppression": {"exact": True},
"autoGainControl": {"exact": True},
"sampleRate": {"ideal": 24000},
"sampleSize": {"ideal": 16},
"channelCount": {"exact": 1},
},
}
self.track_constraints = track_constraints
self.event_handler: Callable | StreamHandler | None = None
super().__init__(
@@ -722,7 +835,8 @@ class WebRTC(Component):
def set_input(self, webrtc_id: str, *args):
if webrtc_id in self.connections:
self.connections[webrtc_id].set_args(list(args))
for conn in self.connections[webrtc_id]:
conn.set_args(list(args))
def on_additional_outputs(
self,
@@ -767,7 +881,10 @@ class WebRTC(Component):
def stream(
self,
fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
fn: Callable[..., Any]
| StreamHandlerImpl
| AudioVideoStreamHandlerImpl
| None = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
@@ -790,16 +907,16 @@ class WebRTC(Component):
self.concurrency_limit = (
1 if concurrency_limit in ["default", None] else concurrency_limit
)
self.event_handler = fn
self.event_handler = fn # type: ignore
self.time_limit = time_limit
if (
self.mode == "send-receive"
and self.modality == "audio"
and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
and self.modality in ["audio", "audio-video"]
and not isinstance(self.event_handler, StreamHandlerBase)
):
raise ValueError(
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
"In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase."
)
if self.mode == "send-receive" or self.mode == "send":
@@ -815,13 +932,23 @@ class WebRTC(Component):
raise ValueError(
"In the webrtc stream event, the only output component must be the WebRTC component."
)
for input_component in inputs[1:]: # type: ignore
if hasattr(input_component, "change"):
input_component.change( # type: ignore
self.set_input,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
concurrency_limit=None,
time_limit=None,
js=js,
)
return self.tick( # type: ignore
self.set_input,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
concurrency_limit=None,
stream_every=0.5,
time_limit=None,
js=js,
)
@@ -855,9 +982,11 @@ class WebRTC(Component):
await pc.close()
def clean_up(self, webrtc_id: str):
connection = self.connections.pop(webrtc_id, None)
if isinstance(connection, AudioCallback):
connection.event_handler.shutdown()
self.handlers.pop(webrtc_id, None)
connection = self.connections.pop(webrtc_id, [])
for conn in connection:
if isinstance(conn, AudioCallback):
conn.event_handler.shutdown()
self.additional_outputs.pop(webrtc_id, None)
self.data_channels.pop(webrtc_id, None)
return connection
@@ -874,6 +1003,13 @@ class WebRTC(Component):
pc = RTCPeerConnection()
self.pcs.add(pc)
if isinstance(self.event_handler, StreamHandlerBase):
handler = self.event_handler.copy()
else:
handler = cast(Callable, self.event_handler)
self.handlers[body["webrtc_id"]] = handler
set_outputs = self.set_additional_outputs(body["webrtc_id"])
@pc.on("iceconnectionstatechange")
@@ -891,7 +1027,8 @@ class WebRTC(Component):
await pc.close()
connection = self.clean_up(body["webrtc_id"])
if connection:
connection.stop()
for conn in connection:
conn.stop()
self.pcs.discard(pc)
if pc.connectionState == "connected":
if self.time_limit is not None:
@@ -900,28 +1037,38 @@ class WebRTC(Component):
@pc.on("track")
def on_track(track):
relay = MediaRelay()
if self.modality == "video":
handler = self.handlers[body["webrtc_id"]]
if self.modality == "video" and track.kind == "video":
cb = VideoCallback(
relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
event_handler=cast(VideoEventHandler, handler),
set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode),
)
elif self.modality == "audio":
handler = cast(StreamHandler, self.event_handler).copy()
handler._loop = asyncio.get_running_loop()
elif self.modality == "audio-video" and track.kind == "video":
cb = VideoStreamHander(
relay.subscribe(track),
event_handler=handler, # type: ignore
set_additional_outputs=set_outputs,
)
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
eh = cast(StreamHandlerImpl, handler)
eh._loop = asyncio.get_running_loop()
cb = AudioCallback(
relay.subscribe(track),
event_handler=handler,
event_handler=eh,
set_additional_outputs=set_outputs,
)
else:
raise ValueError("Modality must be either video or audio")
self.connections[body["webrtc_id"]] = cb
raise ValueError("Modality must be either video, audio, or audio-video")
if body["webrtc_id"] not in self.connections:
self.connections[body["webrtc_id"]] = []
self.connections[body["webrtc_id"]].append(cb)
if body["webrtc_id"] in self.data_channels:
self.connections[body["webrtc_id"]].set_channel(
self.data_channels[body["webrtc_id"]]
)
for conn in self.connections[body["webrtc_id"]]:
conn.set_channel(self.data_channels[body["webrtc_id"]])
if self.mode == "send-receive":
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
@@ -944,7 +1091,7 @@ class WebRTC(Component):
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
self.connections[body["webrtc_id"]].append(cb)
cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
@pc.on("datachannel")
@@ -957,7 +1104,8 @@ class WebRTC(Component):
while not self.connections.get(webrtc_id):
await asyncio.sleep(0.05)
logger.debug("setting channel for webrtc id %s", webrtc_id)
self.connections[webrtc_id].set_channel(channel)
for conn in self.connections[webrtc_id]:
conn.set_channel(channel)
asyncio.create_task(set_channel(body["webrtc_id"]))