Add ReplyOnStopWords (#35)

* add code

* fix dependencies

* add code:
This commit is contained in:
Freddy Boulton
2024-12-11 18:25:53 -08:00
committed by GitHub
parent b1e4326ae3
commit 6c983482b8
14 changed files with 368 additions and 18 deletions

View File

@@ -1,10 +1,13 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List
from typing import List, Literal, overload
import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from ..utils import AudioChunk
logger = logging.getLogger(__name__)
@@ -76,7 +79,7 @@ class SileroVADModel:
return h, c
@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
audio: np.ndarray,
vad_options: SileroVadOptions,
**kwargs,
) -> List[dict]:
) -> List[AudioChunk]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
@@ -236,15 +239,33 @@ class SileroVADModel:
return speeches
@overload
def vad(
self,
audio_tuple: tuple[int, np.ndarray],
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
) -> float:
return_chunks: Literal[True],
) -> tuple[float, List[AudioChunk]]: ...
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float: ...
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float | tuple[float, List[AudioChunk]]:
sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape)
try:
audio = audio.astype(np.float32) / 32768.0
if audio.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0
sr = 16000
if sr != sampling_rate:
try:
@@ -262,7 +283,8 @@ class SileroVADModel:
audio = self.collect_chunks(audio, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape)
duration_after_vad = audio.shape[0] / sr
if return_chunks:
return duration_after_vad, speech_chunks
return duration_after_vad
except Exception as e:
import math
@@ -280,7 +302,7 @@ class SileroVADModel:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25:
if sr / x.shape[1] > 31.25: # type: ignore
raise ValueError("Input audio chunk is too short")
h, c = state