Add first-class support for Cartesia text-to-speech (#298)

* Demo

* patient intake

* cartesia

* Add cartesia

* Fix

* lint

* Move test

* Fix

* Fix

* Fix

* Fix
This commit is contained in:
Freddy Boulton
2025-04-23 15:15:57 -04:00
committed by GitHub
parent 24349dee0c
commit 745701c79c
4 changed files with 99 additions and 14 deletions

View File

@@ -34,6 +34,6 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
python -m pip install -U pip python -m pip install -U pip
pip install .[dev] pip install '.[dev, tts]'
python -m pytest --capture=no python -m pytest --capture=no
shell: bash shell: bash

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import importlib.util
import re import re
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass from dataclasses import dataclass
@@ -9,6 +10,8 @@ import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from numpy.typing import NDArray from numpy.typing import NDArray
from fastrtc.utils import async_aggregate_bytes_to_16bit
class TTSOptions: class TTSOptions:
pass pass
@@ -20,15 +23,15 @@ T = TypeVar("T", bound=TTSOptions, contravariant=True)
class TTSModel(Protocol[T]): class TTSModel(Protocol[T]):
def tts( def tts(
self, text: str, options: T | None = None self, text: str, options: T | None = None
) -> tuple[int, NDArray[np.float32]]: ... ) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]: ...
def stream_tts( def stream_tts(
self, text: str, options: T | None = None self, text: str, options: T | None = None
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ... ) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]: ...
def stream_tts_sync( def stream_tts_sync(
self, text: str, options: T | None = None self, text: str, options: T | None = None
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ... ) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]: ...
@dataclass @dataclass
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
@lru_cache @lru_cache
def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel: def get_tts_model(
m = KokoroTTSModel() model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
m.tts("Hello, world!") ) -> TTSModel:
return m if model == "kokoro":
m = KokoroTTSModel()
m.tts("Hello, world!")
return m
elif model == "cartesia":
m = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
m.tts("Hello, world!")
return m
else:
raise ValueError(f"Invalid model: {model}")
class KokoroFixedBatchSize: class KokoroFixedBatchSize:
@@ -139,3 +151,77 @@ class KokoroTTSModel(TTSModel):
yield loop.run_until_complete(iterator.__anext__()) yield loop.run_until_complete(iterator.__anext__())
except StopAsyncIteration: except StopAsyncIteration:
break break
from dataclasses import dataclass, field


@dataclass
class CartesiaTTSOptions(TTSOptions):
    """Options for the Cartesia TTS backend.

    Mirrors the dataclass style of KokoroTTSOptions so both option types
    construct and compare the same way.
    """

    # Default Cartesia voice id.
    voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
    language: str = "en"
    # NOTE: was a shared mutable class attribute (`emotion: list[str] = []`);
    # default_factory gives each instance its own list.
    emotion: list[str] = field(default_factory=list)
    # Cartesia API version header value.
    cartesia_version: str = "2024-06-10"
    model: str = "sonic-2"
    # Output PCM sample rate in Hz.
    sample_rate: int = 22_050
class CartesiaTTSModel(TTSModel):
    """Text-to-speech model backed by the Cartesia async API.

    Streams raw 16-bit PCM from Cartesia sentence-by-sentence so playback
    can begin before the whole prompt is synthesized.
    """

    def __init__(self, api_key: str):
        """Create the async Cartesia client.

        Raises:
            RuntimeError: if the optional `cartesia` package is not installed.
        """
        if importlib.util.find_spec("cartesia") is None:
            raise RuntimeError(
                "cartesia is not installed. Please install it using 'pip install cartesia'."
            )
        from cartesia import AsyncCartesia

        self.client = AsyncCartesia(api_key=api_key)

    async def stream_tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
        """Yield (sample_rate, int16 audio chunk) tuples for `text`."""
        options = options or CartesiaTTSOptions()

        # Split on sentence boundaries so audio starts streaming early.
        sentences = re.split(r"(?<=[.!?])\s+", text.strip())
        for sentence in sentences:
            if not sentence.strip():
                continue
            async for output in async_aggregate_bytes_to_16bit(
                self.client.tts.bytes(
                    # Honor the caller's options; these were previously
                    # hard-coded to "sonic-2" and "en", silently ignoring
                    # options.model / options.language.
                    model_id=options.model,
                    transcript=sentence,
                    voice={"id": options.voice},  # type: ignore
                    language=options.language,
                    output_format={
                        "container": "raw",
                        "sample_rate": options.sample_rate,
                        "encoding": "pcm_s16le",
                    },
                )
            ):
                yield options.sample_rate, np.frombuffer(output, dtype=np.int16)

    def stream_tts_sync(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
        """Synchronous wrapper around `stream_tts` for non-async callers."""
        loop = asyncio.new_event_loop()
        try:
            iterator = self.stream_tts(text, options).__aiter__()
            while True:
                try:
                    yield loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Previously leaked: the loop was never closed.
            loop.close()

    def tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> tuple[int, NDArray[np.int16]]:
        """Synthesize `text` in full and return (sample_rate, int16 audio)."""
        options = options or CartesiaTTSOptions()
        # Reuse stream_tts_sync (which owns/closes its event loop) and
        # concatenate once at the end instead of O(n^2) repeated
        # np.concatenate calls in the loop.
        chunks = [chunk for _, chunk in self.stream_tts_sync(text, options)]
        if chunks:
            buffer = np.concatenate(chunks)
        else:
            buffer = np.array([], dtype=np.int16)
        return options.sample_rate, buffer

View File

@@ -1,13 +1,11 @@
import pytest
from fastrtc.text_to_speech.tts import get_tts_model from fastrtc.text_to_speech.tts import get_tts_model
def test_tts_long_prompt(): @pytest.mark.parametrize("model", ["kokoro"])
model = get_tts_model() def test_tts_long_prompt(model):
model = get_tts_model(model=model)
prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable" prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"
for i, chunk in enumerate(model.stream_tts_sync(prompt)): for i, chunk in enumerate(model.stream_tts_sync(prompt)):
print(f"Chunk {i}: {chunk[1].shape}") print(f"Chunk {i}: {chunk[1].shape}")
if __name__ == "__main__":
test_tts_long_prompt()

View File

@@ -39,6 +39,7 @@ class MinimalTestStream(WebRTCConnectionMixin):
) )
self.time_limit = time_limit self.time_limit = time_limit
self.allow_extra_tracks = allow_extra_tracks self.allow_extra_tracks = allow_extra_tracks
self.server_rtc_configuration = None
def mount(self, app: FastAPI, path: str = ""): def mount(self, app: FastAPI, path: str = ""):
from fastapi import APIRouter from fastapi import APIRouter