mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 09:59:22 +08:00
code (#48)
This commit is contained in:
@@ -16,10 +16,21 @@ from .utils import (
|
||||
audio_to_file,
|
||||
audio_to_float32,
|
||||
)
|
||||
from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC
|
||||
from .webrtc import (
|
||||
AsyncAudioVideoStreamHandler,
|
||||
AsyncStreamHandler,
|
||||
AudioVideoStreamHandler,
|
||||
StreamHandler,
|
||||
WebRTC,
|
||||
VideoEmitType,
|
||||
AudioEmitType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AsyncStreamHandler",
|
||||
"AudioVideoStreamHandler",
|
||||
"AudioEmitType",
|
||||
"AsyncAudioVideoStreamHandler",
|
||||
"AlgoOptions",
|
||||
"AdditionalOutputs",
|
||||
"aggregate_bytes_to_16bit",
|
||||
@@ -36,6 +47,7 @@ __all__ = [
|
||||
"stt",
|
||||
"stt_for_chunks",
|
||||
"StreamHandler",
|
||||
"VideoEmitType",
|
||||
"WebRTC",
|
||||
"WebRTCError",
|
||||
"Warning",
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, Callable, Generator, Literal, Union, cast
|
||||
import numpy as np
|
||||
|
||||
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
|
||||
from gradio_webrtc.utils import AdditionalOutputs
|
||||
from gradio_webrtc.webrtc import EmitType, StreamHandler
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -147,16 +147,18 @@ async def player_worker_decode(
|
||||
|
||||
logger.debug(
|
||||
"received array with shape %s sample rate %s layout %s",
|
||||
audio_array.shape,
|
||||
audio_array.shape, # type: ignore
|
||||
sample_rate,
|
||||
layout,
|
||||
layout, # type: ignore
|
||||
)
|
||||
format = "s16" if audio_array.dtype == "int16" else "fltp"
|
||||
format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore
|
||||
|
||||
# Convert to audio frame and resample
|
||||
# This runs in the same timeout context
|
||||
frame = av.AudioFrame.from_ndarray( # type: ignore
|
||||
audio_array, format=format, layout=layout
|
||||
audio_array, # type: ignore
|
||||
format=format,
|
||||
layout=layout, # type: ignore
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -40,6 +42,7 @@ from aiortc.mediastreams import MediaStreamError
|
||||
from gradio import wasm_utils
|
||||
from gradio.components.base import Component, server
|
||||
from gradio_client import handle_file
|
||||
from numpy import typing as npt
|
||||
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
@@ -61,6 +64,11 @@ if wasm_utils.IS_WASM:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VideoEmitType = Union[
|
||||
AdditionalOutputs, tuple[npt.ArrayLike, AdditionalOutputs], npt.ArrayLike, None
|
||||
]
|
||||
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
|
||||
|
||||
|
||||
class VideoCallback(VideoStreamTrack):
|
||||
"""
|
||||
@@ -72,7 +80,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
def __init__(
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: Callable,
|
||||
event_handler: VideoEventHandler,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
mode: Literal["send-receive", "send"] = "send-receive",
|
||||
@@ -86,6 +94,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.channel_set = asyncio.Event()
|
||||
self.has_started = False
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
@@ -132,7 +141,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
if current_channel.get() != self.channel:
|
||||
current_channel.set(self.channel)
|
||||
|
||||
async def recv(self):
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
try:
|
||||
frame = cast(VideoFrame, await self.track.recv())
|
||||
@@ -142,7 +151,6 @@ class VideoCallback(VideoStreamTrack):
|
||||
|
||||
await self.wait_for_channel()
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
if self.latest_args == "not_set":
|
||||
return frame
|
||||
|
||||
@@ -253,6 +261,7 @@ EmitType: TypeAlias = Union[
|
||||
tuple[tuple[int, np.ndarray], AdditionalOutputs],
|
||||
None,
|
||||
]
|
||||
AudioEmitType = EmitType
|
||||
|
||||
|
||||
class StreamHandler(StreamHandlerBase):
|
||||
@@ -282,19 +291,104 @@ class AsyncStreamHandler(StreamHandlerBase):
|
||||
StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
|
||||
|
||||
|
||||
class AudioVideoStreamHandler(StreamHandlerBase):
    """Synchronous handler interface for the video half of an audio-video stream.

    Subclasses supply the two video hooks declared below as plain
    (non-coroutine) methods.

    NOTE(review): the audio track calls ``receive``/``emit`` on the same
    handler object, but those are not declared abstract on this class --
    presumably subclasses are expected to implement them as well; confirm
    against ``StreamHandlerBase``.
    """

    @abstractmethod
    def video_receive(self, frame: npt.NDArray) -> None:
        """Consume one incoming video frame, given as an array."""
        ...

    @abstractmethod
    def video_emit(self) -> VideoEmitType:
        """Produce the next outgoing video payload (see ``VideoEmitType``)."""
        ...
|
||||
|
||||
|
||||
class AsyncAudioVideoStreamHandler(StreamHandlerBase):
    """Coroutine-based handler interface for the video half of an audio-video stream.

    Mirrors ``AudioVideoStreamHandler`` but both video hooks are declared
    with ``async def`` and are awaited by the caller.

    NOTE(review): the audio track awaits ``receive``/``emit`` on the same
    handler object, but those are not declared abstract on this class --
    presumably subclasses are expected to implement them as well; confirm
    against ``StreamHandlerBase``.
    """

    @abstractmethod
    async def video_receive(self, frame: npt.NDArray) -> None:
        """Consume one incoming video frame, given as an array."""
        ...

    @abstractmethod
    async def video_emit(self) -> VideoEmitType:
        """Produce the next outgoing video payload (see ``VideoEmitType``)."""
        ...
|
||||
|
||||
|
||||
# Any handler exposing the video_receive/video_emit pair, sync or async.
VideoStreamHandlerImpl = Union[AudioVideoStreamHandler, AsyncAudioVideoStreamHandler]
# The same union under the name used in the ``WebRTC.stream`` signature.
AudioVideoStreamHandlerImpl = VideoStreamHandlerImpl
# Handlers whose audio ``receive``/``emit`` are awaited by the audio track.
# NOTE(review): this alias is passed to isinstance(); a typing.Union is only
# accepted there on Python 3.10+ -- confirm the minimum supported version.
AsyncHandler = Union[AsyncStreamHandler, AsyncAudioVideoStreamHandler]
|
||||
|
||||
|
||||
class VideoStreamHander(VideoCallback):
    """Video track that drives a handler's ``video_receive``/``video_emit`` pair.

    Incoming frames are read from the wrapped track by a background task and
    handed to ``video_receive``; outgoing frames are requested from
    ``video_emit`` each time ``recv`` is called. Either hook may be a plain
    method or a coroutine function.
    """

    async def process_frames(self):
        """Pump frames from the wrapped track into ``video_receive``.

        Runs until ``thread_quit`` is set or the wrapped track raises
        ``MediaStreamError``.
        """
        while not self.thread_quit.is_set():
            try:
                # Do not consume frames until a data channel has been attached.
                await self.channel_set.wait()
                frame = cast(VideoFrame, await self.track.recv())
                frame_array = frame.to_ndarray(format="bgr24")
                handler = cast(VideoStreamHandlerImpl, self.event_handler)
                if inspect.iscoroutinefunction(handler.video_receive):
                    await handler.video_receive(frame_array)
                else:
                    handler.video_receive(frame_array)
            except MediaStreamError:
                self.stop()
                # The source track is finished: leave the loop explicitly
                # instead of relying on stop() to set ``thread_quit``, which
                # would otherwise leave this task retrying a dead track.
                return

    def start(self):
        """Launch the frame-pumping task once; later calls are no-ops."""
        if not self.has_started:
            # Hold a strong reference: the event loop only keeps weak
            # references to tasks, so an unreferenced task may be collected
            # before it finishes.
            self._process_frames_task = asyncio.create_task(self.process_frames())
            self.has_started = True

    async def recv(self):  # type: ignore
        """Return the next outgoing frame obtained from ``video_emit``.

        Any ``AdditionalOutputs`` returned alongside the frame are passed to
        ``set_additional_outputs`` and announced on the data channel with a
        ``"change"`` message. Returns ``None`` when the handler produced no
        array in ``"send"`` mode, or when any exception occurred (the
        exception is logged at debug level and not re-raised).
        """
        self.start()
        try:
            handler = cast(VideoStreamHandlerImpl, self.event_handler)
            if inspect.iscoroutinefunction(handler.video_emit):
                outputs = await handler.video_emit()
            else:
                outputs = handler.video_emit()

            array, outputs = split_output(outputs)
            if (
                isinstance(outputs, AdditionalOutputs)
                and self.set_additional_outputs
                and self.channel
            ):
                self.set_additional_outputs(outputs)
                self.channel.send("change")
            if array is None and self.mode == "send":
                return None

            new_frame = self.array_to_frame(array)

            # Will probably have to give developer ability to set pts and time_base
            pts, time_base = await self.next_timestamp()
            new_frame.pts = pts
            new_frame.time_base = time_base

            return new_frame
        except Exception as e:
            # Best-effort: a failing handler must not tear down the track.
            logger.debug("exception %s", e)
            tb = traceback.format_exc()
            logger.debug("traceback %s", tb)
            return None
|
||||
|
||||
|
||||
class AudioCallback(AudioStreamTrack):
|
||||
kind = "audio"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandlerImpl,
|
||||
event_handler: StreamHandlerBase,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.event_handler = cast(StreamHandlerImpl, event_handler)
|
||||
self.current_timestamp = 0
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.queue = asyncio.Queue()
|
||||
@@ -322,7 +416,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
frame = cast(AudioFrame, await self.track.recv())
|
||||
for frame in self.event_handler.resample(frame):
|
||||
numpy_array = frame.to_ndarray()
|
||||
if isinstance(self.event_handler, AsyncStreamHandler):
|
||||
if isinstance(self.event_handler, AsyncHandler):
|
||||
await self.event_handler.receive(
|
||||
(frame.sample_rate, numpy_array)
|
||||
)
|
||||
@@ -337,7 +431,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
if isinstance(self.event_handler, AsyncStreamHandler):
|
||||
if isinstance(self.event_handler, AsyncHandler):
|
||||
callable = self.event_handler.emit
|
||||
else:
|
||||
callable = functools.partial(
|
||||
@@ -358,7 +452,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
)
|
||||
self.has_started = True
|
||||
|
||||
async def recv(self):
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
@@ -383,7 +477,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
|
||||
# control playback rate
|
||||
if self._start is None:
|
||||
self._start = time.time() - data_time
|
||||
self._start = time.time() - data_time # type: ignore
|
||||
else:
|
||||
wait = self._start + data_time - time.time()
|
||||
await asyncio.sleep(wait)
|
||||
@@ -434,7 +528,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.latest_args = list(args)
|
||||
self.args_set.set()
|
||||
|
||||
async def recv(self):
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
pts, time_base = await self.next_timestamp()
|
||||
await self.args_set.wait()
|
||||
@@ -523,7 +617,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
)
|
||||
self.has_started = True
|
||||
|
||||
async def recv(self):
|
||||
async def recv(self): # type: ignore
|
||||
try:
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
@@ -539,7 +633,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
# control playback rate
|
||||
if data_time is not None:
|
||||
if self._start is None:
|
||||
self._start = time.time() - data_time
|
||||
self._start = time.time() - data_time # type: ignore
|
||||
else:
|
||||
wait = self._start + data_time - time.time()
|
||||
await asyncio.sleep(wait)
|
||||
@@ -576,10 +670,12 @@ class WebRTC(Component):
|
||||
pcs: set[RTCPeerConnection] = set([])
|
||||
relay = MediaRelay()
|
||||
connections: dict[
|
||||
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
|
||||
] = {}
|
||||
str,
|
||||
list[VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback],
|
||||
] = defaultdict(list)
|
||||
data_channels: dict[str, DataChannel] = {}
|
||||
additional_outputs: dict[str, list[AdditionalOutputs]] = {}
|
||||
handlers: dict[str, StreamHandlerBase | Callable] = {}
|
||||
|
||||
EVENTS = ["tick", "state_change"]
|
||||
|
||||
@@ -606,7 +702,7 @@ class WebRTC(Component):
|
||||
track_constraints: dict[str, Any] | None = None,
|
||||
time_limit: float | None = None,
|
||||
mode: Literal["send-receive", "receive", "send"] = "send-receive",
|
||||
modality: Literal["video", "audio"] = "video",
|
||||
modality: Literal["video", "audio", "audio-video"] = "video",
|
||||
rtp_params: dict[str, Any] | None = None,
|
||||
icon: str | None = None,
|
||||
icon_button_color: str | None = None,
|
||||
@@ -669,6 +765,23 @@ class WebRTC(Component):
|
||||
"height": {"ideal": 500},
|
||||
"frameRate": {"ideal": 30},
|
||||
}
|
||||
if track_constraints is None and modality == "audio-video":
|
||||
track_constraints = {
|
||||
"video": {
|
||||
"facingMode": "user",
|
||||
"width": {"ideal": 500},
|
||||
"height": {"ideal": 500},
|
||||
"frameRate": {"ideal": 30},
|
||||
},
|
||||
"audio": {
|
||||
"echoCancellation": True,
|
||||
"noiseSuppression": {"exact": True},
|
||||
"autoGainControl": {"exact": True},
|
||||
"sampleRate": {"ideal": 24000},
|
||||
"sampleSize": {"ideal": 16},
|
||||
"channelCount": {"exact": 1},
|
||||
},
|
||||
}
|
||||
self.track_constraints = track_constraints
|
||||
self.event_handler: Callable | StreamHandler | None = None
|
||||
super().__init__(
|
||||
@@ -722,7 +835,8 @@ class WebRTC(Component):
|
||||
|
||||
def set_input(self, webrtc_id: str, *args):
    """Pass the latest input values to every track registered for ``webrtc_id``.

    Args:
        webrtc_id: Identifier of the connection whose tracks are updated.
        *args: Input values; each track receives its own ``list(args)`` copy
            through its ``set_args`` method.

    Does nothing when no track is registered under ``webrtc_id``.
    """
    # ``connections`` maps an id to a list of tracks. Use .get() rather than
    # indexing so an unknown id is never inserted as an empty entry.
    for conn in self.connections.get(webrtc_id, ()):
        conn.set_args(list(args))
|
||||
|
||||
def on_additional_outputs(
|
||||
self,
|
||||
@@ -767,7 +881,10 @@ class WebRTC(Component):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
|
||||
fn: Callable[..., Any]
|
||||
| StreamHandlerImpl
|
||||
| AudioVideoStreamHandlerImpl
|
||||
| None = None,
|
||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
outputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
js: str | None = None,
|
||||
@@ -790,16 +907,16 @@ class WebRTC(Component):
|
||||
self.concurrency_limit = (
|
||||
1 if concurrency_limit in ["default", None] else concurrency_limit
|
||||
)
|
||||
self.event_handler = fn
|
||||
self.event_handler = fn # type: ignore
|
||||
self.time_limit = time_limit
|
||||
|
||||
if (
|
||||
self.mode == "send-receive"
|
||||
and self.modality == "audio"
|
||||
and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
|
||||
and self.modality in ["audio", "audio-video"]
|
||||
and not isinstance(self.event_handler, StreamHandlerBase)
|
||||
):
|
||||
raise ValueError(
|
||||
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
|
||||
"In the send-receive mode for audio, the event handler must be an instance of StreamHandlerBase."
|
||||
)
|
||||
|
||||
if self.mode == "send-receive" or self.mode == "send":
|
||||
@@ -815,13 +932,23 @@ class WebRTC(Component):
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||
)
|
||||
for input_component in inputs[1:]: # type: ignore
|
||||
if hasattr(input_component, "change"):
|
||||
input_component.change( # type: ignore
|
||||
self.set_input,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
return self.tick( # type: ignore
|
||||
self.set_input,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
concurrency_limit=None,
|
||||
stream_every=0.5,
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
@@ -855,9 +982,11 @@ class WebRTC(Component):
|
||||
await pc.close()
|
||||
|
||||
def clean_up(self, webrtc_id: str):
    """Drop all per-connection state kept for ``webrtc_id``.

    Removes the handler, the registered tracks, the additional outputs and
    the data channel stored under ``webrtc_id``. Every removed track that is
    an ``AudioCallback`` has its event handler's ``shutdown()`` called.

    Args:
        webrtc_id: Identifier of the connection being torn down.

    Returns:
        The list of tracks that were registered for ``webrtc_id`` (an empty
        list when there were none).
    """
    self.handlers.pop(webrtc_id, None)
    # Pop exactly once: a second pop would always see the entry already gone
    # and the shutdown loop below would never run.
    connection = self.connections.pop(webrtc_id, [])
    for conn in connection:
        if isinstance(conn, AudioCallback):
            conn.event_handler.shutdown()
    self.additional_outputs.pop(webrtc_id, None)
    self.data_channels.pop(webrtc_id, None)
    return connection
|
||||
@@ -874,6 +1003,13 @@ class WebRTC(Component):
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.add(pc)
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerBase):
|
||||
handler = self.event_handler.copy()
|
||||
else:
|
||||
handler = cast(Callable, self.event_handler)
|
||||
|
||||
self.handlers[body["webrtc_id"]] = handler
|
||||
|
||||
set_outputs = self.set_additional_outputs(body["webrtc_id"])
|
||||
|
||||
@pc.on("iceconnectionstatechange")
|
||||
@@ -891,7 +1027,8 @@ class WebRTC(Component):
|
||||
await pc.close()
|
||||
connection = self.clean_up(body["webrtc_id"])
|
||||
if connection:
|
||||
connection.stop()
|
||||
for conn in connection:
|
||||
conn.stop()
|
||||
self.pcs.discard(pc)
|
||||
if pc.connectionState == "connected":
|
||||
if self.time_limit is not None:
|
||||
@@ -900,28 +1037,38 @@ class WebRTC(Component):
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
relay = MediaRelay()
|
||||
if self.modality == "video":
|
||||
handler = self.handlers[body["webrtc_id"]]
|
||||
|
||||
if self.modality == "video" and track.kind == "video":
|
||||
cb = VideoCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler),
|
||||
event_handler=cast(VideoEventHandler, handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
handler = cast(StreamHandler, self.event_handler).copy()
|
||||
handler._loop = asyncio.get_running_loop()
|
||||
elif self.modality == "audio-video" and track.kind == "video":
|
||||
cb = VideoStreamHander(
|
||||
relay.subscribe(track),
|
||||
event_handler=handler, # type: ignore
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
||||
eh = cast(StreamHandlerImpl, handler)
|
||||
eh._loop = asyncio.get_running_loop()
|
||||
cb = AudioCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=handler,
|
||||
event_handler=eh,
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Modality must be either video or audio")
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
raise ValueError("Modality must be either video, audio, or audio-video")
|
||||
if body["webrtc_id"] not in self.connections:
|
||||
self.connections[body["webrtc_id"]] = []
|
||||
|
||||
self.connections[body["webrtc_id"]].append(cb)
|
||||
if body["webrtc_id"] in self.data_channels:
|
||||
self.connections[body["webrtc_id"]].set_channel(
|
||||
self.data_channels[body["webrtc_id"]]
|
||||
)
|
||||
for conn in self.connections[body["webrtc_id"]]:
|
||||
conn.set_channel(self.data_channels[body["webrtc_id"]])
|
||||
if self.mode == "send-receive":
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
@@ -944,7 +1091,7 @@ class WebRTC(Component):
|
||||
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
self.connections[body["webrtc_id"]].append(cb)
|
||||
cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
|
||||
|
||||
@pc.on("datachannel")
|
||||
@@ -957,7 +1104,8 @@ class WebRTC(Component):
|
||||
while not self.connections.get(webrtc_id):
|
||||
await asyncio.sleep(0.05)
|
||||
logger.debug("setting channel for webrtc id %s", webrtc_id)
|
||||
self.connections[webrtc_id].set_channel(channel)
|
||||
for conn in self.connections[webrtc_id]:
|
||||
conn.set_channel(channel)
|
||||
|
||||
asyncio.create_task(set_channel(body["webrtc_id"]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user