stt models (#147)

Freddy Boulton
2025-03-07 17:03:11 -05:00
committed by GitHub
parent cbbfa17679
commit 504eb452f0
6 changed files with 55 additions and 64 deletions

View File

@@ -13,7 +13,7 @@ from .reply_on_pause import (
     ReplyFnGenerator,
     ReplyOnPause,
 )
-from .speech_to_text import get_stt_model
+from .speech_to_text import get_stt_model, stt_for_chunks
 from .utils import audio_to_float32, create_message
 
 logger = logging.getLogger(__name__)

@@ -105,10 +105,9 @@ class ReplyOnStopWords(ReplyOnPause):
             dur_vad, chunks = self.model.vad(
                 (16000, state.post_stop_word_buffer),
                 self.model_options,
-                return_chunks=True,
             )
-            text = self.stt_model.stt_for_chunks(
-                (16000, state.post_stop_word_buffer), chunks
+            text = stt_for_chunks(
+                self.stt_model, (16000, state.post_stop_word_buffer), chunks
             )
             logger.debug(f"STT: {text}")
             state.stop_word_detected = self.stop_word_detected(text)
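Net effect of the hunk above: chunk-level transcription moves from a method on the model to a module-level helper that takes the model as its first argument. A minimal usage sketch, assuming get_stt_model() accepts its defaults, that the helper is importable from fastrtc.speech_to_text, and that each AudioChunk is a start/end mapping indexed in samples (as the slicing in the new helper suggests):

# Sketch only; the import path and chunk literals are assumptions, not part of this commit.
import numpy as np
from fastrtc.speech_to_text import get_stt_model, stt_for_chunks

model = get_stt_model()                               # cached via @lru_cache, warmed up on load
sr, audio = 16000, np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
chunks = [{"start": 0, "end": 8000}, {"start": 8000, "end": 16000}]
text = stt_for_chunks(model, (sr, audio), chunks)     # runs stt() per chunk, joins with " "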

View File

@@ -1,3 +1,3 @@
-from .stt_ import MoonshineSTT, get_stt_model
+from .stt_ import MoonshineSTT, get_stt_model, stt_for_chunks
 
-__all__ = ["get_stt_model", "MoonshineSTT", "get_stt_model"]
+__all__ = ["get_stt_model", "MoonshineSTT", "get_stt_model", "stt_for_chunks"]

View File

@@ -15,12 +15,6 @@ curr_dir = Path(__file__).parent
 class STTModel(Protocol):
     def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str: ...
 
-    def stt_for_chunks(
-        self,
-        audio: tuple[int, NDArray[np.int16 | np.float32]],
-        chunks: list[AudioChunk],
-    ) -> str: ...
-
 
 class MoonshineSTT(STTModel):
     def __init__(

@@ -49,19 +43,6 @@ class MoonshineSTT(STTModel):
         tokens = self.model.generate(audio_np)
         return self.tokenizer.decode_batch(tokens)[0]
 
-    def stt_for_chunks(
-        self,
-        audio: tuple[int, NDArray[np.int16 | np.float32]],
-        chunks: list[AudioChunk],
-    ) -> str:
-        sr, audio_np = audio
-        return " ".join(
-            [
-                self.stt((sr, audio_np[chunk["start"] : chunk["end"]]))
-                for chunk in chunks
-            ]
-        )
-
 
 @lru_cache
 def get_stt_model(

@@ -79,3 +60,17 @@ def get_stt_model(
     m.stt((16000, audio))
     print(click.style("INFO", fg="green") + ":\t STT model warmed up.")
     return m
+
+
+def stt_for_chunks(
+    stt_model: STTModel,
+    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    chunks: list[AudioChunk],
+) -> str:
+    sr, audio_np = audio
+    return " ".join(
+        [
+            stt_model.stt((sr, audio_np[chunk["start"] : chunk["end"]]))
+            for chunk in chunks
+        ]
+    )
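A side effect of slimming the STTModel protocol: a custom model now only has to implement stt to work anywhere stt_for_chunks is used, since the chunk loop lives in the helper rather than on each model. A sketch with a hypothetical stand-in (EchoSTT is illustrative, not part of this commit):

# Hypothetical STTModel; the slimmed protocol requires only stt().
import numpy as np
from numpy.typing import NDArray


class EchoSTT:
    """Toy model that reports clip duration instead of transcribing."""

    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
        sr, samples = audio
        return f"[{len(samples) / sr:.2f}s of audio]"


# Usable wherever an STTModel is expected, e.g.:
#   stt_for_chunks(EchoSTT(), (16000, samples), chunks)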

View File

@@ -1,5 +1,4 @@
 import fastapi
-from fastapi.responses import FileResponse
 from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
 from fastrtc.utils import audio_to_bytes
 from openai import OpenAI

@@ -9,7 +8,6 @@ from fastapi.middleware.cors import CORSMiddleware
 from elevenlabs import VoiceSettings, stream
 from elevenlabs.client import ElevenLabs
 import numpy as np
-import io
 
 from .env import LLM_API_KEY, ELEVENLABS_API_KEY

@@ -22,16 +20,14 @@ Begin a conversation with a self-deprecating joke like 'I'm not sure if I'm read
 
 messages = [{"role": "system", "content": sys_prompt}]
 
-openai_client = OpenAI(
-    api_key=LLM_API_KEY
-)
+openai_client = OpenAI(api_key=LLM_API_KEY)
 elevenlabs_client = ElevenLabs(api_key=ELEVENLABS_API_KEY)
 
 logging.basicConfig(level=logging.INFO)
 
+
 def echo(audio):
     stt_time = time.time()
     logging.info("Performing STT")

@@ -60,10 +56,7 @@ def echo(audio):
     full_response = ""
     response = openai_client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=messages,
-        max_tokens=200,
-        stream=True
+        model="gpt-3.5-turbo", messages=messages, max_tokens=200, stream=True
     )
 
     for chunk in response:

@@ -77,18 +70,17 @@ def echo(audio):
         text=text_stream(),
         voice="Rachel",  # Cassidy is also really good
         voice_settings=VoiceSettings(
-            similarity_boost=0.9,
-            stability=0.6,
-            style=0.4,
-            speed=1
+            similarity_boost=0.9, stability=0.6, style=0.4, speed=1
         ),
         model="eleven_multilingual_v2",
         output_format="pcm_24000",
-        stream=True
+        stream=True,
     )
 
     for audio_chunk in audio_stream:
-        audio_array = np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        audio_array = (
+            np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        )
         yield (24000, audio_array)
 
     messages.append({"role": "assistant", "content": full_response + " "})

@@ -96,22 +88,25 @@ def echo(audio):
     logging.info(f"LLM took {time.time() - llm_time} seconds")
 
-stream = Stream(ReplyOnPause(echo,
-    algo_options=AlgoOptions(
-        audio_chunk_duration=0.5,
-        started_talking_threshold=0.1,
-        speech_threshold=0.03
-    ),
-    model_options=SileroVadOptions(
-        threshold=0.75,
-        min_speech_duration_ms=250,
-        min_silence_duration_ms=1500,
-        speech_pad_ms=400,
-        max_speech_duration_s=15
-    )),
-    modality="audio",
-    mode="send-receive"
-)
+stream = Stream(
+    ReplyOnPause(
+        echo,
+        algo_options=AlgoOptions(
+            audio_chunk_duration=0.5,
+            started_talking_threshold=0.1,
+            speech_threshold=0.03,
+        ),
+        model_options=SileroVadOptions(
+            threshold=0.75,
+            min_speech_duration_ms=250,
+            min_silence_duration_ms=1500,
+            speech_pad_ms=400,
+            max_speech_duration_s=15,
+        ),
+    ),
+    modality="audio",
+    mode="send-receive",
+)
 
 app = fastapi.FastAPI()

@@ -125,6 +120,7 @@ app.add_middleware(
 stream.mount(app)
 
+
 @app.get("/reset")
 async def reset():
     global messages
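One detail worth flagging in the reformatted handler above: the PCM conversion divides int16 samples by 32768.0, normalizing them into [-1.0, 1.0) before yielding (24000, audio_array) tuples. A standalone sketch of the same arithmetic:

# int16 spans [-32768, 32767]; dividing by 32768.0 maps it into [-1.0, 1.0).
import numpy as np

pcm_bytes = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
audio_array = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
assert audio_array.min() >= -1.0 and audio_array.max() < 1.0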

View File

@@ -28,6 +28,7 @@ nav:
   - Cookbook: cookbook.md
   - Deployment: deployment.md
   - Advanced Configuration: advanced-configuration.md
+  - Speech-to-Text Gallery: speech_to_text_gallery.md
   - VAD Gallery: vad_gallery.md
   - Utils: utils.md
   - Frequently Asked Questions: faq.md

View File

@@ -83,4 +83,4 @@ packages = ["/backend/fastrtc"]
[tool.ruff] [tool.ruff]
target-version = "py310" target-version = "py310"
extend-exclude = ["demo/phonic_chat"] extend-exclude = ["demo/phonic_chat", "demo/nextjs_voice_chat"]