From 9cc027898577b2cd25ad217bfef41e9553d82711 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Thu, 27 Feb 2025 12:30:19 -0500 Subject: [PATCH] Video Bugfix + generator (#96) * Code * Fix demo * move to init --------- Co-authored-by: Freddy Boulton --- backend/fastrtc/__init__.py | 2 ++ backend/fastrtc/reply_on_pause.py | 4 +-- backend/fastrtc/stream.py | 1 + backend/fastrtc/utils.py | 34 ++++++++++++++++++++++ backend/fastrtc/webrtc.py | 17 ++++------- backend/fastrtc/webrtc_connection_mixin.py | 27 ++++++++--------- demo/phonic_chat/app.py | 17 ++++------- demo/talk_to_sambanova/app.py | 2 +- docs/userguide/streams.md | 2 +- frontend/Index.svelte | 1 + frontend/shared/InteractiveVideo.svelte | 3 +- frontend/shared/Webcam.svelte | 2 ++ pyproject.toml | 2 +- 13 files changed, 73 insertions(+), 41 deletions(-) diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py index 6068799..d8b9cf7 100644 --- a/backend/fastrtc/__init__.py +++ b/backend/fastrtc/__init__.py @@ -25,6 +25,7 @@ from .utils import ( audio_to_bytes, audio_to_file, audio_to_float32, + audio_to_int16, wait_for_item, ) from .webrtc import ( @@ -43,6 +44,7 @@ __all__ = [ "audio_to_bytes", "audio_to_file", "audio_to_float32", + "audio_to_int16", "get_hf_turn_credentials", "get_twilio_turn_credentials", "get_turn_credentials", diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py index 7e93b57..23dc0b7 100644 --- a/backend/fastrtc/reply_on_pause.py +++ b/backend/fastrtc/reply_on_pause.py @@ -63,7 +63,7 @@ class AppState: ReplyFnGenerator = ( Callable[ - [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]], + [tuple[int, NDArray[np.int16]], Any], Generator[EmitType, None, None], ] | Callable[ @@ -75,7 +75,7 @@ ReplyFnGenerator = ( AsyncGenerator[EmitType, None], ] | Callable[ - [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]], + [tuple[int, NDArray[np.int16]], Any], AsyncGenerator[EmitType, None], ] ) diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index 96966b0..fca6630 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin): additional_outputs: list[Component] | None = None, ui_args: UIArgs | None = None, ): + WebRTCConnectionMixin.__init__(self) self.mode = mode self.modality = modality self.rtp_params = rtp_params diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py index e11a8cc..52e0de2 100644 --- a/backend/fastrtc/utils.py +++ b/backend/fastrtc/utils.py @@ -294,6 +294,40 @@ def audio_to_float32( return audio[1].astype(np.float32) / 32768.0 +def audio_to_int16( + audio: tuple[int, NDArray[np.int16 | np.float32]], +) -> NDArray[np.int16]: + """ + Convert an audio tuple containing sample rate and numpy array data to int16. + + Parameters + ---------- + audio : tuple[int, np.ndarray] + A tuple containing: + - sample_rate (int): The audio sample rate in Hz + - data (np.ndarray): The audio data as a numpy array + + Returns + ------- + np.ndarray + The audio data as a numpy array with dtype int16 + + Example + ------- + >>> sample_rate = 44100 + >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples + >>> audio_tuple = (sample_rate, audio_data) + >>> audio_int16 = audio_to_int16(audio_tuple) + """ + if audio[1].dtype == np.int16: + return audio[1] + elif audio[1].dtype == np.float32: + # Convert float32 to int16 by scaling to the int16 range + return (audio[1] * 32767.0).astype(np.int16) + else: + raise TypeError(f"Unsupported audio data type: {audio[1].dtype}") + + def aggregate_bytes_to_16bit(chunks_iterator): """ Aggregate bytes to 16-bit audio samples. diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py index 0129c92..4374e90 100644 --- a/backend/fastrtc/webrtc.py +++ b/backend/fastrtc/webrtc.py @@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin): button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons. icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50% """ + WebRTCConnectionMixin.__init__(self) self.time_limit = time_limit self.height = height self.width = width @@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin): inputs = [inputs] inputs = list(inputs) - def handler(webrtc_id: str, *args): - if self.additional_outputs[webrtc_id].queue.qsize() > 0: - next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait() - return fn(*args, *next_outputs.args) # type: ignore - return ( - tuple([None for _ in range(len(outputs))]) - if isinstance(outputs, Iterable) - else None - ) + async def handler(webrtc_id: str, *args): + async for next_outputs in self.output_stream(webrtc_id): + yield fn(*args, *next_outputs.args) # type: ignore return self.state_change( # type: ignore fn=handler, @@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin): js=js, concurrency_limit=concurrency_limit, concurrency_id=concurrency_id, - show_progress=show_progress, + show_progress="minimal", queue=queue, - trigger_mode="multiple", + trigger_mode="once", ) def stream( diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index bc72b3c..05026e7 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -35,7 +35,6 @@ from fastrtc.tracks import ( ) from fastrtc.utils import ( AdditionalOutputs, - DataChannel, create_message, webrtc_error_handler, ) @@ -64,18 +63,20 @@ class OutputQueue: class WebRTCConnectionMixin: - pcs: set[RTCPeerConnection] = set([]) - relay = MediaRelay() - connections: dict[str, list[Track]] = defaultdict(list) - data_channels: dict[str, DataChannel] = {} - additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue) - handlers: dict[str, HandlerType | Callable] = {} - connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event) - concurrency_limit: int | float - event_handler: HandlerType - time_limit: float | int | None - modality: Literal["video", "audio", "audio-video"] - mode: Literal["send", "receive", "send-receive"] + def __init__(self): + self.pcs = set([]) + self.relay = MediaRelay() + self.connections = defaultdict(list) + self.data_channels = {} + self.additional_outputs = defaultdict(OutputQueue) + self.handlers = {} + self.connection_timeouts = defaultdict(asyncio.Event) + # These attributes should be set by subclasses: + self.concurrency_limit: int | float | None + self.event_handler: HandlerType | None + self.time_limit: float | None + self.modality: Literal["video", "audio", "audio-video"] + self.mode: Literal["send", "receive", "send-receive"] @staticmethod async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): diff --git a/demo/phonic_chat/app.py b/demo/phonic_chat/app.py index 3acde98..b968bc2 100644 --- a/demo/phonic_chat/app.py +++ b/demo/phonic_chat/app.py @@ -1,7 +1,3 @@ -import subprocess - -subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"]) - import asyncio import base64 import os @@ -80,12 +76,11 @@ class PhonicHandler(AsyncStreamHandler): return super().shutdown() -def add_to_chatbot(state, chatbot, message): - state.append(message) - return state, gr.skip() +def add_to_chatbot(chatbot, message): + chatbot.append(message) + return chatbot -state = gr.State(value=[]) chatbot = gr.Chatbot(type="messages", value=[]) stream = Stream( handler=PhonicHandler(), @@ -99,7 +94,7 @@ stream = Stream( info="Select a voice from the dropdown", ) ], - additional_outputs=[state, chatbot], + additional_outputs=[chatbot], additional_outputs_handler=add_to_chatbot, ui_args={ "title": "Phonic Chat (Powered by FastRTC ⚡️)", @@ -109,8 +104,8 @@ stream = Stream( time_limit=90 if get_space() else None, ) -with stream.ui: - state.change(lambda s: s, inputs=state, outputs=chatbot) +# with stream.ui: +# state.change(lambda s: s, inputs=state, outputs=chatbot) if __name__ == "__main__": if (mode := os.getenv("MODE")) == "UI": diff --git a/demo/talk_to_sambanova/app.py b/demo/talk_to_sambanova/app.py index 2faabae..7add9ef 100644 --- a/demo/talk_to_sambanova/app.py +++ b/demo/talk_to_sambanova/app.py @@ -38,6 +38,7 @@ def response( ): gradio_chatbot = gradio_chatbot or [] conversation_state = conversation_state or [] + print("chatbot", gradio_chatbot) text = stt_model.stt(audio) sample_rate, array = audio @@ -47,7 +48,6 @@ def response( yield AdditionalOutputs(gradio_chatbot, conversation_state) conversation_state.append({"role": "user", "content": text}) - request = client.chat.completions.create( model="meta-llama/Llama-3.2-3B-Instruct", messages=conversation_state, # type: ignore diff --git a/docs/userguide/streams.md b/docs/userguide/streams.md index d70dba9..b0ba5af 100644 --- a/docs/userguide/streams.md +++ b/docs/userguide/streams.md @@ -10,7 +10,7 @@ from fastrtc import Stream import gradio as gr import numpy as np -def detection(image): +def detection(image, slider): return np.flip(image, axis=0) stream = Stream( diff --git a/frontend/Index.svelte b/frontend/Index.svelte index e458cb1..b9faea6 100644 --- a/frontend/Index.svelte +++ b/frontend/Index.svelte @@ -148,6 +148,7 @@ {icon} {icon_button_color} {pulse_color} + {icon_radius} {button_labels} on:clear={() => gradio.dispatch("clear")} on:play={() => gradio.dispatch("play")} diff --git a/frontend/shared/InteractiveVideo.svelte b/frontend/shared/InteractiveVideo.svelte index f223acc..59b38a3 100644 --- a/frontend/shared/InteractiveVideo.svelte +++ b/frontend/shared/InteractiveVideo.svelte @@ -30,6 +30,7 @@ export let icon: string | undefined | ComponentType = undefined; export let icon_button_color: string = "var(--color-accent)"; export let pulse_color: string = "var(--color-accent)"; + export let icon_radius: number = 50; const dispatch = createEventDispatcher<{ change: FileData | null; @@ -62,8 +63,8 @@ {icon} {icon_button_color} {pulse_color} - {button_labels} {icon_radius} + {button_labels} on:error on:start_recording on:stop_recording diff --git a/frontend/shared/Webcam.svelte b/frontend/shared/Webcam.svelte index 4d78673..972fd45 100644 --- a/frontend/shared/Webcam.svelte +++ b/frontend/shared/Webcam.svelte @@ -33,6 +33,7 @@ export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters; export let icon: string | undefined | ComponentType = undefined; export let icon_button_color: string = "var(--color-accent)"; + export let icon_radius: number = 50; export let pulse_color: string = "var(--color-accent)"; export let button_labels: { start: string; stop: string; waiting: string }; @@ -242,6 +243,7 @@ icon={icon || Mic} {icon_button_color} {pulse_color} + {icon_radius} /> {/if} diff --git a/pyproject.toml b/pyproject.toml index da14e47..973497a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "fastrtc" -version = "0.0.8post1" +version = "0.0.9" description = "The realtime communication library for Python" readme = "README.md" license = "apache-2.0"