Lots of bugs

This commit is contained in:
freddyaboulton
2024-10-11 13:01:26 -07:00
parent d681dbfd7e
commit 23b1e20c05
13 changed files with 214 additions and 241 deletions

View File

@@ -54,6 +54,7 @@ else:
# Run the object-detection model on one video frame and return an annotated copy.
# The frame is resized to the model's expected input dimensions, passed through the
# detector, and the annotated output is resized to a fixed 500x500 preview size.
# NOTE(review): `cv2` and `model` are module-level globals not visible in this diff
# hunk — confirm against the full file. Indentation is flattened by the diff view.
def detection(image, conf_threshold=0.3):
# Leftover debug print — the rest of this commit migrates prints to logger.debug;
# presumably this one was missed. TODO confirm and align.
print("running detection")
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))

View File

@@ -1,10 +1,15 @@
import time
import fractions
import av
import asyncio
import fractions
import threading
import time
from typing import Callable
import av
import logging
logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
@@ -39,7 +44,7 @@ def player_worker_decode(
frame = next(generator)
except Exception as exc:
if isinstance(exc, StopIteration):
print("Not iterating")
logger.debug("Stopping audio stream")
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
thread_quit.set()
break

View File

@@ -1,31 +1,33 @@
"""gr.Video() component."""
"""gr.WebRTC() component."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
import fractions
import logging
import threading
import time
from gradio_webrtc.utils import player_worker_decode
import traceback
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, Generator, Sequence, cast
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack
from aiortc.mediastreams import MediaStreamError
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from gradio_client import handle_file
import numpy as np
from aiortc import (
AudioStreamTrack,
RTCPeerConnection,
RTCSessionDescription,
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, MediaRelay, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
from gradio_webrtc.utils import player_worker_decode
if TYPE_CHECKING:
from gradio.components import Timer
from gradio.blocks import Block
from gradio.components import Timer
from gradio.events import Dependency
@@ -33,6 +35,9 @@ if wasm_utils.IS_WASM:
raise ValueError("Not supported in gradio-lite!")
logger = logging.getLogger(__name__)
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -90,10 +95,9 @@ class VideoCallback(VideoStreamTrack):
return new_frame
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
exec = traceback.format_exc()
logger.debug(exec)
class ServerToClientVideo(VideoStreamTrack):
@@ -150,10 +154,9 @@ class ServerToClientVideo(VideoStreamTrack):
next_frame.time_base = time_base
return next_frame
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
exec = traceback.format_exc()
logger.debug(exec)
class ServerToClientAudio(AudioStreamTrack):
@@ -173,26 +176,6 @@ class ServerToClientAudio(AudioStreamTrack):
self._start: float | None = None
super().__init__()
# Convert a (sample_rate, samples) tuple into a PyAV AudioFrame.
# array[0] is the sample rate (Hz); array[1] is an int16 ndarray whose second
# axis is the sample count for a mono layout.
# Advances self.current_timestamp by the number of samples so successive
# frames get monotonically increasing pts values.
# (Shown as deleted lines in this diff hunk.)
def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame:
frame = AudioFrame.from_ndarray(array[1], format="s16", layout="mono")
frame.sample_rate = array[0]
# time_base of 1/sample_rate means pts is measured in samples.
frame.time_base = fractions.Fraction(1, array[0])
self.current_timestamp += array[1].shape[1]
frame.pts = self.current_timestamp
return frame
# Produce a short silent AudioFrame used as filler before real audio is ready.
# Hard-codes a 22050 Hz sample rate and 100 samples of zeroed (silent) data,
# and advances self.current_timestamp so pts stays monotonic with real frames.
# NOTE(review): declared async but performs no awaits — presumably to match an
# async caller's interface; confirm. (Shown as deleted lines in this diff hunk.)
async def empty_frame(self) -> AudioFrame:
sample_rate = 22050
samples = 100
frame = AudioFrame(format="s16", layout="mono", samples=samples)
# Zero-fill every plane's buffer to make the frame silent.
for p in frame.planes:
p.update(bytes(p.buffer_size))
frame.sample_rate = sample_rate
frame.time_base = fractions.Fraction(1, sample_rate)
self.current_timestamp += samples
frame.pts = self.current_timestamp
return frame
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
@@ -232,10 +215,9 @@ class ServerToClientAudio(AudioStreamTrack):
return data
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
exec = traceback.format_exc()
logger.debug(exec)
def stop(self):
self.thread_quit.set()
@@ -244,39 +226,6 @@ class ServerToClientAudio(AudioStreamTrack):
self.__thread = None
super().stop()
# next_frame = await super().recv()
# print("next frame", next_frame)
# return next_frame
# try:
# if self.latest_args == "not_set":
# frame = await self.empty_frame()
# # await self.modify_frame(frame)
# await asyncio.sleep(100 / 22050)
# print("next_frame not set", frame)
# return frame
# if self.generator is None:
# self.generator = cast(
# Generator[Any, None, Any], self.event_handler(*self.latest_args)
# )
# try:
# next_array = next(self.generator)
# print("iteration")
# except StopIteration:
# print("exception")
# self.stop() # type: ignore
# return
# next_frame = self.array_to_frame(next_array)
# # await self.modify_frame(next_frame)
# print("next frame", next_frame)
# return next_frame
# except Exception as e:
# print(e)
# import traceback
# traceback.print_exc()
class WebRTC(Component):
"""
@@ -485,7 +434,8 @@ class WebRTC(Component):
@server
async def offer(self, body):
print("starting")
logger.debug("Starting to handle offer")
logger.debug("Offer body", body)
if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"}
@@ -496,7 +446,7 @@ class WebRTC(Component):
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(pc.iceConnectionState)
logger.debug("ICE connection state change", pc.iceConnectionState)
if pc.iceConnectionState == "failed":
await pc.close()
self.connections.pop(body["webrtc_id"], None)
@@ -519,32 +469,27 @@ class WebRTC(Component):
event_handler=cast(Callable, self.event_handler),
)
self.connections[body["webrtc_id"]] = cb
logger.debug("Adding track to peer connection", cb)
pc.addTrack(cb)
if self.mode == "receive" and self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
if self.mode == "receive" and self.modality == "audio":
print("adding")
cb = ServerToClientAudio(cast(Callable, self.event_handler))
print("cb.recv", cb.recv)
# from aiortc.contrib.media import MediaPlayer
# player = MediaPlayer("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
# pc.addTrack(player.audio)
if self.mode == "receive":
if self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler))
logger.debug("Adding track to peer connection", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
print("here")
# handle offer
await pc.setRemoteDescription(offer)
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
print("done")
logger.debug("done handling offer about to return")
return {
"sdp": pc.localDescription.sdp,

View File

@@ -22,18 +22,20 @@ if account_sid and auth_token:
else:
rtc_configuration = None
import time
# Demo audio generator: yields the same WAV clip num_steps times, pacing the
# stream with a 3.5 s sleep between chunks to simulate real-time generation.
# Each yield is a (sample_rate, samples) tuple; samples are reshaped to
# (1, n) — mono, channels-first — as the WebRTC audio track expects.
# NOTE(review): absolute local path to cantina.wav will only work on the
# author's machine — flag for replacement with a bundled asset.
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
time.sleep(3.5)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>

File diff suppressed because one or more lines are too long

View File

@@ -48,7 +48,7 @@ with gr.Blocks() as demo:
input_video = gr.Video(sources="upload")
with gr.Column():
output_video = WebRTC(label="Video Stream", rtc_configuration=rtc_configuration,
mode="receive", modality="video")
mode="receive", modality="video")
output_video.stream(
fn=generation, inputs=[input_video], outputs=[output_video],
trigger=input_video.upload

View File

@@ -35,10 +35,73 @@
let dragging = false;
$: console.log("value", value);
</script>
<Block
{#if mode == "receive" && modality === "video"}
<Block
{visible}
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
{elem_id}
{elem_classes}
{height}
{width}
{container}
{scale}
{min_width}
allow_overflow={false}
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
<StaticVideo
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
</Block>
{:else if mode == "receive" && modality === "audio"}
<Block
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
allow_overflow={false}
{elem_id}
{elem_classes}
{visible}
{container}
{scale}
{min_width}
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
<StaticAudio
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
i18n={gradio.i18n}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
</Block>
{:else if mode === "send-receive" && modality === "video"}
<Block
{visible}
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
@@ -51,36 +114,13 @@
{scale}
{min_width}
allow_overflow={false}
>
<StatusTracker
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{#if mode === "receive" && modality === "video"}
<StaticVideo
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else if mode == "receive" && modality === "audio"}
<StaticAudio
bind:value={value}
{label}
{show_label}
{server}
{rtc_configuration}
i18n={gradio.i18n}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else if mode === "send-receive" && modality === "video"}
<Video
bind:value={value}
{label}
@@ -105,5 +145,5 @@
>
<UploadText i18n={gradio.i18n} type="video" />
</Video>
{/if}
</Block>
</Block>
{/if}

View File

@@ -38,7 +38,8 @@
let dragging = false;
$: dispatch("drag", dragging);
$: console.log("interactive value", value);
$: console.log("value", value)
</script>
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />

View File

@@ -49,34 +49,27 @@
$: if( value === "start_webrtc_stream") {
stream_state = "connecting";
value = _webrtc_id;
const fallback_config = {
iceServers: [
{
urls: 'stun:stun.l.google.com:19302'
}
]
};
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.log("connected");
stream_state = "open";
break;
case "disconnected":
console.log("closed");
stop(pc);
break;
default:
break;
}
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
break;
case "disconnected":
console.info("closed");
stop(pc);
break;
default:
break;
}
)
}
)
start(null, pc, audio_player, server.offer, _webrtc_id, "audio").then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
@@ -91,7 +84,6 @@
float={false}
label={label || i18n("audio.audio")}
/>
<audio
class="standard-player"
class:hidden={value === "__webrtc_value__"}

View File

@@ -41,66 +41,59 @@
$: if( value === "start_webrtc_stream") {
value = _webrtc_id;
const fallback_config = {
iceServers: [
{
urls: 'stun:stun.l.google.com:19302'
}
]
};
const configuration = rtc_configuration || fallback_config;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.log("connected");
stream_state = "open";
break;
case "disconnected":
console.log("closed");
stop(pc);
break;
default:
break;
}
}
)
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.log("connected");
stream_state = "open";
break;
case "disconnected":
console.log("closed");
stop(pc);
break;
default:
break;
}
}
)
start(null, pc, video_element, server.offer, _webrtc_id).then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
</script>
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
{#if value === "__webrtc_value__"}
<Empty unpadded_box={true} size="large"><Video /></Empty>
{/if}
<div class="wrap">
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
{#if value === "__webrtc_value__"}
<Empty unpadded_box={true} size="large"><Video /></Empty>
{/if}
<video
class:hidden={value === "__webrtc_value__"}
bind:this={video_element}
autoplay={true}
on:loadeddata={dispatch.bind(null, "loadeddata")}
on:click={dispatch.bind(null, "click")}
on:play={dispatch.bind(null, "play")}
on:pause={dispatch.bind(null, "pause")}
on:ended={dispatch.bind(null, "ended")}
on:mouseover={dispatch.bind(null, "mouseover")}
on:mouseout={dispatch.bind(null, "mouseout")}
on:focus={dispatch.bind(null, "focus")}
on:blur={dispatch.bind(null, "blur")}
on:load
data-testid={$$props["data-testid"]}
crossorigin="anonymous"
>
<track kind="captions" />
</video>
<video
class:hidden={value === "__webrtc_value__"}
bind:this={video_element}
autoplay={true}
on:loadeddata={dispatch.bind(null, "loadeddata")}
on:click={dispatch.bind(null, "click")}
on:play={dispatch.bind(null, "play")}
on:pause={dispatch.bind(null, "pause")}
on:ended={dispatch.bind(null, "ended")}
on:mouseover={dispatch.bind(null, "mouseover")}
on:mouseout={dispatch.bind(null, "mouseout")}
on:focus={dispatch.bind(null, "focus")}
on:blur={dispatch.bind(null, "blur")}
on:load
data-testid={$$props["data-testid"]}
crossorigin="anonymous"
>
<track kind="captions" />
</video>
</div>

View File

@@ -24,7 +24,7 @@
let _time_limit: number | null = null;
export let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed";
export const webrtc_id = Math.random().toString(36).substring(2);
const _webrtc_id = Math.random().toString(36).substring(2);
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
state: "open" | "closed" | "waiting"
@@ -114,19 +114,11 @@
let webcam_accessed = false;
let pc: RTCPeerConnection;
export let webrtc_id;
async function start_webrtc(): Promise<void> {
if (stream_state === 'closed') {
const fallback_config = {
iceServers: [
{
urls: 'stun:stun.l.google.com:19302'
}
]
};
const configuration = rtc_configuration || fallback_config;
console.log("config", configuration);
pc = new RTCPeerConnection(configuration);
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
@@ -136,6 +128,7 @@
break;
case "disconnected":
stream_state = "closed";
_time_limit = null;
await access_webcam();
break;
default:
@@ -144,10 +137,11 @@
}
)
stream_state = "waiting"
webrtc_id = _webrtc_id;
start(stream, pc, video_source, server.offer, webrtc_id).then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")
console.info("catching")
stream_state = "closed";
dispatch("error", "Too many concurrent users. Come back later!");
});

View File

@@ -3,7 +3,7 @@ export function createPeerConnection(pc, node) {
pc.addEventListener(
"icegatheringstatechange",
() => {
console.log(pc.iceGatheringState);
console.debug(pc.iceGatheringState);
},
false
);
@@ -11,7 +11,7 @@ export function createPeerConnection(pc, node) {
pc.addEventListener(
"iceconnectionstatechange",
() => {
console.log(pc.iceConnectionState);
console.debug(pc.iceConnectionState);
},
false
);
@@ -19,25 +19,25 @@ export function createPeerConnection(pc, node) {
pc.addEventListener(
"signalingstatechange",
() => {
console.log(pc.signalingState);
console.debug(pc.signalingState);
},
false
);
// connect audio / video from server to local
pc.addEventListener("track", (evt) => {
console.log("track event listener");
console.debug("track event listener");
if (node.srcObject !== evt.streams[0]) {
console.log("streams", evt.streams);
console.debug("streams", evt.streams);
node.srcObject = evt.streams[0];
console.log("node.srcOject", node.srcObject);
console.debug("node.srcOject", node.srcObject);
if (evt.track.kind === 'audio') {
node.volume = 1.0; // Ensure volume is up
node.muted = false;
node.autoplay = true;
// Attempt to play (needed for some browsers)
node.play().catch(e => console.log("Autoplay failed:", e));
node.play().catch(e => console.debug("Autoplay failed:", e));
}
}
});
@@ -51,11 +51,11 @@ export async function start(stream, pc: RTCPeerConnection, node, server_fn, webr
stream.getTracks().forEach((track) => {
track.applyConstraints({ frameRate: { max: 30 } });
console.log("Track stream callback", track);
console.debug("Track stream callback", track);
pc.addTrack(track, stream);
});
} else {
console.log("Creating transceiver!");
console.debug("Creating transceiver!");
pc.addTransceiver(modality, { direction: "recvonly" });
}
@@ -66,9 +66,9 @@ export async function start(stream, pc: RTCPeerConnection, node, server_fn, webr
function make_offer(server_fn: any, body): Promise<object> {
return new Promise((resolve, reject) => {
server_fn(body).then((data) => {
console.log("data", data)
console.debug("data", data)
if(data?.status === "failed") {
console.log("rejecting")
console.debug("rejecting")
reject("error")
}
resolve(data);
@@ -89,13 +89,13 @@ async function negotiate(
.then(() => {
// wait for ICE gathering to complete
return new Promise<void>((resolve) => {
console.log("ice gathering state", pc.iceGatheringState);
console.debug("ice gathering state", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") {
resolve();
} else {
const checkState = () => {
if (pc.iceGatheringState === "complete") {
console.log("ice complete");
console.debug("ice complete");
pc.removeEventListener("icegatheringstatechange", checkState);
resolve();
}
@@ -124,8 +124,7 @@ async function negotiate(
}
export function stop(pc: RTCPeerConnection) {
console.log("pc", pc);
console.log("STOPPING");
console.debug("Stopping peer connection");
// close transceivers
if (pc.getTransceivers) {
pc.getTransceivers().forEach((transceiver) => {

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.1"
version = "0.0.2"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"