Video Bugfix + generator (#96)

* Code

* Fix demo

* move to init

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
Freddy Boulton <freddyboulton@hf-freddy.local>
2025-02-27 12:30:19 -05:00
committed by GitHub
parent 43e42c1b22
commit 9cc0278985

13 changed files with 73 additions and 41 deletions

```diff
@@ -25,6 +25,7 @@ from .utils import (
     audio_to_bytes,
     audio_to_file,
     audio_to_float32,
+    audio_to_int16,
     wait_for_item,
 )
 from .webrtc import (
@@ -43,6 +44,7 @@ __all__ = [
     "audio_to_bytes",
     "audio_to_file",
     "audio_to_float32",
+    "audio_to_int16",
     "get_hf_turn_credentials",
     "get_twilio_turn_credentials",
     "get_turn_credentials",
```

```diff
@@ -63,7 +63,7 @@ class AppState:
 ReplyFnGenerator = (
     Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         Generator[EmitType, None, None],
     ]
     | Callable[
@@ -75,7 +75,7 @@ ReplyFnGenerator = (
         AsyncGenerator[EmitType, None],
     ]
     | Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         AsyncGenerator[EmitType, None],
     ]
 )
```
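
This hunk loosens the second parameter of the generator variants of `ReplyFnGenerator` from `list[dict[Any, Any]]` to `Any`, so additional inputs are no longer assumed to be chat history. A minimal sketch of a generator handler the widened type now admits (the handler and its `gain` input are hypothetical, not part of this commit):

```python
from typing import Any, Generator

import numpy as np
from numpy.typing import NDArray


def echo_with_gain(
    audio: tuple[int, NDArray[np.int16]], gain: Any
) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
    # The extra argument is a plain float here; under the old annotation
    # only list[dict[Any, Any]] (chat-history style) inputs type-checked.
    sample_rate, samples = audio
    scaled = np.clip(samples.astype(np.float32) * float(gain), -32768, 32767)
    yield sample_rate, scaled.astype(np.int16)
```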

```diff
@@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin):
         additional_outputs: list[Component] | None = None,
         ui_args: UIArgs | None = None,
     ):
+        WebRTCConnectionMixin.__init__(self)
         self.mode = mode
         self.modality = modality
         self.rtp_params = rtp_params
```

```diff
@@ -294,6 +294,40 @@ def audio_to_float32(
     return audio[1].astype(np.float32) / 32768.0
 
 
+def audio_to_int16(
+    audio: tuple[int, NDArray[np.int16 | np.float32]],
+) -> NDArray[np.int16]:
+    """
+    Convert an audio tuple containing sample rate and numpy array data to int16.
+
+    Parameters
+    ----------
+    audio : tuple[int, np.ndarray]
+        A tuple containing:
+        - sample_rate (int): The audio sample rate in Hz
+        - data (np.ndarray): The audio data as a numpy array
+
+    Returns
+    -------
+    np.ndarray
+        The audio data as a numpy array with dtype int16
+
+    Example
+    -------
+    >>> sample_rate = 44100
+    >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
+    >>> audio_tuple = (sample_rate, audio_data)
+    >>> audio_int16 = audio_to_int16(audio_tuple)
+    """
+    if audio[1].dtype == np.int16:
+        return audio[1]
+    elif audio[1].dtype == np.float32:
+        # Convert float32 to int16 by scaling to the int16 range
+        return (audio[1] * 32767.0).astype(np.int16)
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
+
+
 def aggregate_bytes_to_16bit(chunks_iterator):
     """
     Aggregate bytes to 16-bit audio samples.
```
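
The new `audio_to_int16` mirrors the existing `audio_to_float32` (the context line above, which divides by 32768.0). A standalone numpy sketch of the round trip, following the scaling used in this diff:

```python
import numpy as np

sample_rate = 44100
float_audio = np.array([0.1, -0.2, 0.3], dtype=np.float32)

# audio_to_int16 scaling: multiply by 32767, cast truncates toward zero.
int16_audio = (float_audio * 32767.0).astype(np.int16)
print(int16_audio)  # [ 3276 -6553  9830]

# audio_to_float32 scaling: divide by 32768. The round trip is lossy by
# at most about 1/32768 per sample.
recovered = int16_audio.astype(np.float32) / 32768.0
assert np.allclose(recovered, float_audio, atol=1 / 32768)
```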

```diff
@@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
             button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons.
             icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
         """
+        WebRTCConnectionMixin.__init__(self)
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             inputs = [inputs]
         inputs = list(inputs)
 
-        def handler(webrtc_id: str, *args):
-            if self.additional_outputs[webrtc_id].queue.qsize() > 0:
-                next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait()
-                return fn(*args, *next_outputs.args)  # type: ignore
-            return (
-                tuple([None for _ in range(len(outputs))])
-                if isinstance(outputs, Iterable)
-                else None
-            )
+        async def handler(webrtc_id: str, *args):
+            async for next_outputs in self.output_stream(webrtc_id):
+                yield fn(*args, *next_outputs.args)  # type: ignore
 
         return self.state_change(  # type: ignore
             fn=handler,
@@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             js=js,
             concurrency_limit=concurrency_limit,
             concurrency_id=concurrency_id,
-            show_progress=show_progress,
+            show_progress="minimal",
             queue=queue,
-            trigger_mode="multiple",
+            trigger_mode="once",
         )
 
     def stream(
```
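
The rewritten `handler` replaces a one-shot `qsize()`/`get_nowait()` poll, which returned `None` whenever the queue was momentarily empty, with an async generator that yields every `AdditionalOutputs` as it arrives. The underlying pattern, sketched here against a bare `asyncio.Queue` rather than fastrtc's actual `output_stream` internals:

```python
import asyncio
from typing import Any, AsyncGenerator


async def output_stream(queue: asyncio.Queue) -> AsyncGenerator[Any, None]:
    # Await each item instead of polling, so no output is ever dropped.
    while True:
        item = await queue.get()
        if item is None:  # sentinel for "connection closed"
            return
        yield item


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for msg in ("first", "second", None):
        q.put_nowait(msg)
    async for out in output_stream(q):
        print(out)  # prints "first" then "second"


asyncio.run(main())
```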

```diff
@@ -35,7 +35,6 @@ from fastrtc.tracks import (
 )
 from fastrtc.utils import (
     AdditionalOutputs,
-    DataChannel,
     create_message,
     webrtc_error_handler,
 )
@@ -64,18 +63,20 @@ class OutputQueue:
 
 
 class WebRTCConnectionMixin:
-    pcs: set[RTCPeerConnection] = set([])
-    relay = MediaRelay()
-    connections: dict[str, list[Track]] = defaultdict(list)
-    data_channels: dict[str, DataChannel] = {}
-    additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
-    handlers: dict[str, HandlerType | Callable] = {}
-    connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
-    concurrency_limit: int | float
-    event_handler: HandlerType
-    time_limit: float | int | None
-    modality: Literal["video", "audio", "audio-video"]
-    mode: Literal["send", "receive", "send-receive"]
+    def __init__(self):
+        self.pcs = set([])
+        self.relay = MediaRelay()
+        self.connections = defaultdict(list)
+        self.data_channels = {}
+        self.additional_outputs = defaultdict(OutputQueue)
+        self.handlers = {}
+        self.connection_timeouts = defaultdict(asyncio.Event)
+        # These attributes should be set by subclasses:
+        self.concurrency_limit: int | float | None
+        self.event_handler: HandlerType | None
+        self.time_limit: float | None
+        self.modality: Literal["video", "audio", "audio-video"]
+        self.mode: Literal["send", "receive", "send-receive"]
 
     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
```
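
Moving this state from class attributes into `__init__` (the "move to init" commit) fixes the classic shared-mutable-class-attribute pitfall: with class-level defaults, every `Stream` and `WebRTC` instance in a process mutated the same `connections`, `handlers`, and `data_channels` dicts. That is also why `Stream.__init__` and `WebRTC.__init__` above now call `WebRTCConnectionMixin.__init__(self)` explicitly. A minimal illustration of the difference:

```python
from collections import defaultdict


class SharedState:
    connections = defaultdict(list)  # old pattern: one dict for the whole class


class PerInstanceState:
    def __init__(self):
        self.connections = defaultdict(list)  # new pattern: one dict per object


a, b = SharedState(), SharedState()
a.connections["webrtc-id"].append("track")
print(len(b.connections["webrtc-id"]))  # 1 -- b sees a's mutation

c, d = PerInstanceState(), PerInstanceState()
c.connections["webrtc-id"].append("track")
print(len(d.connections["webrtc-id"]))  # 0 -- isolated
```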

```diff
@@ -1,7 +1,3 @@
-import subprocess
-
-subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
-
 import asyncio
 import base64
 import os
@@ -80,12 +76,11 @@ class PhonicHandler(AsyncStreamHandler):
         return super().shutdown()
 
 
-def add_to_chatbot(state, chatbot, message):
-    state.append(message)
-    return state, gr.skip()
+def add_to_chatbot(chatbot, message):
+    chatbot.append(message)
+    return chatbot
 
 
-state = gr.State(value=[])
 chatbot = gr.Chatbot(type="messages", value=[])
 stream = Stream(
     handler=PhonicHandler(),
@@ -99,7 +94,7 @@ stream = Stream(
             info="Select a voice from the dropdown",
         )
     ],
-    additional_outputs=[state, chatbot],
+    additional_outputs=[chatbot],
     additional_outputs_handler=add_to_chatbot,
     ui_args={
         "title": "Phonic Chat (Powered by FastRTC ⚡️)",
@@ -109,8 +104,8 @@ stream = Stream(
     time_limit=90 if get_space() else None,
 )
 
-with stream.ui:
-    state.change(lambda s: s, inputs=state, outputs=chatbot)
+# with stream.ui:
+#     state.change(lambda s: s, inputs=state, outputs=chatbot)
 
 if __name__ == "__main__":
     if (mode := os.getenv("MODE")) == "UI":
```

```diff
@@ -38,6 +38,7 @@ def response(
 ):
     gradio_chatbot = gradio_chatbot or []
     conversation_state = conversation_state or []
+    print("chatbot", gradio_chatbot)
     text = stt_model.stt(audio)
     sample_rate, array = audio
@@ -47,7 +48,6 @@ def response(
     yield AdditionalOutputs(gradio_chatbot, conversation_state)
     conversation_state.append({"role": "user", "content": text})
-
     request = client.chat.completions.create(
         model="meta-llama/Llama-3.2-3B-Instruct",
         messages=conversation_state,  # type: ignore
```

```diff
@@ -10,7 +10,7 @@ from fastrtc import Stream
 import gradio as gr
 import numpy as np
 
-def detection(image):
+def detection(image, slider):
     return np.flip(image, axis=0)
 
 stream = Stream(
```
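
The demo handler gains a second parameter so its signature matches the additional input the UI forwards, in line with the widened generator signature above. A plausible end-to-end wiring; the `additional_inputs` kwarg and the slider settings here are illustrative, not taken from this diff:

```python
import gradio as gr
import numpy as np
from fastrtc import Stream


def detection(image, slider):
    # slider arrives as the second positional argument; the demo still
    # just flips each frame vertically.
    return np.flip(image, axis=0)


stream = Stream(
    handler=detection,
    modality="video",
    mode="send-receive",
    additional_inputs=[gr.Slider(minimum=0, maximum=1, value=0.5)],
)
```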

```diff
@@ -148,6 +148,7 @@
     {icon}
     {icon_button_color}
     {pulse_color}
+    {icon_radius}
     {button_labels}
     on:clear={() => gradio.dispatch("clear")}
     on:play={() => gradio.dispatch("play")}
```

```diff
@@ -30,6 +30,7 @@
     export let icon: string | undefined | ComponentType = undefined;
     export let icon_button_color: string = "var(--color-accent)";
     export let pulse_color: string = "var(--color-accent)";
+    export let icon_radius: number = 50;
 
     const dispatch = createEventDispatcher<{
         change: FileData | null;
@@ -62,8 +63,8 @@
     {icon}
     {icon_button_color}
     {pulse_color}
-    {button_labels}
     {icon_radius}
+    {button_labels}
     on:error
     on:start_recording
     on:stop_recording
```

```diff
@@ -33,6 +33,7 @@
     export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
     export let icon: string | undefined | ComponentType = undefined;
     export let icon_button_color: string = "var(--color-accent)";
+    export let icon_radius: number = 50;
     export let pulse_color: string = "var(--color-accent)";
     export let button_labels: { start: string; stop: string; waiting: string };
@@ -242,6 +243,7 @@
             icon={icon || Mic}
             {icon_button_color}
             {pulse_color}
+            {icon_radius}
         />
     </div>
 {/if}
```

```diff
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fastrtc"
-version = "0.0.8post1"
+version = "0.0.9"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "apache-2.0"
```