Lots of bugs

This commit is contained in:
freddyaboulton
2024-10-11 13:01:26 -07:00
parent d681dbfd7e
commit 23b1e20c05
13 changed files with 214 additions and 241 deletions

View File

@@ -1,10 +1,15 @@
import time
import fractions
import av
import asyncio
import fractions
import threading
import time
from typing import Callable
import av
import logging
logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
@@ -39,7 +44,7 @@ def player_worker_decode(
frame = next(generator)
except Exception as exc:
if isinstance(exc, StopIteration):
print("Not iterating")
logger.debug("Stopping audio stream")
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
thread_quit.set()
break
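
For context, the pattern this worker uses to hand decoded frames from a background thread to the asyncio consumer is run_coroutine_threadsafe on the track's event loop, with None as an end-of-stream sentinel. A minimal, self-contained sketch of just that handoff (the function and variable names are illustrative, not the real player_worker_decode signature):

import asyncio
import threading

def worker(generator, queue: asyncio.Queue, loop: asyncio.AbstractEventLoop, thread_quit: threading.Event) -> None:
    while not thread_quit.is_set():
        try:
            item = next(generator)
        except StopIteration:
            # Signal end-of-stream to the consumer running on the event loop.
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)
            thread_quit.set()
            break
        # Schedule the put on the loop's thread; the worker never touches the loop directly.
        asyncio.run_coroutine_threadsafe(queue.put(item), loop)

async def consume(queue: asyncio.Queue) -> None:
    while (item := await queue.get()) is not None:
        print("received", item)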

View File

@@ -1,31 +1,33 @@
"""gr.Video() component."""
"""gr.WebRTC() component."""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast, Generator
import fractions
import logging
import threading
import time
from gradio_webrtc.utils import player_worker_decode
import traceback
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, Generator, Sequence, cast
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack
from aiortc.mediastreams import MediaStreamError
from aiortc.contrib.media import AudioFrame, VideoFrame # type: ignore
from gradio_client import handle_file
import numpy as np
from aiortc import (
AudioStreamTrack,
RTCPeerConnection,
RTCSessionDescription,
VideoStreamTrack,
)
from aiortc.contrib.media import AudioFrame, MediaRelay, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from gradio import wasm_utils
from gradio.components.base import Component, server
from gradio_client import handle_file
from gradio_webrtc.utils import player_worker_decode
if TYPE_CHECKING:
from gradio.components import Timer
from gradio.blocks import Block
from gradio.components import Timer
from gradio.events import Dependency
@@ -33,6 +35,9 @@ if wasm_utils.IS_WASM:
raise ValueError("Not supported in gradio-lite!")
logger = logging.getLogger(__name__)
class VideoCallback(VideoStreamTrack):
"""
This works for streaming input and output
@@ -90,10 +95,9 @@ class VideoCallback(VideoStreamTrack):
return new_frame
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
logger.debug(traceback.format_exc())
class ServerToClientVideo(VideoStreamTrack):
@@ -150,10 +154,9 @@ class ServerToClientVideo(VideoStreamTrack):
next_frame.time_base = time_base
return next_frame
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
logger.debug(traceback.format_exc())
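
Both callback tracks follow aiortc's custom-track pattern: subclass VideoStreamTrack, override async recv, and return an av.VideoFrame stamped with the pts/time_base pair from next_timestamp(). A minimal sketch under those assumptions (the solid-color frame source is illustrative only):

import numpy as np
from aiortc import VideoStreamTrack
from av import VideoFrame

class SolidColorTrack(VideoStreamTrack):
    async def recv(self) -> VideoFrame:
        pts, time_base = await self.next_timestamp()  # provided by VideoStreamTrack
        array = np.zeros((480, 640, 3), dtype=np.uint8)  # one black 640x480 frame
        frame = VideoFrame.from_ndarray(array, format="bgr24")
        frame.pts = pts
        frame.time_base = time_base
        return frame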
class ServerToClientAudio(AudioStreamTrack):
@@ -173,26 +176,6 @@ class ServerToClientAudio(AudioStreamTrack):
self._start: float | None = None
super().__init__()
def array_to_frame(self, array: tuple[int, np.ndarray]) -> AudioFrame:
frame = AudioFrame.from_ndarray(array[1], format="s16", layout="mono")
frame.sample_rate = array[0]
frame.time_base = fractions.Fraction(1, array[0])
self.current_timestamp += array[1].shape[1]
frame.pts = self.current_timestamp
return frame
async def empty_frame(self) -> AudioFrame:
sample_rate = 22050
samples = 100
frame = AudioFrame(format="s16", layout="mono", samples=samples)
for p in frame.planes:
p.update(bytes(p.buffer_size))
frame.sample_rate = sample_rate
frame.time_base = fractions.Fraction(1, sample_rate)
self.current_timestamp += samples
frame.pts = self.current_timestamp
return frame
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
@@ -232,10 +215,9 @@ class ServerToClientAudio(AudioStreamTrack):
return data
except Exception as e:
print(e)
import traceback
traceback.print_exc()
logger.debug(e)
logger.debug(traceback.format_exc())
def stop(self):
self.thread_quit.set()
@@ -244,39 +226,6 @@ class ServerToClientAudio(AudioStreamTrack):
self.__thread = None
super().stop()
# next_frame = await super().recv()
# print("next frame", next_frame)
# return next_frame
# try:
# if self.latest_args == "not_set":
# frame = await self.empty_frame()
# # await self.modify_frame(frame)
# await asyncio.sleep(100 / 22050)
# print("next_frame not set", frame)
# return frame
# if self.generator is None:
# self.generator = cast(
# Generator[Any, None, Any], self.event_handler(*self.latest_args)
# )
# try:
# next_array = next(self.generator)
# print("iteration")
# except StopIteration:
# print("exception")
# self.stop() # type: ignore
# return
# next_frame = self.array_to_frame(next_array)
# # await self.modify_frame(next_frame)
# print("next frame", next_frame)
# return next_frame
# except Exception as e:
# print(e)
# import traceback
# traceback.print_exc()
class WebRTC(Component):
"""
@@ -485,7 +434,8 @@ class WebRTC(Component):
@server
async def offer(self, body):
print("starting")
logger.debug("Starting to handle offer")
logger.debug("Offer body", body)
if len(self.connections) >= cast(int, self.concurrency_limit):
return {"status": "failed"}
@@ -496,7 +446,7 @@ class WebRTC(Component):
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(pc.iceConnectionState)
logger.debug("ICE connection state change", pc.iceConnectionState)
if pc.iceConnectionState == "failed":
await pc.close()
self.connections.pop(body["webrtc_id"], None)
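
The hunk below finishes the standard aiortc offer/answer handshake: apply the client's offer as the remote description, create an answer, set it as the local description, and return its SDP. A minimal sketch of that exchange, assuming the client posts a JSON body with "sdp" and "type" fields (the track added here is illustrative):

from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack

async def answer_offer(body: dict) -> dict:
    pc = RTCPeerConnection()
    pc.addTrack(VideoStreamTrack())  # whatever track should be sent back to the client
    await pc.setRemoteDescription(RTCSessionDescription(sdp=body["sdp"], type=body["type"]))
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}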
@@ -519,32 +469,27 @@ class WebRTC(Component):
event_handler=cast(Callable, self.event_handler),
)
self.connections[body["webrtc_id"]] = cb
logger.debug("Adding track to peer connection", cb)
pc.addTrack(cb)
if self.mode == "receive" and self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
if self.mode == "receive" and self.modality == "audio":
print("adding")
cb = ServerToClientAudio(cast(Callable, self.event_handler))
print("cb.recv", cb.recv)
# from aiortc.contrib.media import MediaPlayer
# player = MediaPlayer("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
# pc.addTrack(player.audio)
if self.mode == "receive":
if self.modality == "video":
cb = ServerToClientVideo(cast(Callable, self.event_handler))
elif self.modality == "audio":
cb = ServerToClientAudio(cast(Callable, self.event_handler))
logger.debug("Adding track to peer connection", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
print("here")
# handle offer
await pc.setRemoteDescription(offer)
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
print("done")
logger.debug("done handling offer about to return")
return {
"sdp": pc.localDescription.sdp,