mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add first-class support for Cartesia text-to-speech (#298)
* Demo * patient intake * cartesia * Add cartesia * Fix * lint * Move test * Fix * Fix * Fix * Fix
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import re
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from dataclasses import dataclass
|
||||
@@ -9,6 +10,8 @@ import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fastrtc.utils import async_aggregate_bytes_to_16bit
|
||||
|
||||
|
||||
class TTSOptions:
|
||||
pass
|
||||
@@ -20,15 +23,15 @@ T = TypeVar("T", bound=TTSOptions, contravariant=True)
|
||||
class TTSModel(Protocol[T]):
|
||||
def tts(
|
||||
self, text: str, options: T | None = None
|
||||
) -> tuple[int, NDArray[np.float32]]: ...
|
||||
) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]: ...
|
||||
|
||||
def stream_tts(
|
||||
self, text: str, options: T | None = None
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]: ...
|
||||
|
||||
def stream_tts_sync(
|
||||
self, text: str, options: T | None = None
|
||||
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
|
||||
) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
|
||||
m = KokoroTTSModel()
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
def get_tts_model(
|
||||
model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
|
||||
) -> TTSModel:
|
||||
if model == "kokoro":
|
||||
m = KokoroTTSModel()
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
elif model == "cartesia":
|
||||
m = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
|
||||
m.tts("Hello, world!")
|
||||
return m
|
||||
else:
|
||||
raise ValueError(f"Invalid model: {model}")
|
||||
|
||||
|
||||
class KokoroFixedBatchSize:
|
||||
@@ -139,3 +151,77 @@ class KokoroTTSModel(TTSModel):
|
||||
yield loop.run_until_complete(iterator.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
||||
class CartesiaTTSOptions(TTSOptions):
|
||||
voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
|
||||
language: str = "en"
|
||||
emotion: list[str] = []
|
||||
cartesia_version: str = "2024-06-10"
|
||||
model: str = "sonic-2"
|
||||
sample_rate: int = 22_050
|
||||
|
||||
|
||||
class CartesiaTTSModel(TTSModel):
|
||||
def __init__(self, api_key: str):
|
||||
if importlib.util.find_spec("cartesia") is None:
|
||||
raise RuntimeError(
|
||||
"cartesia is not installed. Please install it using 'pip install cartesia'."
|
||||
)
|
||||
from cartesia import AsyncCartesia
|
||||
|
||||
self.client = AsyncCartesia(api_key=api_key)
|
||||
|
||||
async def stream_tts(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
|
||||
options = options or CartesiaTTSOptions()
|
||||
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
|
||||
|
||||
for sentence in sentences:
|
||||
if not sentence.strip():
|
||||
continue
|
||||
async for output in async_aggregate_bytes_to_16bit(
|
||||
self.client.tts.bytes(
|
||||
model_id="sonic-2",
|
||||
transcript=sentence,
|
||||
voice={"id": options.voice}, # type: ignore
|
||||
language="en",
|
||||
output_format={
|
||||
"container": "raw",
|
||||
"sample_rate": options.sample_rate,
|
||||
"encoding": "pcm_s16le",
|
||||
},
|
||||
)
|
||||
):
|
||||
yield options.sample_rate, np.frombuffer(output, dtype=np.int16)
|
||||
|
||||
def stream_tts_sync(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
iterator = self.stream_tts(text, options).__aiter__()
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(iterator.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
def tts(
|
||||
self, text: str, options: CartesiaTTSOptions | None = None
|
||||
) -> tuple[int, NDArray[np.int16]]:
|
||||
loop = asyncio.new_event_loop()
|
||||
buffer = np.array([], dtype=np.int16)
|
||||
|
||||
options = options or CartesiaTTSOptions()
|
||||
|
||||
iterator = self.stream_tts(text, options).__aiter__()
|
||||
while True:
|
||||
try:
|
||||
_, chunk = loop.run_until_complete(iterator.__anext__())
|
||||
buffer = np.concatenate([buffer, chunk])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
return options.sample_rate, buffer
|
||||
|
||||
Reference in New Issue
Block a user