diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py index 9b42ae5..4c646e5 100644 --- a/backend/fastrtc/__init__.py +++ b/backend/fastrtc/__init__.py @@ -3,7 +3,13 @@ from .credentials import ( get_turn_credentials, get_twilio_turn_credentials, ) -from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions +from .pause_detection import ( + ModelOptions, + PauseDetectionModel, + SileroVadOptions, + get_silero_model, +) +from .reply_on_pause import AlgoOptions, ReplyOnPause from .reply_on_stopwords import ReplyOnStopWords from .speech_to_text import MoonshineSTT, get_stt_model from .stream import Stream, UIArgs @@ -63,4 +69,8 @@ __all__ = [ "KokoroTTSOptions", "wait_for_item", "UIArgs", + "ModelOptions", + "PauseDetectionModel", + "get_silero_model", + "SileroVadOptions", ] diff --git a/backend/fastrtc/pause_detection/__init__.py b/backend/fastrtc/pause_detection/__init__.py index e4874b7..ab18632 100644 --- a/backend/fastrtc/pause_detection/__init__.py +++ b/backend/fastrtc/pause_detection/__init__.py @@ -1,3 +1,10 @@ -from .vad import SileroVADModel, SileroVadOptions +from .protocol import ModelOptions, PauseDetectionModel +from .silero import SileroVADModel, SileroVadOptions, get_silero_model -__all__ = ["SileroVADModel", "SileroVadOptions"] +__all__ = [ + "SileroVADModel", + "SileroVadOptions", + "PauseDetectionModel", + "ModelOptions", + "get_silero_model", +] diff --git a/backend/fastrtc/pause_detection/protocol.py b/backend/fastrtc/pause_detection/protocol.py new file mode 100644 index 0000000..e73859a --- /dev/null +++ b/backend/fastrtc/pause_detection/protocol.py @@ -0,0 +1,20 @@ +from typing import Any, Protocol, TypeAlias + +import numpy as np +from numpy.typing import NDArray + +from ..utils import AudioChunk + +ModelOptions: TypeAlias = Any + + +class PauseDetectionModel(Protocol): + def vad( + self, + audio: tuple[int, NDArray[np.int16] | NDArray[np.float32]], + options: ModelOptions, + ) -> tuple[float, list[AudioChunk]]: 
... + + def warmup( + self, + ) -> None: ... diff --git a/backend/fastrtc/pause_detection/vad.py b/backend/fastrtc/pause_detection/silero.py similarity index 85% rename from backend/fastrtc/pause_detection/vad.py rename to backend/fastrtc/pause_detection/silero.py index bf4bb1e..196a27a 100644 --- a/backend/fastrtc/pause_detection/vad.py +++ b/backend/fastrtc/pause_detection/silero.py @@ -1,13 +1,16 @@ import logging import warnings from dataclasses import dataclass -from typing import List, Literal, overload +from functools import lru_cache +from typing import List +import click import numpy as np from huggingface_hub import hf_hub_download from numpy.typing import NDArray from ..utils import AudioChunk +from .protocol import PauseDetectionModel logger = logging.getLogger(__name__) @@ -15,6 +18,26 @@ logger = logging.getLogger(__name__) # The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py +@lru_cache +def get_silero_model() -> PauseDetectionModel: + """Returns the VAD model instance and warms it up with dummy data.""" + # Warm up the model with dummy data + + try: + import importlib.util + + mod = importlib.util.find_spec("onnxruntime") + if mod is None: + raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause") + except (ValueError, ModuleNotFoundError): + raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause") + model = SileroVADModel() + print(click.style("INFO", fg="green") + ":\t Warming up VAD model.") + model.warmup() + print(click.style("INFO", fg="green") + ":\t VAD model warmed up.") + return model + + @dataclass class SileroVadOptions: """VAD options. @@ -239,33 +262,21 @@ class SileroVADModel: return speeches - @overload - def vad( - self, - audio_tuple: tuple[int, NDArray], - vad_parameters: None | SileroVadOptions, - return_chunks: Literal[True], - ) -> tuple[float, List[AudioChunk]]: ... 
- - @overload - def vad( - self, - audio_tuple: tuple[int, NDArray], - vad_parameters: None | SileroVadOptions, - return_chunks: bool = False, - ) -> float: ... + def warmup(self): + for _ in range(10): + dummy_audio = np.zeros(102400, dtype=np.float32) + self.vad((24000, dummy_audio), None) def vad( self, - audio_tuple: tuple[int, NDArray], - vad_parameters: None | SileroVadOptions, - return_chunks: bool = False, - ) -> float | tuple[float, List[AudioChunk]]: - sampling_rate, audio = audio_tuple - logger.debug("VAD audio shape input: %s", audio.shape) + audio: tuple[int, NDArray[np.float32] | NDArray[np.int16]], + options: None | SileroVadOptions, + ) -> tuple[float, list[AudioChunk]]: + sampling_rate, audio_ = audio + logger.debug("VAD audio shape input: %s", audio_.shape) try: - if audio.dtype != np.float32: - audio = audio.astype(np.float32) / 32768.0 + if audio_.dtype != np.float32: + audio_ = audio_.astype(np.float32) / 32768.0 sr = 16000 if sr != sampling_rate: try: @@ -274,18 +285,16 @@ class SileroVADModel: raise RuntimeError( "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" ) from e - audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr) + audio_ = librosa.resample(audio_, orig_sr=sampling_rate, target_sr=sr) - if not vad_parameters: + if not (vad_parameters := options): vad_parameters = SileroVadOptions() - speech_chunks = self.get_speech_timestamps(audio, vad_parameters) + speech_chunks = self.get_speech_timestamps(audio_, vad_parameters) logger.debug("VAD speech chunks: %s", speech_chunks) - audio = self.collect_chunks(audio, speech_chunks) - logger.debug("VAD audio shape: %s", audio.shape) - duration_after_vad = audio.shape[0] / sr - if return_chunks: - return duration_after_vad, speech_chunks - return duration_after_vad + audio_ = self.collect_chunks(audio_, speech_chunks) + logger.debug("VAD audio shape: %s", audio_.shape) + duration_after_vad = audio_.shape[0] / sr + return duration_after_vad, speech_chunks except
Exception as e: import math import traceback @@ -293,7 +302,7 @@ class SileroVADModel: logger.debug("VAD Exception: %s", str(e)) exec = traceback.format_exc() logger.debug("traceback %s", exec) - return math.inf + return math.inf, [] def __call__(self, x, state, sr: int): if len(x.shape) == 1: diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py index 9b6405c..276f171 100644 --- a/backend/fastrtc/reply_on_pause.py +++ b/backend/fastrtc/reply_on_pause.py @@ -1,44 +1,19 @@ import asyncio import inspect from dataclasses import dataclass, field -from functools import lru_cache from logging import getLogger from threading import Event from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast -import click import numpy as np from numpy.typing import NDArray -from .pause_detection import SileroVADModel, SileroVadOptions +from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model from .tracks import EmitType, StreamHandler from .utils import create_message, split_output logger = getLogger(__name__) -counter = 0 - - -@lru_cache -def get_vad_model() -> SileroVADModel: - """Returns the VAD model instance and warms it up with dummy data.""" - try: - import importlib.util - - mod = importlib.util.find_spec("onnxruntime") - if mod is None: - raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause") - except (ValueError, ModuleNotFoundError): - raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause") - model = SileroVADModel() - # Warm up the model with dummy data - print(click.style("INFO", fg="green") + ":\t Warming up VAD model.") - for _ in range(10): - dummy_audio = np.zeros(102400, dtype=np.float32) - model.vad((24000, dummy_audio), None) - print(click.style("INFO", fg="green") + ":\t VAD model warmed up.") - return model - @dataclass class AlgoOptions: @@ -94,12 +69,13 @@ class ReplyOnPause(StreamHandler): self, fn: ReplyFnGenerator, algo_options: AlgoOptions | None = None, - model_options: 
SileroVadOptions | None = None, + model_options: ModelOptions | None = None, can_interrupt: bool = True, expected_layout: Literal["mono", "stereo"] = "mono", output_sample_rate: int = 24000, output_frame_size: int = 480, input_sample_rate: int = 48000, + model: PauseDetectionModel | None = None, ): super().__init__( expected_layout, @@ -111,7 +87,7 @@ class ReplyOnPause(StreamHandler): self.expected_layout: Literal["mono", "stereo"] = expected_layout self.output_sample_rate = output_sample_rate self.output_frame_size = output_frame_size - self.model = get_vad_model() + self.model = model or get_silero_model() self.fn = fn self.is_async = inspect.isasyncgenfunction(fn) self.event = Event() @@ -145,7 +121,7 @@ class ReplyOnPause(StreamHandler): duration = len(audio) / sampling_rate if duration >= self.algo_options.audio_chunk_duration: - dur_vad = self.model.vad((sampling_rate, audio), self.model_options) + dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options) logger.debug("VAD duration: %s", dur_vad) if ( dur_vad > self.algo_options.started_talking_threshold diff --git a/backend/fastrtc/reply_on_stopwords.py b/backend/fastrtc/reply_on_stopwords.py index d0063ac..6a05e76 100644 --- a/backend/fastrtc/reply_on_stopwords.py +++ b/backend/fastrtc/reply_on_stopwords.py @@ -8,9 +8,10 @@ import numpy as np from .reply_on_pause import ( AlgoOptions, AppState, + ModelOptions, + PauseDetectionModel, ReplyFnGenerator, ReplyOnPause, - SileroVadOptions, ) from .speech_to_text import get_stt_model from .utils import audio_to_float32, create_message @@ -33,12 +34,13 @@ class ReplyOnStopWords(ReplyOnPause): fn: ReplyFnGenerator, stop_words: list[str], algo_options: AlgoOptions | None = None, - model_options: SileroVadOptions | None = None, + model_options: ModelOptions | None = None, can_interrupt: bool = True, expected_layout: Literal["mono", "stereo"] = "mono", output_sample_rate: int = 24000, output_frame_size: int = 480, input_sample_rate: int = 48000, + model: 
PauseDetectionModel | None = None, ): super().__init__( fn, @@ -49,6 +51,7 @@ class ReplyOnStopWords(ReplyOnPause): output_sample_rate=output_sample_rate, output_frame_size=output_frame_size, input_sample_rate=input_sample_rate, + model=model, ) self.stop_words = stop_words self.state = ReplyOnStopWordsState() @@ -114,7 +117,7 @@ class ReplyOnStopWords(ReplyOnPause): self.send_stopword() state.buffer = None else: - dur_vad = self.model.vad((sampling_rate, audio), self.model_options) + dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options) logger.debug("VAD duration: %s", dur_vad) if ( dur_vad > self.algo_options.started_talking_threshold diff --git a/docs/vad_gallery.md b/docs/vad_gallery.md new file mode 100644 index 0000000..e8632e0 --- /dev/null +++ b/docs/vad_gallery.md @@ -0,0 +1,60 @@ + + +A collection of VAD models ready to use with FastRTC. Click on the tags below to find the VAD model you're looking for! + + +
+ + + + +