Add first-class support for Cartesia text-to-speech (#298)

* Demo

* patient intake

* cartesia

* Add cartesia

* Fix

* lint

* Move test

* Fix

* Fix

* Fix

* Fix
This commit is contained in:
Freddy Boulton
2025-04-23 15:15:57 -04:00
committed by GitHub
parent 24349dee0c
commit 745701c79c
4 changed files with 99 additions and 14 deletions

View File

@@ -34,6 +34,6 @@ jobs:
- name: Run tests
run: |
python -m pip install -U pip
pip install .[dev]
pip install '.[dev, tts]'
python -m pytest --capture=no
shell: bash

View File

@@ -1,4 +1,5 @@
import asyncio
import importlib.util
import re
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass
@@ -9,6 +10,8 @@ import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from fastrtc.utils import async_aggregate_bytes_to_16bit
class TTSOptions:
pass
@@ -20,15 +23,15 @@ T = TypeVar("T", bound=TTSOptions, contravariant=True)
class TTSModel(Protocol[T]):
def tts(
self, text: str, options: T | None = None
) -> tuple[int, NDArray[np.float32]]: ...
) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]: ...
def stream_tts(
self, text: str, options: T | None = None
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]: ...
def stream_tts_sync(
self, text: str, options: T | None = None
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]: ...
@dataclass
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
@lru_cache
def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
m = KokoroTTSModel()
m.tts("Hello, world!")
return m
def get_tts_model(
    model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
) -> TTSModel:
    """Return a ready-to-use TTS model, warmed up with a short utterance.

    Args:
        model: Which backend to construct, ``"kokoro"`` or ``"cartesia"``.
        **kwargs: Backend-specific settings; ``cartesia_api_key`` is read
            for the Cartesia backend.

    Returns:
        A ``TTSModel`` instance that has already synthesized one phrase.

    Raises:
        ValueError: If ``model`` is not a recognized backend name.
    """
    if model == "kokoro":
        instance: TTSModel = KokoroTTSModel()
    elif model == "cartesia":
        # NOTE(review): an absent key falls through as "" — presumably the
        # Cartesia client surfaces the auth error itself; confirm.
        instance = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
    else:
        raise ValueError(f"Invalid model: {model}")
    # Warm-up call so the first real request does not pay start-up cost.
    instance.tts("Hello, world!")
    return instance
class KokoroFixedBatchSize:
@@ -139,3 +151,77 @@ class KokoroTTSModel(TTSModel):
yield loop.run_until_complete(iterator.__anext__())
except StopAsyncIteration:
break
class CartesiaTTSOptions(TTSOptions):
    """Options for Cartesia text-to-speech requests.

    The previous version declared ``emotion: list[str] = []`` as a class
    attribute, so every instance shared one mutable list; defaults are now
    bound per instance in ``__init__`` (``CartesiaTTSOptions()`` with no
    arguments behaves exactly as before, and each field may now also be
    overridden by keyword).
    """

    def __init__(
        self,
        voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121",
        language: str = "en",
        emotion: list[str] | None = None,
        cartesia_version: str = "2024-06-10",
        model: str = "sonic-2",
        sample_rate: int = 22_050,
    ):
        # Cartesia voice id (default is a stock English voice).
        self.voice = voice
        self.language = language
        # Fresh list per instance — avoids the shared-mutable-default bug.
        self.emotion: list[str] = [] if emotion is None else emotion
        # API version string expected by the Cartesia service.
        self.cartesia_version = cartesia_version
        self.model = model
        # Output PCM sample rate in Hz.
        self.sample_rate = sample_rate
class CartesiaTTSModel(TTSModel):
    """Text-to-speech backed by Cartesia's streaming API.

    Audio is produced as 16-bit PCM at the sample rate configured in
    ``CartesiaTTSOptions``. Requires the optional ``cartesia`` package.
    """

    def __init__(self, api_key: str):
        """Create an async Cartesia client.

        Raises:
            RuntimeError: If the ``cartesia`` package is not installed.
        """
        if importlib.util.find_spec("cartesia") is None:
            raise RuntimeError(
                "cartesia is not installed. Please install it using 'pip install cartesia'."
            )
        from cartesia import AsyncCartesia

        self.client = AsyncCartesia(api_key=api_key)

    async def stream_tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
        """Yield ``(sample_rate, pcm_chunk)`` pairs for ``text``.

        Text is split into sentences so audio for early sentences streams
        out while later ones are still being synthesized.
        """
        options = options or CartesiaTTSOptions()

        # Naive splitter: break after ., ! or ? followed by whitespace.
        sentences = re.split(r"(?<=[.!?])\s+", text.strip())
        for sentence in sentences:
            if not sentence.strip():
                continue
            async for output in async_aggregate_bytes_to_16bit(
                self.client.tts.bytes(
                    # Fix: honor the options instead of hard-coding
                    # model_id="sonic-2" and language="en", which made
                    # those option fields dead.
                    model_id=options.model,
                    transcript=sentence,
                    voice={"id": options.voice},  # type: ignore
                    language=options.language,
                    output_format={
                        "container": "raw",
                        "sample_rate": options.sample_rate,
                        "encoding": "pcm_s16le",
                    },
                    # NOTE(review): options.emotion and
                    # options.cartesia_version are still unused — wire up
                    # once the client call signature is confirmed.
                )
            ):
                yield options.sample_rate, np.frombuffer(output, dtype=np.int16)

    def stream_tts_sync(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
        """Synchronous wrapper around :meth:`stream_tts`."""
        loop = asyncio.new_event_loop()
        iterator = self.stream_tts(text, options).__aiter__()
        try:
            while True:
                try:
                    yield loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Fix: close the private loop so its resources are released
            # even if the consumer abandons this generator early.
            loop.close()

    def tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> tuple[int, NDArray[np.int16]]:
        """Synthesize ``text`` and return the complete audio buffer.

        Returns:
            ``(sample_rate, samples)`` where ``samples`` is int16 PCM
            (empty when the text contains no speakable sentences).
        """
        options = options or CartesiaTTSOptions()
        # Collect chunks, then concatenate once — the old per-chunk
        # np.concatenate made accumulation quadratic. Reusing
        # stream_tts_sync also ensures the event loop is closed.
        chunks = [chunk for _, chunk in self.stream_tts_sync(text, options)]
        if not chunks:
            return options.sample_rate, np.array([], dtype=np.int16)
        return options.sample_rate, np.concatenate(chunks)

View File

@@ -1,13 +1,11 @@
import pytest
from fastrtc.text_to_speech.tts import get_tts_model
def test_tts_long_prompt():
model = get_tts_model()
@pytest.mark.parametrize("model", ["kokoro"])
def test_tts_long_prompt(model):
    """Stream a long prompt through the chosen backend, printing each chunk."""
    # Use a distinct local name — the original rebound the `model` parameter.
    tts = get_tts_model(model=model)
    prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"
    for index, (_, audio) in enumerate(tts.stream_tts_sync(prompt)):
        print(f"Chunk {index}: {audio.shape}")
if __name__ == "__main__":
test_tts_long_prompt()

View File

@@ -39,6 +39,7 @@ class MinimalTestStream(WebRTCConnectionMixin):
)
self.time_limit = time_limit
self.allow_extra_tracks = allow_extra_tracks
self.server_rtc_configuration = None
def mount(self, app: FastAPI, path: str = ""):
from fastapi import APIRouter