diff --git a/README.md b/README.md
index 76943fa..f6a58ad 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ app_file: space.py
 ---
 
 # `gradio_webrtc`
-Static Badge
+PyPI - Version
 
 Stream images in realtime with webrtc
 
@@ -358,6 +358,32 @@ float | None
 None
 None
+
+
+mode
+
+
+```python
+Literal["send-receive", "receive"]
+```
+
+
+"send-receive"
+The connection mode. "send-receive" streams media both to and from the server; "receive" only streams from the server to the client.
+
+
+
+modality
+
+
+```python
+Literal["video", "audio"]
+```
+
+
+"video"
+The type of media to stream, either "video" or "audio".
+
diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py
new file mode 100644
index 0000000..5553b23
--- /dev/null
+++ b/backend/gradio_webrtc/utils.py
@@ -0,0 +1,65 @@
+import asyncio
+import fractions
+import threading
+import time
+from typing import Callable
+
+import av
+
+AUDIO_PTIME = 0.020  # 20 ms audio packetization time
+
+
+def player_worker_decode(
+    loop,
+    callable: Callable,
+    stream,
+    queue: asyncio.Queue,
+    throttle_playback: bool,
+    thread_quit: threading.Event,
+):
+    audio_sample_rate = 48000
+    audio_samples = 0
+    audio_time_base = fractions.Fraction(1, audio_sample_rate)
+    audio_resampler = av.AudioResampler(
+        format="s16",
+        layout="stereo",
+        rate=audio_sample_rate,
+        frame_size=int(audio_sample_rate * AUDIO_PTIME),
+    )
+
+    frame_time = None
+    start_time = time.time()
+    generator = None
+
+    while not thread_quit.is_set():
+        # Wait until the client has set the event handler arguments.
+        if stream.latest_args == "not_set":
+            time.sleep(0.05)
+            continue
+        if generator is None:
+            generator = callable(*stream.latest_args)
+        try:
+            frame = next(generator)
+        except StopIteration:
+            # Signal end-of-stream to the consuming track and exit the worker.
+            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+            thread_quit.set()
+            break
+
+        # read up to 1 second ahead
+        if throttle_playback:
+            elapsed_time = time.time() - start_time
+            if frame_time and frame_time > elapsed_time + 1:
+                time.sleep(0.1)
+
+        sample_rate, audio_array = frame
+        frame = av.AudioFrame.from_ndarray(audio_array, format="s16", layout="mono")
+        frame.sample_rate = sample_rate
+        for frame in audio_resampler.resample(frame):
+            # fix timestamps
+            frame.pts = audio_samples
+            frame.time_base = audio_time_base
+            audio_samples += frame.samples
+
+            frame_time = frame.time
+            asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
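Note: `player_worker_decode` expects the user's event handler to be a generator yielding `(sample_rate, audio_array)` tuples, where `audio_array` is an `int16` array of shape `(1, num_samples)`, the layout `av.AudioFrame.from_ndarray(..., format="s16", layout="mono")` requires. A minimal sketch of a compatible handler; the tone parameters are illustrative and not part of this patch:

```python
import numpy as np


def tone_generator(num_chunks: int):
    """Yield one-second chunks of a 440 Hz sine wave as (sample_rate, array) tuples."""
    sample_rate = 24000
    t = np.arange(sample_rate) / sample_rate
    chunk = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
    for _ in range(num_chunks):
        # Shape (1, num_samples): one packed plane, as from_ndarray expects for mono s16.
        yield (sample_rate, chunk.reshape(1, -1))
```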
print("exception") self.stop() return - print("pts", pts) - print("time_base", time_base) next_frame = self.array_to_frame(next_array) next_frame.pts = pts next_frame.time_base = time_base @@ -153,6 +153,131 @@ class ServerToClientVideo(VideoStreamTrack): traceback.print_exc() +class ServerToClientAudio(AudioStreamTrack): + kind = "audio" + + def __init__( + self, + event_handler: Callable, + ) -> None: + self.generator: Generator[Any, None, Any] | None = None + self.event_handler = event_handler + self.current_timestamp = 0 + self.latest_args = "not_set" + self.queue = asyncio.Queue() + self.thread_quit = threading.Event() + self.__thread = None + self._start: float | None = None + super().__init__() + + def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame: + frame = AudioFrame.from_ndarray(array[1], format="s16", layout="mono") + frame.sample_rate = array[0] + frame.time_base = fractions.Fraction(1, array[0]) + self.current_timestamp += array[1].shape[1] + frame.pts = self.current_timestamp + return frame + + async def empty_frame(self) -> AudioFrame: + sample_rate = 22050 + samples = 100 + frame = AudioFrame(format="s16", layout="mono", samples=samples) + for p in frame.planes: + p.update(bytes(p.buffer_size)) + frame.sample_rate = sample_rate + frame.time_base = fractions.Fraction(1, sample_rate) + self.current_timestamp += samples + frame.pts = self.current_timestamp + return frame + + def start(self): + if self.__thread is None: + self.__thread = threading.Thread( + name="generator-runner", + target=player_worker_decode, + args=( + asyncio.get_event_loop(), + self.event_handler, + self, + self.queue, + False, + self.thread_quit + ), + ) + self.__thread.start() + + + async def recv(self): + try: + if self.readyState != "live": + raise MediaStreamError + + self.start() + data = await self.queue.get() + if data is None: + self.stop() + raise MediaStreamError + + data_time = data.time + + # control playback rate + if ( + data_time is not None + ): + if self._start is None: + self._start = time.time() - data_time + else: + wait = self._start + data_time - time.time() + await asyncio.sleep(wait) + + return data + except Exception as e: + print(e) + import traceback + + traceback.print_exc() + + def stop(self): + super().stop() + self.thread_quit.set() + if self.__thread is not None: + self.__thread.join() + self.__thread = None + + # next_frame = await super().recv() + # print("next frame", next_frame) + # return next_frame + #try: + # if self.latest_args == "not_set": + # frame = await self.empty_frame() + + # # await self.modify_frame(frame) + # await asyncio.sleep(100 / 22050) + # print("next_frame not set", frame) + # return frame + # if self.generator is None: + # self.generator = cast( + # Generator[Any, None, Any], self.event_handler(*self.latest_args) + # ) + + # try: + # next_array = next(self.generator) + # print("iteration") + # except StopIteration: + # print("exception") + # self.stop() # type: ignore + # return + # next_frame = self.array_to_frame(next_array) + # # await self.modify_frame(next_frame) + # print("next frame", next_frame) + # return next_frame + # except Exception as e: + # print(e) + # import traceback + + # traceback.print_exc() + + class WebRTC(Component): """ Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). 
@@ -166,7 +291,9 @@ class WebRTC(Component):
 
     pcs: set[RTCPeerConnection] = set([])
     relay = MediaRelay()
-    connections: dict[str, VideoCallback | ServerToClientVideo] = {}
+    connections: dict[
+        str, VideoCallback | ServerToClientVideo | ServerToClientAudio
+    ] = {}
 
     EVENTS = ["tick"]
 
@@ -191,7 +318,8 @@ class WebRTC(Component):
         mirror_webcam: bool = True,
         rtc_configuration: dict[str, Any] | None = None,
         time_limit: float | None = None,
-        mode: Literal["video-in-out", "video-out"] = "video-in-out",
+        mode: Literal["send-receive", "receive"] = "send-receive",
+        modality: Literal["video", "audio"] = "video",
     ):
         """
         Parameters:
@@ -223,6 +351,11 @@ class WebRTC(Component):
             streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
             watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
+            mode: the connection mode. "send-receive" streams media both to and from the server; "receive" only streams from the server to the client.
+            modality: the type of media to stream, either "video" or "audio".
         """
+        if modality == "audio" and mode == "send-receive":
+            raise ValueError("Audio modality is not supported in send-receive mode")
+
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -230,6 +361,7 @@ class WebRTC(Component):
         self.concurrency_limit = 1
         self.rtc_configuration = rtc_configuration
         self.mode = mode
+        self.modality = modality
         self.event_handler: Callable | None = None
         super().__init__(
             label=label,
@@ -268,9 +400,11 @@ class WebRTC(Component):
 
     def set_output(self, webrtc_id: str, *args):
         if webrtc_id in self.connections:
-            if self.mode == "video-in-out":
-                self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(args)
-            elif self.mode == "video-out":
+            if self.mode == "send-receive":
+                self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(
+                    args
+                )
+            elif self.mode == "receive":
                 self.connections[webrtc_id].latest_args = list(args)
 
     def stream(
@@ -296,9 +430,8 @@ class WebRTC(Component):
         )
         self.event_handler = fn
         self.time_limit = time_limit
-
-        if self.mode == "video-in-out":
+        if self.mode == "send-receive":
             if cast(list[Block], inputs)[0] != self:
                 raise ValueError(
                     "In the webrtc stream event, the first input component must be the WebRTC component."
@@ -321,27 +454,29 @@ class WebRTC(Component):
                 time_limit=None,
                 js=js,
             )
-        elif self.mode == "video-out":
+        elif self.mode == "receive":
             if self in cast(list[Block], inputs):
                 raise ValueError(
-                    "In the video-out stream event, the WebRTC component cannot be an input."
+                    "In the receive mode stream event, the WebRTC component cannot be an input."
                 )
             if (
                 len(cast(list[Block], outputs)) != 1
                 and cast(list[Block], outputs)[0] != self
            ):
                raise ValueError(
-                    "In the video-out stream, the only output component must be the WebRTC component."
+                    "In the receive mode stream, the only output component must be the WebRTC component."
                 )
             if trigger is None:
                 raise ValueError(
-                    "In the video-out stream event, the trigger parameter must be provided"
+                    "In the receive mode stream event, the trigger parameter must be provided"
                 )
             trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
             self.tick(
-                self.set_output, inputs=[self] + inputs, outputs=None, concurrency_id=concurrency_id
+                self.set_output,
+                inputs=[self] + inputs,
+                outputs=None,
+                concurrency_id=concurrency_id,
             )
-
 
     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
@@ -384,19 +519,23 @@ class WebRTC(Component):
             )
             self.connections[body["webrtc_id"]] = cb
             pc.addTrack(cb)
-
-        if self.mode == "video-out":
+
+        if self.mode == "receive" and self.modality == "video":
             cb = ServerToClientVideo(cast(Callable, self.event_handler))
             pc.addTrack(cb)
             self.connections[body["webrtc_id"]] = cb
+        if self.mode == "receive" and self.modality == "audio":
+            cb = ServerToClientAudio(cast(Callable, self.event_handler))
+            pc.addTrack(cb)
+            self.connections[body["webrtc_id"]] = cb
 
         # handle offer
         await pc.setRemoteDescription(offer)
         # send answer
         answer = await pc.createAnswer()
         await pc.setLocalDescription(answer)  # type: ignore
 
         return {
             "sdp": pc.localDescription.sdp,
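Note: the validation logic above implies a specific wiring for "receive" mode: the WebRTC component must be the only output, must not appear among the inputs, and an explicit trigger is required. A minimal sketch, with an illustrative silence-yielding handler that is not part of this patch:

```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC


def my_generator(prompt: str):
    # Yield 0.5 s of silence per chunk; a real handler would synthesize audio here.
    sample_rate = 24000
    silence = np.zeros((1, sample_rate // 2), dtype=np.int16)
    for _ in range(10):
        yield (sample_rate, silence)


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    audio = WebRTC(label="Stream", mode="receive", modality="audio")
    button = gr.Button("Start")

    # WebRTC is the sole output, absent from inputs, and a trigger is supplied.
    audio.stream(fn=my_generator, inputs=[prompt], outputs=[audio], trigger=button.click)

demo.launch()
```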
diff --git a/demo/audio_out.py b/demo/audio_out.py
new file mode 100644
index 0000000..716c1ab
--- /dev/null
+++ b/demo/audio_out.py
@@ -0,0 +1,71 @@
+import os
+
+import gradio as gr
+import numpy as np
+from gradio_webrtc import WebRTC
+from pydub import AudioSegment
+from twilio.rest import Client
+
+account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
+auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
+
+if account_sid and auth_token:
+    client = Client(account_sid, auth_token)
+
+    token = client.tokens.create()
+
+    rtc_configuration = {
+        "iceServers": token.ice_servers,
+        "iceTransportPolicy": "relay",
+    }
+else:
+    rtc_configuration = None
+
+
+def generation(num_steps):
+    # Load the clip once; "cantina.wav" is expected in the working directory.
+    segment = AudioSegment.from_file("cantina.wav")
+    audio_array = np.array(segment.get_array_of_samples()).reshape(1, -1)
+    for _ in range(num_steps):
+        yield (segment.frame_rate, audio_array)
+
+
+css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
+          .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
+
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML(
+        """
+    <h1 style='text-align: center'>
+    Audio Streaming (Powered by WebRTC ⚡️)
+    </h1>
+    """
+    )
+    with gr.Column(elem_classes=["my-column"]):
+        with gr.Group(elem_classes=["my-group"]):
+            audio = WebRTC(
+                label="Stream",
+                rtc_configuration=rtc_configuration,
+                mode="receive",
+                modality="audio",
+            )
+            num_steps = gr.Slider(
+                label="Number of Steps",
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=5,
+            )
+            button = gr.Button("Generate")
+
+            audio.stream(
+                fn=generation,
+                inputs=[num_steps],
+                outputs=[audio],
+                trigger=button.click,
+            )
+
+
+if __name__ == "__main__":
+    demo.launch()
diff --git a/demo/video_out.py b/demo/video_out.py
new file mode 100644
index 0000000..696d3fe
--- /dev/null
+++ b/demo/video_out.py
@@ -0,0 +1,67 @@
+import os
+
+import cv2
+import gradio as gr
+from gradio_webrtc import WebRTC
+from twilio.rest import Client
+
+account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
+auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
+
+if account_sid and auth_token:
+    client = Client(account_sid, auth_token)
+
+    token = client.tokens.create()
+
+    rtc_configuration = {
+        "iceServers": token.ice_servers,
+        "iceTransportPolicy": "relay",
+    }
+else:
+    rtc_configuration = None
+
+
+def generation(input_video):
+    cap = cv2.VideoCapture(input_video)
+
+    iterating = True
+    while iterating:
+        iterating, frame = cap.read()
+        if not iterating:
+            # cap.read() returns False at end of file; stop before touching frame
+            break
+
+        # flip frame vertically
+        frame = cv2.flip(frame, 0)
+        display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        yield display_frame
+
+    cap.release()
+
+
+with gr.Blocks() as demo:
+    gr.HTML(
+        """
+    <h1 style='text-align: center'>
+    Video Streaming (Powered by WebRTC ⚡️)
+    </h1>
+    """
+    )
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(sources="upload")
+        with gr.Column():
+            output_video = WebRTC(
+                label="Video Stream",
+                rtc_configuration=rtc_configuration,
+                mode="receive",
+                modality="video",
+            )
+            output_video.stream(
+                fn=generation,
+                inputs=[input_video],
+                outputs=[output_video],
+                trigger=input_video.upload,
+            )
+
+
+if __name__ == "__main__":
+    demo.launch()
diff --git a/frontend/Index.svelte b/frontend/Index.svelte
index 0dcbeaf..5f022fc 100644
--- a/frontend/Index.svelte
+++ b/frontend/Index.svelte
@@ -6,6 +6,7 @@
 	import { StatusTracker } from "@gradio/statustracker";
 	import type { LoadingStatus } from "@gradio/statustracker";
 	import StaticVideo from "./shared/StaticVideo.svelte";
+	import StaticAudio from "./shared/StaticAudio.svelte";
 
 	export let elem_id = "";
 	export let elem_classes: string[] = [];
@@ -28,7 +29,8 @@
 	export let gradio;
 	export let rtc_configuration: Object;
 	export let time_limit: number | null = null;
-	export let mode: "video-in-out" | "video-out" = "video-in-out";
+	export let modality: "video" | "audio" = "video";
+	export let mode: "send-receive" | "receive" = "send-receive";
 
 	let dragging = false;
 
@@ -57,10 +59,19 @@
 		on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
 	/>
 
-	{#if mode === "video-out"}
+	{#if mode === "receive" && modality === "video"}
 		<StaticVideo
 			on:tick={() => gradio.dispatch("tick")}
 			on:error={({ detail }) => gradio.dispatch("error", detail)}
 		/>
-	{:else}
+	{:else if mode === "receive" && modality === "audio"}
+		<StaticAudio
+			on:tick={() => gradio.dispatch("tick")}
+			on:error={({ detail }) => gradio.dispatch("error", detail)}
+		/>
+	{:else if mode === "send-receive" && modality === "video"}