Fix kokoro batch issue (#128)

* Fix kokoro batch issue

* code

* fix batch size

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
Freddy Boulton
2025-03-05 19:43:02 -05:00
committed by GitHub
parent 6517a93472
commit df0706e048
2 changed files with 58 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
from fastrtc.text_to_speech.tts import get_tts_model
def test_tts_long_prompt():
model = get_tts_model()
prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"
for i, chunk in enumerate(model.stream_tts_sync(prompt)):
print(f"Chunk {i}: {chunk[1].shape}")
if __name__ == "__main__":
test_tts_long_prompt()

View File

@@ -39,6 +39,48 @@ def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
return m
class KokoroFixedBatchSize:
# Source: https://github.com/thewh1teagle/kokoro-onnx/issues/115#issuecomment-2676625392
def _split_phonemes(self, phonemes: str) -> list[str]:
MAX_PHONEME_LENGTH = 510
max_length = MAX_PHONEME_LENGTH - 1
batched_phonemes = []
while len(phonemes) > max_length:
# Find best split point within limit
split_idx = max_length
# Try to find the last period before max_length
period_idx = phonemes.rfind(".", 0, max_length)
if period_idx != -1:
split_idx = period_idx + 1 # Include period
else:
# Try other punctuation
match = re.search(
r"[!?;,]", phonemes[:max_length][::-1]
) # Search backwards
if match:
split_idx = max_length - match.start()
else:
# Try last space
space_idx = phonemes.rfind(" ", 0, max_length)
if space_idx != -1:
split_idx = space_idx
# If no good split point is found, force split at max_length
chunk = phonemes[:split_idx].strip()
batched_phonemes.append(chunk)
# Move to the next part
phonemes = phonemes[split_idx:].strip()
# Add remaining phonemes
if phonemes:
batched_phonemes.append(phonemes)
return batched_phonemes
class KokoroTTSModel(TTSModel):
def __init__(self):
from kokoro_onnx import Kokoro
@@ -48,6 +90,8 @@ class KokoroTTSModel(TTSModel):
voices_path=hf_hub_download("fastrtc/kokoro-onnx", "voices-v1.0.bin"),
)
self.model._split_phonemes = KokoroFixedBatchSize()._split_phonemes
def tts(
self, text: str, options: KokoroTTSOptions | None = None
) -> tuple[int, NDArray[np.float32]]:
@@ -74,6 +118,7 @@ class KokoroTTSModel(TTSModel):
):
if s_idx != 0 and chunk_idx == 0:
yield chunk[1], np.zeros(chunk[1] // 7, dtype=np.float32)
chunk_idx += 1
yield chunk[1], chunk[0]
def stream_tts_sync(