Add method for loading community VAD models (#136)

* Add code

* add code
This commit is contained in:
Freddy Boulton
2025-03-07 16:27:18 -05:00
committed by GitHub
parent 6905810f37
commit cbbfa17679
8 changed files with 156 additions and 70 deletions

View File

@@ -3,7 +3,13 @@ from .credentials import (
get_turn_credentials,
get_twilio_turn_credentials,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
from .pause_detection import (
ModelOptions,
PauseDetectionModel,
SileroVadOptions,
get_silero_model,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause
from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import MoonshineSTT, get_stt_model
from .stream import Stream, UIArgs
@@ -63,4 +69,8 @@ __all__ = [
"KokoroTTSOptions",
"wait_for_item",
"UIArgs",
"ModelOptions",
"PauseDetectionModel",
"get_silero_model",
"SileroVadOptions",
]

View File

@@ -1,3 +1,10 @@
from .vad import SileroVADModel, SileroVadOptions
from .protocol import ModelOptions, PauseDetectionModel
from .silero import SileroVADModel, SileroVadOptions, get_silero_model
__all__ = ["SileroVADModel", "SileroVadOptions"]
__all__ = [
"SileroVADModel",
"SileroVadOptions",
"PauseDetectionModel",
"ModelOptions",
"get_silero_model",
]

View File

@@ -0,0 +1,20 @@
from typing import Any, Protocol, TypeAlias
import numpy as np
from numpy.typing import NDArray
from ..utils import AudioChunk
ModelOptions: TypeAlias = Any
class PauseDetectionModel(Protocol):
    """Structural (duck-typed) interface for pause-detection / VAD models.

    Any object implementing ``vad`` and ``warmup`` with these signatures can
    be plugged into ``ReplyOnPause`` — no inheritance required.
    """
    def vad(
        self,
        # (sample_rate_hz, samples) — int16 or float32 mono audio.
        audio: tuple[int, NDArray[np.int16] | NDArray[np.float32]],
        # Implementation-specific options object (ModelOptions is an alias for Any).
        options: ModelOptions,
        # Returns (speech duration in seconds after VAD filtering, detected speech chunks).
    ) -> tuple[float, list[AudioChunk]]: ...
    def warmup(
        self,
        # Prime the model (e.g. run inference on dummy data) so first real call is fast.
    ) -> None: ...

View File

@@ -1,13 +1,16 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List, Literal, overload
from functools import lru_cache
from typing import List
import click
import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from ..utils import AudioChunk
from .protocol import PauseDetectionModel
logger = logging.getLogger(__name__)
@@ -15,6 +18,26 @@ logger = logging.getLogger(__name__)
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
@lru_cache
def get_silero_model() -> PauseDetectionModel:
    """Return a cached, warmed-up Silero VAD model instance.

    The ``lru_cache`` decorator makes this a process-wide singleton, so the
    (relatively expensive) model load and warmup happen at most once.

    Raises:
        RuntimeError: if the optional ``onnxruntime`` dependency is not
            installed (install via ``fastrtc[vad]``).
    """
    # Verify the optional onnxruntime dependency is importable before
    # constructing the model, so the user gets an actionable error message.
    try:
        import importlib.util

        mod = importlib.util.find_spec("onnxruntime")
        if mod is None:
            raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
    except (ValueError, ModuleNotFoundError) as e:
        # Chain the original import failure for easier debugging.
        raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause") from e
    model = SileroVADModel()
    # Warm up the model with dummy data so the first real call is fast.
    print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
    model.warmup()
    print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
    return model
@dataclass
class SileroVadOptions:
"""VAD options.
@@ -239,33 +262,21 @@ class SileroVADModel:
return speeches
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: Literal[True],
) -> tuple[float, List[AudioChunk]]: ...
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float: ...
def warmup(self):
    """Prime the model by running VAD several times on silent dummy audio."""
    for _ in range(10):
        silence = np.zeros(102400, dtype=np.float32)
        self.vad((24000, silence), None)
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float | tuple[float, List[AudioChunk]]:
sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape)
audio: tuple[int, NDArray[np.float32] | NDArray[np.int16]],
options: None | SileroVadOptions,
) -> tuple[float, list[AudioChunk]]:
sampling_rate, audio_ = audio
logger.debug("VAD audio shape input: %s", audio_.shape)
try:
if audio.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0
if audio_.dtype != np.float32:
audio_ = audio_.astype(np.float32) / 32768.0
sr = 16000
if sr != sampling_rate:
try:
@@ -274,18 +285,16 @@ class SileroVADModel:
raise RuntimeError(
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
) from e
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
audio_ = librosa.resample(audio_, orig_sr=sampling_rate, target_sr=sr)
if not vad_parameters:
if not options:
vad_parameters = SileroVadOptions()
speech_chunks = self.get_speech_timestamps(audio, vad_parameters)
speech_chunks = self.get_speech_timestamps(audio_, vad_parameters)
logger.debug("VAD speech chunks: %s", speech_chunks)
audio = self.collect_chunks(audio, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape)
duration_after_vad = audio.shape[0] / sr
if return_chunks:
return duration_after_vad, speech_chunks
return duration_after_vad
audio_ = self.collect_chunks(audio_, speech_chunks)
logger.debug("VAD audio shape: %s", audio_.shape)
duration_after_vad = audio_.shape[0] / sr
return duration_after_vad, speech_chunks
except Exception as e:
import math
import traceback
@@ -293,7 +302,7 @@ class SileroVADModel:
logger.debug("VAD Exception: %s", str(e))
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
return math.inf
return math.inf, []
def __call__(self, x, state, sr: int):
if len(x.shape) == 1:

View File

@@ -1,44 +1,19 @@
import asyncio
import inspect
from dataclasses import dataclass, field
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
import click
import numpy as np
from numpy.typing import NDArray
from .pause_detection import SileroVADModel, SileroVadOptions
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
from .tracks import EmitType, StreamHandler
from .utils import create_message, split_output
logger = getLogger(__name__)
counter = 0
@lru_cache
def get_vad_model() -> SileroVADModel:
    """Returns the VAD model instance and warms it up with dummy data."""
    # lru_cache makes this a process-wide singleton: the model is built and
    # warmed up at most once, then reused by every ReplyOnPause handler.
    try:
        import importlib.util
        # Check that the optional onnxruntime dependency is importable before
        # constructing the model, so the user gets an actionable error.
        mod = importlib.util.find_spec("onnxruntime")
        if mod is None:
            raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
    except (ValueError, ModuleNotFoundError):
        raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
    model = SileroVADModel()
    # Warm up the model with dummy data
    print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
    for _ in range(10):
        # ~4.3 s of silence at 24 kHz per pass (102400 / 24000 samples).
        dummy_audio = np.zeros(102400, dtype=np.float32)
        model.vad((24000, dummy_audio), None)
    print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
    return model
@dataclass
class AlgoOptions:
@@ -94,12 +69,13 @@ class ReplyOnPause(StreamHandler):
self,
fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
input_sample_rate: int = 48000,
model: PauseDetectionModel | None = None,
):
super().__init__(
expected_layout,
@@ -111,7 +87,7 @@ class ReplyOnPause(StreamHandler):
self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.model = get_vad_model()
self.model = model or get_silero_model()
self.fn = fn
self.is_async = inspect.isasyncgenfunction(fn)
self.event = Event()
@@ -145,7 +121,7 @@ class ReplyOnPause(StreamHandler):
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad)
if (
dur_vad > self.algo_options.started_talking_threshold

View File

@@ -8,9 +8,10 @@ import numpy as np
from .reply_on_pause import (
AlgoOptions,
AppState,
ModelOptions,
PauseDetectionModel,
ReplyFnGenerator,
ReplyOnPause,
SileroVadOptions,
)
from .speech_to_text import get_stt_model
from .utils import audio_to_float32, create_message
@@ -33,12 +34,13 @@ class ReplyOnStopWords(ReplyOnPause):
fn: ReplyFnGenerator,
stop_words: list[str],
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
model_options: ModelOptions | None = None,
can_interrupt: bool = True,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
input_sample_rate: int = 48000,
model: PauseDetectionModel | None = None,
):
super().__init__(
fn,
@@ -49,6 +51,7 @@ class ReplyOnStopWords(ReplyOnPause):
output_sample_rate=output_sample_rate,
output_frame_size=output_frame_size,
input_sample_rate=input_sample_rate,
model=model,
)
self.stop_words = stop_words
self.state = ReplyOnStopWordsState()
@@ -114,7 +117,7 @@ class ReplyOnStopWords(ReplyOnPause):
self.send_stopword()
state.buffer = None
else:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad)
if (
dur_vad > self.algo_options.started_talking_threshold

60
docs/vad_gallery.md Normal file
View File

@@ -0,0 +1,60 @@
<style>
.tag-button {
cursor: pointer;
opacity: 0.5;
transition: opacity 0.2s ease;
}
.tag-button > code {
color: var(--supernova);
}
.tag-button.active {
opacity: 1;
}
</style>
A collection of VAD models ready to use with FastRTC. Click on the tags below to find the VAD model you're looking for!
<div class="tag-buttons">
<button class="tag-button" data-tag="pytorch"><code>pytorch</code></button>
</div>
<script>
// Show only the gallery cards whose data-tags intersect the active tag
// buttons; when no button is active, every card stays visible.
function filterCards() {
    const active = Array.from(document.querySelectorAll('.tag-button.active'))
        .map((btn) => btn.getAttribute('data-tag'));
    for (const card of document.querySelectorAll('.grid.cards > ul > li > p[data-tags]')) {
        const tags = card.getAttribute('data-tags').split(',');
        const visible = active.length === 0 || active.some((tag) => tags.includes(tag));
        card.parentElement.style.display = visible ? 'block' : 'none';
    }
}
// Each tag button toggles its own active state, then re-filters the gallery.
for (const btn of document.querySelectorAll('.tag-button')) {
    btn.addEventListener('click', () => {
        btn.classList.toggle('active');
        filterCards();
    });
}
</script>
<div class="grid cards" markdown>
- :speaking_head:{ .lg .middle }:eyes:{ .lg .middle } __Your VAD Model__
{: data-tags="pytorch"}
---
Description
Install Instructions
Usage
[:octicons-arrow-right-24: Demo](Your demo here)
[:octicons-code-16: Repository](Code here)

View File

@@ -28,6 +28,7 @@ nav:
- Cookbook: cookbook.md
- Deployment: deployment.md
- Advanced Configuration: advanced-configuration.md
- VAD Gallery: vad_gallery.md
- Utils: utils.md
- Frequently Asked Questions: faq.md
extra_javascript:
@@ -49,4 +50,4 @@ markdown_extensions:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- admonition
- pymdownx.details
- pymdownx.details