mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Fix kokoro batch issue (#128)
* Fix kokoro batch issue * code * fix batch size --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -39,6 +39,48 @@ def get_tts_model(model: Literal["kokoro"] = "kokoro") -> TTSModel:
|
||||
return m
|
||||
|
||||
|
||||
class KokoroFixedBatchSize:
|
||||
# Source: https://github.com/thewh1teagle/kokoro-onnx/issues/115#issuecomment-2676625392
|
||||
def _split_phonemes(self, phonemes: str) -> list[str]:
|
||||
MAX_PHONEME_LENGTH = 510
|
||||
max_length = MAX_PHONEME_LENGTH - 1
|
||||
batched_phonemes = []
|
||||
while len(phonemes) > max_length:
|
||||
# Find best split point within limit
|
||||
split_idx = max_length
|
||||
|
||||
# Try to find the last period before max_length
|
||||
period_idx = phonemes.rfind(".", 0, max_length)
|
||||
if period_idx != -1:
|
||||
split_idx = period_idx + 1 # Include period
|
||||
|
||||
else:
|
||||
# Try other punctuation
|
||||
match = re.search(
|
||||
r"[!?;,]", phonemes[:max_length][::-1]
|
||||
) # Search backwards
|
||||
if match:
|
||||
split_idx = max_length - match.start()
|
||||
|
||||
else:
|
||||
# Try last space
|
||||
space_idx = phonemes.rfind(" ", 0, max_length)
|
||||
if space_idx != -1:
|
||||
split_idx = space_idx
|
||||
|
||||
# If no good split point is found, force split at max_length
|
||||
chunk = phonemes[:split_idx].strip()
|
||||
batched_phonemes.append(chunk)
|
||||
|
||||
# Move to the next part
|
||||
phonemes = phonemes[split_idx:].strip()
|
||||
|
||||
# Add remaining phonemes
|
||||
if phonemes:
|
||||
batched_phonemes.append(phonemes)
|
||||
return batched_phonemes
|
||||
|
||||
|
||||
class KokoroTTSModel(TTSModel):
|
||||
def __init__(self):
|
||||
from kokoro_onnx import Kokoro
|
||||
@@ -48,6 +90,8 @@ class KokoroTTSModel(TTSModel):
|
||||
voices_path=hf_hub_download("fastrtc/kokoro-onnx", "voices-v1.0.bin"),
|
||||
)
|
||||
|
||||
self.model._split_phonemes = KokoroFixedBatchSize()._split_phonemes
|
||||
|
||||
def tts(
|
||||
self, text: str, options: KokoroTTSOptions | None = None
|
||||
) -> tuple[int, NDArray[np.float32]]:
|
||||
@@ -74,6 +118,7 @@ class KokoroTTSModel(TTSModel):
|
||||
):
|
||||
if s_idx != 0 and chunk_idx == 0:
|
||||
yield chunk[1], np.zeros(chunk[1] // 7, dtype=np.float32)
|
||||
chunk_idx += 1
|
||||
yield chunk[1], chunk[0]
|
||||
|
||||
def stream_tts_sync(
|
||||
|
||||
Reference in New Issue
Block a user