diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py
index 43f61ec..2e08399 100644
--- a/backend/fastrtc/__init__.py
+++ b/backend/fastrtc/__init__.py
@@ -17,7 +17,11 @@ from .reply_on_pause import AlgoOptions, ReplyOnPause
 from .reply_on_stopwords import ReplyOnStopWords
 from .speech_to_text import MoonshineSTT, get_stt_model
 from .stream import Stream, UIArgs
-from .text_to_speech import KokoroTTSOptions, get_tts_model
+from .text_to_speech import (
+    CartesiaTTSOptions,
+    KokoroTTSOptions,
+    get_tts_model,
+)
 from .tracks import (
     AsyncAudioVideoStreamHandler,
     AsyncStreamHandler,
@@ -87,4 +91,5 @@ __all__ = [
    "VideoStreamHandler",
    "CloseStream",
    "get_current_context",
+    "CartesiaTTSOptions",
 ]
diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py
index d1b53ac..1a2acdc 100644
--- a/backend/fastrtc/stream.py
+++ b/backend/fastrtc/stream.py
@@ -59,6 +59,8 @@ class UIArgs(TypedDict):
     If "submit", the input will be sent when the submit event is triggered by the user.
     If "change", the input will be sent whenever the user changes the input value.
     """
+    hide_title: NotRequired[bool]
+    """If True, the title and subtitle will not be displayed."""
 
 
 class Stream(WebRTCConnectionMixin):
@@ -339,21 +341,22 @@ class Stream(WebRTCConnectionMixin):
                     same_components.append(component)
         if self.modality == "video" and self.mode == "receive":
             with gr.Blocks() as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     with gr.Column():
                         if additional_input_components:
@@ -391,21 +394,22 @@ class Stream(WebRTCConnectionMixin):
                 )
         elif self.modality == "video" and self.mode == "send":
             with gr.Blocks() as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Video Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     if additional_input_components:
                         with gr.Column():
@@ -494,21 +498,22 @@ class Stream(WebRTCConnectionMixin):
                 )
         elif self.modality == "audio" and self.mode == "receive":
             with gr.Blocks() as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     with gr.Column():
                         for component in additional_input_components:
@@ -549,21 +554,22 @@ class Stream(WebRTCConnectionMixin):
                 )
         elif self.modality == "audio" and self.mode == "send":
             with gr.Blocks() as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     with gr.Column():
                         with gr.Group():
@@ -604,21 +610,22 @@ class Stream(WebRTCConnectionMixin):
                 )
         elif self.modality == "audio" and self.mode == "send-receive":
             with gr.Blocks() as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Audio Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     with gr.Column():
                         with gr.Group():
@@ -662,21 +669,22 @@ class Stream(WebRTCConnectionMixin):
             css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
                      .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
             with gr.Blocks(css=css) as demo:
-                gr.HTML(
-                    f"""
-                <h1 style='text-align: center'>
-                {ui_args.get("title", "Audio Video Streaming (Powered by FastRTC ⚡️)")}
-                </h1>
-                """
-                )
-                if ui_args.get("subtitle"):
-                    gr.Markdown(
+                if not ui_args.get("hide_title"):
+                    gr.HTML(
                         f"""
-                    <div style='text-align: center'>
-                    {ui_args.get("subtitle")}
-                    </div>
-                    """
+                    <h1 style='text-align: center'>
+                    {ui_args.get("title", "Audio Video Streaming (Powered by FastRTC ⚡️)")}
+                    </h1>
+                    """
                     )
+                    if ui_args.get("subtitle"):
+                        gr.Markdown(
+                            f"""
+                        <div style='text-align: center'>
+                        {ui_args.get("subtitle")}
+                        </div>
+                        """
+                        )
                 with gr.Row():
                     with gr.Column(elem_classes=["my-column"]):
                         with gr.Group(elem_classes=["my-group"]):
diff --git a/backend/fastrtc/text_to_speech/__init__.py b/backend/fastrtc/text_to_speech/__init__.py
index 2cc082a..0d55538 100644
--- a/backend/fastrtc/text_to_speech/__init__.py
+++ b/backend/fastrtc/text_to_speech/__init__.py
@@ -1,3 +1,7 @@
-from .tts import KokoroTTSOptions, get_tts_model
+from .tts import (
+    CartesiaTTSOptions,
+    KokoroTTSOptions,
+    get_tts_model,
+)
 
-__all__ = ["get_tts_model", "KokoroTTSOptions"]
+__all__ = ["get_tts_model", "KokoroTTSOptions", "CartesiaTTSOptions"]
diff --git a/backend/fastrtc/text_to_speech/tts.py b/backend/fastrtc/text_to_speech/tts.py
index 37743be..f800e4b 100644
--- a/backend/fastrtc/text_to_speech/tts.py
+++ b/backend/fastrtc/text_to_speech/tts.py
@@ -2,7 +2,7 @@ import asyncio
 import importlib.util
 import re
 from collections.abc import AsyncGenerator, Generator
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Literal, Protocol, TypeVar
 
@@ -153,10 +153,11 @@ class KokoroTTSModel(TTSModel):
             break
 
 
+@dataclass
 class CartesiaTTSOptions(TTSOptions):
     voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
     language: str = "en"
-    emotion: list[str] = []
+    emotion: list[str] = field(default_factory=list)
     cartesia_version: str = "2024-06-10"
     model: str = "sonic-2"
     sample_rate: int = 22_050
diff --git a/demo/talk_to_llama4/AV_Huggy.png b/demo/talk_to_llama4/AV_Huggy.png
new file mode 100644
index 0000000..6a9fb74
Binary files /dev/null and b/demo/talk_to_llama4/AV_Huggy.png differ
diff --git a/demo/talk_to_llama4/app.py b/demo/talk_to_llama4/app.py
index 9929fca..4010fc5 100644
--- a/demo/talk_to_llama4/app.py
+++ b/demo/talk_to_llama4/app.py
@@ -9,11 +9,12 @@ from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastrtc import (
     AdditionalOutputs,
+    CartesiaTTSOptions,
     ReplyOnPause,
     Stream,
-    audio_to_bytes,
     get_cloudflare_turn_credentials_async,
     get_current_context,
+    get_stt_model,
     get_tts_model,
 )
 from groq import Groq
@@ -22,9 +23,11 @@ from numpy.typing import NDArray
 curr_dir = Path(__file__).parent
 
 load_dotenv()
-tts_model = get_tts_model()
+tts_model = get_tts_model(
+    model="cartesia", cartesia_api_key=os.getenv("CARTESIA_API_KEY")
+)
 groq = Groq(api_key=os.getenv("GROQ_API_KEY"))
-
+stt_model = get_stt_model()
 conversations: dict[str, list[dict[str, str]]] = {}
 
 
@@ -43,14 +46,8 @@ def response(user_audio: tuple[int, NDArray[np.int16]]):
         ]
     messages = conversations[context.webrtc_id]
 
-    transcription = groq.audio.transcriptions.create(
-        file=("audio.wav", audio_to_bytes(user_audio)),
-        model="distil-whisper-large-v3-en",
-        response_format="verbose_json",
-    )
-    print(transcription.text)
-
-    messages.append({"role": "user", "content": transcription.text})
+    transcription = stt_model.stt(user_audio)
+    messages.append({"role": "user", "content": transcription})
 
     completion = groq.chat.completions.create(  # type: ignore
         model="meta-llama/llama-4-scout-17b-16e-instruct",
@@ -68,7 +65,9 @@ def response(user_audio: tuple[int, NDArray[np.int16]]):
     long_response = response["long"]
     messages.append({"role": "assistant", "content": long_response})
     conversations[context.webrtc_id] = messages
-    yield from tts_model.stream_tts_sync(short_response)
+    yield from tts_model.stream_tts_sync(
+        short_response, options=CartesiaTTSOptions(sample_rate=24_000)
+    )
     yield AdditionalOutputs(messages)
 
 
@@ -78,9 +77,22 @@ stream = Stream(
mode="send-receive", additional_outputs=[gr.Chatbot(type="messages")], additional_outputs_handler=lambda old, new: new, - rtc_configuration=get_cloudflare_turn_credentials_async, + rtc_configuration=None, + ui_args={"hide_title": True}, ) +with gr.Blocks() as demo: + gr.HTML( + f""" +

+ AV Huggy FastRTC + Cartesia TTS = Blazing Fast LLM Audio +

+ """ + ) + stream.ui.render() + +stream.ui = demo + app = FastAPI() stream.mount(app) @@ -109,9 +121,13 @@ async def _(webrtc_id: str): if __name__ == "__main__": import os + from pathlib import Path if (mode := os.getenv("MODE")) == "UI": - stream.ui.launch(server_port=7860) + stream.ui.launch( + server_port=7860, + allowed_paths=[str((Path(__file__).parent / "AV_Huggy.png").resolve())], + ) elif mode == "PHONE": raise ValueError("Phone mode not supported") else: