mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
implementation
This commit is contained in:
63
backend/gradio_webrtc/utils.py
Normal file
63
backend/gradio_webrtc/utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
import fractions
|
||||
import av
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Callable
|
||||
|
||||
AUDIO_PTIME = 0.020
|
||||
|
||||
|
||||
def player_worker_decode(
    loop,
    callable: Callable,
    stream,
    queue: asyncio.Queue,
    throttle_playback: bool,
    thread_quit: threading.Event,
):
    """Drive an audio generator from a worker thread and queue resampled frames.

    Waits until ``stream.latest_args`` is no longer the ``"not_set"``
    placeholder, builds a generator with ``callable(*stream.latest_args)`` and
    pulls items from it until it is exhausted, it fails, or ``thread_quit`` is
    set. Each item is unpacked as ``(sample_rate, audio_array)``, wrapped in a
    mono ``s16`` ``av.AudioFrame`` and resampled to 48 kHz stereo ``s16`` in
    chunks of ``AUDIO_PTIME`` seconds. Resampled frames get monotonically
    increasing timestamps and are put on ``queue`` through ``loop``.

    When the generator ends (or raises), ``None`` is queued as an end-of-stream
    marker and ``thread_quit`` is set.

    Parameters:
        loop: event loop that owns ``queue``; puts are scheduled on it with
            ``asyncio.run_coroutine_threadsafe`` because this function runs in
            a different thread.
        callable: generator function producing ``(sample_rate, audio_array)``
            items. (The name shadows the builtin; it is kept so that keyword
            callers keep working.)
        stream: object exposing a ``latest_args`` attribute.
        queue: destination for the produced frames and the final ``None``.
        throttle_playback: when true, pause briefly whenever the produced audio
            is more than one second ahead of wall-clock time.
        thread_quit: event used to request termination of the worker.
    """
    import traceback

    audio_sample_rate = 48000
    audio_samples = 0
    audio_time_base = fractions.Fraction(1, audio_sample_rate)
    audio_resampler = av.AudioResampler(
        format="s16",
        layout="stereo",
        rate=audio_sample_rate,
        frame_size=int(audio_sample_rate * AUDIO_PTIME),
    )

    frame_time = None
    start_time = time.time()
    generator = None

    while not thread_quit.is_set():
        if stream.latest_args == "not_set":
            # Arguments have not arrived yet. Block on the quit event instead
            # of spinning so the thread neither burns CPU nor misses a stop
            # request.
            thread_quit.wait(AUDIO_PTIME)
            continue

        try:
            if generator is None:
                generator = callable(*stream.latest_args)
            chunk = next(generator)
        except Exception as exc:
            if not isinstance(exc, StopIteration):
                # The handler failed; report it rather than dying silently.
                traceback.print_exc()
            # Queue the end-of-stream marker so the consumer stops waiting.
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
            thread_quit.set()
            break

        # read up to 1 second ahead
        if throttle_playback:
            elapsed_time = time.time() - start_time
            if frame_time and frame_time > elapsed_time + 1:
                time.sleep(0.1)

        sample_rate, audio_array = chunk
        # NOTE(review): a packed mono "s16" frame presumably needs an int16
        # array shaped (1, n_samples) -- confirm against the handlers in use.
        source_frame = av.AudioFrame.from_ndarray(
            audio_array, format="s16", layout="mono"
        )
        source_frame.sample_rate = sample_rate

        for resampled in audio_resampler.resample(source_frame):
            # fix timestamps: pts counts samples emitted so far at 48 kHz
            resampled.pts = audio_samples
            resampled.time_base = audio_time_base
            audio_samples += resampled.samples

            frame_time = resampled.time
            asyncio.run_coroutine_threadsafe(queue.put(resampled), loop)
|
||||
@@ -2,16 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
|
||||
|
||||
import fractions
|
||||
import threading
|
||||
import time
|
||||
from gradio_webrtc.utils import player_worker_decode
|
||||
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaRelay
|
||||
from aiortc import VideoStreamTrack
|
||||
from aiortc import VideoStreamTrack, AudioStreamTrack
|
||||
from aiortc.mediastreams import MediaStreamError
|
||||
from aiortc.contrib.media import VideoFrame # type: ignore
|
||||
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
|
||||
from gradio_client import handle_file
|
||||
import numpy as np
|
||||
|
||||
@@ -124,7 +128,6 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
|
||||
pts, time_base = await self.next_timestamp()
|
||||
if self.latest_args == "not_set":
|
||||
frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
|
||||
@@ -137,12 +140,9 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
try:
|
||||
next_array = next(self.generator)
|
||||
except StopIteration:
|
||||
print("exception")
|
||||
self.stop()
|
||||
return
|
||||
|
||||
print("pts", pts)
|
||||
print("time_base", time_base)
|
||||
next_frame = self.array_to_frame(next_array)
|
||||
next_frame.pts = pts
|
||||
next_frame.time_base = time_base
|
||||
@@ -153,6 +153,131 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class ServerToClientAudio(AudioStreamTrack):
    """Audio track whose frames come from a server-side generator function.

    ``recv`` lazily starts a worker thread running ``player_worker_decode``,
    which calls ``event_handler`` once ``latest_args`` has been replaced with
    real arguments and pushes the resulting frames onto ``queue``. A ``None``
    item on the queue marks the end of the stream.
    """

    kind = "audio"

    def __init__(
        self,
        event_handler: Callable,
    ) -> None:
        """Store the handler and create the queue/thread bookkeeping.

        Parameters:
            event_handler: generator function handed to the worker thread.
        """
        self.generator: Generator[Any, None, Any] | None = None
        self.event_handler = event_handler
        # Running sample count used as pts by the frame helper methods.
        self.current_timestamp = 0
        # Placeholder until real arguments are assigned from outside.
        self.latest_args = "not_set"
        self.queue = asyncio.Queue()
        self.thread_quit = threading.Event()
        self.__thread = None
        # Wall-clock time corresponding to media time 0; set on first frame.
        self._start: float | None = None
        super().__init__()

    def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame:
        """Convert a ``(sample_rate, samples)`` pair into a mono s16 frame.

        The frame's ``pts`` is the number of samples emitted before it, so the
        first frame starts at 0; the counter is advanced afterwards by
        ``samples.shape[1]``.
        """
        sample_rate, samples = array
        frame = AudioFrame.from_ndarray(samples, format="s16", layout="mono")
        frame.sample_rate = sample_rate
        frame.time_base = fractions.Fraction(1, sample_rate)
        # Stamp the frame with its start position, then advance the counter.
        frame.pts = self.current_timestamp
        self.current_timestamp += samples.shape[1]
        return frame

    async def empty_frame(self) -> AudioFrame:
        """Return a 100-sample zero-filled mono s16 frame at 22050 Hz."""
        sample_rate = 22050
        samples = 100
        frame = AudioFrame(format="s16", layout="mono", samples=samples)
        for plane in frame.planes:
            # Overwrite the plane with zero bytes.
            plane.update(bytes(plane.buffer_size))
        frame.sample_rate = sample_rate
        frame.time_base = fractions.Fraction(1, sample_rate)
        # Stamp the frame with its start position, then advance the counter.
        frame.pts = self.current_timestamp
        self.current_timestamp += samples
        return frame

    def start(self):
        """Start the worker thread if it is not already running."""
        if self.__thread is None:
            self.__thread = threading.Thread(
                name="generator-runner",
                target=player_worker_decode,
                args=(
                    asyncio.get_event_loop(),
                    self.event_handler,
                    self,
                    self.queue,
                    False,
                    self.thread_quit,
                ),
            )
            self.__thread.start()

    async def recv(self):
        """Return the next queued frame, paced against wall-clock time.

        Raises:
            MediaStreamError: if the track is not live, or when the ``None``
                end-of-stream marker is received (the track is stopped first).
        """
        try:
            if self.readyState != "live":
                raise MediaStreamError

            self.start()
            data = await self.queue.get()
            if data is None:
                self.stop()
                raise MediaStreamError

            data_time = data.time

            # control playback rate
            if data_time is not None:
                if self._start is None:
                    # First timed frame: anchor media time to the clock.
                    self._start = time.time() - data_time
                else:
                    wait = self._start + data_time - time.time()
                    await asyncio.sleep(wait)

            return data
        except MediaStreamError:
            # End-of-stream must reach the caller; swallowing it would make
            # recv() return None instead of a frame.
            raise
        except Exception:
            import traceback

            traceback.print_exc()
            raise

    def stop(self):
        """Stop the track, signal the worker thread and wait for it to exit."""
        super().stop()
        self.thread_quit.set()
        if self.__thread is not None:
            self.__thread.join()
            self.__thread = None
|
||||
|
||||
# next_frame = await super().recv()
|
||||
# print("next frame", next_frame)
|
||||
# return next_frame
|
||||
#try:
|
||||
# if self.latest_args == "not_set":
|
||||
# frame = await self.empty_frame()
|
||||
|
||||
# # await self.modify_frame(frame)
|
||||
# await asyncio.sleep(100 / 22050)
|
||||
# print("next_frame not set", frame)
|
||||
# return frame
|
||||
# if self.generator is None:
|
||||
# self.generator = cast(
|
||||
# Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
# )
|
||||
|
||||
# try:
|
||||
# next_array = next(self.generator)
|
||||
# print("iteration")
|
||||
# except StopIteration:
|
||||
# print("exception")
|
||||
# self.stop() # type: ignore
|
||||
# return
|
||||
# next_frame = self.array_to_frame(next_array)
|
||||
# # await self.modify_frame(next_frame)
|
||||
# print("next frame", next_frame)
|
||||
# return next_frame
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# import traceback
|
||||
|
||||
# traceback.print_exc()
|
||||
|
||||
|
||||
class WebRTC(Component):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
@@ -166,7 +291,9 @@ class WebRTC(Component):
|
||||
|
||||
pcs: set[RTCPeerConnection] = set([])
|
||||
relay = MediaRelay()
|
||||
connections: dict[str, VideoCallback | ServerToClientVideo] = {}
|
||||
connections: dict[
|
||||
str, VideoCallback | ServerToClientVideo | ServerToClientAudio
|
||||
] = {}
|
||||
|
||||
EVENTS = ["tick"]
|
||||
|
||||
@@ -191,7 +318,8 @@ class WebRTC(Component):
|
||||
mirror_webcam: bool = True,
|
||||
rtc_configuration: dict[str, Any] | None = None,
|
||||
time_limit: float | None = None,
|
||||
mode: Literal["video-in-out", "video-out"] = "video-in-out",
|
||||
mode: Literal["send-receive", "receive"] = "send-receive",
|
||||
modality: Literal["video", "audio"] = "video",
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@@ -223,6 +351,9 @@ class WebRTC(Component):
|
||||
streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
|
||||
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
|
||||
"""
|
||||
if modality == "audio" and mode == "send-receive":
|
||||
raise ValueError("Audio modality is not supported in send-receive mode")
|
||||
|
||||
self.time_limit = time_limit
|
||||
self.height = height
|
||||
self.width = width
|
||||
@@ -230,6 +361,7 @@ class WebRTC(Component):
|
||||
self.concurrency_limit = 1
|
||||
self.rtc_configuration = rtc_configuration
|
||||
self.mode = mode
|
||||
self.modality = modality
|
||||
self.event_handler: Callable | None = None
|
||||
super().__init__(
|
||||
label=label,
|
||||
@@ -268,9 +400,11 @@ class WebRTC(Component):
|
||||
|
||||
def set_output(self, webrtc_id: str, *args):
    """Store the latest handler arguments on the connection ``webrtc_id``.

    Does nothing when ``webrtc_id`` is not a key of ``self.connections``.
    In ``"send-receive"`` mode the stored list is prefixed with the
    ``"__webrtc_value__"`` placeholder; in ``"receive"`` mode the arguments
    are stored as given. Any other mode leaves the connection untouched.
    """
    if webrtc_id not in self.connections:
        return

    connection = self.connections[webrtc_id]
    if self.mode == "send-receive":
        connection.latest_args = ["__webrtc_value__"] + list(args)
    elif self.mode == "receive":
        connection.latest_args = list(args)
|
||||
|
||||
def stream(
|
||||
@@ -296,9 +430,8 @@ class WebRTC(Component):
|
||||
)
|
||||
self.event_handler = fn
|
||||
self.time_limit = time_limit
|
||||
|
||||
if self.mode == "video-in-out":
|
||||
|
||||
if self.mode == "send-receive":
|
||||
if cast(list[Block], inputs)[0] != self:
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||
@@ -321,27 +454,29 @@ class WebRTC(Component):
|
||||
time_limit=None,
|
||||
js=js,
|
||||
)
|
||||
elif self.mode == "video-out":
|
||||
elif self.mode == "receive":
|
||||
if self in cast(list[Block], inputs):
|
||||
raise ValueError(
|
||||
"In the video-out stream event, the WebRTC component cannot be an input."
|
||||
"In the receive mode stream event, the WebRTC component cannot be an input."
|
||||
)
|
||||
if (
|
||||
len(cast(list[Block], outputs)) != 1
|
||||
and cast(list[Block], outputs)[0] != self
|
||||
):
|
||||
raise ValueError(
|
||||
"In the video-out stream, the only output component must be the WebRTC component."
|
||||
"In the receive mode stream, the only output component must be the WebRTC component."
|
||||
)
|
||||
if trigger is None:
|
||||
raise ValueError(
|
||||
"In the video-out stream event, the trigger parameter must be provided"
|
||||
"In the receive mode stream event, the trigger parameter must be provided"
|
||||
)
|
||||
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
|
||||
self.tick(
|
||||
self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
|
||||
self.set_output,
|
||||
inputs=[self] + inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
|
||||
@@ -350,6 +485,7 @@ class WebRTC(Component):
|
||||
|
||||
@server
|
||||
async def offer(self, body):
|
||||
print("starting")
|
||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
||||
return {"status": "failed"}
|
||||
|
||||
@@ -384,19 +520,29 @@ class WebRTC(Component):
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
pc.addTrack(cb)
|
||||
|
||||
if self.mode == "video-out":
|
||||
|
||||
if self.mode == "receive" and self.modality == "video":
|
||||
cb = ServerToClientVideo(cast(Callable, self.event_handler))
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
if self.mode == "receive" and self.modality == "audio":
|
||||
print("adding")
|
||||
cb = ServerToClientAudio(cast(Callable, self.event_handler))
|
||||
print("cb.recv", cb.recv)
|
||||
# from aiortc.contrib.media import MediaPlayer
|
||||
# player = MediaPlayer("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
|
||||
# pc.addTrack(player.audio)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
|
||||
|
||||
print("here")
|
||||
# handle offer
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
# send answer
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
print("done")
|
||||
|
||||
return {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
|
||||
Reference in New Issue
Block a user