Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-04 17:39:23 +08:00)
gradio_webrtc/__init__.py
@@ -6,12 +6,22 @@ from .credentials import (
 from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
 from .reply_on_stopwords import ReplyOnStopWords
 from .speech_to_text import stt, stt_for_chunks
-from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
-from .webrtc import StreamHandler, WebRTC
+from .utils import (
+    AdditionalOutputs,
+    aggregate_bytes_to_16bit,
+    async_aggregate_bytes_to_16bit,
+    audio_to_bytes,
+    audio_to_file,
+    audio_to_float32,
+)
+from .webrtc import AsyncStreamHandler, StreamHandler, WebRTC

 __all__ = [
+    "AsyncStreamHandler",
     "AlgoOptions",
     "AdditionalOutputs",
+    "aggregate_bytes_to_16bit",
+    "async_aggregate_bytes_to_16bit",
     "audio_to_bytes",
     "audio_to_file",
     "audio_to_float32",
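Note (not part of the diff): the re-export above makes the new helpers importable
from the package root. A quick smoke test, assuming gradio_webrtc 0.0.24 is installed:

    from gradio_webrtc import (
        AdditionalOutputs,
        AsyncStreamHandler,
        StreamHandler,
        aggregate_bytes_to_16bit,
        async_aggregate_bytes_to_16bit,
    )

    print(AsyncStreamHandler.__mro__)  # AsyncStreamHandler -> StreamHandlerBase -> ABC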
gradio_webrtc/reply_on_pause.py
@@ -10,7 +10,7 @@ import numpy as np

 from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
 from gradio_webrtc.utils import AdditionalOutputs
-from gradio_webrtc.webrtc import StreamHandler
+from gradio_webrtc.webrtc import EmitType, StreamHandler

 logger = getLogger(__name__)
@@ -47,25 +47,11 @@ ReplyFnGenerator = Union[
     # For two arguments
     Callable[
         [tuple[int, np.ndarray], list[dict[Any, Any]]],
-        Generator[
-            tuple[int, np.ndarray]
-            | tuple[int, np.ndarray, Literal["mono", "stereo"]]
-            | AdditionalOutputs
-            | tuple[tuple[int, np.ndarray], AdditionalOutputs],
-            None,
-            None,
-        ],
+        Generator[EmitType, None, None],
     ],
     Callable[
         [tuple[int, np.ndarray]],
-        Generator[
-            tuple[int, np.ndarray]
-            | tuple[int, np.ndarray, Literal["mono", "stereo"]]
-            | AdditionalOutputs
-            | tuple[tuple[int, np.ndarray], AdditionalOutputs],
-            None,
-            None,
-        ],
+        Generator[EmitType, None, None],
    ],
 ]
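Note (illustration, not from the diff): both ReplyFnGenerator shapes now read as
"take the audio turn (plus optional chat history) and yield EmitType values". A
hypothetical echo generator matching the single-argument Callable:

    import numpy as np

    def echo(audio: tuple[int, np.ndarray]):
        sample_rate, array = audio
        # Any EmitType value may be yielded; the simplest is the frame itself.
        yield (sample_rate, array)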
@@ -99,11 +85,9 @@ class ReplyOnPause(StreamHandler):
         self.is_async = inspect.isasyncgenfunction(fn)
         self.event = Event()
         self.state = AppState()
-        self.generator = None
+        self.generator: Generator[EmitType, None, None] | None = None
         self.model_options = model_options
         self.algo_options = algo_options or AlgoOptions()
-        self.latest_args: list[Any] = []
-        self.args_set = Event()

     @property
     def _needs_additional_inputs(self) -> bool:
@@ -168,23 +152,12 @@ class ReplyOnPause(StreamHandler):
         self.event.set()

     def reset(self):
-        self.args_set.clear()
+        super().reset()
         self.generator = None
         self.event.clear()
         self.state = AppState()

-    def set_args(self, args: list[Any]):
-        super().set_args(args)
-        self.args_set.set()
-
-    async def fetch_args(
-        self,
-    ):
-        if self.channel:
-            self.channel.send("tick")
-            logger.debug("Sent tick")
-
-    async def async_iterate(self, generator) -> Any:
+    async def async_iterate(self, generator) -> EmitType:
         return await anext(generator)

     def emit(self):
@@ -193,8 +166,9 @@ class ReplyOnPause(StreamHandler):
         else:
             if not self.generator:
                 if self._needs_additional_inputs and not self.args_set.is_set():
-                    asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
-                    self.args_set.wait()
+                    asyncio.run_coroutine_threadsafe(
+                        self.wait_for_args(), self.loop
+                    ).result()
                 logger.debug("Creating generator")
                 audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
                 if self._needs_additional_inputs:
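Note (toy reproduction; nothing here comes from the library): the rewritten call
chains .result() onto the concurrent.futures.Future returned by
run_coroutine_threadsafe, so the worker thread genuinely blocks until
wait_for_args finishes on the event loop, instead of firing fetch_args and
waiting on a separate threading.Event:

    import asyncio
    import threading

    async def wait_for_args() -> str:  # stand-in coroutine
        await asyncio.sleep(0.05)
        return "args ready"

    def emit_thread(loop: asyncio.AbstractEventLoop) -> None:
        # Blocks this thread until the coroutine completes on the loop.
        print(asyncio.run_coroutine_threadsafe(wait_for_args(), loop).result())

    async def main():
        thread = threading.Thread(target=emit_thread, args=(asyncio.get_running_loop(),))
        thread.start()
        await asyncio.sleep(0.2)  # keep the loop running while the thread waits
        thread.join()

    asyncio.run(main())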
gradio_webrtc/reply_on_stopwords.py
@@ -71,7 +71,7 @@ class ReplyOnStopWords(ReplyOnPause):
     def send_stopword(self):
         asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

-    def determine_pause(
+    def determine_pause(  # type: ignore
         self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
     ) -> bool:
         """Take in the stream, determine if a pause happened"""
@@ -128,7 +128,7 @@ class ReplyOnStopWords(ReplyOnPause):
         return False

     def reset(self):
-        self.args_set.clear()
+        super().reset()
         self.generator = None
         self.event.clear()
         self.state = ReplyOnStopWordsState()
gradio_webrtc/utils.py
@@ -218,3 +218,43 @@ def audio_to_float32(audio: tuple[int, np.ndarray]) -> np.ndarray:
     >>> audio_float32 = audio_to_float32(audio_tuple)
     """
     return audio[1].astype(np.float32) / 32768.0
+
+
+def aggregate_bytes_to_16bit(chunks_iterator):
+    leftover = b""  # Store incomplete bytes between chunks
+
+    for chunk in chunks_iterator:
+        # Combine with any leftover bytes from previous chunk
+        current_bytes = leftover + chunk
+
+        # Calculate complete samples
+        n_complete_samples = len(current_bytes) // 2  # int16 = 2 bytes
+        bytes_to_process = n_complete_samples * 2
+
+        # Split into complete samples and leftover
+        to_process = current_bytes[:bytes_to_process]
+        leftover = current_bytes[bytes_to_process:]
+
+        if to_process:  # Only yield if we have complete samples
+            audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
+            yield audio_array
+
+
+async def async_aggregate_bytes_to_16bit(chunks_iterator):
+    leftover = b""  # Store incomplete bytes between chunks
+
+    async for chunk in chunks_iterator:
+        # Combine with any leftover bytes from previous chunk
+        current_bytes = leftover + chunk
+
+        # Calculate complete samples
+        n_complete_samples = len(current_bytes) // 2  # int16 = 2 bytes
+        bytes_to_process = n_complete_samples * 2
+
+        # Split into complete samples and leftover
+        to_process = current_bytes[:bytes_to_process]
+        leftover = current_bytes[bytes_to_process:]
+
+        if to_process:  # Only yield if we have complete samples
+            audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
+            yield audio_array
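Note (usage sketch, not from the diff): both helpers carry a trailing odd byte
over to the next chunk, so int16 samples split across chunk boundaries are
reassembled intact. The chunk boundaries below are deliberately misaligned:

    import numpy as np
    from gradio_webrtc import aggregate_bytes_to_16bit

    pcm = np.arange(5, dtype=np.int16).tobytes()  # 10 bytes = 5 samples
    chunks = [pcm[:3], pcm[3:7], pcm[7:]]         # splits fall mid-sample

    frames = list(aggregate_bytes_to_16bit(iter(chunks)))
    assert np.concatenate(frames, axis=1).shape == (1, 5)  # every sample recovered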
gradio_webrtc/webrtc.py
@@ -19,7 +19,9 @@ from typing import (
     Literal,
     ParamSpec,
     Sequence,
+    TypeAlias,
     TypeVar,
+    Union,
     cast,
 )
@@ -161,7 +163,7 @@ class VideoCallback(VideoStreamTrack):
             logger.debug("traceback %s", exec)


-class StreamHandler(ABC):
+class StreamHandlerBase(ABC):
     def __init__(
         self,
         expected_layout: Literal["mono", "stereo"] = "mono",
@@ -173,10 +175,11 @@ class StreamHandler(ABC):
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
         self.input_sample_rate = input_sample_rate
-        self.latest_args: str | list[Any] = "not_set"
+        self.latest_args: list[Any] = []
         self._resampler = None
         self._channel: DataChannel | None = None
         self._loop: asyncio.AbstractEventLoop
+        self.args_set = asyncio.Event()

     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -189,15 +192,30 @@ class StreamHandler(ABC):
     def set_channel(self, channel: DataChannel):
         self._channel = channel

+    async def fetch_args(
+        self,
+    ):
+        if self.channel:
+            self.channel.send("tick")
+            logger.debug("Sent tick")
+
+    async def wait_for_args(self):
+        await self.fetch_args()
+        await self.args_set.wait()
+
     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
         self.latest_args = ["__webrtc_value__"] + list(args)
+        self.args_set.set()
+
+    def reset(self):
+        self.args_set.clear()

     def shutdown(self):
         pass

     @abstractmethod
-    def copy(self) -> "StreamHandler":
+    def copy(self) -> "StreamHandlerBase":
         pass

     def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
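Note (toy reproduction with plain asyncio, no WebRTC involved): fetch_args,
wait_for_args and set_args form a tick handshake: the server pings the data
channel, the client answers with the current input values, and set_args releases
the waiter:

    import asyncio

    async def main():
        args_set = asyncio.Event()
        latest_args: list = []

        async def wait_for_args():
            print("tick")          # stands in for channel.send("tick")
            await args_set.wait()  # released once set_args runs

        def set_args(args):        # called when the client answers the tick
            latest_args.extend(["__webrtc_value__", *args])
            args_set.set()

        asyncio.get_running_loop().call_later(0.05, set_args, ["hello"])
        await wait_for_args()
        print(latest_args)  # ['__webrtc_value__', 'hello']

    asyncio.run(main())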
@@ -210,6 +228,17 @@ class StreamHandler(ABC):
             )
         yield from self._resampler.resample(frame)

+
+EmitType: TypeAlias = Union[
+    tuple[int, np.ndarray],
+    tuple[int, np.ndarray, Literal["mono", "stereo"]],
+    AdditionalOutputs,
+    tuple[tuple[int, np.ndarray], AdditionalOutputs],
+    None,
+]
+
+
+class StreamHandler(StreamHandlerBase):
     @abstractmethod
     def receive(self, frame: tuple[int, np.ndarray]) -> None:
         pass
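Note (illustration, not from the diff): EmitType names every value a handler may
emit. Each element below is a valid member; the 24 kHz rate and 480-sample frame
are arbitrary choices:

    import numpy as np
    from gradio_webrtc import AdditionalOutputs

    frame = (24000, np.zeros((1, 480), dtype=np.int16))
    valid_emits = [
        frame,                                    # tuple[int, np.ndarray]
        (*frame, "mono"),                         # with an explicit layout
        AdditionalOutputs("partial transcript"),  # UI-only update
        (frame, AdditionalOutputs("done")),       # audio plus extra outputs
        None,                                     # nothing ready yet
    ]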
@@ -217,22 +246,32 @@ class StreamHandler(ABC):
     @abstractmethod
     def emit(
         self,
-    ) -> (
-        tuple[int, np.ndarray]
-        | AdditionalOutputs
-        | None
-        | tuple[tuple[int, np.ndarray], AdditionalOutputs]
-    ):
+    ) -> EmitType:
         pass


+class AsyncStreamHandler(StreamHandlerBase):
+    @abstractmethod
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        pass
+
+    @abstractmethod
+    async def emit(
+        self,
+    ) -> EmitType:
+        pass
+
+
+StreamHandlerImpl = Union[StreamHandler, AsyncStreamHandler]
+
+
 class AudioCallback(AudioStreamTrack):
     kind = "audio"

     def __init__(
         self,
         track: MediaStreamTrack,
-        event_handler: StreamHandler,
+        event_handler: StreamHandlerImpl,
         channel: DataChannel | None = None,
         set_additional_outputs: Callable | None = None,
     ) -> None:
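Note (hypothetical subclass, not part of the diff; it assumes the
StreamHandlerBase defaults are sufficient): a minimal AsyncStreamHandler that
echoes input back, showing the intended shape of the new async API:

    import asyncio

    import numpy as np
    from gradio_webrtc import AsyncStreamHandler

    class EchoHandler(AsyncStreamHandler):
        def __init__(self) -> None:
            super().__init__()
            self.queue: asyncio.Queue = asyncio.Queue()

        async def receive(self, frame: tuple[int, np.ndarray]) -> None:
            await self.queue.put(frame)    # runs on the event loop; must not block

        async def emit(self):
            return await self.queue.get()  # any EmitType value is acceptable

        def copy(self) -> "EchoHandler":
            return EchoHandler()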
@@ -262,9 +301,14 @@ class AudioCallback(AudioStreamTrack):
                 frame = cast(AudioFrame, await self.track.recv())
                 for frame in self.event_handler.resample(frame):
                     numpy_array = frame.to_ndarray()
-                    await anyio.to_thread.run_sync(
-                        self.event_handler.receive, (frame.sample_rate, numpy_array)
-                    )
+                    if isinstance(self.event_handler, AsyncStreamHandler):
+                        await self.event_handler.receive(
+                            (frame.sample_rate, numpy_array)
+                        )
+                    else:
+                        await anyio.to_thread.run_sync(
+                            self.event_handler.receive, (frame.sample_rate, numpy_array)
+                        )
             except MediaStreamError:
                 logger.debug("MediaStreamError in process_input_frames")
                 break
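Note (the same dispatch in isolation; EchoHandler as sketched earlier): a
coroutine receive stays on the event loop, while a blocking sync receive is
pushed to a worker thread via anyio:

    import anyio
    from gradio_webrtc import AsyncStreamHandler

    async def feed(handler, frame):
        if isinstance(handler, AsyncStreamHandler):
            await handler.receive(frame)  # native coroutine
        else:
            # Sync receive may block; keep it off the event loop.
            await anyio.to_thread.run_sync(handler.receive, frame)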
@@ -272,9 +316,12 @@ class AudioCallback(AudioStreamTrack):
     def start(self):
         if not self.has_started:
             loop = asyncio.get_running_loop()
-            callable = functools.partial(
-                loop.run_in_executor, None, self.event_handler.emit
-            )
+            if isinstance(self.event_handler, AsyncStreamHandler):
+                callable = self.event_handler.emit
+            else:
+                callable = functools.partial(
+                    loop.run_in_executor, None, self.event_handler.emit
+                )
             asyncio.create_task(self.process_input_frames())
             asyncio.create_task(
                 player_worker_decode(
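Note (stand-in emit functions, not from the library): both branches hand the
player loop a zero-argument callable whose result can be awaited. The async emit
is already a coroutine function, and the sync emit gets wrapped so it runs in
the default executor:

    import asyncio
    import functools

    async def async_emit():  # stands in for AsyncStreamHandler.emit
        return (24000, None)

    def sync_emit():         # stands in for StreamHandler.emit (may block)
        return (24000, None)

    async def main():
        loop = asyncio.get_running_loop()
        wrapped = functools.partial(loop.run_in_executor, None, sync_emit)
        for call in (async_emit, wrapped):
            print(await call())  # both are awaited the same way

    asyncio.run(main())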
@@ -692,7 +739,7 @@ class WebRTC(Component):

     def stream(
         self,
-        fn: Callable[..., Any] | StreamHandler | None = None,
+        fn: Callable[..., Any] | StreamHandler | AsyncStreamHandler | None = None,
         inputs: Block | Sequence[Block] | set[Block] | None = None,
         outputs: Block | Sequence[Block] | set[Block] | None = None,
         js: str | None = None,
@@ -721,7 +768,7 @@ class WebRTC(Component):
         if (
             self.mode == "send-receive"
             and self.modality == "audio"
-            and not isinstance(self.event_handler, StreamHandler)
+            and not isinstance(self.event_handler, (AsyncStreamHandler, StreamHandler))
         ):
             raise ValueError(
                 "In the send-receive mode for audio, the event handler must be an instance of StreamHandler."
@@ -840,6 +887,8 @@ class WebRTC(Component):
                 event_handler=handler,
                 set_additional_outputs=set_outputs,
             )
+        else:
+            raise ValueError("Modality must be either video or audio")
         self.connections[body["webrtc_id"]] = cb
         if body["webrtc_id"] in self.data_channels:
             self.connections[body["webrtc_id"]].set_channel(
@@ -862,6 +911,8 @@ class WebRTC(Component):
                 cast(Callable, self.event_handler),
                 set_additional_outputs=set_outputs,
             )
+        else:
+            raise ValueError("Modality must be either video or audio")

         logger.debug("Adding track to peer connection %s", cb)
         pc.addTrack(cb)
pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"

 [project]
 name = "gradio_webrtc"
-version = "0.0.23"
+version = "0.0.24"
 description = "Stream images in realtime with webrtc"
 readme = "README.md"
 license = "apache-2.0"