mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -4,6 +4,8 @@ from .credentials import (
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
|
||||
from .reply_on_stopwords import ReplyOnStopWords
|
||||
from .speech_to_text import stt, stt_for_chunks
|
||||
from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
|
||||
from .webrtc import StreamHandler, WebRTC
|
||||
|
||||
@@ -17,7 +19,10 @@ __all__ = [
|
||||
"get_twilio_turn_credentials",
|
||||
"get_turn_credentials",
|
||||
"ReplyOnPause",
|
||||
"ReplyOnStopWords",
|
||||
"SileroVadOptions",
|
||||
"stt",
|
||||
"stt_for_chunks",
|
||||
"StreamHandler",
|
||||
"WebRTC",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Literal, overload
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..utils import AudioChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,7 +79,7 @@ class SileroVADModel:
|
||||
return h, c
|
||||
|
||||
@staticmethod
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
|
||||
"""Collects and concatenates audio chunks."""
|
||||
if not chunks:
|
||||
return np.array([], dtype=np.float32)
|
||||
@@ -90,7 +93,7 @@ class SileroVADModel:
|
||||
audio: np.ndarray,
|
||||
vad_options: SileroVadOptions,
|
||||
**kwargs,
|
||||
) -> List[dict]:
|
||||
) -> List[AudioChunk]:
|
||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||
|
||||
Args:
|
||||
@@ -236,15 +239,33 @@ class SileroVADModel:
|
||||
|
||||
return speeches
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, np.ndarray],
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
) -> float:
|
||||
return_chunks: Literal[True],
|
||||
) -> tuple[float, List[AudioChunk]]: ...
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float: ...
|
||||
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float | tuple[float, List[AudioChunk]]:
|
||||
sampling_rate, audio = audio_tuple
|
||||
logger.debug("VAD audio shape input: %s", audio.shape)
|
||||
try:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
sr = 16000
|
||||
if sr != sampling_rate:
|
||||
try:
|
||||
@@ -262,7 +283,8 @@ class SileroVADModel:
|
||||
audio = self.collect_chunks(audio, speech_chunks)
|
||||
logger.debug("VAD audio shape: %s", audio.shape)
|
||||
duration_after_vad = audio.shape[0] / sr
|
||||
|
||||
if return_chunks:
|
||||
return duration_after_vad, speech_chunks
|
||||
return duration_after_vad
|
||||
except Exception as e:
|
||||
import math
|
||||
@@ -280,7 +302,7 @@ class SileroVADModel:
|
||||
raise ValueError(
|
||||
f"Too many dimensions for input audio chunk {len(x.shape)}"
|
||||
)
|
||||
if sr / x.shape[1] > 31.25:
|
||||
if sr / x.shape[1] > 31.25: # type: ignore
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
h, c = state
|
||||
|
||||
@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
self.input_sample_rate,
|
||||
)
|
||||
|
||||
def determine_pause(
|
||||
|
||||
146
backend/gradio_webrtc/reply_on_stopwords.py
Normal file
146
backend/gradio_webrtc/reply_on_stopwords.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
from .reply_on_pause import (
|
||||
AlgoOptions,
|
||||
AppState,
|
||||
ReplyFnGenerator,
|
||||
ReplyOnPause,
|
||||
SileroVadOptions,
|
||||
)
|
||||
from .speech_to_text import get_stt_model, stt_for_chunks
|
||||
from .utils import audio_to_float32
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplyOnStopWordsState(AppState):
    """Per-stream state for `ReplyOnStopWords`, extending `AppState`.

    Tracks where the handler is in the stop-word workflow in addition to
    the usual pause-detection bookkeeping inherited from `AppState`.
    """

    # True once one of the configured stop words has been heard.
    stop_word_detected: bool = False
    # Rolling 16 kHz audio buffer that is transcribed while listening for
    # the stop word (capped at ~2 s by `determine_pause`).
    post_stop_word_buffer: np.ndarray | None = None
    # Whether speech started before the stop word was detected.
    started_talking_pre_stop_word: bool = False
|
||||
|
||||
|
||||
class ReplyOnStopWords(ReplyOnPause):
    """`ReplyOnPause` variant that waits for a stop word before replying.

    Incoming audio is continuously transcribed until one of ``stop_words``
    is heard; only afterwards does the usual VAD/pause logic decide when
    to hand the accumulated speech to ``fn``.
    """

    def __init__(
        self,
        fn: ReplyFnGenerator,
        stop_words: list[str],
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 48000,
    ):
        super().__init__(
            fn,
            algo_options=algo_options,
            model_options=model_options,
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            output_frame_size=output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.stop_words = stop_words
        self.state = ReplyOnStopWordsState()
        # Download/load the STT model eagerly so the first audio chunk
        # is not delayed by a model download.
        get_stt_model()

    def stop_word_detected(self, text: str) -> bool:
        """Return True if any configured stop word/phrase occurs in ``text``.

        Matching is case-insensitive (STT output may be capitalized) and
        whitespace-tolerant: a multi-word phrase matches across any run of
        whitespace, and ``\\b`` boundaries prevent partial-word hits
        (e.g. "stop" inside "stopwatch").
        """
        haystack = text.lower()
        for stop_word in self.stop_words:
            words = stop_word.lower().strip().split(" ")
            pattern = r"\b" + r"\s+".join(map(re.escape, words)) + r"\b"
            if re.search(pattern, haystack):
                logger.debug("Stop word detected: %s", stop_word)
                return True
        return False

    async def _send_stopword(self) -> None:
        # Notify the client over the data channel that the stop word was heard.
        if self.channel:
            self.channel.send("stopword")
            logger.debug("Sent stopword")

    def send_stopword(self) -> None:
        # Schedule the async notification from the (synchronous) audio thread.
        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ) -> bool:
        """Take in the stream, determine if a pause happened."""
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            if not state.stop_word_detected:
                # Still listening for the stop word: maintain a rolling
                # 16 kHz buffer (max 2 s) and transcribe it.
                audio_f32 = audio_to_float32((sampling_rate, audio))
                audio_rs = librosa.resample(
                    audio_f32, orig_sr=sampling_rate, target_sr=16000
                )
                if state.post_stop_word_buffer is None:
                    state.post_stop_word_buffer = audio_rs
                else:
                    state.post_stop_word_buffer = np.concatenate(
                        (state.post_stop_word_buffer, audio_rs)
                    )
                if len(state.post_stop_word_buffer) / 16000 > 2:
                    # Keep only the most recent 2 seconds (32000 samples at 16 kHz).
                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
                dur_vad, chunks = self.model.vad(
                    (16000, state.post_stop_word_buffer),
                    self.model_options,
                    return_chunks=True,
                )
                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
                logger.debug("STT: %s", text)
                state.stop_word_detected = self.stop_word_detected(text)
                if state.stop_word_detected:
                    logger.debug("Stop word detected")
                    self.send_stopword()
                state.buffer = None
            else:
                # Stop word already heard: plain pause detection, as in
                # ReplyOnPause.determine_pause.
                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
                logger.debug("VAD duration: %s", dur_vad)
                if (
                    dur_vad > self.algo_options.started_talking_threshold
                    and not state.started_talking
                    and state.stop_word_detected
                ):
                    state.started_talking = True
                    logger.debug("Started talking")
                if state.started_talking:
                    if state.stream is None:
                        state.stream = audio
                    else:
                        state.stream = np.concatenate((state.stream, audio))
                    state.buffer = None
                if (
                    dur_vad < self.algo_options.speech_threshold
                    and state.started_talking
                    and state.stop_word_detected
                ):
                    return True
        return False

    def reset(self):
        # Clear per-conversation state so the handler can be reused.
        self.args_set.clear()
        self.generator = None
        self.event.clear()
        self.state = ReplyOnStopWordsState()

    def copy(self):
        # Fresh handler with identical configuration (one per connection).
        return ReplyOnStopWords(
            self.fn,
            self.stop_words,
            self.algo_options,
            self.model_options,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
        )
|
||||
3
backend/gradio_webrtc/speech_to_text/__init__.py
Normal file
3
backend/gradio_webrtc/speech_to_text/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .stt_ import get_stt_model, stt, stt_for_chunks
|
||||
|
||||
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]
|
||||
52
backend/gradio_webrtc/speech_to_text/stt_.py
Normal file
52
backend/gradio_webrtc/speech_to_text/stt_.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from silero import silero_stt
|
||||
|
||||
from ..utils import AudioChunk
|
||||
|
||||
|
||||
@dataclass
class STTModel:
    """Pair of callables making up the speech-to-text pipeline.

    Attributes:
        encoder: maps a batched float32 waveform tensor to model output.
        decoder: maps the encoder's output to the transcribed string.
    """

    encoder: Callable
    decoder: Callable
|
||||
|
||||
|
||||
@lru_cache
def get_stt_model() -> STTModel:
    """Load and memoize the silero English STT model.

    The first call downloads/loads the ``jit_xlarge`` v6 checkpoint;
    subsequent calls return the cached `STTModel` wrapper.
    """
    encoder, decoder, _utils = silero_stt(
        language="en", version="v6", jit_model="jit_xlarge"
    )
    return STTModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
|
||||
def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    """Transcribe an audio chunk to text with the silero STT model.

    Args:
        audio: Tuple of ``(sample_rate, samples)``. int16 samples are
            scaled to float32 in [-1, 1); float32 input is used as-is.
            NOTE(review): the sample rate is not passed to the model —
            callers appear to resample to 16 kHz upstream; confirm.

    Returns:
        The transcribed text.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    # Fail fast before any conversion work if torch is missing.
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
        )
    model = get_stt_model()
    sr, audio_np = audio
    if audio_np.dtype != np.float32:
        # int16 PCM -> float32 in [-1, 1)
        audio_np = audio_np.astype(np.float32) / 32768.0
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        # The model expects a (batch, samples) tensor.
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    return model.decoder(model.encoder(audio_torch)[0])
|
||||
|
||||
|
||||
def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    """Transcribe each speech chunk of ``audio`` and join the texts with spaces."""
    sr, samples = audio
    texts = [stt((sr, samples[chunk["start"] : chunk["end"]])) for chunk in chunks]
    return " ".join(texts)
|
||||
@@ -3,7 +3,7 @@ import fractions
|
||||
import io
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Callable, Protocol, cast
|
||||
from typing import Any, Callable, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
|
||||
AUDIO_PTIME = 0.020
|
||||
|
||||
|
||||
class AudioChunk(TypedDict):
    """Half-open sample range ``[start, end)`` marking speech in an audio array."""

    start: int  # index of the first sample in the chunk
    end: int  # index one past the last sample (slice-style)
|
||||
|
||||
|
||||
class AdditionalOutputs:
|
||||
def __init__(self, *args) -> None:
|
||||
self.args = args
|
||||
|
||||
Reference in New Issue
Block a user