Files
gradio-webrtc/backend/fastrtc/speech_to_text/stt_.py
Freddy Boulton 853d6a06b5 Rebrand to FastRTC (#60)
* Add code

* add code

* add code

* Rename messages

* rename

* add code

* Add demo

* docs + demos + bug fixes

* add code

* styles

* user guide

* Styles

* Add code

* misc docs updates

* print nit

* whisper + pr

* url for images

* whisper update

* Fix bugs

* remove demo files

* version number

* Fix pypi readme

* Fix

* demos

* Add llama code editor

* Update llama code editor and object detection cookbook

* Add more cookbook demos

* add code

* Fix links for PR deploys

* add code

* Fix the install

* add tts

* TTS docs

* Typo

* Pending bubbles for reply on pause

* Stream redesign (#63)

* better error handling

* Websocket error handling

* add code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>

* remove docs from dist

* Some docs typos

* more typos

* upload changes + docs

* docs

* better phone

* update docs

* add code

* Make demos better

* fix docs + websocket start_up

* remove mention of FastAPI app

* fastphone tweaks

* add code

* ReplyOnStopWord fixes

* Fix cookbook

* Fix pypi readme

* add code

* bump versions

* sambanova cookbook

* Fix tags

* Llm voice chat

* kyutai tag

* Add error message to all index.html

* STT module uses Moonshine

* Not required from typing extensions

* fix llm voice chat

* Add vpn warning

* demo fixes

* demos

* Add more ui args and gemini audio-video

* update cookbook

* version 9

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
2025-02-24 01:13:42 -05:00

82 lines
2.4 KiB
Python

from functools import lru_cache
from pathlib import Path
from typing import Literal, Protocol
import click
import librosa
import numpy as np
from numpy.typing import NDArray
from ..utils import AudioChunk, audio_to_float32
# Directory containing this module; used below to locate the bundled
# warm-up audio file (test_file.wav).
curr_dir = Path(__file__).parent
class STTModel(Protocol):
    """Structural (duck-typed) interface for speech-to-text backends.

    Audio is passed as a ``(sample_rate, samples)`` tuple where ``samples``
    is a numpy array of int16 PCM or float32 values.
    """

    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str: ...

    def stt_for_chunks(
        self,
        audio: tuple[int, NDArray[np.int16 | np.float32]],
        chunks: list[AudioChunk],
    ) -> str: ...
class MoonshineSTT(STTModel):
    """Speech-to-text implementation backed by the Moonshine ONNX models.

    Requires the optional ``moonshine_onnx`` dependency, installed via the
    ``fastrtc[stt]`` extra.
    """

    def __init__(
        self, model: Literal["moonshine/base", "moonshine/tiny"] = "moonshine/base"
    ):
        # Import lazily so the package works without the optional extra.
        try:
            from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Install fastrtc[stt] for speech-to-text and stopword detection support."
            )
        self.model = MoonshineOnnxModel(model_name=model)
        self.tokenizer = load_tokenizer()

    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
        """Transcribe one utterance and return the decoded text."""
        sample_rate, samples = audio  # type: ignore
        if samples.dtype == np.int16:
            # Convert int16 PCM to float32 before inference.
            # NOTE(review): the whole (sr, samples) tuple — not just the
            # array — is handed to audio_to_float32; preserved as-is, but
            # confirm against utils.audio_to_float32's expected argument.
            samples = audio_to_float32(audio)
        if sample_rate != 16000:
            # Moonshine expects 16 kHz input; resample anything else.
            samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        if samples.ndim == 1:
            # Add the batch dimension the model expects: (1, n_samples).
            samples = samples.reshape(1, -1)
        token_ids = self.model.generate(samples)
        return self.tokenizer.decode_batch(token_ids)[0]

    def stt_for_chunks(
        self,
        audio: tuple[int, NDArray[np.int16 | np.float32]],
        chunks: list[AudioChunk],
    ) -> str:
        """Transcribe each chunk independently and join the texts with spaces."""
        sample_rate, samples = audio
        texts: list[str] = []
        for chunk in chunks:
            segment = samples[chunk["start"] : chunk["end"]]
            texts.append(self.stt((sample_rate, segment)))
        return " ".join(texts)
@lru_cache
def get_stt_model(
    model: Literal["moonshine/base", "moonshine/tiny"] = "moonshine/base",
) -> STTModel:
    """Build, warm up, and cache a Moonshine STT model.

    Thanks to ``lru_cache``, each model variant is constructed at most once
    per process. A bundled audio file is transcribed once up front so the
    first real request does not pay the warm-up cost.
    """
    import os

    # Presumably set to silence/avoid tokenizers parallelism issues after
    # fork — TODO confirm this is still needed.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    stt_model = MoonshineSTT(model)
    from moonshine_onnx import load_audio

    warmup_audio = load_audio(str(curr_dir / "test_file.wav"))
    print(click.style("INFO", fg="green") + ":\t Warming up STT model.")
    stt_model.stt((16000, warmup_audio))
    print(click.style("INFO", fg="green") + ":\t STT model warmed up.")
    return stt_model