working prototype

This commit is contained in:
freddyaboulton
2024-10-17 15:34:57 -07:00
parent 35c2e313d2
commit cff6073df0
18 changed files with 1240 additions and 496 deletions

View File

@@ -1,3 +1,3 @@
from .webrtc import WebRTC
from .webrtc import WebRTC, StreamHandler
__all__ = ["WebRTC"]
__all__ = ["StreamHandler", "WebRTC"]

View File

@@ -15,8 +15,7 @@ AUDIO_PTIME = 0.020
def player_worker_decode(
loop,
callable: Callable,
stream,
next: Callable,
queue: asyncio.Queue,
throttle_playback: bool,
thread_quit: threading.Event,
@@ -33,22 +32,10 @@ def player_worker_decode(
frame_time = None
start_time = time.time()
generator = None
while not thread_quit.is_set():
if stream.latest_args == "not_set":
continue
if generator is None:
generator = callable(*stream.latest_args)
try:
frame = next(generator)
except Exception as exc:
if isinstance(exc, StopIteration):
logger.debug("Stopping audio stream")
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
thread_quit.set()
break
frame = next()
logger.debug("emitted %s", frame)
# read up to 1 second ahead
if throttle_playback:
elapsed_time = time.time() - start_time
@@ -56,7 +43,7 @@ def player_worker_decode(
time.sleep(0.1)
sample_rate, audio_array = frame
format = "s16" if audio_array.dtype == "int16" else "fltp"
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono")
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="stereo")
frame.sample_rate = sample_rate
for frame in audio_resampler.resample(frame):
# fix timestamps
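The decode worker no longer drives a generator itself; it now calls a plain `next()` callable and expects each result to be a `(sample_rate, numpy_array)` pair that it wraps with `av.AudioFrame.from_ndarray` (int16 arrays map to the `s16` format, anything else to `fltp`). A rough sketch of a compatible callable, assuming the packed `(1, n_samples)` int16 shape the demos in this commit use:

```python
import numpy as np

def make_tone_source(sample_rate: int = 48000, ptime: float = 0.02):
    """Return a next()-style callable that yields 20 ms int16 chunks."""
    n = int(sample_rate * ptime)
    t = np.arange(n) / sample_rate
    chunk = (0.2 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

    def next_frame() -> tuple[int, np.ndarray]:
        # Shape (1, n_samples): packed samples, as from_ndarray expects.
        return sample_rate, chunk.reshape(1, -1)

    return next_frame
```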

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
import logging
import threading
import time
@@ -10,14 +11,16 @@ import traceback
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
import anyio.to_thread
import numpy as np
from aiortc import (
AudioStreamTrack,
RTCPeerConnection,
RTCSessionDescription,
VideoStreamTrack,
MediaStreamTrack,
)
from aiortc.contrib.media import MediaRelay, VideoFrame # type: ignore
from aiortc.contrib.media import MediaRelay, AudioFrame, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from gradio import wasm_utils
from gradio.components.base import Component, server
@@ -47,7 +50,7 @@ class VideoCallback(VideoStreamTrack):
def __init__(
self,
track,
track: MediaStreamTrack,
event_handler: Callable,
) -> None:
super().__init__() # don't forget this!
@@ -72,7 +75,7 @@ class VideoCallback(VideoStreamTrack):
async def recv(self):
try:
try:
frame = await self.track.recv()
frame = cast(VideoFrame, await self.track.recv())
except MediaStreamError:
return
frame_array = frame.to_ndarray(format="bgr24")
@@ -100,6 +103,100 @@ class VideoCallback(VideoStreamTrack):
logger.debug(exec)
class StreamHandler(ABC):
@abstractmethod
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
pass
@abstractmethod
def emit(self) -> None:
pass
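For context, a concrete handler under this contract: `receive()` is handed each incoming `(sample_rate, ndarray)` chunk from the client's track, and `emit()` is polled by the output worker for the next chunk to send back. A sketch along the lines of the echo demo later in this commit; the half-volume gain is purely illustrative:

```python
import numpy as np
from queue import Queue
from gradio_webrtc import StreamHandler

class GainHandler(StreamHandler):
    """Illustrative: replay the caller's audio at half volume."""

    def __init__(self) -> None:
        self.queue: Queue = Queue()

    def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
        sample_rate, array = frame
        self.queue.put((sample_rate, (array * 0.5).astype(array.dtype)))

    def emit(self) -> None:
        return self.queue.get()
```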
class AudioCallback(AudioStreamTrack):
kind = "audio"
def __init__(
self,
track: MediaStreamTrack,
event_handler: StreamHandler,
) -> None:
self.track = track
self.event_handler = event_handler
self.current_timestamp = 0
self.latest_args = "not_set"
self.queue = asyncio.Queue()
self.thread_quit = threading.Event()
self.__thread = None
self._start: float | None = None
self.has_started = False
super().__init__()
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
frame = cast(AudioFrame, await self.track.recv())
numpy_array = frame.to_ndarray()
logger.debug("numpy array shape %s", numpy_array.shape)
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
break
def start(self):
if not self.has_started:
asyncio.create_task(self.process_input_frames())
self.__thread = threading.Thread(
name="audio-output-decoders",
target=player_worker_decode,
args=(
asyncio.get_event_loop(),
self.event_handler.emit,
self.queue,
True,
self.thread_quit,
),
)
self.__thread.start()
self.has_started = True
async def recv(self):
try:
if self.readyState != "live":
raise MediaStreamError
self.start()
data = await self.queue.get()
logger.debug("data %s", data)
if data is None:
self.stop()
return
data_time = data.time
# control playback rate
if data_time is not None:
if self._start is None:
self._start = time.time() - data_time
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
return data
except Exception as e:
logger.debug(e)
exec = traceback.format_exc()
logger.debug(exec)
def stop(self):
self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop()
class ServerToClientVideo(VideoStreamTrack):
"""
This works for streaming input and output
@@ -116,17 +213,6 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
) -> list[Any]:
new_args = []
for val in args:
if isinstance(val, str) and val == "__webrtc_value__":
new_args.append(frame)
else:
new_args.append(val)
return new_args
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
return VideoFrame.from_ndarray(array, format="bgr24")
@@ -176,6 +262,25 @@ class ServerToClientAudio(AudioStreamTrack):
self._start: float | None = None
super().__init__()
def next(self) -> tuple[int, np.ndarray] | None:
import anyio
if self.latest_args == "not_set":
return
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
try:
frame = next(self.generator)
return frame
except Exception as exc:
if isinstance(exc, StopIteration):
logger.debug("Stopping audio stream")
asyncio.run_coroutine_threadsafe(
self.queue.put(None), asyncio.get_event_loop()
)
self.thread_quit.set()
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
@@ -183,8 +288,7 @@ class ServerToClientAudio(AudioStreamTrack):
target=player_worker_decode,
args=(
asyncio.get_event_loop(),
self.event_handler,
self,
self.next,
self.queue,
False,
self.thread_quit,
@@ -241,7 +345,7 @@ class WebRTC(Component):
pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay()
connections: dict[
str, VideoCallback | ServerToClientVideo | ServerToClientAudio
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {}
EVENTS = ["tick"]
@@ -300,9 +404,6 @@ class WebRTC(Component):
streaming: when used as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
"""
if modality == "audio" and mode == "send-receive":
raise ValueError("Audio modality is not supported in send-receive mode")
self.time_limit = time_limit
self.height = height
self.width = width
@@ -358,7 +459,7 @@ class WebRTC(Component):
def stream(
self,
fn: Callable[..., Any] | None = None,
fn: Callable[..., Any] | StreamHandler | None = None,
inputs: Block | Sequence[Block] | set[Block] | None = None,
outputs: Block | Sequence[Block] | set[Block] | None = None,
js: str | None = None,
@@ -384,6 +485,15 @@ class WebRTC(Component):
self.event_handler = fn
self.time_limit = time_limit
if (
self.mode == "send-receive"
and self.modality == "audio"
and not isinstance(self.event_handler, StreamHandler)
):
raise ValueError(
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
)
if self.mode == "send-receive":
if cast(list[Block], inputs)[0] != self:
raise ValueError(
@@ -439,7 +549,7 @@ class WebRTC(Component):
@server
async def offer(self, body):
logger.debug("Starting to handle offer")
logger.debug("Offer body", body)
logger.debug("Offer body %s", body)
if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"}
@@ -450,7 +560,7 @@ class WebRTC(Component):
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
logger.debug("ICE connection state change", pc.iceConnectionState)
logger.debug("ICE connection state change %s", pc.iceConnectionState)
if pc.iceConnectionState == "failed":
await pc.close()
self.connections.pop(body["webrtc_id"], None)
@@ -468,12 +578,19 @@ class WebRTC(Component):
@pc.on("track")
def on_track(track):
cb = VideoCallback(
self.relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
)
relay = MediaRelay()
if self.modality == "video":
cb = VideoCallback(
relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
)
elif self.modality == "audio":
cb = AudioCallback(
relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler),
)
self.connections[body["webrtc_id"]] = cb
logger.debug("Adding track to peer connection", cb)
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
if self.mode == "receive":
@@ -482,7 +599,7 @@ class WebRTC(Component):
elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler))
logger.debug("Adding track to peer connection", cb)
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))

View File

@@ -1,72 +1,63 @@
import logging
# Configure the root logger to WARNING to suppress debug messages from other libraries
logging.basicConfig(level=logging.WARNING)
# Create a console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# Create a formatter
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
# Configure the logger for your specific library
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
import gradio as gr
import cv2
from huggingface_hub import hf_hub_download
from gradio_webrtc import WebRTC
from twilio.rest import Client
import os
from inference import YOLOv10
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
import numpy as np
from gradio_webrtc import WebRTC, StreamHandler
from queue import Queue
import time
def detection(image, conf_threshold=0.3):
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))
class EchoHandler(StreamHandler):
def __init__(self) -> None:
self.queue = Queue()
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
self.queue.put(frame)
def emit(self) -> None:
return self.queue.get()
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
Audio Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
audio = WebRTC(
label="Stream",
rtc_configuration=None,
mode="send-receive",
modality="audio",
)
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
if __name__ == "__main__":
demo.launch()

View File

@@ -6,7 +6,6 @@ import os
from pydub import AudioSegment
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
@@ -24,10 +23,16 @@ else:
import time
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
time.sleep(3.5)
@@ -45,8 +50,12 @@ with gr.Blocks() as demo:
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
audio = WebRTC(label="Stream", rtc_configuration=rtc_configuration,
mode="receive", modality="audio")
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
)
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
@@ -57,8 +66,7 @@ with gr.Blocks() as demo:
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
)

View File

@@ -6,7 +6,6 @@ import os
from pydub import AudioSegment
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
@@ -24,10 +23,16 @@ else:
import time
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
time.sleep(3.5)
@@ -48,9 +53,12 @@ with gr.Blocks() as demo:
gr.Slider()
with gr.Column():
# audio = gr.Audio(interactive=False)
audio = WebRTC(label="Stream", rtc_configuration=rtc_configuration,
mode="receive", modality="audio")
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
)
if __name__ == "__main__":

View File

@@ -1,26 +1,91 @@
import gradio as gr
import os
_docs = {'WebRTC':
{'description': 'Stream audio/video with WebRTC',
'members': {'__init__':
{
'rtc_configuration': {'type': 'dict[str, Any] | None', 'default': 'None', 'description': "The configuration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used."},
'height': {'type': 'int | str | None', 'default': 'None', 'description': 'The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.'},
'width': {'type': 'int | str | None', 'default': 'None', 'description': 'The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.'},
'label': {'type': 'str | None', 'default': 'None', 'description': 'the label for this component. Appears above the component and is also used as the header if there is a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.'},
'show_label': {'type': 'bool | None', 'default': 'None', 'description': 'if True, will display label.'}, 'container': {'type': 'bool', 'default': 'True', 'description': 'if True, will place the component in a container - providing some extra padding around the border.'},
'scale': {'type': 'int | None', 'default': 'None', 'description': 'relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.'},
'min_width': {'type': 'int', 'default': '160', 'description': 'minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.'},
'interactive': {'type': 'bool | None', 'default': 'None', 'description': 'if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.'}, 'visible': {'type': 'bool', 'default': 'True', 'description': 'if False, component will be hidden.'},
'elem_id': {'type': 'str | None', 'default': 'None', 'description': 'an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.'},
'elem_classes': {'type': 'list[str] | str | None', 'default': 'None', 'description': 'an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.'},
'render': {'type': 'bool', 'default': 'True', 'description': 'if False, component will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.'},
'key': {'type': 'int | str | None', 'default': 'None', 'description': 'if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.'},
'mirror_webcam': {'type': 'bool', 'default': 'True', 'description': 'if True webcam will be mirrored. Default is True.'},
},
'events': {'tick': {'type': None, 'default': None, 'description': ''}}}, '__meta__': {'additional_interfaces': {}, 'user_fn_refs': {'WebRTC': []}}}
_docs = {
"WebRTC": {
"description": "Stream audio/video with WebRTC",
"members": {
"__init__": {
"rtc_configuration": {
"type": "dict[str, Any] | None",
"default": "None",
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
},
"height": {
"type": "int | str | None",
"default": "None",
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"width": {
"type": "int | str | None",
"default": "None",
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"label": {
"type": "str | None",
"default": "None",
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
},
"show_label": {
"type": "bool | None",
"default": "None",
"description": "if True, will display label.",
},
"container": {
"type": "bool",
"default": "True",
"description": "if True, will place the component in a container - providing some extra padding around the border.",
},
"scale": {
"type": "int | None",
"default": "None",
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
},
"min_width": {
"type": "int",
"default": "160",
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
},
"interactive": {
"type": "bool | None",
"default": "None",
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
},
"visible": {
"type": "bool",
"default": "True",
"description": "if False, component will be hidden.",
},
"elem_id": {
"type": "str | None",
"default": "None",
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"elem_classes": {
"type": "list[str] | str | None",
"default": "None",
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"render": {
"type": "bool",
"default": "True",
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
},
"key": {
"type": "int | str | None",
"default": "None",
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
},
"mirror_webcam": {
"type": "bool",
"default": "True",
"description": "if True webcam will be mirrored. Default is True.",
},
},
"events": {"tick": {"type": None, "default": None, "description": ""}},
},
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
}
}
@@ -36,16 +101,19 @@ with gr.Blocks(
),
) as demo:
gr.Markdown(
"""
"""
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.5%20-%20orange">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
</div>
""", elem_classes=["md-custom"], header_links=True)
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
"""
## Installation
```bash
@@ -195,17 +263,24 @@ with gr.Blocks() as demo:
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
...
```
""", elem_classes=["md-custom"], header_links=True)
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown("""
gr.Markdown(
"""
##
""", elem_classes=["md-custom"], header_links=True)
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
demo.load(None, js=r"""function() {
demo.load(
None,
js=r"""function() {
const refs = {};
const user_fn_refs = {
WebRTC: [], };
@@ -239,6 +314,7 @@ with gr.Blocks() as demo:
})
}
""")
""",
)
demo.launch()

View File

@@ -24,7 +24,6 @@ else:
def generation(input_video):
cap = cv2.VideoCapture(input_video)
iterating = True
while iterating:
@@ -35,6 +34,7 @@ def generation(input_video):
display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
yield display_frame
with gr.Blocks() as demo:
gr.HTML(
"""
@@ -47,11 +47,17 @@ with gr.Blocks() as demo:
with gr.Column():
input_video = gr.Video(sources="upload")
with gr.Column():
output_video = WebRTC(label="Video Stream", rtc_configuration=rtc_configuration,
mode="receive", modality="video")
output_video = WebRTC(
label="Video Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="video",
)
output_video.stream(
fn=generation, inputs=[input_video], outputs=[output_video],
trigger=input_video.upload
fn=generation,
inputs=[input_video],
outputs=[output_video],
trigger=input_video.upload,
)

View File

@@ -30,7 +30,6 @@ def generation():
yield frame
with gr.Blocks() as demo:
gr.HTML(
"""
@@ -39,12 +38,15 @@ with gr.Blocks() as demo:
</h1>
"""
)
output_video = WebRTC(label="Video Stream", rtc_configuration=rtc_configuration,
mode="receive", modality="video")
output_video = WebRTC(
label="Video Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="video",
)
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click
fn=generation, inputs=None, outputs=[output_video], trigger=button.click
)

View File

@@ -7,6 +7,7 @@
import type { LoadingStatus } from "@gradio/statustracker";
import StaticVideo from "./shared/StaticVideo.svelte";
import StaticAudio from "./shared/StaticAudio.svelte";
import InteractiveAudio from "./shared/InteractiveAudio.svelte";
export let elem_id = "";
export let elem_classes: string[] = [];
@@ -37,8 +38,7 @@
$: console.log("value", value);
</script>
{#if mode == "receive" && modality === "video"}
<Block
<Block
{visible}
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
@@ -59,6 +59,7 @@
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{#if mode == "receive" && modality === "video"}
<StaticVideo
bind:value={value}
{label}
@@ -68,27 +69,7 @@
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
</Block>
{:else if mode == "receive" && modality === "audio"}
<Block
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
allow_overflow={false}
{elem_id}
{elem_classes}
{visible}
{container}
{scale}
{min_width}
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{:else if mode == "receive" && modality === "audio"}
<StaticAudio
bind:value={value}
{label}
@@ -99,28 +80,7 @@
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
</Block>
{:else if mode === "send-receive" && modality === "video"}
<Block
{visible}
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
{elem_id}
{elem_classes}
{height}
{width}
{container}
{scale}
{min_width}
allow_overflow={false}
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{:else if mode === "send-receive" && modality === "video"}
<Video
bind:value={value}
{label}
@@ -145,5 +105,17 @@
>
<UploadText i18n={gradio.i18n} type="video" />
</Video>
</Block>
{/if}
{:else if mode === "send-receive" && modality === "audio"}
<InteractiveAudio
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
{time_limit}
i18n={gradio.i18n}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{/if}
</Block>

File diff suppressed because it is too large

View File

@@ -9,15 +9,15 @@
"dependencies": {
"@ffmpeg/ffmpeg": "^0.12.10",
"@ffmpeg/util": "^0.12.1",
"@gradio/atoms": "0.9.0",
"@gradio/client": "1.6.0",
"@gradio/atoms": "0.9.2",
"@gradio/client": "1.7.0",
"@gradio/icons": "0.8.0",
"@gradio/image": "0.16.0",
"@gradio/markdown": "^0.10.0",
"@gradio/statustracker": "0.8.0",
"@gradio/upload": "0.13.0",
"@gradio/image": "0.16.4",
"@gradio/markdown": "^0.10.3",
"@gradio/statustracker": "0.9.1",
"@gradio/upload": "0.13.3",
"@gradio/utils": "0.7.0",
"@gradio/wasm": "0.14.0",
"@gradio/wasm": "0.14.2",
"hls.js": "^1.5.16",
"mrmime": "^2.0.0"
},

View File

@@ -2,7 +2,7 @@
import { onMount, onDestroy } from 'svelte';
export let numBars = 16;
export let stream_state: "open" | "closed" = "closed";
export let stream_state: "open" | "closed" | "waiting" = "closed";
export let audio_source: HTMLAudioElement;
let audioContext: AudioContext;
@@ -69,17 +69,12 @@
{#each Array(numBars) as _}
<div class="box"></div>
{/each}
</div>
<button class="muteButton" on:click={toggleMute}>
{is_muted ? '🔈' : '🔊'}
</div>
</div>
<style>
.waveContainer {
position: relative;
display: flex;
flex-direction: column;
display: flex;
}
@@ -98,15 +93,4 @@
border-radius: 8px;
transition: transform 0.05s ease;
}
.muteButton {
margin-top: 10px;
padding: 10px 20px;
font-size: 24px;
cursor: pointer;
background: none;
border: none;
border-radius: 5px;
color: var(--color-accent);
}

View File

@@ -0,0 +1,244 @@
<script lang="ts">
import {
BlockLabel,
} from "@gradio/atoms";
import type { I18nFormatter } from "@gradio/utils";
import { createEventDispatcher } from "svelte";
import { onMount } from "svelte";
import { StreamingBar } from "@gradio/statustracker";
import {
Circle,
Square,
Spinner,
Music
} from "@gradio/icons";
import { start, stop } from "./webrtc_utils";
import AudioWave from "./AudioWave.svelte";
export let value: string | null = null;
export let label: string | undefined = undefined;
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter;
export let time_limit: number | null = null;
let _time_limit: number | null = null;
$: console.log("time_limit", time_limit);
export let server: {
offer: (body: any) => Promise<any>;
};
let stream_state: "open" | "closed" | "waiting" = "closed";
let audio_player: HTMLAudioElement;
let pc: RTCPeerConnection;
let _webrtc_id = Math.random().toString(36).substring(2);
const dispatch = createEventDispatcher<{
tick: undefined;
error: string
play: undefined;
stop: undefined;
}>();
onMount(() => {
window.setInterval(() => {
if (stream_state == "open") {
dispatch("tick");
}
}, 1000);
}
)
async function start_stream(): Promise<void> {
if( stream_state === "open"){
stop(pc);
stream_state = "closed";
_time_limit = null;
return;
}
value = _webrtc_id;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
_time_limit = time_limit;
break;
case "disconnected":
console.info("closed");
stream_state = "closed";
_time_limit = null;
stop(pc);
break;
default:
break;
}
}
)
stream_state = "waiting"
let stream = null
try {
stream = await navigator.mediaDevices.getUserMedia({ audio: {
echoCancellation: true,
noiseSuppression: {exact: true},
autoGainControl: {exact: true},
sampleRate: {ideal: 48000},
sampleSize: {ideal: 16},
channelCount: 2,
} });
} catch (err) {
if (!navigator.mediaDevices) {
dispatch("error", i18n("audio.no_device_support"));
return;
}
if (err instanceof DOMException && err.name == "NotAllowedError") {
dispatch("error", i18n("audio.allow_recording_access"));
return;
}
throw err;
}
if (stream == null) return;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
</script>
<BlockLabel
{show_label}
Icon={Music}
float={false}
label={label || i18n("audio.audio")}
/>
<div class="audio-container">
<audio
class="standard-player"
class:hidden={value === "__webrtc_value__"}
on:load
bind:this={audio_player}
on:ended={() => dispatch("stop")}
on:play={() => dispatch("play")}
/>
<AudioWave audio_source={audio_player} {stream_state}/>
<StreamingBar time_limit={_time_limit} />
<div class="button-wrap">
<button
on:click={start_stream}
aria-label={"start stream"}
>
{#if stream_state === "waiting"}
<div class="icon-with-text" style="width:var(--size-24);">
<div class="icon color-primary" title="spinner">
<Spinner />
</div>
{i18n("audio.waiting")}
</div>
{:else if stream_state === "open"}
<div class="icon-with-text">
<div class="icon color-primary" title="stop recording">
<Square />
</div>
{i18n("audio.stop")}
</div>
{:else}
<div class="icon-with-text">
<div class="icon color-primary" title="start recording">
<Circle />
</div>
{i18n("audio.record")}
</div>
{/if}
</button>
</div>
</div>
<style>
.audio-container {
display: flex;
height: 100%;
flex-direction: column;
justify-content: center;
align-items: center;
}
:global(::part(wrapper)) {
margin-bottom: var(--size-2);
}
.standard-player {
width: 100%;
padding: var(--size-2);
}
.hidden {
display: none;
}
.button-wrap {
margin-top: var(--size-2);
margin-bottom: var(--size-2);
background-color: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: var(--radius-xl);
padding: var(--size-1-5);
display: flex;
bottom: var(--size-2);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--radius-xl);
line-height: var(--size-3);
color: var(--button-secondary-text-color);
}
.icon-with-text {
width: var(--size-20);
align-items: center;
margin: 0 var(--spacing-xl);
display: flex;
justify-content: space-evenly;
}
@media (--screen-md) {
button {
bottom: var(--size-4);
}
}
@media (--screen-xl) {
button {
bottom: var(--size-8);
}
}
.icon {
width: 18px;
height: 18px;
display: flex;
justify-content: space-between;
align-items: center;
}
.color-primary {
fill: var(--primary-600);
stroke: var(--primary-600);
color: var(--primary-600);
}
</style>

View File

@@ -62,22 +62,6 @@
</div>
<style>
.file-name {
padding: var(--size-6);
font-size: var(--text-xxl);
word-break: break-all;
}
.file-size {
padding: var(--size-2);
font-size: var(--text-xl);
}
.upload-container {
height: 100%;
width: 100%;
}
.video-container {
display: flex;
height: 100%;

View File

@@ -17,13 +17,12 @@
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter;
export let autoplay: boolean = true;
export let server: {
offer: (body: any) => Promise<any>;
};
let stream_state = "closed";
let stream_state: "open" | "closed" | "connecting" = "closed";
let audio_player: HTMLAudioElement;
let pc: RTCPeerConnection;
let _webrtc_id = Math.random().toString(36).substring(2);
@@ -46,33 +45,38 @@
}
)
$: if( value === "start_webrtc_stream") {
stream_state = "connecting";
value = _webrtc_id;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
break;
case "disconnected":
console.info("closed");
stop(pc);
break;
default:
break;
async function start_stream(value: string): Promise<void> {
if( value === "start_webrtc_stream") {
stream_state = "connecting";
value = _webrtc_id;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
break;
case "disconnected":
console.info("closed");
stop(pc);
break;
default:
break;
}
}
}
)
start(null, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
)
let stream = null;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
}
$: start_stream(value);

View File

@@ -21,8 +21,8 @@ export async function get_video_stream(
device_id?: string
): Promise<MediaStream> {
const size = {
width: { ideal: 1920 },
height: { ideal: 1440 }
width: { ideal: 500 },
height: { ideal: 500 }
};
const constraints = {

View File

@@ -35,7 +35,6 @@ export function createPeerConnection(pc, node) {
node.volume = 1.0; // Ensure volume is up
node.muted = false;
node.autoplay = true;
// Attempt to play (needed for some browsers)
node.play().catch(e => console.debug("Autoplay failed:", e));
}
@@ -49,8 +48,8 @@ export async function start(stream, pc: RTCPeerConnection, node, server_fn, webr
pc = createPeerConnection(pc, node);
if (stream) {
stream.getTracks().forEach((track) => {
track.applyConstraints({ frameRate: { max: 30 } });
if(modality == "video") track.applyConstraints({ frameRate: { max: 30 } });
else if(modality == "audio") track.applyConstraints({ sampleRate: 48000, channelCount: 1 });
console.debug("Track stream callback", track);
pc.addTrack(track, stream);
});