Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)
Merge pull request #5 from freddyaboulton/server-to-client-audio
Server to client audio
README.md (28 lines changed)

@@ -10,7 +10,7 @@ app_file: space.py
 ---
 
 # `gradio_webrtc`
 
-<img alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.1%20-%20orange">
+<a href="https://pypi.org/project/gradio_webrtc/" target="_blank"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/gradio_webrtc"></a>
 
 Stream images in realtime with webrtc
 
@@ -358,6 +358,32 @@ float | None
 <td align="left"><code>None</code></td>
 <td align="left">None</td>
 </tr>
 
+<tr>
+<td align="left"><code>mode</code></td>
+<td align="left" style="width: 25%;">
+
+```python
+Literal["send-receive", "receive"]
+```
+
+</td>
+<td align="left"><code>"send-receive"</code></td>
+<td align="left">None</td>
+</tr>
+
+<tr>
+<td align="left"><code>modality</code></td>
+<td align="left" style="width: 25%;">
+
+```python
+Literal["video", "audio"]
+```
+
+</td>
+<td align="left"><code>"video"</code></td>
+<td align="left">None</td>
+</tr>
 </tbody></table>
 
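For orientation (a sketch, not part of the diff): the two new README rows document the `mode` and `modality` parameters this commit adds to the `WebRTC` component. A minimal receive-only audio component, mirroring the `demo/audio_out.py` added later in this commit, would look like:

```python
from gradio_webrtc import WebRTC

# mode="receive" streams server -> client only; modality selects the track kind.
# rtc_configuration is omitted here; behind NAT you would pass STUN/TURN
# servers, as the demos below do with Twilio credentials.
audio = WebRTC(label="Stream", mode="receive", modality="audio")
```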
backend/gradio_webrtc/utils.py (new file, 63 lines)

@@ -0,0 +1,63 @@
import time
import fractions
import av
import asyncio
import threading
from typing import Callable

AUDIO_PTIME = 0.020


def player_worker_decode(
    loop,
    callable: Callable,
    stream,
    queue: asyncio.Queue,
    throttle_playback: bool,
    thread_quit: threading.Event,
):
    audio_sample_rate = 48000
    audio_samples = 0
    audio_time_base = fractions.Fraction(1, audio_sample_rate)
    audio_resampler = av.AudioResampler(
        format="s16",
        layout="stereo",
        rate=audio_sample_rate,
        frame_size=int(audio_sample_rate * AUDIO_PTIME),
    )

    frame_time = None
    start_time = time.time()
    generator = None

    while not thread_quit.is_set():
        if stream.latest_args == "not_set":
            continue
        if generator is None:
            generator = callable(*stream.latest_args)
        try:
            frame = next(generator)
        except Exception as exc:
            if isinstance(exc, StopIteration):
                print("Not iterating")
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
            thread_quit.set()
            break

        # read up to 1 second ahead
        if throttle_playback:
            elapsed_time = time.time() - start_time
            if frame_time and frame_time > elapsed_time + 1:
                time.sleep(0.1)

        sample_rate, audio_array = frame
        format = "s16" if audio_array.dtype == "int16" else "fltp"
        frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono")
        frame.sample_rate = sample_rate
        for frame in audio_resampler.resample(frame):
            # fix timestamps
            frame.pts = audio_samples
            frame.time_base = audio_time_base
            audio_samples += frame.samples

            frame_time = frame.time
            asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
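A note on the contract above (a gloss, not part of the diff): `callable` is the user's event handler and must return a generator yielding `(sample_rate, numpy_array)` tuples; the worker wraps each tuple in an `av.AudioFrame`, resamples everything to 48 kHz 16-bit stereo in 20 ms chunks (`AUDIO_PTIME`), and pushes the frames onto the asyncio queue from a plain thread. A minimal compatible generator, assuming int16 mono samples in the `(channels, samples)` layout the worker expects:

```python
import numpy as np


def beeps(num_beeps: int):
    """Yield half-second 440 Hz tones as (sample_rate, int16 array of shape (1, n))."""
    sr = 16000
    t = np.linspace(0, 0.5, sr // 2, endpoint=False)
    tone = (0.2 * 32767 * np.sin(2 * np.pi * 440.0 * t)).astype(np.int16)
    for _ in range(num_beeps):
        yield (sr, tone.reshape(1, -1))  # mono: one channel row
```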
@@ -2,16 +2,20 @@
 
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 import asyncio
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, Any, Literal, cast, Generator
+import fractions
+import threading
+import time
+from gradio_webrtc.utils import player_worker_decode
 
 from aiortc import RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
-from aiortc import VideoStreamTrack
+from aiortc import VideoStreamTrack, AudioStreamTrack
 from aiortc.mediastreams import MediaStreamError
-from aiortc.contrib.media import VideoFrame  # type: ignore
+from aiortc.contrib.media import AudioFrame, VideoFrame  # type: ignore
 from gradio_client import handle_file
 import numpy as np
 
@@ -124,7 +128,6 @@ class ServerToClientVideo(VideoStreamTrack):
 
     async def recv(self):
         try:
-
             pts, time_base = await self.next_timestamp()
             if self.latest_args == "not_set":
                 frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
@@ -132,17 +135,16 @@ class ServerToClientVideo(VideoStreamTrack):
                 frame.time_base = time_base
                 return frame
             elif self.generator is None:
-                self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args))
+                self.generator = cast(
+                    Generator[Any, None, Any], self.event_handler(*self.latest_args)
+                )
             try:
                 next_array = next(self.generator)
             except StopIteration:
-                print("exception")
                 self.stop()
                 return
 
-            print("pts", pts)
-            print("time_base", time_base)
             next_frame = self.array_to_frame(next_array)
             next_frame.pts = pts
             next_frame.time_base = time_base
@@ -150,9 +152,132 @@ class ServerToClientVideo(VideoStreamTrack):
         except Exception as e:
             print(e)
             import traceback
 
             traceback.print_exc()
 
 
+class ServerToClientAudio(AudioStreamTrack):
+    kind = "audio"
+
+    def __init__(
+        self,
+        event_handler: Callable,
+    ) -> None:
+        self.generator: Generator[Any, None, Any] | None = None
+        self.event_handler = event_handler
+        self.current_timestamp = 0
+        self.latest_args = "not_set"
+        self.queue = asyncio.Queue()
+        self.thread_quit = threading.Event()
+        self.__thread = None
+        self._start: float | None = None
+        super().__init__()
+
+    def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame:
+        frame = AudioFrame.from_ndarray(array[1], format="s16", layout="mono")
+        frame.sample_rate = array[0]
+        frame.time_base = fractions.Fraction(1, array[0])
+        self.current_timestamp += array[1].shape[1]
+        frame.pts = self.current_timestamp
+        return frame
+
+    async def empty_frame(self) -> AudioFrame:
+        sample_rate = 22050
+        samples = 100
+        frame = AudioFrame(format="s16", layout="mono", samples=samples)
+        for p in frame.planes:
+            p.update(bytes(p.buffer_size))
+        frame.sample_rate = sample_rate
+        frame.time_base = fractions.Fraction(1, sample_rate)
+        self.current_timestamp += samples
+        frame.pts = self.current_timestamp
+        return frame
+
+    def start(self):
+        if self.__thread is None:
+            self.__thread = threading.Thread(
+                name="generator-runner",
+                target=player_worker_decode,
+                args=(
+                    asyncio.get_event_loop(),
+                    self.event_handler,
+                    self,
+                    self.queue,
+                    False,
+                    self.thread_quit,
+                ),
+            )
+            self.__thread.start()
+
+    async def recv(self):
+        try:
+            if self.readyState != "live":
+                raise MediaStreamError
+
+            self.start()
+            data = await self.queue.get()
+            if data is None:
+                self.stop()
+                return
+
+            data_time = data.time
+
+            # control playback rate
+            if data_time is not None:
+                if self._start is None:
+                    self._start = time.time() - data_time
+                else:
+                    wait = self._start + data_time - time.time()
+                    await asyncio.sleep(wait)
+
+            return data
+        except Exception as e:
+            print(e)
+            import traceback
+
+            traceback.print_exc()
+
+    def stop(self):
+        self.thread_quit.set()
+        if self.__thread is not None:
+            self.__thread.join()
+            self.__thread = None
+        super().stop()
+
+    # next_frame = await super().recv()
+    # print("next frame", next_frame)
+    # return next_frame
+    # try:
+    #     if self.latest_args == "not_set":
+    #         frame = await self.empty_frame()
+    #         # await self.modify_frame(frame)
+    #         await asyncio.sleep(100 / 22050)
+    #         print("next_frame not set", frame)
+    #         return frame
+    #     if self.generator is None:
+    #         self.generator = cast(
+    #             Generator[Any, None, Any], self.event_handler(*self.latest_args)
+    #         )
+    #     try:
+    #         next_array = next(self.generator)
+    #         print("iteration")
+    #     except StopIteration:
+    #         print("exception")
+    #         self.stop()  # type: ignore
+    #         return
+    #     next_frame = self.array_to_frame(next_array)
+    #     # await self.modify_frame(next_frame)
+    #     print("next frame", next_frame)
+    #     return next_frame
+    # except Exception as e:
+    #     print(e)
+    #     import traceback
+    #     traceback.print_exc()
+
+
 class WebRTC(Component):
     """
     Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -166,7 +291,9 @@ class WebRTC(Component):
 
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
-    connections: dict[str, VideoCallback | ServerToClientVideo] = {}
+    connections: dict[
+        str, VideoCallback | ServerToClientVideo | ServerToClientAudio
+    ] = {}
 
     EVENTS = ["tick"]
 
@@ -191,7 +318,8 @@ class WebRTC(Component):
         mirror_webcam: bool = True,
         rtc_configuration: dict[str, Any] | None = None,
         time_limit: float | None = None,
-        mode: Literal["video-in-out", "video-out"] = "video-in-out",
+        mode: Literal["send-receive", "receive"] = "send-receive",
+        modality: Literal["video", "audio"] = "video",
     ):
         """
         Parameters:
@@ -223,6 +351,9 @@ class WebRTC(Component):
             streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
             watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
         """
+        if modality == "audio" and mode == "send-receive":
+            raise ValueError("Audio modality is not supported in send-receive mode")
+
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -230,6 +361,7 @@ class WebRTC(Component):
         self.concurrency_limit = 1
         self.rtc_configuration = rtc_configuration
         self.mode = mode
+        self.modality = modality
         self.event_handler: Callable | None = None
         super().__init__(
             label=label,
@@ -268,9 +400,11 @@ class WebRTC(Component):
 
     def set_output(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
-            if self.mode == "video-in-out":
-                self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
-            elif self.mode == "video-out":
+            if self.mode == "send-receive":
+                self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(
+                    args
+                )
+            elif self.mode == "receive":
                 self.connections[webrtc_id].latest_args = list(args)
 
     def stream(
@@ -297,8 +431,7 @@ class WebRTC(Component):
         self.event_handler = fn
         self.time_limit = time_limit
 
-        if self.mode == "video-in-out":
-
+        if self.mode == "send-receive":
             if cast(list[Block], inputs)[0] != self:
                 raise ValueError(
                     "In the webrtc stream event, the first input component must be the WebRTC component."
@@ -321,28 +454,30 @@ class WebRTC(Component):
                 time_limit=None,
                 js=js,
             )
-        elif self.mode == "video-out":
+        elif self.mode == "receive":
             if self in cast(list[Block], inputs):
                 raise ValueError(
-                    "In the video-out stream event, the WebRTC component cannot be an input."
+                    "In the receive mode stream event, the WebRTC component cannot be an input."
                 )
             if (
                 len(cast(list[Block], outputs)) != 1
                 and cast(list[Block], outputs)[0] != self
             ):
                 raise ValueError(
-                    "In the video-out stream, the only output component must be the WebRTC component."
+                    "In the receive mode stream, the only output component must be the WebRTC component."
                 )
             if trigger is None:
                 raise ValueError(
-                    "In the video-out stream event, the trigger parameter must be provided"
+                    "In the receive mode stream event, the trigger parameter must be provided"
                 )
             trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
             self.tick(
-                self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
+                self.set_output,
+                inputs=[self] + inputs,
+                outputs=None,
+                concurrency_id=concurrency_id,
             )
 
     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
         await asyncio.sleep(time_limit)
@@ -350,6 +485,7 @@ class WebRTC(Component):
 
     @server
     async def offer(self, body):
+        print("starting")
         if len(self.connections) >= cast(int, self.concurrency_limit):
             return {"status": "failed"}
 
@@ -385,18 +521,30 @@ class WebRTC(Component):
             self.connections[body["webrtc_id"]] = cb
             pc.addTrack(cb)
 
-        if self.mode == "video-out":
+        if self.mode == "receive" and self.modality == "video":
             cb = ServerToClientVideo(cast(Callable, self.event_handler))
             pc.addTrack(cb)
             self.connections[body["webrtc_id"]] = cb
+            cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
+        if self.mode == "receive" and self.modality == "audio":
+            print("adding")
+            cb = ServerToClientAudio(cast(Callable, self.event_handler))
+            print("cb.recv", cb.recv)
+            # from aiortc.contrib.media import MediaPlayer
+            # player = MediaPlayer("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
+            # pc.addTrack(player.audio)
+            pc.addTrack(cb)
+            self.connections[body["webrtc_id"]] = cb
+            cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
 
+        print("here")
         # handle offer
         await pc.setRemoteDescription(offer)
 
         # send answer
         answer = await pc.createAnswer()
         await pc.setLocalDescription(answer)  # type: ignore
+        print("done")
 
         return {
             "sdp": pc.localDescription.sdp,
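The `# control playback rate` block in `ServerToClientAudio.recv` anchors wall-clock time to the first frame's timestamp and then sleeps until each subsequent frame is due. A standalone sketch of the same arithmetic (an illustration, not code from the diff):

```python
import asyncio
import time


async def pace(frames):
    """frames: iterable of (timestamp_seconds, payload) in presentation order."""
    start = None
    for ts, payload in frames:
        if start is None:
            start = time.time() - ts  # anchor the clock so frame 0 plays now
        else:
            # frame is due at start + ts; if we are already behind, don't sleep
            await asyncio.sleep(max(0.0, start + ts - time.time()))
        yield payload
```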
demo/audio_out.py (new file, 64 lines)

@@ -0,0 +1,64 @@
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from twilio.rest import Client
import os
from pydub import AudioSegment


account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    client = Client(account_sid, auth_token)

    token = client.tokens.create()

    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None


def generation(num_steps):
    for _ in range(num_steps):
        segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
        yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))


css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""


with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Audio Streaming (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            audio = WebRTC(label="Stream", rtc_configuration=rtc_configuration,
                           mode="receive", modality="audio")
            num_steps = gr.Slider(
                label="Number of Steps",
                minimum=1,
                maximum=10,
                step=1,
                value=5,
            )
            button = gr.Button("Generate")

        audio.stream(
            fn=generation, inputs=[num_steps], outputs=[audio],
            trigger=button.click
        )


if __name__ == "__main__":
    demo.launch()
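Two things worth flagging in this demo: the wav path is hard-coded to the author's machine, and `get_array_of_samples()` returns a flat interleaved array that `reshape(1, -1)` puts into the `(channels, samples)` layout `ServerToClientAudio.array_to_frame` expects. A more portable way to resolve the file (an assumption: `cantina.wav` would need to sit next to the script):

```python
from pathlib import Path

from pydub import AudioSegment

# Resolve the sample file relative to this script rather than an absolute path.
wav_path = Path(__file__).parent / "cantina.wav"
segment = AudioSegment.from_file(wav_path)
```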
demo/video_out.py (new file, 59 lines)

@@ -0,0 +1,59 @@
import gradio as gr
from gradio_webrtc import WebRTC
from twilio.rest import Client
import os
import cv2


account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    client = Client(account_sid, auth_token)

    token = client.tokens.create()

    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None


def generation(input_video):
    cap = cv2.VideoCapture(input_video)

    iterating = True

    while iterating:
        iterating, frame = cap.read()

        # flip frame vertically
        frame = cv2.flip(frame, 0)
        display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        yield display_frame


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Streaming (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(sources="upload")
        with gr.Column():
            output_video = WebRTC(label="Video Stream", rtc_configuration=rtc_configuration,
                                  mode="receive", modality="video")
    output_video.stream(
        fn=generation, inputs=[input_video], outputs=[output_video],
        trigger=input_video.upload
    )


if __name__ == "__main__":
    demo.launch()
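One caveat (not addressed by the diff): `cap.read()` returns `(False, None)` once the file is exhausted, so the loop above hands `None` to `cv2.flip` on its final iteration. A defensive variant of the generator:

```python
import cv2


def generation(input_video):
    cap = cv2.VideoCapture(input_video)
    while True:
        ok, frame = cap.read()
        if not ok:  # end of stream or decode failure: stop cleanly
            break
        frame = cv2.flip(frame, 0)  # flip frame vertically
        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    cap.release()
```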
@@ -6,6 +6,7 @@
 import { StatusTracker } from "@gradio/statustracker";
 import type { LoadingStatus } from "@gradio/statustracker";
 import StaticVideo from "./shared/StaticVideo.svelte";
+import StaticAudio from "./shared/StaticAudio.svelte";
 
 export let elem_id = "";
 export let elem_classes: string[] = [];
@@ -28,7 +29,8 @@
 export let gradio;
 export let rtc_configuration: Object;
 export let time_limit: number | null = null;
-export let mode: "video-in-out" | "video-out" = "video-in-out";
+export let modality: "video" | "audio" = "video";
+export let mode: "send-receive" | "receive" = "send-receive";
 
 let dragging = false;
 
@@ -57,7 +59,7 @@
 on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
 />
 
-{#if mode === "video-out"}
+{#if mode === "receive" && modality === "video"}
 <StaticVideo
 	bind:value={value}
 	{label}
@@ -67,7 +69,18 @@
 	on:tick={() => gradio.dispatch("tick")}
 	on:error={({ detail }) => gradio.dispatch("error", detail)}
 />
-{:else}
+{:else if mode == "receive" && modality === "audio"}
+<StaticAudio
+	bind:value={value}
+	{label}
+	{show_label}
+	{server}
+	{rtc_configuration}
+	i18n={gradio.i18n}
+	on:tick={() => gradio.dispatch("tick")}
+	on:error={({ detail }) => gradio.dispatch("error", detail)}
+/>
+{:else if mode === "send-receive" && modality === "video"}
 <Video
 	bind:value={value}
 	{label}
frontend/shared/AudioWave.svelte (new file, 123 lines)

@@ -0,0 +1,123 @@
<script lang="ts">
	import { onMount, onDestroy } from 'svelte';

	export let numBars = 16;
	export let stream_state: "open" | "closed" = "closed";
	export let audio_source: HTMLAudioElement;

	let audioContext: AudioContext;
	let analyser: AnalyserNode;
	let dataArray: Uint8Array;
	let animationId: number;
	let is_muted = false;

	$: containerWidth = `calc((var(--boxSize) + var(--gutter)) * ${numBars})`;

	$: if (stream_state === "open") setupAudioContext();

	onDestroy(() => {
		if (animationId) {
			cancelAnimationFrame(animationId);
		}
		if (audioContext) {
			audioContext.close();
		}
	});

	function setupAudioContext() {
		console.log("set up");
		audioContext = new (window.AudioContext || window.webkitAudioContext)();
		analyser = audioContext.createAnalyser();
		console.log("audio_source", audio_source.srcObject);
		const source = audioContext.createMediaStreamSource(audio_source.srcObject);
		source.connect(analyser);
		analyser.connect(audioContext.destination);

		analyser.fftSize = 64;
		dataArray = new Uint8Array(analyser.frequencyBinCount);

		updateBars();
	}

	function updateBars() {
		analyser.getByteFrequencyData(dataArray);

		const bars = document.querySelectorAll('.box');
		for (let i = 0; i < bars.length; i++) {
			const barHeight = (dataArray[i] / 255) * 2; // Amplify the effect
			bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
		}

		animationId = requestAnimationFrame(updateBars);
	}

	function toggleMute() {
		if (audio_source && audio_source.srcObject) {
			const audioTracks = (audio_source.srcObject as MediaStream).getAudioTracks();
			audioTracks.forEach(track => {
				track.enabled = !track.enabled;
			});
			is_muted = !audioTracks[0].enabled;
		}
	}
</script>

<div class="waveContainer">
	<div class="boxContainer" style:width={containerWidth}>
		{#each Array(numBars) as _}
			<div class="box"></div>
		{/each}
	</div>
	<button class="muteButton" on:click={toggleMute}>
		{is_muted ? '🔈' : '🔊'}
	</button>
</div>

<style>
	.waveContainer {
		position: relative;
		display: flex;
		flex-direction: column;
		align-items: center;
	}

	.boxContainer {
		display: flex;
		justify-content: space-between;
		height: 64px;
		--boxSize: 8px;
		--gutter: 4px;
	}

	.box {
		height: 100%;
		width: var(--boxSize);
		background: var(--color-accent);
		border-radius: 8px;
		transition: transform 0.05s ease;
	}

	.muteButton {
		margin-top: 10px;
		padding: 10px 20px;
		font-size: 24px;
		cursor: pointer;
		background: none;
		border: none;
		border-radius: 5px;
		color: var(--color-accent);
	}

	:global(body) {
		display: flex;
		justify-content: center;
		background: black;
		margin: 0;
		padding: 0;
		align-items: center;
		height: 100vh;
		color: white;
		font-family: Arial, sans-serif;
	}
</style>
frontend/shared/StaticAudio.svelte (new file, 126 lines)

@@ -0,0 +1,126 @@
<script lang="ts">
	import { Empty } from "@gradio/atoms";
	import {
		BlockLabel,
	} from "@gradio/atoms";
	import { Music } from "@gradio/icons";
	import type { I18nFormatter } from "@gradio/utils";
	import { createEventDispatcher } from "svelte";
	import { onMount } from "svelte";

	import { start, stop } from "./webrtc_utils";
	import AudioWave from "./AudioWave.svelte";

	export let value: string | null = null;
	export let label: string | undefined = undefined;
	export let show_label = true;
	export let rtc_configuration: Object | null = null;
	export let i18n: I18nFormatter;
	export let autoplay: boolean = true;

	export let server: {
		offer: (body: any) => Promise<any>;
	};

	let stream_state = "closed";
	let audio_player: HTMLAudioElement;
	let pc: RTCPeerConnection;
	let _webrtc_id = Math.random().toString(36).substring(2);

	const dispatch = createEventDispatcher<{
		tick: undefined;
		error: string;
		play: undefined;
		stop: undefined;
	}>();

	onMount(() => {
		window.setInterval(() => {
			if (stream_state == "open") {
				dispatch("tick");
			}
		}, 1000);
	});

	$: if (value === "start_webrtc_stream") {
		stream_state = "connecting";
		value = _webrtc_id;
		const fallback_config = {
			iceServers: [
				{
					urls: 'stun:stun.l.google.com:19302'
				}
			]
		};
		pc = new RTCPeerConnection(rtc_configuration);
		pc.addEventListener("connectionstatechange",
			async (event) => {
				switch (pc.connectionState) {
					case "connected":
						console.log("connected");
						stream_state = "open";
						break;
					case "disconnected":
						console.log("closed");
						stop(pc);
						break;
					default:
						break;
				}
			}
		);
		start(null, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
			pc = connection;
		}).catch(() => {
			console.log("catching");
			dispatch("error", "Too many concurrent users. Come back later!");
		});
	}
</script>

<BlockLabel
	{show_label}
	Icon={Music}
	float={false}
	label={label || i18n("audio.audio")}
/>

<audio
	class="standard-player"
	class:hidden={value === "__webrtc_value__"}
	on:load
	bind:this={audio_player}
	on:ended={() => dispatch("stop")}
	on:play={() => dispatch("play")}
/>
{#if value !== "__webrtc_value__"}
	<AudioWave audio_source={audio_player} {stream_state}/>
{/if}
{#if value === "__webrtc_value__"}
	<Empty size="small">
		<Music />
	</Empty>
{/if}

<style>
	:global(::part(wrapper)) {
		margin-bottom: var(--size-2);
	}

	.standard-player {
		width: 100%;
		padding: var(--size-2);
	}

	.hidden {
		display: none;
	}
</style>
@@ -1,5 +1,5 @@
 <script lang="ts">
-	import { createEventDispatcher, afterUpdate, tick } from "svelte";
+	import { createEventDispatcher, onMount } from "svelte";
 	import {
 		BlockLabel,
 		Empty
@@ -29,15 +29,16 @@
 	}>();
 
 	let stream_state = "closed";
-	window.setInterval(() => {
-		if (stream_state == "open") {
-			dispatch("tick");
-		}
-	}, 1000);
 
-	$: console.log("static video value", value);
+	onMount(() => {
+		window.setInterval(() => {
+			if (stream_state == "open") {
+				dispatch("tick");
+			}
+		}, 1000);
+	});
+
 	$: if (value === "start_webrtc_stream") {
 		value = _webrtc_id;
 		const fallback_config = {
@@ -48,8 +49,7 @@
 			]
 		};
 		const configuration = rtc_configuration || fallback_config;
-		console.log("config", configuration);
-		pc = new RTCPeerConnection(configuration);
+		pc = new RTCPeerConnection(rtc_configuration);
 		pc.addEventListener("connectionstatechange",
 			async (event) => {
 				switch(pc.connectionState) {
@@ -27,17 +27,25 @@ export function createPeerConnection(pc, node) {
 	// connect audio / video from server to local
 	pc.addEventListener("track", (evt) => {
 		console.log("track event listener");
-		if (evt.track.kind == "video") {
+		if (node.srcObject !== evt.streams[0]) {
 			console.log("streams", evt.streams);
 			node.srcObject = evt.streams[0];
 			console.log("node.srcOject", node.srcObject);
+			if (evt.track.kind === 'audio') {
+				node.volume = 1.0; // Ensure volume is up
+				node.muted = false;
+				node.autoplay = true;
+
+				// Attempt to play (needed for some browsers)
+				node.play().catch(e => console.log("Autoplay failed:", e));
+			}
 		}
 	});
 
 	return pc;
 }
 
-export async function start(stream, pc, node, server_fn, webrtc_id) {
+export async function start(stream, pc: RTCPeerConnection, node, server_fn, webrtc_id, modality: "video" | "audio" = "video") {
 	pc = createPeerConnection(pc, node);
 	if (stream) {
 		stream.getTracks().forEach((track) => {
@@ -48,7 +56,7 @@ export async function start(stream, pc, node, server_fn, webrtc_id) {
 		});
 	} else {
 		console.log("Creating transceiver!");
-		pc.addTransceiver("video", { direction: "recvonly" });
+		pc.addTransceiver(modality, { direction: "recvonly" });
 	}
 
 	await negotiate(pc, server_fn, webrtc_id);