mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 09:29:23 +08:00
Add first-class support for Cartesia text-to-speech (#298)
* Demo * patient intake * cartesia * Add cartesia * Fix * lint * Move test * Fix * Fix * Fix * Fix
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -34,6 +34,6 @@ jobs:
|
||||
- name: Run tests
|
||||
run: |
|
||||
python -m pip install -U pip
|
||||
pip install .[dev]
|
||||
pip install '.[dev, tts]'
|
||||
python -m pytest --capture=no
|
||||
shell: bash
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import re
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from dataclasses import dataclass
|
||||
@@ -9,6 +10,8 @@ import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fastrtc.utils import async_aggregate_bytes_to_16bit
|
||||
|
||||
|
||||
class TTSOptions:
|
||||
pass
|
||||
@@ -20,15 +23,15 @@ T = TypeVar("T", bound=TTSOptions, contravariant=True)
|
||||
class TTSModel(Protocol[T]):
|
||||
def tts(
|
||||
self, text: str, options: T | None = None
|
||||
) -> tuple[int, NDArray[np.float32]]: ...
|
||||
) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]: ...
|
||||
|
||||
def stream_tts(
|
||||
self, text: str, options: T | None = None
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]: ...
|
||||
|
||||
def stream_tts_sync(
|
||||
self, text: str, options: T | None = None
|
||||
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
|
||||
) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
|
||||
m = KokoroTTSModel()
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
def get_tts_model(
|
||||
model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
|
||||
) -> TTSModel:
|
||||
if model == "kokoro":
|
||||
m = KokoroTTSModel()
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
elif model == "cartesia":
|
||||
m = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
else:
|
||||
raise ValueError(f"Invalid model: {model}")
|
||||
|
||||
|
||||
class KokoroFixedBatchSize:
|
||||
@@ -139,3 +151,77 @@ class KokoroTTSModel(TTSModel):
|
||||
yield loop.run_until_complete(iterator.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
||||
class CartesiaTTSOptions(TTSOptions):
|
||||
voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
|
||||
language: str = "en"
|
||||
emotion: list[str] = []
|
||||
cartesia_version: str = "2024-06-10"
|
||||
model: str = "sonic-2"
|
||||
sample_rate: int = 22_050
|
||||
|
||||
|
||||
class CartesiaTTSModel(TTSModel):
|
||||
def __init__(self, api_key: str):
|
||||
if importlib.util.find_spec("cartesia") is None:
|
||||
raise RuntimeError(
|
||||
"cartesia is not installed. Please install it using 'pip install cartesia'."
|
||||
)
|
||||
from cartesia import AsyncCartesia
|
||||
|
||||
self.client = AsyncCartesia(api_key=api_key)
|
||||
|
||||
async def stream_tts(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
|
||||
options = options or CartesiaTTSOptions()
|
||||
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
|
||||
|
||||
for sentence in sentences:
|
||||
if not sentence.strip():
|
||||
continue
|
||||
async for output in async_aggregate_bytes_to_16bit(
|
||||
self.client.tts.bytes(
|
||||
model_id="sonic-2",
|
||||
transcript=sentence,
|
||||
voice={"id": options.voice}, # type: ignore
|
||||
language="en",
|
||||
output_format={
|
||||
"container": "raw",
|
||||
"sample_rate": options.sample_rate,
|
||||
"encoding": "pcm_s16le",
|
||||
},
|
||||
)
|
||||
):
|
||||
yield options.sample_rate, np.frombuffer(output, dtype=np.int16)
|
||||
|
||||
def stream_tts_sync(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
iterator = self.stream_tts(text, options).__aiter__()
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(iterator.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
def tts(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> tuple[int, NDArray[np.int16]]:
|
||||
loop = asyncio.new_event_loop()
|
||||
buffer = np.array([], dtype=np.int16)
|
||||
|
||||
options = options or CartesiaTTSOptions()
|
||||
|
||||
iterator = self.stream_tts(text, options).__aiter__()
|
||||
while True:
|
||||
try:
|
||||
_, chunk = loop.run_until_complete(iterator.__anext__())
|
||||
buffer = np.concatenate([buffer, chunk])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
return options.sample_rate, buffer
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import pytest
|
||||
from fastrtc.text_to_speech.tts import get_tts_model
|
||||
|
||||
|
||||
def test_tts_long_prompt():
|
||||
model = get_tts_model()
|
||||
@pytest.mark.parametrize("model", ["kokoro"])
|
||||
def test_tts_long_prompt(model):
|
||||
model = get_tts_model(model=model)
|
||||
prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"
|
||||
|
||||
for i, chunk in enumerate(model.stream_tts_sync(prompt)):
|
||||
print(f"Chunk {i}: {chunk[1].shape}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tts_long_prompt()
|
||||
@@ -39,6 +39,7 @@ class MinimalTestStream(WebRTCConnectionMixin):
|
||||
)
|
||||
self.time_limit = time_limit
|
||||
self.allow_extra_tracks = allow_extra_tracks
|
||||
self.server_rtc_configuration = None
|
||||
|
||||
def mount(self, app: FastAPI, path: str = ""):
|
||||
from fastapi import APIRouter
|
||||
|
||||
Reference in New Issue
Block a user