Add ability to Hide Title in Built-in UI + llama 4 cartesia tweaks (#299)

* merge title

* Fix
This commit is contained in:
Freddy Boulton
2025-04-23 16:01:54 -04:00
committed by GitHub
parent 745701c79c
commit 02aef9da58
6 changed files with 131 additions and 97 deletions

View File

@@ -17,7 +17,11 @@ from .reply_on_pause import AlgoOptions, ReplyOnPause
from .reply_on_stopwords import ReplyOnStopWords from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import MoonshineSTT, get_stt_model from .speech_to_text import MoonshineSTT, get_stt_model
from .stream import Stream, UIArgs from .stream import Stream, UIArgs
from .text_to_speech import KokoroTTSOptions, get_tts_model from .text_to_speech import (
CartesiaTTSOptions,
KokoroTTSOptions,
get_tts_model,
)
from .tracks import ( from .tracks import (
AsyncAudioVideoStreamHandler, AsyncAudioVideoStreamHandler,
AsyncStreamHandler, AsyncStreamHandler,
@@ -87,4 +91,5 @@ __all__ = [
"VideoStreamHandler", "VideoStreamHandler",
"CloseStream", "CloseStream",
"get_current_context", "get_current_context",
"CartesiaTTSOptions",
] ]

View File

@@ -59,6 +59,8 @@ class UIArgs(TypedDict):
If "submit", the input will be sent when the submit event is triggered by the user. If "submit", the input will be sent when the submit event is triggered by the user.
If "change", the input will be sent whenever the user changes the input value. If "change", the input will be sent whenever the user changes the input value.
""" """
hide_title: NotRequired[bool]
"""If True, the title and subtitle will not be displayed."""
class Stream(WebRTCConnectionMixin): class Stream(WebRTCConnectionMixin):
@@ -339,21 +341,22 @@ class Stream(WebRTCConnectionMixin):
same_components.append(component) same_components.append(component)
if self.modality == "video" and self.mode == "receive": if self.modality == "video" and self.mode == "receive":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
if additional_input_components: if additional_input_components:
@@ -391,21 +394,22 @@ class Stream(WebRTCConnectionMixin):
) )
elif self.modality == "video" and self.mode == "send": elif self.modality == "video" and self.mode == "send":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
if additional_input_components: if additional_input_components:
with gr.Column(): with gr.Column():
@@ -494,21 +498,22 @@ class Stream(WebRTCConnectionMixin):
) )
elif self.modality == "audio" and self.mode == "receive": elif self.modality == "audio" and self.mode == "receive":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
for component in additional_input_components: for component in additional_input_components:
@@ -549,21 +554,22 @@ class Stream(WebRTCConnectionMixin):
) )
elif self.modality == "audio" and self.mode == "send": elif self.modality == "audio" and self.mode == "send":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Group(): with gr.Group():
@@ -604,21 +610,22 @@ class Stream(WebRTCConnectionMixin):
) )
elif self.modality == "audio" and self.mode == "send-receive": elif self.modality == "audio" and self.mode == "send-receive":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Group(): with gr.Group():
@@ -662,21 +669,22 @@ class Stream(WebRTCConnectionMixin):
css = """.my-group {max-width: 600px !important; max-height: 600 !important;} css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};""" .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo: with gr.Blocks(css=css) as demo:
gr.HTML( if not ui_args.get("hide_title"):
f""" gr.HTML(
<h1 style='text-align: center'>
{ui_args.get("title", "Audio Video Streaming (Powered by FastRTC ⚡️)")}
</h1>
"""
)
if ui_args.get("subtitle"):
gr.Markdown(
f""" f"""
<div style='text-align: center'> <h1 style='text-align: center'>
{ui_args.get("subtitle")} {ui_args.get("title", "Audio Video Streaming (Powered by FastRTC ⚡️)")}
</div> </h1>
""" """
) )
if ui_args.get("subtitle"):
gr.Markdown(
f"""
<div style='text-align: center'>
{ui_args.get("subtitle")}
</div>
"""
)
with gr.Row(): with gr.Row():
with gr.Column(elem_classes=["my-column"]): with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]): with gr.Group(elem_classes=["my-group"]):

View File

@@ -1,3 +1,7 @@
from .tts import KokoroTTSOptions, get_tts_model from .tts import (
CartesiaTTSOptions,
KokoroTTSOptions,
get_tts_model,
)
__all__ = ["get_tts_model", "KokoroTTSOptions"] __all__ = ["get_tts_model", "KokoroTTSOptions", "CartesiaTTSOptions"]

View File

@@ -2,7 +2,7 @@ import asyncio
import importlib.util import importlib.util
import re import re
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import Literal, Protocol, TypeVar from typing import Literal, Protocol, TypeVar
@@ -153,10 +153,11 @@ class KokoroTTSModel(TTSModel):
break break
@dataclass
class CartesiaTTSOptions(TTSOptions): class CartesiaTTSOptions(TTSOptions):
voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121" voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
language: str = "en" language: str = "en"
emotion: list[str] = [] emotion: list[str] = field(default_factory=list)
cartesia_version: str = "2024-06-10" cartesia_version: str = "2024-06-10"
model: str = "sonic-2" model: str = "sonic-2"
sample_rate: int = 22_050 sample_rate: int = 22_050

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

View File

@@ -9,11 +9,12 @@ from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import ( from fastrtc import (
AdditionalOutputs, AdditionalOutputs,
CartesiaTTSOptions,
ReplyOnPause, ReplyOnPause,
Stream, Stream,
audio_to_bytes,
get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials_async,
get_current_context, get_current_context,
get_stt_model,
get_tts_model, get_tts_model,
) )
from groq import Groq from groq import Groq
@@ -22,9 +23,11 @@ from numpy.typing import NDArray
curr_dir = Path(__file__).parent curr_dir = Path(__file__).parent
load_dotenv() load_dotenv()
tts_model = get_tts_model() tts_model = get_tts_model(
model="cartesia", cartesia_api_key=os.getenv("CARTESIA_API_KEY")
)
groq = Groq(api_key=os.getenv("GROQ_API_KEY")) groq = Groq(api_key=os.getenv("GROQ_API_KEY"))
stt_model = get_stt_model()
conversations: dict[str, list[dict[str, str]]] = {} conversations: dict[str, list[dict[str, str]]] = {}
@@ -43,14 +46,8 @@ def response(user_audio: tuple[int, NDArray[np.int16]]):
] ]
messages = conversations[context.webrtc_id] messages = conversations[context.webrtc_id]
transcription = groq.audio.transcriptions.create( transcription = stt_model.stt(user_audio)
file=("audio.wav", audio_to_bytes(user_audio)), messages.append({"role": "user", "content": transcription})
model="distil-whisper-large-v3-en",
response_format="verbose_json",
)
print(transcription.text)
messages.append({"role": "user", "content": transcription.text})
completion = groq.chat.completions.create( # type: ignore completion = groq.chat.completions.create( # type: ignore
model="meta-llama/llama-4-scout-17b-16e-instruct", model="meta-llama/llama-4-scout-17b-16e-instruct",
@@ -68,7 +65,9 @@ def response(user_audio: tuple[int, NDArray[np.int16]]):
long_response = response["long"] long_response = response["long"]
messages.append({"role": "assistant", "content": long_response}) messages.append({"role": "assistant", "content": long_response})
conversations[context.webrtc_id] = messages conversations[context.webrtc_id] = messages
yield from tts_model.stream_tts_sync(short_response) yield from tts_model.stream_tts_sync(
short_response, options=CartesiaTTSOptions(sample_rate=24_000)
)
yield AdditionalOutputs(messages) yield AdditionalOutputs(messages)
@@ -78,9 +77,22 @@ stream = Stream(
mode="send-receive", mode="send-receive",
additional_outputs=[gr.Chatbot(type="messages")], additional_outputs=[gr.Chatbot(type="messages")],
additional_outputs_handler=lambda old, new: new, additional_outputs_handler=lambda old, new: new,
rtc_configuration=get_cloudflare_turn_credentials_async, rtc_configuration=None,
ui_args={"hide_title": True},
) )
with gr.Blocks() as demo:
gr.HTML(
f"""
<h1 style='text-align: center; display: flex; align-items: center; justify-content: center;'>
<img src="/gradio_api/file={str((Path(__file__).parent / "AV_Huggy.png").resolve())}" alt="AV Huggy" style="height: 100px; margin-right: 10px"> FastRTC + Cartesia TTS = Blazing Fast LLM Audio
</h1>
"""
)
stream.ui.render()
stream.ui = demo
app = FastAPI() app = FastAPI()
stream.mount(app) stream.mount(app)
@@ -109,9 +121,13 @@ async def _(webrtc_id: str):
if __name__ == "__main__": if __name__ == "__main__":
import os import os
from pathlib import Path
if (mode := os.getenv("MODE")) == "UI": if (mode := os.getenv("MODE")) == "UI":
stream.ui.launch(server_port=7860) stream.ui.launch(
server_port=7860,
allowed_paths=[str((Path(__file__).parent / "AV_Huggy.png").resolve())],
)
elif mode == "PHONE": elif mode == "PHONE":
raise ValueError("Phone mode not supported") raise ValueError("Phone mode not supported")
else: else: