Audio in only (#15)

* Audio + Video / test Audio

* Add code

* Fix demo

* support additional inputs

* Add code

* Add code
This commit is contained in:
Freddy Boulton
2024-10-30 13:08:09 -04:00
committed by GitHub
parent 2068b91854
commit 3bf4a437fb
29 changed files with 1613 additions and 416 deletions

View File

@@ -2,11 +2,14 @@ from dataclasses import dataclass
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Callable, Generator, Literal, cast
import inspect
from typing import Any, Callable, Generator, Literal, Union, cast
import asyncio
import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.utils import AdditionalOutputs
from gradio_webrtc.webrtc import StreamHandler
logger = getLogger(__name__)
@@ -40,12 +43,29 @@ class AppState:
buffer: np.ndarray | None = None
ReplyFnGenerator = Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
None,
None,
ReplyFnGenerator = Union[
# For two arguments
Callable[
[tuple[int, np.ndarray], list[dict[Any, Any]]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
],
Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray]
| tuple[int, np.ndarray, Literal["mono", "stereo"]]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs],
None,
None,
],
],
]
@@ -71,6 +91,12 @@ class ReplyOnPause(StreamHandler):
self.generator = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
self.latest_args: list[Any] = []
self.args_set = Event()
@property
def _needs_additional_inputs(self) -> bool:
    """Report whether ``self.fn`` declares more than one parameter.

    A handler with a single parameter receives only the audio tuple; a
    handler with more parameters also expects the extra input values
    stored in ``self.latest_args``.
    """
    handler_params = inspect.signature(self.fn).parameters
    return len(handler_params) > 1
def copy(self):
return ReplyOnPause(
@@ -130,17 +156,38 @@ class ReplyOnPause(StreamHandler):
self.event.set()
def reset(self):
    """Return the handler to its idle state.

    Clears the "args received" flag, drops the active generator, clears
    the main event and replaces ``self.state`` with a fresh ``AppState``.
    """
    # Force the next generator creation to wait for fresh input values.
    self.args_set.clear()
    self.generator = None
    self.event.clear()
    self.state = AppState()
def set_args(self, args: list[Any]):
    """Store the latest input values and signal that they have arrived.

    Args:
        args: Values forwarded to the parent class's ``set_args``.
    """
    super().set_args(args)
    # Raise the flag last so a waiter only wakes once the values are stored.
    self.args_set.set()
async def fetch_args(
    self,
):
    """Send a ``"tick"`` message over the data channel, if one is set.

    Does nothing when ``self.channel`` is falsy.
    NOTE(review): the tick presumably prompts the client to submit the
    current input values -- confirm against the frontend.
    """
    if not self.channel:
        return
    self.channel.send("tick")
    logger.debug("Sent tick")
def emit(self):
if not self.event.is_set():
return None
else:
if not self.generator:
if self._needs_additional_inputs and not self.args_set.is_set():
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
self.args_set.wait()
logger.debug("Creating generator")
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
self.generator = self.fn((self.state.sampling_rate, audio))
if self._needs_additional_inputs:
self.latest_args[0] = (self.state.sampling_rate, audio)
self.generator = self.fn(*self.latest_args)
else:
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
logger.debug("Latest args: %s", self.latest_args)
self.state.responding = True
try:
return next(self.generator)

View File

@@ -22,6 +22,8 @@ class DataChannel(Protocol):
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, AdditionalOutputs):
return None, data
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):

View File

@@ -72,6 +72,7 @@ class VideoCallback(VideoStreamTrack):
event_handler: Callable,
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
mode: Literal["send-receive", "send"] = "send-receive",
) -> None:
super().__init__() # don't forget this!
self.track = track
@@ -79,6 +80,14 @@ class VideoCallback(VideoStreamTrack):
self.latest_args: str | list[Any] = "not_set"
self.channel = channel
self.set_additional_outputs = set_additional_outputs
self.thread_quit = asyncio.Event()
self.mode = mode
def set_channel(self, channel: DataChannel):
    """Remember the data channel associated with this video track.

    Args:
        channel: Data channel stored on ``self.channel``.
    """
    self.channel = channel
def set_args(self, args: list[Any]):
    """Store the latest input values behind a placeholder first element.

    Args:
        args: Input values; copied into a new list after the
            ``"__webrtc_value__"`` marker at index 0.
    """
    self.latest_args = ["__webrtc_value__", *args]
def add_frame_to_payload(
self, args: list[Any], frame: np.ndarray | None
@@ -94,11 +103,29 @@ class VideoCallback(VideoStreamTrack):
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
    """Wrap a pixel array in a ``VideoFrame`` using the ``bgr24`` format.

    Args:
        array: Image data passed to ``VideoFrame.from_ndarray``.

    Returns:
        The frame built from ``array``.
    """
    frame = VideoFrame.from_ndarray(array, format="bgr24")
    return frame
async def process_frames(self):
    """Keep awaiting ``self.recv()`` until ``self.thread_quit`` is set.

    A ``TimeoutError`` raised by ``recv`` is ignored and the loop simply
    tries again; any other exception propagates to the caller.
    """
    while True:
        if self.thread_quit.is_set():
            break
        try:
            await self.recv()
        except TimeoutError:
            # Nothing arrived in time; re-check the quit flag and retry.
            pass
def start(
    self,
):
    """Schedule ``process_frames`` as a background task on the running loop.

    Must be called while an asyncio event loop is running, because
    ``asyncio.create_task`` raises ``RuntimeError`` otherwise.
    """
    # The event loop only keeps weak references to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes.
    # Holding it on the instance keeps the frame pump alive.
    self._process_frames_task = asyncio.create_task(self.process_frames())
def stop(self):
    """Stop the track and tell the frame-processing loop to exit.

    Calls the parent class's ``stop`` first, then sets
    ``self.thread_quit``, which ``process_frames`` checks on each pass.
    """
    super().stop()
    logger.debug("video callback stop")
    self.thread_quit.set()
async def recv(self):
try:
try:
frame = cast(VideoFrame, await self.track.recv())
except MediaStreamError:
self.stop()
return
frame_array = frame.to_ndarray(format="bgr24")
@@ -115,6 +142,8 @@ class VideoCallback(VideoStreamTrack):
):
self.set_additional_outputs(outputs)
self.channel.send("change")
if array is None and self.mode == "send":
return
new_frame = self.array_to_frame(array)
if frame:
@@ -142,7 +171,25 @@ class StreamHandler(ABC):
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.latest_args: str | list[Any] = "not_set"
self._resampler = None
self._channel: DataChannel | None = None
self._loop: asyncio.AbstractEventLoop
@property
def loop(self) -> asyncio.AbstractEventLoop:
    """Event loop stored on ``self._loop``.

    Raises ``AttributeError`` if ``_loop`` has not been assigned yet.
    """
    event_loop = self._loop
    return cast(asyncio.AbstractEventLoop, event_loop)
@property
def channel(self) -> DataChannel | None:
    """Data channel currently attached to this handler, or ``None``."""
    attached = self._channel
    return attached
def set_channel(self, channel: DataChannel):
    """Attach a data channel to this handler.

    Args:
        channel: Data channel later returned by the ``channel`` property.
    """
    self._channel = channel
def set_args(self, args: list[Any]):
    """Record the latest input values behind a placeholder first element.

    Args:
        args: Input values; copied into a new list after the
            ``"__webrtc_value__"`` marker at index 0.
    """
    logger.debug("setting args in audio callback %s", args)
    self.latest_args = ["__webrtc_value__", *args]
@abstractmethod
def copy(self) -> "StreamHandler":
@@ -190,6 +237,13 @@ class AudioCallback(AudioStreamTrack):
self.set_additional_outputs = set_additional_outputs
super().__init__()
def set_channel(self, channel: DataChannel):
    """Store the data channel on the track and pass it to the handler.

    Args:
        channel: Data channel kept on ``self.channel`` and forwarded to
            ``self.event_handler.set_channel``.
    """
    self.channel = channel
    handler = self.event_handler
    handler.set_channel(channel)
def set_args(self, args: list[Any]):
    """Forward the latest input values to the wrapped event handler.

    Args:
        args: Values passed unchanged to ``self.event_handler.set_args``.
    """
    handler = self.event_handler
    handler.set_args(args)
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
@@ -284,6 +338,13 @@ class ServerToClientVideo(VideoStreamTrack):
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
    """Build a ``VideoFrame`` from a pixel array using the ``bgr24`` format.

    Args:
        array: Image data passed to ``VideoFrame.from_ndarray``.

    Returns:
        The frame built from ``array``.
    """
    frame = VideoFrame.from_ndarray(array, format="bgr24")
    return frame
def set_channel(self, channel: DataChannel):
    """Remember the data channel associated with this outgoing video track.

    Args:
        channel: Data channel stored on ``self.channel``.
    """
    self.channel = channel
def set_args(self, args: list[Any]):
    """Store a copy of the input values and signal that they are ready.

    Args:
        args: Input values; copied into a new list on ``self.latest_args``.
    """
    self.latest_args = [*args]
    # Set the flag only after the values are in place.
    self.args_set.set()
async def recv(self):
try:
pts, time_base = await self.next_timestamp()
@@ -338,6 +399,13 @@ class ServerToClientAudio(AudioStreamTrack):
self._start: float | None = None
super().__init__()
def set_channel(self, channel: DataChannel):
    """Remember the data channel associated with this outgoing audio track.

    Args:
        channel: Data channel stored on ``self.channel``.
    """
    self.channel = channel
def set_args(self, args: list[Any]):
    """Store a copy of the input values and release anything waiting on them.

    Args:
        args: Input values; copied into a new list on ``self.latest_args``.
    """
    self.latest_args = [*args]
    # Set the flag only after the values are in place.
    self.args_set.set()
def next(self) -> tuple[int, np.ndarray] | None:
self.args_set.wait()
if self.generator is None:
@@ -447,7 +515,7 @@ class WebRTC(Component):
rtc_configuration: dict[str, Any] | None = None,
track_constraints: dict[str, Any] | None = None,
time_limit: float | None = None,
mode: Literal["send-receive", "receive"] = "send-receive",
mode: Literal["send-receive", "receive", "send"] = "send-receive",
modality: Literal["video", "audio"] = "video",
):
"""
@@ -549,17 +617,11 @@ class WebRTC(Component):
"""
return value
def set_output(self, webrtc_id: str, *args):
def set_input(self, webrtc_id: str, *args):
if webrtc_id in self.connections:
if self.mode == "send-receive":
self.connections[webrtc_id].latest_args = ["__webrtc_value__"] + list(
args
)
elif self.mode == "receive":
self.connections[webrtc_id].latest_args = list(args)
self.connections[webrtc_id].args_set.set() # type: ignore
self.connections[webrtc_id].set_args(list(args))
def change(
def on_additional_outputs(
self,
fn: Callable[Concatenate[P], R],
inputs: Block | Sequence[Block] | set[Block] | None = None,
@@ -628,7 +690,7 @@ class WebRTC(Component):
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
)
if self.mode == "send-receive":
if self.mode == "send-receive" or self.mode == "send":
if cast(list[Block], inputs)[0] != self:
raise ValueError(
"In the webrtc stream event, the first input component must be the WebRTC component."
@@ -642,7 +704,7 @@ class WebRTC(Component):
"In the webrtc stream event, the only output component must be the WebRTC component."
)
return self.tick( # type: ignore
self.set_output,
self.set_input,
inputs=inputs,
outputs=None,
concurrency_id=concurrency_id,
@@ -669,7 +731,7 @@ class WebRTC(Component):
)
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
self.tick( # type: ignore
self.set_output,
self.set_input,
inputs=[self] + list(inputs),
outputs=None,
concurrency_id=concurrency_id,
@@ -680,6 +742,12 @@ class WebRTC(Component):
await asyncio.sleep(time_limit)
await pc.close()
def clean_up(self, webrtc_id: str):
    """Forget all per-connection state held for ``webrtc_id``.

    Removes the id from ``self.connections``, ``self.additional_outputs``
    and ``self.data_channels``; ids that are not present are ignored.

    Args:
        webrtc_id: Identifier of the connection to discard.

    Returns:
        The entry removed from ``self.connections``, or ``None`` if there
        was none.
    """
    removed_connection = self.connections.pop(webrtc_id, None)
    self.additional_outputs.pop(webrtc_id, None)
    self.data_channels.pop(webrtc_id, None)
    return removed_connection
@server
async def offer(self, body):
logger.debug("Starting to handle offer")
@@ -707,7 +775,7 @@ class WebRTC(Component):
logger.debug("pc.connectionState %s", pc.connectionState)
if pc.connectionState in ["failed", "closed"]:
await pc.close()
connection = self.connections.pop(body["webrtc_id"], None)
connection = self.clean_up(body["webrtc_id"])
if connection:
connection.stop()
self.pcs.discard(pc)
@@ -723,20 +791,26 @@ class WebRTC(Component):
relay.subscribe(track),
event_handler=cast(Callable, self.event_handler),
set_additional_outputs=set_outputs,
mode=cast(Literal["send", "send-receive"], self.mode),
)
elif self.modality == "audio":
handler = cast(StreamHandler, self.event_handler).copy()
handler._loop = asyncio.get_running_loop()
cb = AudioCallback(
relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler).copy(),
event_handler=handler,
set_additional_outputs=set_outputs,
)
self.connections[body["webrtc_id"]] = cb
if body["webrtc_id"] in self.data_channels:
self.connections[body["webrtc_id"]].channel = self.data_channels[
body["webrtc_id"]
]
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]].set_channel(
self.data_channels[body["webrtc_id"]]
)
if self.mode == "send-receive":
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
elif self.mode == "send":
cast(AudioCallback | VideoCallback, cb).start()
if self.mode == "receive":
if self.modality == "video":
@@ -753,21 +827,19 @@ class WebRTC(Component):
logger.debug("Adding track to peer connection %s", cb)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
@pc.on("datachannel")
def on_datachannel(channel):
print("data channel established")
logger.debug(f"Data channel established: {channel.label}")
self.data_channels[body["webrtc_id"]] = channel
async def set_channel(webrtc_id: str):
print("webrtc_id", webrtc_id)
while not self.connections.get(webrtc_id):
await asyncio.sleep(0.05)
print("setting channel")
self.connections[webrtc_id].channel = channel
logger.debug("setting channel for webrtc id %s", webrtc_id)
self.connections[webrtc_id].set_channel(channel)
asyncio.create_task(set_channel(body["webrtc_id"]))