Add ReplyOnStopWords (#35)

* add code

* fix dependencies

* add code:
Freddy Boulton authored 2024-12-11 18:25:53 -08:00, committed by GitHub
parent b1e4326ae3, commit 6c983482b8
14 changed files with 368 additions and 18 deletions

View File

@@ -4,6 +4,8 @@ from .credentials import (
     get_twilio_turn_credentials,
 )
 from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
+from .reply_on_stopwords import ReplyOnStopWords
+from .speech_to_text import stt, stt_for_chunks
 from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
 from .webrtc import StreamHandler, WebRTC
@@ -17,7 +19,10 @@ __all__ = [
     "get_twilio_turn_credentials",
     "get_turn_credentials",
     "ReplyOnPause",
+    "ReplyOnStopWords",
     "SileroVadOptions",
+    "stt",
+    "stt_for_chunks",
     "StreamHandler",
     "WebRTC",
 ]
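For orientation, a minimal sketch of the import surface this hunk creates, assuming the distribution is imported as gradio_webrtc (the package name is not visible in this diff):

import numpy as np
from gradio_webrtc import ReplyOnStopWords, stt  # package name assumed

# stt transcribes (sample_rate, int16 array) tuples; zero-filled stand-in audio here.
text = stt((16000, np.zeros(16000, dtype=np.int16)))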

View File

@@ -1,10 +1,13 @@
 import logging
 import warnings
 from dataclasses import dataclass
-from typing import List
+from typing import List, Literal, overload

 import numpy as np
 from huggingface_hub import hf_hub_download
+from numpy.typing import NDArray
+
+from ..utils import AudioChunk

 logger = logging.getLogger(__name__)
@@ -76,7 +79,7 @@ class SileroVADModel:
         return h, c

     @staticmethod
-    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+    def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
         """Collects and concatenates audio chunks."""
         if not chunks:
             return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
         audio: np.ndarray,
         vad_options: SileroVadOptions,
         **kwargs,
-    ) -> List[dict]:
+    ) -> List[AudioChunk]:
         """This method is used for splitting long audios into speech chunks using silero VAD.

         Args:
@@ -236,15 +239,33 @@ class SileroVADModel:

         return speeches

+    @overload
     def vad(
         self,
-        audio_tuple: tuple[int, np.ndarray],
+        audio_tuple: tuple[int, NDArray],
         vad_parameters: None | SileroVadOptions,
-    ) -> float:
+        return_chunks: Literal[True],
+    ) -> tuple[float, List[AudioChunk]]: ...
+
+    @overload
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float: ...
+
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float | tuple[float, List[AudioChunk]]:
         sampling_rate, audio = audio_tuple
         logger.debug("VAD audio shape input: %s", audio.shape)
         try:
-            audio = audio.astype(np.float32) / 32768.0
+            if audio.dtype != np.float32:
+                audio = audio.astype(np.float32) / 32768.0
             sr = 16000
             if sr != sampling_rate:
                 try:
@@ -262,7 +283,8 @@ class SileroVADModel:
             audio = self.collect_chunks(audio, speech_chunks)
             logger.debug("VAD audio shape: %s", audio.shape)
             duration_after_vad = audio.shape[0] / sr
+            if return_chunks:
+                return duration_after_vad, speech_chunks
             return duration_after_vad
         except Exception as e:
             import math
@@ -280,7 +302,7 @@ class SileroVADModel:
             raise ValueError(
                 f"Too many dimensions for input audio chunk {len(x.shape)}"
             )
-        if sr / x.shape[1] > 31.25:
+        if sr / x.shape[1] > 31.25:  # type: ignore
            raise ValueError("Input audio chunk is too short")

        h, c = state
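With the new @overloads, vad can optionally return the detected speech chunks alongside the post-VAD duration. A minimal sketch of both call shapes, assuming an already-initialized SileroVADModel instance named model and 16 kHz int16 input:

import numpy as np

# Placeholder signal: one second of silence at 16 kHz.
audio = np.zeros(16000, dtype=np.int16)

# Original call shape: seconds of detected speech only.
duration = model.vad((16000, audio), None)

# New in this commit: also get the chunk boundaries, so callers
# (e.g. chunk-wise STT) can slice the original signal.
duration, chunks = model.vad((16000, audio), None, return_chunks=True)
for chunk in chunks:
    segment = audio[chunk["start"] : chunk["end"]]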

View File

@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
+            self.input_sample_rate,
         )

     def determine_pause(
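This one-line fix matters because the streaming layer clones the handler per connection via copy(); previously a non-default input rate was silently dropped on the clone. A hypothetical illustration (my_reply_fn is a placeholder):

# Hypothetical: before this fix the clone reverted to the default 48000.
handler = ReplyOnPause(my_reply_fn, input_sample_rate=16000)
clone = handler.copy()
assert clone.input_sample_rate == 16000  # holds after this commit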

View File

@@ -0,0 +1,146 @@
import asyncio
import logging
import re
from typing import Literal

import librosa
import numpy as np

from .reply_on_pause import (
    AlgoOptions,
    AppState,
    ReplyFnGenerator,
    ReplyOnPause,
    SileroVadOptions,
)
from .speech_to_text import get_stt_model, stt_for_chunks
from .utils import audio_to_float32

logger = logging.getLogger(__name__)


class ReplyOnStopWordsState(AppState):
    stop_word_detected: bool = False
    post_stop_word_buffer: np.ndarray | None = None
    started_talking_pre_stop_word: bool = False


class ReplyOnStopWords(ReplyOnPause):
    def __init__(
        self,
        fn: ReplyFnGenerator,
        stop_words: list[str],
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 48000,
    ):
        super().__init__(
            fn,
            algo_options=algo_options,
            model_options=model_options,
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            output_frame_size=output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.stop_words = stop_words
        self.state = ReplyOnStopWordsState()
        # Download the STT model up front so the first utterance is not delayed.
        get_stt_model()

    def stop_word_detected(self, text: str) -> bool:
        for stop_word in self.stop_words:
            stop_word = stop_word.lower().strip().split(" ")
            if bool(
                re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text)
            ):
                logger.debug("Stop word detected: %s", stop_word)
                return True
        return False

    async def _send_stopword(self):
        if self.channel:
            self.channel.send("stopword")
            logger.debug("Sent stopword")

    def send_stopword(self):
        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ) -> bool:
        """Take in the stream, determine if a pause happened."""
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            if not state.stop_word_detected:
                audio_f32 = audio_to_float32((sampling_rate, audio))
                audio_rs = librosa.resample(
                    audio_f32, orig_sr=sampling_rate, target_sr=16000
                )
                if state.post_stop_word_buffer is None:
                    state.post_stop_word_buffer = audio_rs
                else:
                    state.post_stop_word_buffer = np.concatenate(
                        (state.post_stop_word_buffer, audio_rs)
                    )
                # Keep a rolling two-second window (32000 samples at 16 kHz).
                if len(state.post_stop_word_buffer) / 16000 > 2:
                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
                dur_vad, chunks = self.model.vad(
                    (16000, state.post_stop_word_buffer),
                    self.model_options,
                    return_chunks=True,
                )
                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
                logger.debug("STT: %s", text)
                state.stop_word_detected = self.stop_word_detected(text)
                if state.stop_word_detected:
                    logger.debug("Stop word detected")
                    self.send_stopword()
                state.buffer = None
            else:
                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
                logger.debug("VAD duration: %s", dur_vad)
                if (
                    dur_vad > self.algo_options.started_talking_threshold
                    and not state.started_talking
                    and state.stop_word_detected
                ):
                    state.started_talking = True
                    logger.debug("Started talking")
                if state.started_talking:
                    if state.stream is None:
                        state.stream = audio
                    else:
                        state.stream = np.concatenate((state.stream, audio))
                    state.buffer = None
                if (
                    dur_vad < self.algo_options.speech_threshold
                    and state.started_talking
                    and state.stop_word_detected
                ):
                    return True
        return False

    def reset(self):
        self.args_set.clear()
        self.generator = None
        self.event.clear()
        self.state = ReplyOnStopWordsState()

    def copy(self):
        return ReplyOnStopWords(
            self.fn,
            self.stop_words,
            self.algo_options,
            self.model_options,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
        )
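End to end, the handler transcribes a rolling two-second window until a stop word appears, then falls back to ReplyOnPause-style VAD to detect the end of the command. A minimal usage sketch, assuming the gradio_webrtc distribution name and its WebRTC component API; the reply function body is a placeholder:

import gradio as gr
import numpy as np
from gradio_webrtc import ReplyOnStopWords, WebRTC  # package name assumed

def response(audio: tuple[int, np.ndarray]):
    # Invoked only after a stop word is heard and the speaker pauses;
    # yield (sample_rate, ndarray) chunks to play back.
    yield audio  # echo, for illustration

with gr.Blocks() as demo:
    webrtc = WebRTC(mode="send-receive", modality="audio")
    webrtc.stream(
        ReplyOnStopWords(response, stop_words=["computer"]),
        inputs=[webrtc],
        outputs=[webrtc],
    )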

View File

@@ -0,0 +1,3 @@
from .stt_ import get_stt_model, stt, stt_for_chunks
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]

View File

@@ -0,0 +1,52 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable

import numpy as np
from numpy.typing import NDArray
from silero import silero_stt

from ..utils import AudioChunk


@dataclass
class STTModel:
    encoder: Callable
    decoder: Callable


@lru_cache
def get_stt_model() -> STTModel:
    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
    return STTModel(model, decoder)


def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    model = get_stt_model()
    sr, audio_np = audio
    if audio_np.dtype != np.float32:
        # Normalize int16 PCM to float32 in [-1, 1].
        audio_np = audio_np.astype(np.float32) / 32768.0
    try:
        import torch
    except ImportError as e:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. "
            "Run `pip install torch`."
        ) from e
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    return model.decoder(model.encoder(audio_torch)[0])


def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    sr, audio_np = audio
    return " ".join(
        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
    )
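For reference, a minimal sketch of calling the helpers directly; silero's English model expects 16 kHz audio, and the zero-filled array here is a stand-in, so the transcript will be empty in practice:

import numpy as np

sample_rate = 16000
audio = np.zeros(sample_rate, dtype=np.int16)  # placeholder: one second of silence

text = stt((sample_rate, audio))

# Or restrict transcription to VAD-detected regions:
chunks = [{"start": 0, "end": sample_rate // 2}]  # hypothetical chunk
text = stt_for_chunks((sample_rate, audio), chunks)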

View File

@@ -3,7 +3,7 @@ import fractions
 import io
 import logging
 import tempfile
-from typing import Any, Callable, Protocol, cast
+from typing import Any, Callable, Protocol, TypedDict, cast

 import av
 import numpy as np
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)

 AUDIO_PTIME = 0.020

+class AudioChunk(TypedDict):
+    start: int
+    end: int
+
+
 class AdditionalOutputs:
     def __init__(self, *args) -> None:
         self.args = args
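AudioChunk gives the VAD-to-STT boundary a typed contract: start and end are sample offsets into the audio array. A short sketch:

import numpy as np

audio = np.zeros(16000, dtype=np.int16)  # placeholder signal
chunk: AudioChunk = {"start": 1600, "end": 8000}
segment = audio[chunk["start"] : chunk["end"]]  # the detected speech span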