Video Bugfix + generator (#96)

* Code

* Fix demo

* move to init

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-02-27 12:30:19 -05:00
committed by GitHub
parent 43e42c1b22
commit 9cc0278985
13 changed files with 73 additions and 41 deletions

View File

@@ -25,6 +25,7 @@ from .utils import (
audio_to_bytes,
audio_to_file,
audio_to_float32,
audio_to_int16,
wait_for_item,
)
from .webrtc import (
@@ -43,6 +44,7 @@ __all__ = [
"audio_to_bytes",
"audio_to_file",
"audio_to_float32",
"audio_to_int16",
"get_hf_turn_credentials",
"get_twilio_turn_credentials",
"get_turn_credentials",

View File

@@ -63,7 +63,7 @@ class AppState:
ReplyFnGenerator = (
Callable[
[tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
[tuple[int, NDArray[np.int16]], Any],
Generator[EmitType, None, None],
]
| Callable[
@@ -75,7 +75,7 @@ ReplyFnGenerator = (
AsyncGenerator[EmitType, None],
]
| Callable[
[tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
[tuple[int, NDArray[np.int16]], Any],
AsyncGenerator[EmitType, None],
]
)

View File

@@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin):
additional_outputs: list[Component] | None = None,
ui_args: UIArgs | None = None,
):
WebRTCConnectionMixin.__init__(self)
self.mode = mode
self.modality = modality
self.rtp_params = rtp_params

View File

@@ -294,6 +294,40 @@ def audio_to_float32(
return audio[1].astype(np.float32) / 32768.0
def audio_to_int16(
audio: tuple[int, NDArray[np.int16 | np.float32]],
) -> NDArray[np.int16]:
"""
Convert an audio tuple containing sample rate and numpy array data to int16.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
Returns
-------
np.ndarray
The audio data as a numpy array with dtype int16
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_int16 = audio_to_int16(audio_tuple)
"""
if audio[1].dtype == np.int16:
return audio[1]
elif audio[1].dtype == np.float32:
# Convert float32 to int16 by scaling to the int16 range
return (audio[1] * 32767.0).astype(np.int16)
else:
raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
def aggregate_bytes_to_16bit(chunks_iterator):
"""
Aggregate bytes to 16-bit audio samples.

View File

@@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons.
icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
"""
WebRTCConnectionMixin.__init__(self)
self.time_limit = time_limit
self.height = height
self.width = width
@@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
inputs = [inputs]
inputs = list(inputs)
def handler(webrtc_id: str, *args):
if self.additional_outputs[webrtc_id].queue.qsize() > 0:
next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait()
return fn(*args, *next_outputs.args) # type: ignore
return (
tuple([None for _ in range(len(outputs))])
if isinstance(outputs, Iterable)
else None
)
async def handler(webrtc_id: str, *args):
async for next_outputs in self.output_stream(webrtc_id):
yield fn(*args, *next_outputs.args) # type: ignore
return self.state_change( # type: ignore
fn=handler,
@@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
js=js,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
show_progress=show_progress,
show_progress="minimal",
queue=queue,
trigger_mode="multiple",
trigger_mode="once",
)
def stream(

View File

@@ -35,7 +35,6 @@ from fastrtc.tracks import (
)
from fastrtc.utils import (
AdditionalOutputs,
DataChannel,
create_message,
webrtc_error_handler,
)
@@ -64,18 +63,20 @@ class OutputQueue:
class WebRTCConnectionMixin:
pcs: set[RTCPeerConnection] = set([])
relay = MediaRelay()
connections: dict[str, list[Track]] = defaultdict(list)
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
handlers: dict[str, HandlerType | Callable] = {}
connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
concurrency_limit: int | float
event_handler: HandlerType
time_limit: float | int | None
modality: Literal["video", "audio", "audio-video"]
mode: Literal["send", "receive", "send-receive"]
def __init__(self):
    """Initialize per-instance WebRTC connection state.

    Each mixin instance gets its own peer-connection set, media relay,
    track/channel registries, and output queues (rather than sharing
    state across instances — presumably why this was moved off the
    class body into ``__init__``; confirm against the class definition).
    """
    # Active RTCPeerConnection objects owned by this instance.
    self.pcs = set([])
    # Relay used to fan out incoming media tracks.
    self.relay = MediaRelay()
    # webrtc_id -> list of Track objects for that connection.
    self.connections = defaultdict(list)
    # webrtc_id -> DataChannel for that connection.
    self.data_channels = {}
    # webrtc_id -> OutputQueue of additional outputs produced by handlers.
    self.additional_outputs = defaultdict(OutputQueue)
    # webrtc_id -> handler callable for that connection.
    self.handlers = {}
    # webrtc_id -> Event set when the connection times out.
    self.connection_timeouts = defaultdict(asyncio.Event)
    # These attributes should be set by subclasses:
    # (the lines below are bare annotations — they declare the expected
    # types but do not assign values at runtime)
    self.concurrency_limit: int | float | None
    self.event_handler: HandlerType | None
    self.time_limit: float | None
    self.modality: Literal["video", "audio", "audio-video"]
    self.mode: Literal["send", "receive", "send-receive"]
@staticmethod
async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):

View File

@@ -1,7 +1,3 @@
import subprocess
subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
import asyncio
import base64
import os
@@ -80,12 +76,11 @@ class PhonicHandler(AsyncStreamHandler):
return super().shutdown()
def add_to_chatbot(state, chatbot, message):
state.append(message)
return state, gr.skip()
def add_to_chatbot(chatbot, message):
    """Append *message* to the chatbot history in place and return the same list."""
    chatbot += [message]
    return chatbot
state = gr.State(value=[])
chatbot = gr.Chatbot(type="messages", value=[])
stream = Stream(
handler=PhonicHandler(),
@@ -99,7 +94,7 @@ stream = Stream(
info="Select a voice from the dropdown",
)
],
additional_outputs=[state, chatbot],
additional_outputs=[chatbot],
additional_outputs_handler=add_to_chatbot,
ui_args={
"title": "Phonic Chat (Powered by FastRTC ⚡️)",
@@ -109,8 +104,8 @@ stream = Stream(
time_limit=90 if get_space() else None,
)
with stream.ui:
state.change(lambda s: s, inputs=state, outputs=chatbot)
# with stream.ui:
# state.change(lambda s: s, inputs=state, outputs=chatbot)
if __name__ == "__main__":
if (mode := os.getenv("MODE")) == "UI":

View File

@@ -38,6 +38,7 @@ def response(
):
gradio_chatbot = gradio_chatbot or []
conversation_state = conversation_state or []
print("chatbot", gradio_chatbot)
text = stt_model.stt(audio)
sample_rate, array = audio
@@ -47,7 +48,6 @@ def response(
yield AdditionalOutputs(gradio_chatbot, conversation_state)
conversation_state.append({"role": "user", "content": text})
request = client.chat.completions.create(
model="meta-llama/Llama-3.2-3B-Instruct",
messages=conversation_state, # type: ignore

View File

@@ -10,7 +10,7 @@ from fastrtc import Stream
import gradio as gr
import numpy as np
def detection(image):
def detection(image, slider):
    """Return *image* flipped vertically (along axis 0).

    The *slider* argument is accepted for the component interface but is
    not used in the computation.
    """
    flipped = np.flip(image, axis=0)
    return flipped
stream = Stream(

View File

@@ -148,6 +148,7 @@
{icon}
{icon_button_color}
{pulse_color}
{icon_radius}
{button_labels}
on:clear={() => gradio.dispatch("clear")}
on:play={() => gradio.dispatch("play")}

View File

@@ -30,6 +30,7 @@
export let icon: string | undefined | ComponentType = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
export let icon_radius: number = 50;
const dispatch = createEventDispatcher<{
change: FileData | null;
@@ -62,8 +63,8 @@
{icon}
{icon_button_color}
{pulse_color}
{button_labels}
{icon_radius}
{button_labels}
on:error
on:start_recording
on:stop_recording

View File

@@ -33,6 +33,7 @@
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let icon: string | undefined | ComponentType = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let icon_radius: number = 50;
export let pulse_color: string = "var(--color-accent)";
export let button_labels: { start: string; stop: string; waiting: string };
@@ -242,6 +243,7 @@
icon={icon || Mic}
{icon_button_color}
{pulse_color}
{icon_radius}
/>
</div>
{/if}

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "fastrtc"
version = "0.0.8post1"
version = "0.0.9"
description = "The realtime communication library for Python"
readme = "README.md"
license = "apache-2.0"