Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)
Video Bugfix + generator (#96)
* Code
* Fix demo
* move to init

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
@@ -25,6 +25,7 @@ from .utils import (
     audio_to_bytes,
     audio_to_file,
     audio_to_float32,
+    audio_to_int16,
     wait_for_item,
 )
 from .webrtc import (
@@ -43,6 +44,7 @@ __all__ = [
     "audio_to_bytes",
     "audio_to_file",
     "audio_to_float32",
+    "audio_to_int16",
     "get_hf_turn_credentials",
     "get_twilio_turn_credentials",
     "get_turn_credentials",
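The two hunks above export the new audio_to_int16 helper from the package root. A minimal usage sketch, assuming the fastrtc package from this repository is installed and re-exports the helper as the __all__ change suggests (the sample values are illustrative):

import numpy as np
from fastrtc import audio_to_int16

sample_rate = 16000
# One second of a 440 Hz tone as float32 samples in [-1, 1]
samples = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate).astype(np.float32)
pcm16 = audio_to_int16((sample_rate, samples))  # scaled into the int16 range
print(pcm16.dtype, pcm16.min(), pcm16.max())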
@@ -63,7 +63,7 @@ class AppState:


 ReplyFnGenerator = (
     Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         Generator[EmitType, None, None],
     ]
     | Callable[
@@ -75,7 +75,7 @@ ReplyFnGenerator = (
         AsyncGenerator[EmitType, None],
     ]
     | Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         AsyncGenerator[EmitType, None],
     ]
 )
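The two hunks above loosen the second argument of a generator reply function from a list of message dicts to Any. A minimal sketch of a handler matching the new shape, assuming audio frames arrive as (sample_rate, int16 array) tuples; the function name and echo behavior are illustrative only:

from typing import Any

import numpy as np
from numpy.typing import NDArray


def echo(frame: tuple[int, NDArray[np.int16]], extra: Any = None):
    sample_rate, samples = frame
    # A real handler would run a model here; this one just yields the input back.
    yield (sample_rate, samples)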
@@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin):
         additional_outputs: list[Component] | None = None,
         ui_args: UIArgs | None = None,
     ):
+        WebRTCConnectionMixin.__init__(self)
         self.mode = mode
         self.modality = modality
         self.rtp_params = rtp_params
@@ -294,6 +294,40 @@ def audio_to_float32(
     return audio[1].astype(np.float32) / 32768.0


+def audio_to_int16(
+    audio: tuple[int, NDArray[np.int16 | np.float32]],
+) -> NDArray[np.int16]:
+    """
+    Convert an audio tuple containing sample rate and numpy array data to int16.
+
+    Parameters
+    ----------
+    audio : tuple[int, np.ndarray]
+        A tuple containing:
+        - sample_rate (int): The audio sample rate in Hz
+        - data (np.ndarray): The audio data as a numpy array
+
+    Returns
+    -------
+    np.ndarray
+        The audio data as a numpy array with dtype int16
+
+    Example
+    -------
+    >>> sample_rate = 44100
+    >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
+    >>> audio_tuple = (sample_rate, audio_data)
+    >>> audio_int16 = audio_to_int16(audio_tuple)
+    """
+    if audio[1].dtype == np.int16:
+        return audio[1]
+    elif audio[1].dtype == np.float32:
+        # Convert float32 to int16 by scaling to the int16 range
+        return (audio[1] * 32767.0).astype(np.int16)
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
+
+
 def aggregate_bytes_to_16bit(chunks_iterator):
     """
     Aggregate bytes to 16-bit audio samples.
@@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
             button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons.
             icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
         """
+        WebRTCConnectionMixin.__init__(self)
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             inputs = [inputs]
         inputs = list(inputs)

-        def handler(webrtc_id: str, *args):
-            if self.additional_outputs[webrtc_id].queue.qsize() > 0:
-                next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait()
-                return fn(*args, *next_outputs.args)  # type: ignore
-            return (
-                tuple([None for _ in range(len(outputs))])
-                if isinstance(outputs, Iterable)
-                else None
-            )
+        async def handler(webrtc_id: str, *args):
+            async for next_outputs in self.output_stream(webrtc_id):
+                yield fn(*args, *next_outputs.args)  # type: ignore

         return self.state_change(  # type: ignore
             fn=handler,
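The hunk above replaces the one-shot qsize()/get_nowait() poll with an async generator that follows the per-connection output stream, so the wrapped fn fires once per emitted item rather than at most once. A rough sketch of the pattern it relies on, assuming output_stream drains the connection's OutputQueue as items arrive (this is an assumption, not the actual fastrtc implementation):

import asyncio


async def output_stream_sketch(queue: asyncio.Queue):
    # Yield each AdditionalOutputs item as soon as the stream handler enqueues it.
    while True:
        yield await queue.get()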
@@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             js=js,
             concurrency_limit=concurrency_limit,
             concurrency_id=concurrency_id,
-            show_progress=show_progress,
+            show_progress="minimal",
             queue=queue,
-            trigger_mode="multiple",
+            trigger_mode="once",
         )

     def stream(
@@ -35,7 +35,6 @@ from fastrtc.tracks import (
 )
 from fastrtc.utils import (
     AdditionalOutputs,
     DataChannel,
     create_message,
     webrtc_error_handler,
 )
@@ -64,18 +63,20 @@ class OutputQueue:


 class WebRTCConnectionMixin:
-    pcs: set[RTCPeerConnection] = set([])
-    relay = MediaRelay()
-    connections: dict[str, list[Track]] = defaultdict(list)
-    data_channels: dict[str, DataChannel] = {}
-    additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
-    handlers: dict[str, HandlerType | Callable] = {}
-    connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
-    concurrency_limit: int | float
-    event_handler: HandlerType
-    time_limit: float | int | None
-    modality: Literal["video", "audio", "audio-video"]
-    mode: Literal["send", "receive", "send-receive"]
+    def __init__(self):
+        self.pcs = set([])
+        self.relay = MediaRelay()
+        self.connections = defaultdict(list)
+        self.data_channels = {}
+        self.additional_outputs = defaultdict(OutputQueue)
+        self.handlers = {}
+        self.connection_timeouts = defaultdict(asyncio.Event)
+        # These attributes should be set by subclasses:
+        self.concurrency_limit: int | float | None
+        self.event_handler: HandlerType | None
+        self.time_limit: float | None
+        self.modality: Literal["video", "audio", "audio-video"]
+        self.mode: Literal["send", "receive", "send-receive"]

     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
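The last hunk replaces mutable class-level attributes with instance attributes set in __init__ (the "move to init" bullet in the commit message). Mutable class attributes are shared by every instance, so two components previously shared one connection registry. A small illustration of the difference; the class names below are illustrative, not the mixin's real API:

from collections import defaultdict


class SharedState:
    connections = defaultdict(list)  # class-level: one dict shared by all instances


class PerInstanceState:
    def __init__(self):
        self.connections = defaultdict(list)  # instance-level: each object gets its own


a, b = SharedState(), SharedState()
a.connections["pc-1"].append("track")
assert b.connections["pc-1"] == ["track"]  # state leaks between instances

c, d = PerInstanceState(), PerInstanceState()
c.connections["pc-1"].append("track")
assert d.connections["pc-1"] == []  # isolated, as in the new WebRTCConnectionMixin.__init__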