mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
# Please enter a commit message for your changes. Lines starting with '#' will be ignored, and an empty message aborts the commit.
This commit is contained in:
3
backend/gradio_webrtc/speech_to_text/__init__.py
Normal file
3
backend/gradio_webrtc/speech_to_text/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .stt_ import get_stt_model, stt, stt_for_chunks
|
||||
|
||||
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]
|
||||
53
backend/gradio_webrtc/speech_to_text/stt_.py
Normal file
53
backend/gradio_webrtc/speech_to_text/stt_.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..utils import AudioChunk
|
||||
|
||||
|
||||
@dataclass
class STTModel:
    """Pair of callables that together form a speech-to-text pipeline."""

    # Maps a batched float32 waveform tensor to model outputs (Silero model).
    encoder: Callable
    # Converts the encoder's output into a transcript string (Silero decoder).
    decoder: Callable
|
||||
|
||||
|
||||
@lru_cache
def get_stt_model() -> STTModel:
    """Load the English Silero STT model once and reuse it on later calls.

    Returns:
        An STTModel bundling the Silero encoder and its decoder.
    """
    # Imported lazily so `silero` is only required when STT is actually used.
    from silero import silero_stt

    encoder, decoder, _utils = silero_stt(
        language="en", version="v6", jit_model="jit_xlarge"
    )
    return STTModel(encoder, decoder)
|
||||
|
||||
|
||||
def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    """Transcribe an audio clip to text with the cached Silero STT model.

    Args:
        audio: Tuple of ``(sample_rate, samples)``. int16 PCM samples are
            normalized to float32 in [-1, 1); float32 input is used as-is.
            The sample rate is not used here — Silero presumably expects a
            fixed rate; TODO confirm against callers.

    Returns:
        The decoded transcript string.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    # Check the torch dependency up front, before doing any work.
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
        )
    model = get_stt_model()
    _, audio_np = audio
    if audio_np.dtype != np.float32:
        # int16 PCM -> float32 in [-1, 1).
        audio_np = audio_np.astype(np.float32) / 32768.0
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        # Silero expects a batch dimension: (batch, samples).
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    return model.decoder(model.encoder(audio_torch)[0])
|
||||
|
||||
|
||||
def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    """Transcribe each chunk of *audio* separately and join the texts with spaces.

    Args:
        audio: Tuple of ``(sample_rate, samples)``.
        chunks: Chunk dicts with ``start``/``end`` sample indices into *audio*.

    Returns:
        The per-chunk transcripts concatenated with single spaces.
    """
    rate, samples = audio
    transcripts = [
        stt((rate, samples[segment["start"] : segment["end"]]))
        for segment in chunks
    ]
    return " ".join(transcripts)
|
||||
Reference in New Issue
Block a user