working prototype
@@ -1,3 +1,3 @@
-from .webrtc import WebRTC
+from .webrtc import WebRTC, StreamHandler
 
-__all__ = ["WebRTC"]
+__all__ = ["StreamHandler", "WebRTC"]

@@ -15,8 +15,7 @@ AUDIO_PTIME = 0.020
 
 def player_worker_decode(
     loop,
-    callable: Callable,
-    stream,
+    next: Callable,
     queue: asyncio.Queue,
     throttle_playback: bool,
     thread_quit: threading.Event,
@@ -33,22 +32,10 @@ def player_worker_decode(
 
     frame_time = None
     start_time = time.time()
-    generator = None
 
     while not thread_quit.is_set():
-        if stream.latest_args == "not_set":
-            continue
-        if generator is None:
-            generator = callable(*stream.latest_args)
-        try:
-            frame = next(generator)
-        except Exception as exc:
-            if isinstance(exc, StopIteration):
-                logger.debug("Stopping audio stream")
-                asyncio.run_coroutine_threadsafe(queue.put(None), loop)
-                thread_quit.set()
-            break
-
+        frame = next()
+        logger.debug("emitted %s", frame)
         # read up to 1 second ahead
         if throttle_playback:
             elapsed_time = time.time() - start_time
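
Note: with this change player_worker_decode no longer owns the generator; it simply polls an injected zero-argument `next` callable for frames. A minimal sketch of a conforming frame source (the make_tone_source helper is illustrative, not part of this commit):

    import numpy as np

    # Hypothetical stand-in for StreamHandler.emit / ServerToClientAudio.next:
    # a zero-argument callable returning (sample_rate, samples) tuples, which
    # is all the reworked player_worker_decode asks of its frame source.
    def make_tone_source(freq: float = 440.0, sample_rate: int = 48000):
        phase = 0

        def next_frame() -> tuple[int, np.ndarray]:
            nonlocal phase
            t = (np.arange(960) + phase) / sample_rate  # one 20 ms frame
            phase += 960
            samples = (np.sin(2 * np.pi * freq * t) * 32767).astype(np.int16)
            return sample_rate, samples.reshape(1, -1)  # (channels, samples)

        return next_frame
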
@@ -56,7 +43,7 @@ def player_worker_decode(
                 time.sleep(0.1)
         sample_rate, audio_array = frame
         format = "s16" if audio_array.dtype == "int16" else "fltp"
-        frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono")
+        frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="stereo")
         frame.sample_rate = sample_rate
         for frame in audio_resampler.resample(frame):
             # fix timestamps
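
Note on the mono -> stereo switch: for the packed "s16" format, PyAV's av.AudioFrame.from_ndarray expects an interleaved array of shape (1, samples * channels). A standalone sanity check, assuming PyAV is installed:

    import av
    import numpy as np

    # Interleaved int16 stereo: shape (1, samples * channels) for packed "s16".
    interleaved = np.zeros((1, 960 * 2), dtype=np.int16)
    frame = av.AudioFrame.from_ndarray(interleaved, format="s16", layout="stereo")
    frame.sample_rate = 48000
    print(frame.samples, frame.layout.name)  # 960 stereo
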
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import asyncio
+from abc import ABC, abstractmethod
 import logging
 import threading
 import time
@@ -10,14 +11,16 @@ import traceback
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
 
+import anyio.to_thread
 import numpy as np
 from aiortc import (
     AudioStreamTrack,
     RTCPeerConnection,
     RTCSessionDescription,
     VideoStreamTrack,
+    MediaStreamTrack,
 )
-from aiortc.contrib.media import MediaRelay, VideoFrame  # type: ignore
+from aiortc.contrib.media import MediaRelay, AudioFrame, VideoFrame  # type: ignore
 from aiortc.mediastreams import MediaStreamError
 from gradio import wasm_utils
 from gradio.components.base import Component, server
@@ -47,7 +50,7 @@ class VideoCallback(VideoStreamTrack):
 
     def __init__(
         self,
-        track,
+        track: MediaStreamTrack,
         event_handler: Callable,
     ) -> None:
         super().__init__()  # don't forget this!
@@ -72,7 +75,7 @@ class VideoCallback(VideoStreamTrack):
     async def recv(self):
         try:
             try:
-                frame = await self.track.recv()
+                frame = cast(VideoFrame, await self.track.recv())
             except MediaStreamError:
                 return
             frame_array = frame.to_ndarray(format="bgr24")
@@ -100,6 +103,100 @@ class VideoCallback(VideoStreamTrack):
             logger.debug(exec)
 
 
+class StreamHandler(ABC):
+    @abstractmethod
+    def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
+        pass
+
+    @abstractmethod
+    def emit(self) -> None:
+        pass
+
+
+class AudioCallback(AudioStreamTrack):
+    kind = "audio"
+
+    def __init__(
+        self,
+        track: MediaStreamTrack,
+        event_handler: StreamHandler,
+    ) -> None:
+        self.track = track
+        self.event_handler = event_handler
+        self.current_timestamp = 0
+        self.latest_args = "not_set"
+        self.queue = asyncio.Queue()
+        self.thread_quit = threading.Event()
+        self.__thread = None
+        self._start: float | None = None
+        self.has_started = False
+        super().__init__()
+
+    async def process_input_frames(self) -> None:
+        while not self.thread_quit.is_set():
+            try:
+                frame = cast(AudioFrame, await self.track.recv())
+                numpy_array = frame.to_ndarray()
+                logger.debug("numpy array shape %s", numpy_array.shape)
+                await anyio.to_thread.run_sync(
+                    self.event_handler.receive, (frame.sample_rate, numpy_array)
+                )
+            except MediaStreamError:
+                break
+
+    def start(self):
+        if not self.has_started:
+            asyncio.create_task(self.process_input_frames())
+            self.__thread = threading.Thread(
+                name="audio-output-decoders",
+                target=player_worker_decode,
+                args=(
+                    asyncio.get_event_loop(),
+                    self.event_handler.emit,
+                    self.queue,
+                    True,
+                    self.thread_quit,
+                ),
+            )
+            self.__thread.start()
+            self.has_started = True
+
+    async def recv(self):
+        try:
+            if self.readyState != "live":
+                raise MediaStreamError
+
+            self.start()
+            data = await self.queue.get()
+            logger.debug("data %s", data)
+            if data is None:
+                self.stop()
+                return
+
+            data_time = data.time
+
+            # control playback rate
+            if data_time is not None:
+                if self._start is None:
+                    self._start = time.time() - data_time
+                else:
+                    wait = self._start + data_time - time.time()
+                    await asyncio.sleep(wait)
+
+            return data
+        except Exception as e:
+            logger.debug(e)
+            exec = traceback.format_exc()
+            logger.debug(exec)
+
+    def stop(self):
+        self.thread_quit.set()
+        if self.__thread is not None:
+            self.__thread.join()
+            self.__thread = None
+        super().stop()
+
+
 class ServerToClientVideo(VideoStreamTrack):
     """
     This works for streaming input and output
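
For orientation, a minimal concrete StreamHandler under the new ABC could look like the following (EchoHandler and its queue-based buffering are illustrative only, not part of this commit):

    import queue

    import numpy as np

    from gradio_webrtc import StreamHandler  # exported by this commit

    class EchoHandler(StreamHandler):
        """Sketch: buffers received audio and emits it back unchanged."""

        def __init__(self) -> None:
            self.buffer: "queue.Queue[tuple[int, np.ndarray]]" = queue.Queue()

        def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
            # AudioCallback.process_input_frames calls this off the event
            # loop with (sample_rate, ndarray) tuples.
            self.buffer.put(frame)  # type: ignore[arg-type]

        def emit(self) -> tuple[int, np.ndarray]:
            # Polled by player_worker_decode; blocks until audio is available.
            return self.buffer.get()
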
@@ -116,17 +213,6 @@ class ServerToClientVideo(VideoStreamTrack):
         self.latest_args: str | list[Any] = "not_set"
         self.generator: Generator[Any, None, Any] | None = None
 
-    def add_frame_to_payload(
-        self, args: list[Any], frame: np.ndarray | None
-    ) -> list[Any]:
-        new_args = []
-        for val in args:
-            if isinstance(val, str) and val == "__webrtc_value__":
-                new_args.append(frame)
-            else:
-                new_args.append(val)
-        return new_args
-
     def array_to_frame(self, array: np.ndarray) -> VideoFrame:
         return VideoFrame.from_ndarray(array, format="bgr24")
 
@@ -176,6 +262,25 @@ class ServerToClientAudio(AudioStreamTrack):
         self._start: float | None = None
         super().__init__()
 
+    def next(self) -> tuple[int, np.ndarray] | None:
+        import anyio
+
+        if self.latest_args == "not_set":
+            return
+        if self.generator is None:
+            self.generator = self.event_handler(*self.latest_args)
+        if self.generator is not None:
+            try:
+                frame = next(self.generator)
+                return frame
+            except Exception as exc:
+                if isinstance(exc, StopIteration):
+                    logger.debug("Stopping audio stream")
+                    asyncio.run_coroutine_threadsafe(
+                        self.queue.put(None), asyncio.get_event_loop()
+                    )
+                    self.thread_quit.set()
+
     def start(self):
         if self.__thread is None:
             self.__thread = threading.Thread(
@@ -183,8 +288,7 @@ class ServerToClientAudio(AudioStreamTrack):
                 target=player_worker_decode,
                 args=(
                     asyncio.get_event_loop(),
-                    self.event_handler,
-                    self,
+                    self.next,
                     self.queue,
                     False,
                     self.thread_quit,
@@ -241,7 +345,7 @@ class WebRTC(Component):
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
     connections: dict[
-        str, VideoCallback | ServerToClientVideo | ServerToClientAudio
+        str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
     ] = {}
 
     EVENTS = ["tick"]
@@ -300,9 +404,6 @@ class WebRTC(Component):
             streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
             watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
         """
-        if modality == "audio" and mode == "send-receive":
-            raise ValueError("Audio modality is not supported in send-receive mode")
-
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -358,7 +459,7 @@ class WebRTC(Component):
 
     def stream(
         self,
-        fn: Callable[..., Any] | None = None,
+        fn: Callable[..., Any] | StreamHandler | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -384,6 +485,15 @@ class WebRTC(Component):
         self.event_handler = fn
         self.time_limit = time_limit
 
+        if (
+            self.mode == "send-receive"
+            and self.modality == "audio"
+            and not isinstance(self.event_handler, StreamHandler)
+        ):
+            raise ValueError(
+                "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
+            )
+
         if self.mode == "send-receive":
             if cast(list[Block], inputs)[0] != self:
                 raise ValueError(
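
Usage-wise, the new check means an audio send-receive stream gets wired up roughly like this (a sketch reusing the illustrative EchoHandler from above; the constructor arguments shown are assumptions, not verified against this commit):

    import gradio as gr

    from gradio_webrtc import WebRTC

    with gr.Blocks() as demo:
        audio = WebRTC(mode="send-receive", modality="audio")
        # EchoHandler is the illustrative StreamHandler sketched earlier;
        # a plain callable would now be rejected by the isinstance check.
        audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio])

    demo.launch()
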
@@ -439,7 +549,7 @@ class WebRTC(Component):
     @server
     async def offer(self, body):
         logger.debug("Starting to handle offer")
-        logger.debug("Offer body", body)
+        logger.debug("Offer body %s", body)
         if len(self.connections) >= cast(int, self.concurrency_limit):
             return {"status": "failed"}
 
@@ -450,7 +560,7 @@ class WebRTC(Component):
 
         @pc.on("iceconnectionstatechange")
         async def on_iceconnectionstatechange():
-            logger.debug("ICE connection state change", pc.iceConnectionState)
+            logger.debug("ICE connection state change %s", pc.iceConnectionState)
             if pc.iceConnectionState == "failed":
                 await pc.close()
                 self.connections.pop(body["webrtc_id"], None)
@@ -468,12 +578,19 @@ class WebRTC(Component):
 
         @pc.on("track")
         def on_track(track):
-            cb = VideoCallback(
-                self.relay.subscribe(track),
-                event_handler=cast(Callable, self.event_handler),
-            )
+            relay = MediaRelay()
+            if self.modality == "video":
+                cb = VideoCallback(
+                    relay.subscribe(track),
+                    event_handler=cast(Callable, self.event_handler),
+                )
+            elif self.modality == "audio":
+                cb = AudioCallback(
+                    relay.subscribe(track),
+                    event_handler=cast(StreamHandler, self.event_handler),
+                )
             self.connections[body["webrtc_id"]] = cb
-            logger.debug("Adding track to peer connection", cb)
+            logger.debug("Adding track to peer connection %s", cb)
             pc.addTrack(cb)
 
         if self.mode == "receive":
@@ -482,7 +599,7 @@
             elif self.modality == "audio":
                 cb = ServerToClientAudio(cast(Callable, self.event_handler))
 
-            logger.debug("Adding track to peer connection", cb)
+            logger.debug("Adding track to peer connection %s", cb)
             pc.addTrack(cb)
             self.connections[body["webrtc_id"]] = cb
             cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))