from functools import lru_cache
from pathlib import Path
from typing import Literal, Protocol

import click
import librosa
import numpy as np
from numpy.typing import NDArray

from ..utils import AudioChunk, audio_to_float32

curr_dir = Path(__file__).parent


class STTModel(Protocol):
    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str: ...
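

# A minimal sketch (not part of the original module): STTModel is a structural
# Protocol, so any class with a compatible `stt` method satisfies it without
# subclassing. `DummySTT` is a hypothetical stand-in, useful e.g. in tests.
class DummySTT:
    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
        sr, audio_np = audio
        # Report the clip duration instead of transcribing anything.
        return f"<{audio_np.shape[-1] / sr:.2f}s of audio>"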


class MoonshineSTT(STTModel):
    def __init__(
        self, model: Literal["moonshine/base", "moonshine/tiny"] = "moonshine/base"
    ):
        try:
            from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Install fastrtc[stt] for speech-to-text and stopword detection support."
            )

        self.model = MoonshineOnnxModel(model_name=model)
        self.tokenizer = load_tokenizer()

    def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
        sr, audio_np = audio  # type: ignore
        if audio_np.dtype == np.int16:
            audio_np = audio_to_float32(audio)
        if sr != 16000:
            audio_np: NDArray[np.float32] = librosa.resample(
                audio_np, orig_sr=sr, target_sr=16000
            )
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(1, -1)
        tokens = self.model.generate(audio_np)
        return self.tokenizer.decode_batch(tokens)[0]
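

# Usage sketch (hypothetical helper, assumes a local WAV file exists at `path`):
# librosa.load yields float32 samples, and `MoonshineSTT.stt` handles resampling
# to the 16 kHz mono audio Moonshine expects, so any source rate works here.
def _example_transcribe_wav(path: str) -> str:
    samples, sr = librosa.load(path, sr=None, mono=True)  # float32 in [-1, 1]
    return MoonshineSTT().stt((int(sr), samples))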


@lru_cache
def get_stt_model(
    model: Literal["moonshine/base", "moonshine/tiny"] = "moonshine/base",
) -> STTModel:
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    m = MoonshineSTT(model)
    from moonshine_onnx import load_audio

    audio = load_audio(str(curr_dir / "test_file.wav"))
    print(click.style("INFO", fg="green") + ":\t Warming up STT model.")
    m.stt((16000, audio))
    print(click.style("INFO", fg="green") + ":\t STT model warmed up.")
    return m
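

# Note on caching (added commentary): @lru_cache keys on the `model` argument,
# so `get_stt_model()` acts as a per-model-name singleton factory; repeated
# calls reuse the already warmed-up instance rather than reloading the ONNX
# weights and rerunning the warm-up pass.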


def stt_for_chunks(
    stt_model: STTModel,
    audio: tuple[int, NDArray[np.int16 | np.float32]],
    chunks: list[AudioChunk],
) -> str:
    sr, audio_np = audio
    return " ".join(
        [
            stt_model.stt((sr, audio_np[chunk["start"] : chunk["end"]]))
            for chunk in chunks
        ]
    )
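

# End-to-end sketch, runnable as a script (assumes the fastrtc[stt] extra is
# installed; it reuses the bundled warm-up file referenced above). Treating the
# whole clip as a single chunk is an assumption for illustration; in practice
# chunk boundaries normally come from a VAD stage.
if __name__ == "__main__":
    from moonshine_onnx import load_audio

    stt = get_stt_model("moonshine/base")
    # Flatten any batch axis so sample-index slicing in stt_for_chunks works.
    samples = load_audio(str(curr_dir / "test_file.wav")).reshape(-1)
    chunks: list[AudioChunk] = [{"start": 0, "end": samples.shape[-1]}]
    print(stt_for_chunks(stt, (16000, samples), chunks))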