Video Bugfix + generator (#96)

* Code

* Fix demo

* move to init

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-02-27 12:30:19 -05:00
committed by GitHub
parent 43e42c1b22
commit 9cc0278985
13 changed files with 73 additions and 41 deletions

View File

@@ -25,6 +25,7 @@ from .utils import (
audio_to_bytes,
audio_to_file,
audio_to_float32,
audio_to_int16,
wait_for_item,
)
from .webrtc import (
@@ -43,6 +44,7 @@ __all__ = [
"audio_to_bytes",
"audio_to_file",
"audio_to_float32",
"audio_to_int16",
"get_hf_turn_credentials",
"get_twilio_turn_credentials",
"get_turn_credentials",

View File

@@ -63,7 +63,7 @@ class AppState:
ReplyFnGenerator = (
Callable[
[tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
[tuple[int, NDArray[np.int16]], Any],
Generator[EmitType, None, None],
]
| Callable[
@@ -75,7 +75,7 @@ ReplyFnGenerator = (
AsyncGenerator[EmitType, None],
]
| Callable[
[tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
[tuple[int, NDArray[np.int16]], Any],
AsyncGenerator[EmitType, None],
]
)

View File

@@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin):
additional_outputs: list[Component] | None = None,
ui_args: UIArgs | None = None,
):
WebRTCConnectionMixin.__init__(self)
self.mode = mode
self.modality = modality
self.rtp_params = rtp_params

View File

@@ -294,6 +294,40 @@ def audio_to_float32(
return audio[1].astype(np.float32) / 32768.0
def audio_to_int16(
audio: tuple[int, NDArray[np.int16 | np.float32]],
) -> NDArray[np.int16]:
"""
Convert an audio tuple containing sample rate and numpy array data to int16.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
Returns
-------
np.ndarray
The audio data as a numpy array with dtype int16
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_int16 = audio_to_int16(audio_tuple)
"""
if audio[1].dtype == np.int16:
return audio[1]
elif audio[1].dtype == np.float32:
# Convert float32 to int16 by scaling to the int16 range
return (audio[1] * 32767.0).astype(np.int16)
else:
raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
def aggregate_bytes_to_16bit(chunks_iterator):
"""
Aggregate bytes to 16-bit audio samples.

View File

@@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons.
icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
"""
WebRTCConnectionMixin.__init__(self)
self.time_limit = time_limit
self.height = height
self.width = width
@@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
inputs = [inputs]
inputs = list(inputs)
def handler(webrtc_id: str, *args):
if self.additional_outputs[webrtc_id].queue.qsize() > 0:
next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait()
return fn(*args, *next_outputs.args) # type: ignore
return (
tuple([None for _ in range(len(outputs))])
if isinstance(outputs, Iterable)
else None
)
async def handler(webrtc_id: str, *args):
async for next_outputs in self.output_stream(webrtc_id):
yield fn(*args, *next_outputs.args) # type: ignore
return self.state_change( # type: ignore
fn=handler,
@@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
js=js,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
show_progress=show_progress,
show_progress="minimal",
queue=queue,
trigger_mode="multiple",
trigger_mode="once",
)
def stream(

View File

@@ -35,7 +35,6 @@ from fastrtc.tracks import (
)
from fastrtc.utils import (
AdditionalOutputs,
DataChannel,
create_message,
webrtc_error_handler,
)
@@ -64,18 +63,20 @@ class OutputQueue:
class WebRTCConnectionMixin:
pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay()
connections: dict[str, list[Track]] = defaultdict(list)
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
handlers: dict[str, HandlerType | Callable] = {}
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
concurrency_limit: int | float
event_handler: HandlerType
time_limit: float | int | None
modality: Literal["video", "audio", "audio-video"]
mode: Literal["send", "receive", "send-receive"]
def __init__(self):
    """Initialize per-instance connection bookkeeping.

    Gives each mixin instance its own containers (peer connections,
    tracks, data channels, output queues, handlers, timeout events)
    instead of sharing class-level state.
    """
    # Per-instance connection state.
    self.pcs = set([])
    self.relay = MediaRelay()
    self.connections = defaultdict(list)
    self.data_channels = {}
    self.additional_outputs = defaultdict(OutputQueue)
    self.handlers = {}
    self.connection_timeouts = defaultdict(asyncio.Event)
    # Annotation-only declarations: these attributes should be set by
    # subclasses (no value is assigned here at runtime).
    self.concurrency_limit: int | float | None
    self.event_handler: HandlerType | None
    self.time_limit: float | None
    self.modality: Literal["video", "audio", "audio-video"]
    self.mode: Literal["send", "receive", "send-receive"]
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):