diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f0a0c6d..7a0c833 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,6 +34,6 @@ jobs: - name: Run tests run: | python -m pip install -U pip - pip install .[dev] + pip install '.[dev, tts]' python -m pytest --capture=no shell: bash diff --git a/backend/fastrtc/text_to_speech/tts.py b/backend/fastrtc/text_to_speech/tts.py index 5e910e1..37743be 100644 --- a/backend/fastrtc/text_to_speech/tts.py +++ b/backend/fastrtc/text_to_speech/tts.py @@ -1,4 +1,5 @@ import asyncio +import importlib.util import re from collections.abc import AsyncGenerator, Generator from dataclasses import dataclass @@ -9,6 +10,8 @@ import numpy as np from huggingface_hub import hf_hub_download from numpy.typing import NDArray +from fastrtc.utils import async_aggregate_bytes_to_16bit + class TTSOptions: pass @@ -20,15 +23,15 @@ T = TypeVar("T", bound=TTSOptions, contravariant=True) class TTSModel(Protocol[T]): def tts( self, text: str, options: T | None = None - ) -> tuple[int, NDArray[np.float32]]: ... + ) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]: ... def stream_tts( self, text: str, options: T | None = None - ) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ... + ) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]: ... def stream_tts_sync( self, text: str, options: T | None = None - ) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ... + ) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]: ... 
@lru_cache
def get_tts_model(
    model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
) -> TTSModel:
    """Return a cached, warmed-up text-to-speech model.

    Args:
        model: Which backend to use, ``"kokoro"`` (local) or ``"cartesia"`` (API).
        **kwargs: Backend-specific settings; ``cartesia_api_key`` is read for
            the ``"cartesia"`` backend.

    Returns:
        A ready ``TTSModel`` instance (a short warm-up synthesis has been run).

    Raises:
        ValueError: If ``model`` is not a supported backend name.
    """
    if model == "kokoro":
        m = KokoroTTSModel()
        m.tts("Hello, world!")
        return m
    elif model == "cartesia":
        m = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
        m.tts("Hello, world!")
        return m
    else:
        raise ValueError(f"Invalid model: {model}")


class CartesiaTTSOptions(TTSOptions):
    """Options for Cartesia TTS synthesis.

    Previously these were bare class attributes, which made every instance
    share one mutable ``emotion`` list and made per-call configuration
    impossible without mutating the class. All fields are now per-instance.
    """

    def __init__(
        self,
        voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121",
        language: str = "en",
        emotion: list[str] | None = None,
        cartesia_version: str = "2024-06-10",
        model: str = "sonic-2",
        sample_rate: int = 22_050,
    ) -> None:
        self.voice = voice
        self.language = language
        # Copy into a fresh list: a shared default [] would leak state
        # between instances (classic mutable-default bug).
        self.emotion = list(emotion) if emotion is not None else []
        self.cartesia_version = cartesia_version
        self.model = model
        self.sample_rate = sample_rate


class CartesiaTTSModel(TTSModel):
    """TTS backend that streams PCM audio from the Cartesia API."""

    def __init__(self, api_key: str):
        """Create the async Cartesia client.

        Raises:
            RuntimeError: If the optional ``cartesia`` package is not installed.
        """
        if importlib.util.find_spec("cartesia") is None:
            raise RuntimeError(
                "cartesia is not installed. Please install it using 'pip install cartesia'."
            )
        from cartesia import AsyncCartesia

        self.client = AsyncCartesia(api_key=api_key)

    async def stream_tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
        """Yield ``(sample_rate, int16 audio chunk)`` pairs for *text*.

        The text is split on sentence boundaries so audio for the first
        sentence can start playing before the rest is synthesized.
        """
        options = options or CartesiaTTSOptions()

        sentences = re.split(r"(?<=[.!?])\s+", text.strip())

        for sentence in sentences:
            if not sentence.strip():
                continue
            async for output in async_aggregate_bytes_to_16bit(
                self.client.tts.bytes(
                    # Honor the caller's options: these were previously
                    # hard-coded to "sonic-2" / "en", silently ignoring
                    # CartesiaTTSOptions.model and .language.
                    model_id=options.model,
                    transcript=sentence,
                    voice={"id": options.voice},  # type: ignore
                    language=options.language,
                    output_format={
                        "container": "raw",
                        "sample_rate": options.sample_rate,
                        "encoding": "pcm_s16le",
                    },
                )
            ):
                yield options.sample_rate, np.frombuffer(output, dtype=np.int16)

    def stream_tts_sync(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
        """Synchronous wrapper around :meth:`stream_tts`.

        Drives the async generator on a private event loop; the loop is
        closed when the generator is exhausted (it was previously leaked).
        """
        loop = asyncio.new_event_loop()
        try:
            iterator = self.stream_tts(text, options).__aiter__()
            while True:
                try:
                    yield loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    def tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> tuple[int, NDArray[np.int16]]:
        """Synthesize *text* fully and return ``(sample_rate, audio)``.

        Concatenates every streamed chunk into a single int16 buffer.
        """
        options = options or CartesiaTTSOptions()
        buffer = np.array([], dtype=np.int16)

        loop = asyncio.new_event_loop()
        try:
            iterator = self.stream_tts(text, options).__aiter__()
            while True:
                # Keep only the awaiting call inside try: concatenate cannot
                # raise StopAsyncIteration and doesn't belong in the handler.
                try:
                    _, chunk = loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    break
                buffer = np.concatenate([buffer, chunk])
        finally:
            loop.close()  # previously leaked one loop per call
        return options.sample_rate, buffer
@pytest.mark.parametrize("model", ["kokoro"])
def test_tts_long_prompt(model):
    """A long multi-sentence prompt streams at least one non-empty audio chunk.

    The previous version only printed chunk shapes and asserted nothing, so
    it could never fail on broken output; it also rebound the ``model``
    parameter to the model object.
    """
    tts_model = get_tts_model(model=model)  # don't shadow the param name
    prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"
    chunk_count = 0
    for i, chunk in enumerate(tts_model.stream_tts_sync(prompt)):
        print(f"Chunk {i}: {chunk[1].shape}")
        # Every streamed chunk must carry actual samples.
        assert chunk[1].size > 0, f"chunk {i} is empty"
        chunk_count += 1
    assert chunk_count > 0, "no audio chunks were produced"