diff --git a/README.md b/README.md
index 76943fa..f6a58ad 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ app_file: space.py
---
# `gradio_webrtc`
-
+
Stream images in realtime with webrtc
@@ -358,6 +358,32 @@ float | None
<td align="left"><code>None</code></td>
<td align="left">None</td>
</tr>
+
+<tr>
+<td align="left"><code>mode</code></td>
+<td align="left" style="width: 25%;">
+
+```python
+Literal["send-receive", "receive"]
+```
+
+</td>
+<td align="left"><code>"send-receive"</code></td>
+<td align="left">None</td>
+</tr>
+
+<tr>
+<td align="left"><code>modality</code></td>
+<td align="left" style="width: 25%;">
+
+```python
+Literal["video", "audio"]
+```
+
+</td>
+<td align="left"><code>"video"</code></td>
+<td align="left">None</td>
+</tr>
diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py
new file mode 100644
index 0000000..dedd014
--- /dev/null
+++ b/backend/gradio_webrtc/utils.py
@@ -0,0 +1,63 @@
+import time
+import fractions
+import av
+import asyncio
+import threading
+from typing import Callable
+
+AUDIO_PTIME = 0.020
+
+
+def player_worker_decode(
+ loop,
+ callable: Callable,
+ stream,
+ queue: asyncio.Queue,
+ throttle_playback: bool,
+ thread_quit: threading.Event,
+):
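+    """
+    Run the user generator on a worker thread and feed decoded audio frames
+    into `queue` for ServerToClientAudio.recv to consume.
+
+    The generator returned by `callable` is expected to yield
+    (sample_rate, np.ndarray) tuples of mono samples (int16 or float);
+    each chunk is resampled to 48 kHz stereo s16 frames of ~20 ms.
+    """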
+ audio_sample_rate = 48000
+ audio_samples = 0
+ audio_time_base = fractions.Fraction(1, audio_sample_rate)
+ audio_resampler = av.AudioResampler(
+ format="s16",
+ layout="stereo",
+ rate=audio_sample_rate,
+ frame_size=int(audio_sample_rate * AUDIO_PTIME),
+ )
+
+ frame_time = None
+ start_time = time.time()
+ generator = None
+
+    while not thread_quit.is_set():
+        if stream.latest_args == "not_set":
+            # no args from the client yet; sleep briefly to avoid a busy-wait
+            time.sleep(0.05)
+            continue
+        if generator is None:
+            generator = callable(*stream.latest_args)
+        try:
+            frame = next(generator)
+        except StopIteration:
+            # generator exhausted: signal end-of-stream and stop the worker
+            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+            thread_quit.set()
+            break
+
+ # read up to 1 second ahead
+ if throttle_playback:
+ elapsed_time = time.time() - start_time
+ if frame_time and frame_time > elapsed_time + 1:
+ time.sleep(0.1)
+        sample_rate, audio_array = frame
+        audio_format = "s16" if audio_array.dtype == "int16" else "fltp"
+        frame = av.AudioFrame.from_ndarray(
+            audio_array, format=audio_format, layout="mono"
+        )
+        frame.sample_rate = sample_rate
+        for frame in audio_resampler.resample(frame):
+            # fix timestamps so frames play back contiguously
+            frame.pts = audio_samples
+            frame.time_base = audio_time_base
+            audio_samples += frame.samples
+            frame_time = frame.time
+            # queue every resampled frame, not just the last one
+            asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py
index 9cfbd18..c5a3a13 100644
--- a/backend/gradio_webrtc/webrtc.py
+++ b/backend/gradio_webrtc/webrtc.py
@@ -2,16 +2,20 @@
from __future__ import annotations
+from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
-
+import fractions
+import threading
+import time
+from gradio_webrtc.utils import player_worker_decode
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
-from aiortc import VideoStreamTrack
+from aiortc import VideoStreamTrack, AudioStreamTrack
from aiortc.mediastreams import MediaStreamError
-from aiortc.contrib.media import VideoFrame # type: ignore
+from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from gradio_client import handle_file
import numpy as np
@@ -124,7 +128,6 @@ class ServerToClientVideo(VideoStreamTrack):
async def recv(self):
try:
-
pts, time_base = await self.next_timestamp()
if self.latest_args == "not_set":
frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
@@ -132,17 +135,16 @@ class ServerToClientVideo(VideoStreamTrack):
frame.time_base = time_base
return frame
elif self.generator is None:
- self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args))
+ self.generator = cast(
+ Generator[Any, None, Any], self.event_handler(*self.latest_args)
+ )
try:
next_array = next(self.generator)
except StopIteration:
- print("exception")
self.stop()
return
- print("pts", pts)
- print("time_base", time_base)
next_frame = self.array_to_frame(next_array)
next_frame.pts = pts
next_frame.time_base = time_base
@@ -150,9 +152,132 @@ class ServerToClientVideo(VideoStreamTrack):
except Exception as e:
print(e)
import traceback
+
traceback.print_exc()
+class ServerToClientAudio(AudioStreamTrack):
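+    """
+    A media track that plays audio yielded by a user generator.
+
+    A background thread (player_worker_decode) resamples generator output
+    into av.AudioFrames and puts them on an asyncio queue; recv() pops
+    frames and paces delivery to real time.
+    """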
+ kind = "audio"
+
+ def __init__(
+ self,
+ event_handler: Callable,
+ ) -> None:
+ self.generator: Generator[Any, None, Any] | None = None
+ self.event_handler = event_handler
+ self.current_timestamp = 0
+ self.latest_args = "not_set"
+ self.queue = asyncio.Queue()
+ self.thread_quit = threading.Event()
+ self.__thread = None
+ self._start: float | None = None
+ super().__init__()
+
+ def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame:
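+        # build a mono s16 AudioFrame from a (sample_rate, samples) tuple and
+        # advance the running timestamp by the chunk length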
+ frame = AudioFrame.from_ndarray(array[1], format="s16", layout="mono")
+ frame.sample_rate = array[0]
+ frame.time_base = fractions.Fraction(1, array[0])
+ self.current_timestamp += array[1].shape[1]
+ frame.pts = self.current_timestamp
+ return frame
+
+ async def empty_frame(self) -> AudioFrame:
+ sample_rate = 22050
+ samples = 100
+ frame = AudioFrame(format="s16", layout="mono", samples=samples)
+ for p in frame.planes:
+ p.update(bytes(p.buffer_size))
+ frame.sample_rate = sample_rate
+ frame.time_base = fractions.Fraction(1, sample_rate)
+ self.current_timestamp += samples
+ frame.pts = self.current_timestamp
+ return frame
+
+ def start(self):
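+        # spawn the decode worker exactly once; it fills self.queue with frames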
+ if self.__thread is None:
+ self.__thread = threading.Thread(
+ name="generator-runner",
+ target=player_worker_decode,
+ args=(
+ asyncio.get_event_loop(),
+ self.event_handler,
+ self,
+ self.queue,
+ False,
+ self.thread_quit,
+ ),
+ )
+ self.__thread.start()
+
+ async def recv(self):
+ try:
+ if self.readyState != "live":
+ raise MediaStreamError
+
+ self.start()
+ data = await self.queue.get()
+ if data is None:
+ self.stop()
+ return
+
+ data_time = data.time
+
+ # control playback rate
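+            # anchor wall-clock time to the first frame's timestamp, then
+            # sleep so each later frame is returned at its presentation time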
+ if data_time is not None:
+ if self._start is None:
+ self._start = time.time() - data_time
+ else:
+ wait = self._start + data_time - time.time()
+ await asyncio.sleep(wait)
+
+ return data
+ except Exception as e:
+ print(e)
+ import traceback
+
+ traceback.print_exc()
+
+ def stop(self):
+ self.thread_quit.set()
+ if self.__thread is not None:
+ self.__thread.join()
+ self.__thread = None
+ super().stop()
+
+
class WebRTC(Component):
"""
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@@ -166,7 +291,9 @@ class WebRTC(Component):
pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay()
- connections: dict[str, VideoCallback | ServerToClientVideo] = {}
+ connections: dict[
+ str, VideoCallback | ServerToClientVideo | ServerToClientAudio
+ ] = {}
EVENTS = ["tick"]
@@ -191,7 +318,8 @@ class WebRTC(Component):
mirror_webcam: bool = True,
rtc_configuration: dict[str, Any] | None = None,
time_limit: float | None = None,
- mode: Literal["video-in-out", "video-out"] = "video-in-out",
+ mode: Literal["send-receive", "receive"] = "send-receive",
+ modality: Literal["video", "audio"] = "video",
):
"""
Parameters:
@@ -223,6 +351,9 @@ class WebRTC(Component):
streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
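+            mode: whether the component streams media in both directions ("send-receive") or only from the server to the client ("receive").
+            modality: whether the component streams "video" or "audio".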
"""
+ if modality == "audio" and mode == "send-receive":
+ raise ValueError("Audio modality is not supported in send-receive mode")
+
self.time_limit = time_limit
self.height = height
self.width = width
@@ -230,6 +361,7 @@ class WebRTC(Component):
self.concurrency_limit = 1
self.rtc_configuration = rtc_configuration
self.mode = mode
+ self.modality = modality
self.event_handler: Callable | None = None
super().__init__(
label=label,
@@ -268,9 +400,11 @@ class WebRTC(Component):
def set_output(self, webrtc_id: str, *args):
if webrtc_id in self.connections:
- if self.mode == "video-in-out":
- self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
- elif self.mode == "video-out":
+ if self.mode == "send-receive":
+ self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(
+ args
+ )
+ elif self.mode == "receive":
self.connections[webrtc_id].latest_args = list(args)
def stream(
@@ -296,9 +430,8 @@ class WebRTC(Component):
)
self.event_handler = fn
self.time_limit = time_limit
-
- if self.mode == "video-in-out":
+ if self.mode == "send-receive":
if cast(list[Block], inputs)[0] != self:
raise ValueError(
"In the webrtc stream event, the first input component must be the WebRTC component."
@@ -321,27 +454,29 @@ class WebRTC(Component):
time_limit=None,
js=js,
)
- elif self.mode == "video-out":
+ elif self.mode == "receive":
if self in cast(list[Block], inputs):
raise ValueError(
- "In the video-out stream event, the WebRTC component cannot be an input."
+ "In the receive mode stream event, the WebRTC component cannot be an input."
)
if (
len(cast(list[Block], outputs)) != 1
and cast(list[Block], outputs)[0] != self
):
raise ValueError(
- "In the video-out stream, the only output component must be the WebRTC component."
+ "In the receive mode stream, the only output component must be the WebRTC component."
)
if trigger is None:
raise ValueError(
- "In the video-out stream event, the trigger parameter must be provided"
+ "In the receive mode stream event, the trigger parameter must be provided"
)
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
self.tick(
- self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
+ self.set_output,
+ inputs=[self] + inputs,
+ outputs=None,
+ concurrency_id=concurrency_id,
)
-
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -384,19 +520,31 @@ class WebRTC(Component):
)
self.connections[body["webrtc_id"]] = cb
pc.addTrack(cb)
-
- if self.mode == "video-out":
+
+ if self.mode == "receive" and self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
+ cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
+        elif self.mode == "receive" and self.modality == "audio":
+            cb = ServerToClientAudio(cast(Callable, self.event_handler))
+            pc.addTrack(cb)
+            self.connections[body["webrtc_id"]] = cb
+            cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
-
# handle offer
await pc.setRemoteDescription(offer)
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
return {
"sdp": pc.localDescription.sdp,
diff --git a/demo/audio_out.py b/demo/audio_out.py
new file mode 100644
index 0000000..716c1ab
--- /dev/null
+++ b/demo/audio_out.py
@@ -0,0 +1,64 @@
+import gradio as gr
+import numpy as np
+from gradio_webrtc import WebRTC
+from twilio.rest import Client
+import os
+from pydub import AudioSegment
+
+
+account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
+auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
+
+if account_sid and auth_token:
+ client = Client(account_sid, auth_token)
+
+ token = client.tokens.create()
+
+ rtc_configuration = {
+ "iceServers": token.ice_servers,
+ "iceTransportPolicy": "relay",
+ }
+else:
+ rtc_configuration = None
+
+
+def generation(num_steps):
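+    # each yield is a (sample_rate, np.ndarray) tuple; reshape(1, -1) gives
+    # the (channels, samples) layout expected for a mono stream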
+ for _ in range(num_steps):
+        # assumes a local cantina.wav beside this script; any audio file works
+        segment = AudioSegment.from_file("cantina.wav")
+        yield (
+            segment.frame_rate,
+            np.array(segment.get_array_of_samples()).reshape(1, -1),
+        )
+
+
+css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
+          .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
+
+
+with gr.Blocks(css=css) as demo:
+ gr.HTML(
+ """
+        <h1 style='text-align: center'>
+            Audio Streaming (Powered by WebRTC ⚡️)
+        </h1>
+ """
+ )
+ with gr.Column(elem_classes=["my-column"]):
+ with gr.Group(elem_classes=["my-group"]):
+ audio = WebRTC(label="Stream", rtc_configuration=rtc_configuration,
+ mode="receive", modality="audio")
+ num_steps = gr.Slider(
+ label="Number of Steps",
+ minimum=1,
+ maximum=10,
+ step=1,
+ value=5,
+ )
+ button = gr.Button("Generate")
+
+ audio.stream(
+ fn=generation, inputs=[num_steps], outputs=[audio],
+ trigger=button.click
+ )
+
+
+if __name__ == "__main__":
+ demo.launch()
diff --git a/demo/video_out.py b/demo/video_out.py
new file mode 100644
index 0000000..696d3fe
--- /dev/null
+++ b/demo/video_out.py
@@ -0,0 +1,59 @@
+import gradio as gr
+from gradio_webrtc import WebRTC
+from twilio.rest import Client
+import os
+import cv2
+
+
+account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
+auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
+
+if account_sid and auth_token:
+ client = Client(account_sid, auth_token)
+
+ token = client.tokens.create()
+
+ rtc_configuration = {
+ "iceServers": token.ice_servers,
+ "iceTransportPolicy": "relay",
+ }
+else:
+ rtc_configuration = None
+
+
+def generation(input_video):
+    cap = cv2.VideoCapture(input_video)
+
+    iterating = True
+    while iterating:
+        iterating, frame = cap.read()
+        if not iterating:
+            # cap.read() returns (False, None) once the video is exhausted
+            break
+
+        # flip frame vertically, then convert BGR (OpenCV) to RGB for display
+        frame = cv2.flip(frame, 0)
+        display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        yield display_frame
+
+with gr.Blocks() as demo:
+ gr.HTML(
+ """
+        <h1 style='text-align: center'>
+            Video Streaming (Powered by WebRTC ⚡️)
+        </h1>
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ input_video = gr.Video(sources="upload")
+ with gr.Column():
+ output_video = WebRTC(label="Video Stream", rtc_configuration=rtc_configuration,
+ mode="receive", modality="video")
+ output_video.stream(
+ fn=generation, inputs=[input_video], outputs=[output_video],
+ trigger=input_video.upload
+ )
+
+
+if __name__ == "__main__":
+ demo.launch()
diff --git a/frontend/Index.svelte b/frontend/Index.svelte
index 0dcbeaf..5f022fc 100644
--- a/frontend/Index.svelte
+++ b/frontend/Index.svelte
@@ -6,6 +6,7 @@
import { StatusTracker } from "@gradio/statustracker";
import type { LoadingStatus } from "@gradio/statustracker";
import StaticVideo from "./shared/StaticVideo.svelte";
+ import StaticAudio from "./shared/StaticAudio.svelte";
export let elem_id = "";
export let elem_classes: string[] = [];
@@ -28,7 +29,8 @@
export let gradio;
export let rtc_configuration: Object;
export let time_limit: number | null = null;
- export let mode: "video-in-out" | "video-out" = "video-in-out";
+ export let modality: "video" | "audio" = "video";
+ export let mode: "send-receive" | "receive" = "send-receive";
let dragging = false;
@@ -57,7 +59,7 @@
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
-    {#if mode === "video-out"}
+    {#if mode === "receive" && modality === "video"}
        <StaticVideo
            on:tick={() => gradio.dispatch("tick")}
            on:error={({ detail }) => gradio.dispatch("error", detail)}
        />
-    {:else}
+    {:else if mode === "receive" && modality === "audio"}
+        <StaticAudio
+            on:tick={() => gradio.dispatch("tick")}
+            on:error={({ detail }) => gradio.dispatch("error", detail)}
+        />
+    {:else if mode === "send-receive" && modality === "video"}