class KokoroFixedBatchSize:
    """Phoneme splitter used as a drop-in override for a Kokoro model.

    ``KokoroTTSModel.__init__`` assigns the bound ``_split_phonemes`` of an
    instance of this class onto the loaded model, so the signature of that
    method must stay ``(phonemes: str) -> list[str]``.
    """

    # Source: https://github.com/thewh1teagle/kokoro-onnx/issues/115#issuecomment-2676625392

    # Upper bound taken from the upstream workaround; every emitted chunk is
    # kept strictly shorter than this value.
    MAX_PHONEME_LENGTH = 510

    # Break characters tried when the window contains no period.
    _SECONDARY_BREAKS = "!?;,"

    def _split_phonemes(self, phonemes: str) -> list[str]:
        """Split ``phonemes`` into chunks of at most ``MAX_PHONEME_LENGTH - 1`` chars.

        Split points are chosen, in order of preference, after the last
        period in the window, after the last of ``!?;,``, at the last space,
        and finally at the hard length limit. Chunks produced by a split are
        stripped of surrounding whitespace; an input that already fits is
        returned as the single element, unmodified.

        Args:
            phonemes: The phoneme string to split.

        Returns:
            The non-empty chunks in their original order. Empty or
            whitespace-only pieces produced by a split are dropped.

        Raises:
            ValueError: If ``MAX_PHONEME_LENGTH`` has been overridden with a
                value smaller than 2, which would make progress impossible.
        """
        max_length = self.MAX_PHONEME_LENGTH - 1
        if max_length < 1:
            raise ValueError("MAX_PHONEME_LENGTH must be at least 2")

        batched_phonemes: list[str] = []
        while len(phonemes) > max_length:
            split_idx = self._find_split_index(phonemes, max_length)

            # A split at index 0 (leading space) or an all-whitespace slice
            # yields nothing worth synthesising, so it is not emitted.
            chunk = phonemes[:split_idx].strip()
            if chunk:
                batched_phonemes.append(chunk)

            # Move to the next part. Stripping also guarantees progress when
            # split_idx is 0, because that only happens on a leading space.
            phonemes = phonemes[split_idx:].strip()

        # Add remaining phonemes
        if phonemes:
            batched_phonemes.append(phonemes)
        return batched_phonemes

    def _find_split_index(self, phonemes: str, max_length: int) -> int:
        """Return the index at which to cut the first ``max_length`` chars."""
        # Prefer the last period in the window; the period stays in the chunk.
        period_idx = phonemes.rfind(".", 0, max_length)
        if period_idx != -1:
            return period_idx + 1

        # Otherwise the last of the secondary marks, also kept in the chunk.
        # rfind yields -1 for a missing mark, so max() is -1 only if none exist.
        punct_idx = max(
            phonemes.rfind(mark, 0, max_length) for mark in self._SECONDARY_BREAKS
        )
        if punct_idx != -1:
            return punct_idx + 1

        # Otherwise the last space; the space itself is stripped by the caller.
        space_idx = phonemes.rfind(" ", 0, max_length)
        if space_idx != -1:
            return space_idx

        # If no good split point is found, force split at max_length.
        return max_length