Add Method for loading community Vad Models (#136)

* Add code

* add code
This commit is contained in:
Freddy Boulton
2025-03-07 16:27:18 -05:00
committed by GitHub
parent 6905810f37
commit cbbfa17679
8 changed files with 156 additions and 70 deletions

View File

@@ -3,7 +3,13 @@ from .credentials import (
get_turn_credentials, get_turn_credentials,
get_twilio_turn_credentials, get_twilio_turn_credentials,
) )
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions from .pause_detection import (
ModelOptions,
PauseDetectionModel,
SileroVadOptions,
get_silero_model,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause
from .reply_on_stopwords import ReplyOnStopWords from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import MoonshineSTT, get_stt_model from .speech_to_text import MoonshineSTT, get_stt_model
from .stream import Stream, UIArgs from .stream import Stream, UIArgs
@@ -63,4 +69,8 @@ __all__ = [
"KokoroTTSOptions", "KokoroTTSOptions",
"wait_for_item", "wait_for_item",
"UIArgs", "UIArgs",
"ModelOptions",
"PauseDetectionModel",
"get_silero_model",
"SileroVadOptions",
] ]

View File

@@ -1,3 +1,10 @@
from .vad import SileroVADModel, SileroVadOptions from .protocol import ModelOptions, PauseDetectionModel
from .silero import SileroVADModel, SileroVadOptions, get_silero_model
__all__ = ["SileroVADModel", "SileroVadOptions"] __all__ = [
"SileroVADModel",
"SileroVadOptions",
"PauseDetectionModel",
"ModelOptions",
"get_silero_model",
]

View File

@@ -0,0 +1,20 @@
from typing import Any, Protocol, TypeAlias
import numpy as np
from numpy.typing import NDArray
from ..utils import AudioChunk
# Options object passed to a PauseDetectionModel's `vad` call. Each model
# implementation declares its own concrete options type (e.g. SileroVadOptions),
# so this alias is intentionally left as `Any`.
ModelOptions: TypeAlias = Any
class PauseDetectionModel(Protocol):
    """Structural (duck-typed) interface for pause/voice-activity detection models.

    Any object providing `vad` and `warmup` with these signatures satisfies
    this protocol — no inheritance required.
    """

    def vad(
        self,
        audio: tuple[int, NDArray[np.int16] | NDArray[np.float32]],
        options: ModelOptions,
    ) -> tuple[float, list[AudioChunk]]:
        """Run voice-activity detection on `(sample_rate, samples)` audio.

        Returns the detected speech duration (seconds) and the list of
        detected speech chunks.
        """
        ...

    def warmup(
        self,
    ) -> None:
        """Warm the model up (e.g. run dummy inference) before first real use."""
        ...

View File

@@ -1,13 +1,16 @@
import logging import logging
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Literal, overload from functools import lru_cache
from typing import List
import click
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from numpy.typing import NDArray from numpy.typing import NDArray
from ..utils import AudioChunk from ..utils import AudioChunk
from .protocol import PauseDetectionModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,6 +18,26 @@ logger = logging.getLogger(__name__)
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py # The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
@lru_cache
def get_silero_model() -> PauseDetectionModel:
    """Return a cached Silero VAD model instance, warmed up with dummy data.

    Raises:
        RuntimeError: if the optional `onnxruntime` dependency is missing
            (install with `fastrtc[vad]`).
    """
    # Detect the optional onnxruntime dependency without importing it;
    # find_spec can itself raise for malformed names/packages, so treat
    # those failures the same as "not installed".
    try:
        import importlib.util

        mod = importlib.util.find_spec("onnxruntime")
    except (ValueError, ModuleNotFoundError):
        mod = None
    if mod is None:
        raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
    model = SileroVADModel()
    # Warm up the model with dummy data so the first real request is fast.
    print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
    model.warmup()
    print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
    return model
@dataclass @dataclass
class SileroVadOptions: class SileroVadOptions:
"""VAD options. """VAD options.
@@ -239,33 +262,21 @@ class SileroVADModel:
return speeches return speeches
@overload def warmup(self):
def vad( for _ in range(10):
self, dummy_audio = np.zeros(102400, dtype=np.float32)
audio_tuple: tuple[int, NDArray], self.vad((24000, dummy_audio), None)
vad_parameters: None | SileroVadOptions,
return_chunks: Literal[True],
) -> tuple[float, List[AudioChunk]]: ...
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float: ...
def vad( def vad(
self, self,
audio_tuple: tuple[int, NDArray], audio: tuple[int, NDArray[np.float32] | NDArray[np.int16]],
vad_parameters: None | SileroVadOptions, options: None | SileroVadOptions,
return_chunks: bool = False, ) -> tuple[float, list[AudioChunk]]:
) -> float | tuple[float, List[AudioChunk]]: sampling_rate, audio_ = audio
sampling_rate, audio = audio_tuple logger.debug("VAD audio shape input: %s", audio_.shape)
logger.debug("VAD audio shape input: %s", audio.shape)
try: try:
if audio.dtype != np.float32: if audio_.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0 audio_ = audio_.astype(np.float32) / 32768.0
sr = 16000 sr = 16000
if sr != sampling_rate: if sr != sampling_rate:
try: try:
@@ -274,18 +285,16 @@ class SileroVADModel:
raise RuntimeError( raise RuntimeError(
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
) from e ) from e
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr) audio_ = librosa.resample(audio_, orig_sr=sampling_rate, target_sr=sr)
if not vad_parameters: if not options:
vad_parameters = SileroVadOptions() vad_parameters = SileroVadOptions()
speech_chunks = self.get_speech_timestamps(audio, vad_parameters) speech_chunks = self.get_speech_timestamps(audio_, vad_parameters)
logger.debug("VAD speech chunks: %s", speech_chunks) logger.debug("VAD speech chunks: %s", speech_chunks)
audio = self.collect_chunks(audio, speech_chunks) audio_ = self.collect_chunks(audio_, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape) logger.debug("VAD audio shape: %s", audio_.shape)
duration_after_vad = audio.shape[0] / sr duration_after_vad = audio_.shape[0] / sr
if return_chunks: return duration_after_vad, speech_chunks
return duration_after_vad, speech_chunks
return duration_after_vad
except Exception as e: except Exception as e:
import math import math
import traceback import traceback
@@ -293,7 +302,7 @@ class SileroVADModel:
logger.debug("VAD Exception: %s", str(e)) logger.debug("VAD Exception: %s", str(e))
exec = traceback.format_exc() exec = traceback.format_exc()
logger.debug("traceback %s", exec) logger.debug("traceback %s", exec)
return math.inf return math.inf, []
def __call__(self, x, state, sr: int): def __call__(self, x, state, sr: int):
if len(x.shape) == 1: if len(x.shape) == 1:

View File

@@ -1,44 +1,19 @@
import asyncio import asyncio
import inspect import inspect
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache
from logging import getLogger from logging import getLogger
from threading import Event from threading import Event
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
import click
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from .pause_detection import SileroVADModel, SileroVadOptions from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
from .tracks import EmitType, StreamHandler from .tracks import EmitType, StreamHandler
from .utils import create_message, split_output from .utils import create_message, split_output
logger = getLogger(__name__) logger = getLogger(__name__)
counter = 0
@lru_cache
def get_vad_model() -> SileroVADModel:
"""Returns the VAD model instance and warms it up with dummy data."""
try:
import importlib.util
mod = importlib.util.find_spec("onnxruntime")
if mod is None:
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
except (ValueError, ModuleNotFoundError):
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
model = SileroVADModel()
# Warm up the model with dummy data
print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
for _ in range(10):
dummy_audio = np.zeros(102400, dtype=np.float32)
model.vad((24000, dummy_audio), None)
print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
return model
@dataclass @dataclass
class AlgoOptions: class AlgoOptions:
@@ -94,12 +69,13 @@ class ReplyOnPause(StreamHandler):
self, self,
fn: ReplyFnGenerator, fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None, algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None, model_options: ModelOptions | None = None,
can_interrupt: bool = True, can_interrupt: bool = True,
expected_layout: Literal["mono", "stereo"] = "mono", expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000, output_sample_rate: int = 24000,
output_frame_size: int = 480, output_frame_size: int = 480,
input_sample_rate: int = 48000, input_sample_rate: int = 48000,
model: PauseDetectionModel | None = None,
): ):
super().__init__( super().__init__(
expected_layout, expected_layout,
@@ -111,7 +87,7 @@ class ReplyOnPause(StreamHandler):
self.expected_layout: Literal["mono", "stereo"] = expected_layout self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size self.output_frame_size = output_frame_size
self.model = get_vad_model() self.model = model or get_silero_model()
self.fn = fn self.fn = fn
self.is_async = inspect.isasyncgenfunction(fn) self.is_async = inspect.isasyncgenfunction(fn)
self.event = Event() self.event = Event()
@@ -145,7 +121,7 @@ class ReplyOnPause(StreamHandler):
duration = len(audio) / sampling_rate duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration: if duration >= self.algo_options.audio_chunk_duration:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options) dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad) logger.debug("VAD duration: %s", dur_vad)
if ( if (
dur_vad > self.algo_options.started_talking_threshold dur_vad > self.algo_options.started_talking_threshold

View File

@@ -8,9 +8,10 @@ import numpy as np
from .reply_on_pause import ( from .reply_on_pause import (
AlgoOptions, AlgoOptions,
AppState, AppState,
ModelOptions,
PauseDetectionModel,
ReplyFnGenerator, ReplyFnGenerator,
ReplyOnPause, ReplyOnPause,
SileroVadOptions,
) )
from .speech_to_text import get_stt_model from .speech_to_text import get_stt_model
from .utils import audio_to_float32, create_message from .utils import audio_to_float32, create_message
@@ -33,12 +34,13 @@ class ReplyOnStopWords(ReplyOnPause):
fn: ReplyFnGenerator, fn: ReplyFnGenerator,
stop_words: list[str], stop_words: list[str],
algo_options: AlgoOptions | None = None, algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None, model_options: ModelOptions | None = None,
can_interrupt: bool = True, can_interrupt: bool = True,
expected_layout: Literal["mono", "stereo"] = "mono", expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000, output_sample_rate: int = 24000,
output_frame_size: int = 480, output_frame_size: int = 480,
input_sample_rate: int = 48000, input_sample_rate: int = 48000,
model: PauseDetectionModel | None = None,
): ):
super().__init__( super().__init__(
fn, fn,
@@ -49,6 +51,7 @@ class ReplyOnStopWords(ReplyOnPause):
output_sample_rate=output_sample_rate, output_sample_rate=output_sample_rate,
output_frame_size=output_frame_size, output_frame_size=output_frame_size,
input_sample_rate=input_sample_rate, input_sample_rate=input_sample_rate,
model=model,
) )
self.stop_words = stop_words self.stop_words = stop_words
self.state = ReplyOnStopWordsState() self.state = ReplyOnStopWordsState()
@@ -114,7 +117,7 @@ class ReplyOnStopWords(ReplyOnPause):
self.send_stopword() self.send_stopword()
state.buffer = None state.buffer = None
else: else:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options) dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad) logger.debug("VAD duration: %s", dur_vad)
if ( if (
dur_vad > self.algo_options.started_talking_threshold dur_vad > self.algo_options.started_talking_threshold

60
docs/vad_gallery.md Normal file
View File

@@ -0,0 +1,60 @@
<style>
.tag-button {
cursor: pointer;
opacity: 0.5;
transition: opacity 0.2s ease;
}
.tag-button > code {
color: var(--supernova);
}
.tag-button.active {
opacity: 1;
}
</style>
A collection of VAD models ready to use with FastRTC. Click on the tags below to find the VAD model you're looking for!
<div class="tag-buttons">
<button class="tag-button" data-tag="pytorch"><code>pytorch</code></button>
</div>
<script>
// Show only gallery cards whose data-tags intersect the active tag buttons;
// with no active buttons, every card is shown.
function filterCards() {
    const selectedTags = [];
    document.querySelectorAll('.tag-button.active').forEach(btn => {
        selectedTags.push(btn.getAttribute('data-tag'));
    });
    document.querySelectorAll('.grid.cards > ul > li > p[data-tags]').forEach(card => {
        const cardTags = card.getAttribute('data-tags').split(',');
        let visible = selectedTags.length === 0;
        for (const tag of selectedTags) {
            if (cardTags.includes(tag)) {
                visible = true;
                break;
            }
        }
        card.parentElement.style.display = visible ? 'block' : 'none';
    });
}
document.querySelectorAll('.tag-button').forEach(button => {
button.addEventListener('click', () => {
button.classList.toggle('active');
filterCards();
});
});
</script>
<div class="grid cards" markdown>
- :speaking_head:{ .lg .middle }:eyes:{ .lg .middle } __Your VAD Model__
{: data-tags="pytorch"}
---
Description
Install Instructions
Usage
[:octicons-arrow-right-24: Demo](Your demo here)
[:octicons-code-16: Repository](Code here)

View File

@@ -28,6 +28,7 @@ nav:
- Cookbook: cookbook.md - Cookbook: cookbook.md
- Deployment: deployment.md - Deployment: deployment.md
- Advanced Configuration: advanced-configuration.md - Advanced Configuration: advanced-configuration.md
- VAD Gallery: vad_gallery.md
- Utils: utils.md - Utils: utils.md
- Frequently Asked Questions: faq.md - Frequently Asked Questions: faq.md
extra_javascript: extra_javascript:
@@ -49,4 +50,4 @@ markdown_extensions:
emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg emoji_generator: !!python/name:material.extensions.emoji.to_svg
- admonition - admonition
- pymdownx.details - pymdownx.details