mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -19,7 +19,9 @@ from typing import (
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -161,7 +163,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
logger.debug("traceback %s", exec)
|
||||
|
||||
|
||||
class StreamHandler(ABC):
|
||||
class StreamHandlerBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
@@ -173,10 +175,11 @@ class StreamHandler(ABC):
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.latest_args: list[Any] = []
|
||||
self._resampler = None
|
||||
self._channel: DataChannel | None = None
|
||||
self._loop: asyncio.AbstractEventLoop
|
||||
self.args_set = asyncio.Event()
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
@@ -189,15 +192,30 @@ class StreamHandler(ABC):
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self._channel = channel
|
||||
|
||||
async def fetch_args(
|
||||
self,
|
||||
):
|
||||
if self.channel:
|
||||
self.channel.send("tick")
|
||||
logger.debug("Sent tick")
|
||||
|
||||
async def wait_for_args(self):
|
||||
await self.fetch_args()
|
||||
await self.args_set.wait()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
logger.debug("setting args in audio callback %s", args)
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
self.args_set.set()
|
||||
|
||||
def reset(self):
|
||||
self.args_set.clear()
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def copy(self) -> "StreamHandler":
|
||||
def copy(self) -> "StreamHandlerBase":
|
||||
pass
|
||||
|
||||
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
|
||||
@@ -210,6 +228,17 @@ class StreamHandler(ABC):
|
||||
)
|
||||
yield from self._resampler.resample(frame)
|
||||
|
||||
|
||||
EmitType: TypeAlias = Union[
|
||||
tuple[int, np.ndarray],
|
||||
tuple[int, np.ndarray, Literal["mono", "stereo"]],
|
||||
AdditionalOutputs,
|
||||
tuple[tuple[int, np.ndarray], AdditionalOutputs],
|
||||
None,
|
||||
]
|
||||
|
||||
|
||||
class StreamHandler(StreamHandlerBase):
|
||||
@abstractmethod
|
||||
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
pass
|
||||
@@ -217,22 +246,32 @@ class StreamHandler(ABC):
|
||||
@abstractmethod
|
||||
def emit(
|
||||
self,
|
||||
) -> (
|
||||
tuple[int, np.ndarray]
|
||||
| AdditionalOutputs
|
||||
| None
|
||||
| tuple[tuple[int, np.ndarray], AdditionalOutputs]
|
||||
):
|
||||
) -> EmitType:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncStreamHandler(StreamHandlerBase):
|
||||
@abstractmethod
|
||||
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def emit(
|
||||
self,
|
||||
) -> EmitType:
|
||||
pass
|
||||
|
||||
|
||||
StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
|
||||
|
||||
|
||||
class AudioCallback(AudioStreamTrack):
|
||||
kind = "audio"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
track: MediaStreamTrack,
|
||||
event_handler: StreamHandler,
|
||||
event_handler: StreamHandlerImpl,
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
@@ -262,9 +301,14 @@ class AudioCallback(AudioStreamTrack):
|
||||
frame = cast(AudioFrame, await self.track.recv())
|
||||
for frame in self.event_handler.resample(frame):
|
||||
numpy_array = frame.to_ndarray()
|
||||
await anyio.to_thread.run_sync(
|
||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||
)
|
||||
if isinstance(self.event_handler, AsyncStreamHandler):
|
||||
await self.event_handler.receive(
|
||||
(frame.sample_rate, numpy_array)
|
||||
)
|
||||
else:
|
||||
await anyio.to_thread.run_sync(
|
||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||
)
|
||||
except MediaStreamError:
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
break
|
||||
@@ -272,9 +316,12 @@ class AudioCallback(AudioStreamTrack):
|
||||
def start(self):
|
||||
if not self.has_started:
|
||||
loop = asyncio.get_running_loop()
|
||||
callable = functools.partial(
|
||||
loop.run_in_executor, None, self.event_handler.emit
|
||||
)
|
||||
if isinstance(self.event_handler, AsyncStreamHandler):
|
||||
callable = self.event_handler.emit
|
||||
else:
|
||||
callable = functools.partial(
|
||||
loop.run_in_executor, None, self.event_handler.emit
|
||||
)
|
||||
asyncio.create_task(self.process_input_frames())
|
||||
asyncio.create_task(
|
||||
player_worker_decode(
|
||||
@@ -692,7 +739,7 @@ class WebRTC(Component):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable[..., Any] | StreamHandler | None = None,
|
||||
fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
|
||||
inputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
outputs: Block | Sequence[Block] | set[Block] | None = None,
|
||||
js: str | None = None,
|
||||
@@ -721,7 +768,7 @@ class WebRTC(Component):
|
||||
if (
|
||||
self.mode == "send-receive"
|
||||
and self.modality == "audio"
|
||||
and not isinstance(self.event_handler, StreamHandler)
|
||||
and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
|
||||
):
|
||||
raise ValueError(
|
||||
"In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
|
||||
@@ -840,6 +887,8 @@ class WebRTC(Component):
|
||||
event_handler=handler,
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Modality must be either video or audio")
|
||||
self.connections[body["webrtc_id"]] = cb
|
||||
if body["webrtc_id"] in self.data_channels:
|
||||
self.connections[body["webrtc_id"]].set_channel(
|
||||
@@ -862,6 +911,8 @@ class WebRTC(Component):
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Modality must be either video or audio")
|
||||
|
||||
logger.debug("Adding track to peer connection %s", cb)
|
||||
pc.addTrack(cb)
|
||||
|
||||
Reference in New Issue
Block a user