mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Audio in only (#15)
* Audio + Video / test Audio * Add code * Fix demo * support additional inputs * Add code * Add code
This commit is contained in:
@@ -2,11 +2,14 @@ from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Callable, Generator, Literal, cast
|
||||
import inspect
|
||||
from typing import Any, Callable, Generator, Literal, Union, cast
|
||||
import asyncio
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
|
||||
from gradio_webrtc.utils import AdditionalOutputs
|
||||
from gradio_webrtc.webrtc import StreamHandler
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -40,12 +43,29 @@ class AppState:
|
||||
buffer: np.ndarray | None = None
|
||||
|
||||
|
||||
ReplyFnGenerator = Callable[
|
||||
[tuple[int, np.ndarray]],
|
||||
Generator[
|
||||
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
|
||||
None,
|
||||
None,
|
||||
# Every value a reply generator is allowed to yield. Both callable shapes
# below share this exact union, so it is spelled out only once.
_ReplyFnYield = (
    tuple[int, np.ndarray]
    | tuple[int, np.ndarray, Literal["mono", "stereo"]]
    | AdditionalOutputs
    | tuple[tuple[int, np.ndarray], AdditionalOutputs]
)

# A reply function is a generator function: it yields values and neither
# receives sent values nor returns a final value.
_ReplyFnStream = Generator[_ReplyFnYield, None, None]

ReplyFnGenerator = Union[
    # For two arguments
    Callable[[tuple[int, np.ndarray], list[dict[Any, Any]]], _ReplyFnStream],
    # For a single argument
    Callable[[tuple[int, np.ndarray]], _ReplyFnStream],
]
|
||||
|
||||
@@ -71,6 +91,12 @@ class ReplyOnPause(StreamHandler):
|
||||
self.generator = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
self.latest_args: list[Any] = []
|
||||
self.args_set = Event()
|
||||
|
||||
@property
def _needs_additional_inputs(self) -> bool:
    """Whether ``self.fn`` declares more than one parameter.

    The count is taken from ``inspect.signature(self.fn)``.
    """
    declared_params = inspect.signature(self.fn).parameters
    return len(declared_params) > 1
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(
|
||||
@@ -130,17 +156,38 @@ class ReplyOnPause(StreamHandler):
|
||||
self.event.set()
|
||||
|
||||
def reset(self):
    """Return the handler to its idle state.

    Clears the ``args_set`` and ``event`` flags, drops the current
    generator, and replaces ``state`` with a fresh ``AppState``.
    """
    # Statement order is kept as-is on purpose.
    self.args_set.clear()
    self.generator = None
    self.event.clear()
    self.state = AppState()
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Store ``args`` through the parent class, then set the ``args_set`` flag.

    ``emit`` waits on ``args_set`` before it builds the generator when the
    reply function needs additional inputs.
    """
    super().set_args(args)
    self.args_set.set()
|
||||
|
||||
async def fetch_args(self):
    """Send a ``"tick"`` message over the data channel, if one is attached.

    Does nothing when ``self.channel`` is falsy.
    """
    if not self.channel:
        return
    self.channel.send("tick")
    logger.debug("Sent tick")
|
||||
|
||||
def emit(self):
|
||||
if not self.event.is_set():
|
||||
return None
|
||||
else:
|
||||
if not self.generator:
|
||||
if self._needs_additional_inputs and not self.args_set.is_set():
|
||||
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
|
||||
self.args_set.wait()
|
||||
logger.debug("Creating generator")
|
||||
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
|
||||
self.generator = self.fn((self.state.sampling_rate, audio))
|
||||
if self._needs_additional_inputs:
|
||||
self.latest_args[0] = (self.state.sampling_rate, audio)
|
||||
self.generator = self.fn(*self.latest_args)
|
||||
else:
|
||||
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
|
||||
logger.debug("Latest args: %s", self.latest_args)
|
||||
self.state.responding = True
|
||||
try:
|
||||
return next(self.generator)
|
||||
|
||||
@@ -22,6 +22,8 @@ class DataChannel(Protocol):
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, AdditionalOutputs):
|
||||
return None, data
|
||||
if isinstance(data, tuple):
|
||||
# handle the bare audio case
|
||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||
|
||||
@@ -72,6 +72,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
event_handler: Callable,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
mode: Literal["send-receive", "send"] = "send-receive",
|
||||
) -> None:
|
||||
super().__init__() # don't forget this!
|
||||
self.track = track
|
||||
@@ -79,6 +80,14 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.mode = mode
|
||||
|
||||
def set_channel(self, channel: DataChannel):
    """Attach ``channel`` to this callback as its data channel."""
    self.channel = channel
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Store ``args`` as the latest arguments.

    The ``"__webrtc_value__"`` placeholder string is placed first, followed
    by the given arguments in order.
    """
    self.latest_args = ["__webrtc_value__", *args]
|
||||
|
||||
def add_frame_to_payload(
|
||||
self, args: list[Any], frame: np.ndarray | None
|
||||
@@ -94,11 +103,29 @@ class VideoCallback(VideoStreamTrack):
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
    """Build a ``VideoFrame`` from ``array``, interpreted in ``bgr24`` format."""
    video_frame = VideoFrame.from_ndarray(array, format="bgr24")
    return video_frame
|
||||
|
||||
async def process_frames(self):
    """Keep awaiting ``recv`` until ``thread_quit`` is set.

    A ``TimeoutError`` raised by ``recv`` is ignored and the loop polls
    again; any other exception propagates to the caller.
    """
    while True:
        if self.thread_quit.is_set():
            break
        try:
            await self.recv()
        except TimeoutError:
            # Timed out: go around and re-check the quit flag.
            pass
|
||||
|
||||
def start(
    self,
):
    """Schedule ``process_frames`` as a background task.

    Must be called while an event loop is running, because
    ``asyncio.create_task`` requires one.
    """
    # The event loop only keeps weak references to tasks. Hold a strong
    # reference on the instance so the task cannot be garbage-collected
    # while it is still running.
    self._process_frames_task = asyncio.create_task(self.process_frames())
|
||||
|
||||
def stop(self):
    """Stop the track and tell the frame-processing loop to exit.

    Calls the parent class's ``stop`` first, then sets ``thread_quit``,
    which ``process_frames`` checks on every iteration.
    """
    super().stop()
    logger.debug("video callback stop")
    self.thread_quit.set()
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
try:
|
||||
frame = cast(VideoFrame, await self.track.recv())
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
return
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
@@ -115,6 +142,8 @@ class VideoCallback(VideoStreamTrack):
|
||||
):
|
||||
self.set_additional_outputs(outputs)
|
||||
self.channel.send("change")
|
||||
if array is None and self.mode == "send":
|
||||
return
|
||||
|
||||
new_frame = self.array_to_frame(array)
|
||||
if frame:
|
||||
@@ -142,7 +171,25 @@ class StreamHandler(ABC):
|
||||
self.expected_layout = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self._resampler = None
|
||||
self._channel: DataChannel | None = None
|
||||
self._loop: asyncio.AbstractEventLoop
|
||||
|
||||
@property
def loop(self) -> asyncio.AbstractEventLoop:
    """The event loop stored in ``_loop``.

    ``cast`` only informs the type checker; the stored object is returned
    unchanged.
    """
    event_loop = self._loop
    return cast(asyncio.AbstractEventLoop, event_loop)
|
||||
|
||||
@property
def channel(self) -> DataChannel | None:
    """The data channel stored in ``_channel``, or ``None`` if none is set."""
    current_channel = self._channel
    return current_channel
|
||||
|
||||
def set_channel(self, channel: DataChannel):
    """Store ``channel`` in ``_channel``; the ``channel`` property reads it back."""
    self._channel = channel
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Log ``args`` and store them as the latest arguments.

    The ``"__webrtc_value__"`` placeholder string is placed first, followed
    by the given arguments in order.
    """
    logger.debug("setting args in audio callback %s", args)
    self.latest_args = ["__webrtc_value__", *args]
|
||||
|
||||
@abstractmethod
|
||||
def copy(self) -> "StreamHandler":
|
||||
@@ -190,6 +237,13 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
    """Attach ``channel`` to this track and pass it on to the event handler."""
    self.channel = channel
    handler = self.event_handler
    handler.set_channel(channel)
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Delegate to the event handler's ``set_args`` with ``args`` unchanged."""
    handler = self.event_handler
    handler.set_args(args)
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
@@ -284,6 +338,13 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
def array_to_frame(self, array: np.ndarray) -> VideoFrame:
    """Build a ``VideoFrame`` from ``array``, interpreted in ``bgr24`` format."""
    video_frame = VideoFrame.from_ndarray(array, format="bgr24")
    return video_frame
|
||||
|
||||
def set_channel(self, channel: DataChannel):
    """Attach ``channel`` to this track as its data channel."""
    self.channel = channel
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Store a list copy of ``args`` and set the ``args_set`` flag."""
    self.latest_args = [*args]
    self.args_set.set()
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
pts, time_base = await self.next_timestamp()
|
||||
@@ -338,6 +399,13 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
self._start: float | None = None
|
||||
super().__init__()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
    """Attach ``channel`` to this track as its data channel."""
    self.channel = channel
|
||||
|
||||
def set_args(self, args: list[Any]):
    """Store a list copy of ``args`` and set the ``args_set`` flag.

    ``next`` waits on ``args_set`` before doing any work.
    """
    self.latest_args = [*args]
    self.args_set.set()
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
self.args_set.wait()
|
||||
if self.generator is None:
|
||||
@@ -447,7 +515,7 @@ class WebRTC(Component):
|
||||
rtc_configuration: dict[str, Any] | None = None,
|
||||
track_constraints: dict[str, Any] | None = None,
|
||||
time_limit: float | None = None,
|
||||
mode: Literal["send-receive", "receive"] = "send-receive",
|
||||
mode: Literal["send-receive", "receive", "send"] = "send-receive",
|
||||
modality: Literal["video", "audio"] = "video",
|
||||
):
|
||||
"""
|
||||
@@ -549,17 +617,11 @@ class WebRTC(Component):
|
||||
"""
|
||||
return value
|
||||
|
||||
def set_input(self, webrtc_id: str, *args):
    """Pass the latest input values to the connection for ``webrtc_id``.

    Args:
        webrtc_id: Key of the connection in ``self.connections``.
        *args: Values handed to the connection's ``set_args`` as one list.

    An id with no registered connection is ignored.
    """
    # One lookup instead of a membership test followed by indexing: the
    # entry could be removed between the two steps (for example by
    # ``clean_up``), which would raise ``KeyError``.
    connection = self.connections.get(webrtc_id)
    if connection is None:
        return
    connection.set_args(list(args))
|
||||
|
||||
def change(
|
||||
def on_additional_outputs(
|
||||
self,
|
||||
fn: Callable[Concatenate[P], R],
|
||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
@@ -628,7 +690,7 @@ class WebRTC(Component):
|
||||
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
|
||||
)
|
||||
|
||||
if self.mode == "send-receive":
|
||||
if self.mode == "send-receive" or self.mode == "send":
|
||||
if cast(list[Block], inputs)[0] != self:
|
||||
raise ValueError(
|
||||
"In the webrtc stream event, the first input component must be the WebRTC component."
|
||||
@@ -642,7 +704,7 @@ class WebRTC(Component):
|
||||
"In the webrtc stream event, the only output component must be the WebRTC component."
|
||||
)
|
||||
return self.tick( # type: ignore
|
||||
self.set_output,
|
||||
self.set_input,
|
||||
inputs=inputs,
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
@@ -669,7 +731,7 @@ class WebRTC(Component):
|
||||
)
|
||||
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
|
||||
self.tick( # type: ignore
|
||||
self.set_output,
|
||||
self.set_input,
|
||||
inputs=[self] + list(inputs),
|
||||
outputs=None,
|
||||
concurrency_id=concurrency_id,
|
||||
@@ -680,6 +742,12 @@ class WebRTC(Component):
|
||||
await asyncio.sleep(time_limit)
|
||||
await pc.close()
|
||||
|
||||
def clean_up(self, webrtc_id: str):
    """Drop everything stored for ``webrtc_id`` and return its connection.

    Removes the id from ``connections``, ``additional_outputs`` and
    ``data_channels``; ids missing from any of them are ignored.

    Returns:
        The removed connection, or ``None`` if none was registered.
    """
    removed_connection = self.connections.pop(webrtc_id, None)
    for registry in (self.additional_outputs, self.data_channels):
        registry.pop(webrtc_id, None)
    return removed_connection
|
||||
|
||||
@server
|
||||
async def offer(self, body):
|
||||
logger.debug("Starting to handle offer")
|
||||
@@ -707,7 +775,7 @@ class WebRTC(Component):
|
||||
logger.debug("pc.connectionState %s", pc.connectionState)
|
||||
if pc.connectionState in ["failed", "closed"]:
|
||||
await pc.close()
|
||||
connection = self.connections.pop(body["webrtc_id"], None)
|
||||
connection = self.clean_up(body["webrtc_id"])
|
||||
if connection:
|
||||
connection.stop()
|
||||
self.pcs.discard(pc)
|
||||
@@ -723,20 +791,26 @@ class WebRTC(Component):
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
handler = cast(StreamHandler, self.event_handler).copy()
|
||||
handler._loop = asyncio.get_running_loop()
|
||||
cb = AudioCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(StreamHandler, self.event_handler).copy(),
|
||||
event_handler=handler,
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
if body["webrtc_id"] in self.data_channels:
|
||||
self.connections[body["webrtc_id"]].channel = self.data_channels[
|
||||
body["webrtc_id"]
|
||||
]
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]].set_channel(
|
||||
self.data_channels[body["webrtc_id"]]
|
||||
)
|
||||
if self.mode == "send-receive":
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
elif self.mode == "send":
|
||||
cast(AudioCallback | VideoCallback, cb).start()
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
@@ -753,21 +827,19 @@ class WebRTC(Component):
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
|
||||
cb.on("ended", lambda: self.clean_up(body["webrtc_id"]))
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel):
|
||||
print("data channel established")
|
||||
logger.debug(f"Data channel established: {channel.label}")
|
||||
|
||||
self.data_channels[body["webrtc_id"]] = channel
|
||||
|
||||
async def set_channel(webrtc_id: str):
|
||||
print("webrtc_id", webrtc_id)
|
||||
while not self.connections.get(webrtc_id):
|
||||
await asyncio.sleep(0.05)
|
||||
print("setting channel")
|
||||
self.connections[webrtc_id].channel = channel
|
||||
logger.debug("setting channel for webrtc id %s", webrtc_id)
|
||||
self.connections[webrtc_id].set_channel(channel)
|
||||
|
||||
asyncio.create_task(set_channel(body["webrtc_id"]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user