mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add Method for loading community Vad Models (#136)
* Add code * add code
This commit is contained in:
@@ -3,7 +3,13 @@ from .credentials import (
|
|||||||
get_turn_credentials,
|
get_turn_credentials,
|
||||||
get_twilio_turn_credentials,
|
get_twilio_turn_credentials,
|
||||||
)
|
)
|
||||||
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
|
from .pause_detection import (
|
||||||
|
ModelOptions,
|
||||||
|
PauseDetectionModel,
|
||||||
|
SileroVadOptions,
|
||||||
|
get_silero_model,
|
||||||
|
)
|
||||||
|
from .reply_on_pause import AlgoOptions, ReplyOnPause
|
||||||
from .reply_on_stopwords import ReplyOnStopWords
|
from .reply_on_stopwords import ReplyOnStopWords
|
||||||
from .speech_to_text import MoonshineSTT, get_stt_model
|
from .speech_to_text import MoonshineSTT, get_stt_model
|
||||||
from .stream import Stream, UIArgs
|
from .stream import Stream, UIArgs
|
||||||
@@ -63,4 +69,8 @@ __all__ = [
|
|||||||
"KokoroTTSOptions",
|
"KokoroTTSOptions",
|
||||||
"wait_for_item",
|
"wait_for_item",
|
||||||
"UIArgs",
|
"UIArgs",
|
||||||
|
"ModelOptions",
|
||||||
|
"PauseDetectionModel",
|
||||||
|
"get_silero_model",
|
||||||
|
"SileroVadOptions",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
from .vad import SileroVADModel, SileroVadOptions
|
from .protocol import ModelOptions, PauseDetectionModel
|
||||||
|
from .silero import SileroVADModel, SileroVadOptions, get_silero_model
|
||||||
|
|
||||||
__all__ = ["SileroVADModel", "SileroVadOptions"]
|
__all__ = [
|
||||||
|
"SileroVADModel",
|
||||||
|
"SileroVadOptions",
|
||||||
|
"PauseDetectionModel",
|
||||||
|
"ModelOptions",
|
||||||
|
"get_silero_model",
|
||||||
|
]
|
||||||
|
|||||||
20
backend/fastrtc/pause_detection/protocol.py
Normal file
20
backend/fastrtc/pause_detection/protocol.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import Any, Protocol, TypeAlias
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from ..utils import AudioChunk
|
||||||
|
|
||||||
|
ModelOptions: TypeAlias = Any
|
||||||
|
|
||||||
|
|
||||||
|
class PauseDetectionModel(Protocol):
|
||||||
|
def vad(
|
||||||
|
self,
|
||||||
|
audio: tuple[int, NDArray[np.int16] | NDArray[np.float32]],
|
||||||
|
options: ModelOptions,
|
||||||
|
) -> tuple[float, list[AudioChunk]]: ...
|
||||||
|
|
||||||
|
def warmup(
|
||||||
|
self,
|
||||||
|
) -> None: ...
|
||||||
@@ -1,13 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, overload
|
from functools import lru_cache
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import click
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from ..utils import AudioChunk
|
from ..utils import AudioChunk
|
||||||
|
from .protocol import PauseDetectionModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,6 +18,26 @@ logger = logging.getLogger(__name__)
|
|||||||
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
|
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_silero_model() -> PauseDetectionModel:
|
||||||
|
"""Returns the VAD model instance and warms it up with dummy data."""
|
||||||
|
# Warm up the model with dummy data
|
||||||
|
|
||||||
|
try:
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
mod = importlib.util.find_spec("onnxruntime")
|
||||||
|
if mod is None:
|
||||||
|
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
|
||||||
|
except (ValueError, ModuleNotFoundError):
|
||||||
|
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
|
||||||
|
model = SileroVADModel()
|
||||||
|
print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
|
||||||
|
model.warmup()
|
||||||
|
print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SileroVadOptions:
|
class SileroVadOptions:
|
||||||
"""VAD options.
|
"""VAD options.
|
||||||
@@ -239,33 +262,21 @@ class SileroVADModel:
|
|||||||
|
|
||||||
return speeches
|
return speeches
|
||||||
|
|
||||||
@overload
|
def warmup(self):
|
||||||
def vad(
|
for _ in range(10):
|
||||||
self,
|
dummy_audio = np.zeros(102400, dtype=np.float32)
|
||||||
audio_tuple: tuple[int, NDArray],
|
self.vad((24000, dummy_audio), None)
|
||||||
vad_parameters: None | SileroVadOptions,
|
|
||||||
return_chunks: Literal[True],
|
|
||||||
) -> tuple[float, List[AudioChunk]]: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def vad(
|
|
||||||
self,
|
|
||||||
audio_tuple: tuple[int, NDArray],
|
|
||||||
vad_parameters: None | SileroVadOptions,
|
|
||||||
return_chunks: bool = False,
|
|
||||||
) -> float: ...
|
|
||||||
|
|
||||||
def vad(
|
def vad(
|
||||||
self,
|
self,
|
||||||
audio_tuple: tuple[int, NDArray],
|
audio: tuple[int, NDArray[np.float32] | NDArray[np.int16]],
|
||||||
vad_parameters: None | SileroVadOptions,
|
options: None | SileroVadOptions,
|
||||||
return_chunks: bool = False,
|
) -> tuple[float, list[AudioChunk]]:
|
||||||
) -> float | tuple[float, List[AudioChunk]]:
|
sampling_rate, audio_ = audio
|
||||||
sampling_rate, audio = audio_tuple
|
logger.debug("VAD audio shape input: %s", audio_.shape)
|
||||||
logger.debug("VAD audio shape input: %s", audio.shape)
|
|
||||||
try:
|
try:
|
||||||
if audio.dtype != np.float32:
|
if audio_.dtype != np.float32:
|
||||||
audio = audio.astype(np.float32) / 32768.0
|
audio_ = audio_.astype(np.float32) / 32768.0
|
||||||
sr = 16000
|
sr = 16000
|
||||||
if sr != sampling_rate:
|
if sr != sampling_rate:
|
||||||
try:
|
try:
|
||||||
@@ -274,18 +285,16 @@ class SileroVADModel:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
|
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
|
||||||
) from e
|
) from e
|
||||||
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
|
audio_ = librosa.resample(audio_, orig_sr=sampling_rate, target_sr=sr)
|
||||||
|
|
||||||
if not vad_parameters:
|
if not options:
|
||||||
vad_parameters = SileroVadOptions()
|
vad_parameters = SileroVadOptions()
|
||||||
speech_chunks = self.get_speech_timestamps(audio, vad_parameters)
|
speech_chunks = self.get_speech_timestamps(audio_, vad_parameters)
|
||||||
logger.debug("VAD speech chunks: %s", speech_chunks)
|
logger.debug("VAD speech chunks: %s", speech_chunks)
|
||||||
audio = self.collect_chunks(audio, speech_chunks)
|
audio_ = self.collect_chunks(audio_, speech_chunks)
|
||||||
logger.debug("VAD audio shape: %s", audio.shape)
|
logger.debug("VAD audio shape: %s", audio_.shape)
|
||||||
duration_after_vad = audio.shape[0] / sr
|
duration_after_vad = audio_.shape[0] / sr
|
||||||
if return_chunks:
|
return duration_after_vad, speech_chunks
|
||||||
return duration_after_vad, speech_chunks
|
|
||||||
return duration_after_vad
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import math
|
import math
|
||||||
import traceback
|
import traceback
|
||||||
@@ -293,7 +302,7 @@ class SileroVADModel:
|
|||||||
logger.debug("VAD Exception: %s", str(e))
|
logger.debug("VAD Exception: %s", str(e))
|
||||||
exec = traceback.format_exc()
|
exec = traceback.format_exc()
|
||||||
logger.debug("traceback %s", exec)
|
logger.debug("traceback %s", exec)
|
||||||
return math.inf
|
return math.inf, []
|
||||||
|
|
||||||
def __call__(self, x, state, sr: int):
|
def __call__(self, x, state, sr: int):
|
||||||
if len(x.shape) == 1:
|
if len(x.shape) == 1:
|
||||||
@@ -1,44 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
|
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
|
||||||
|
|
||||||
import click
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from .pause_detection import SileroVADModel, SileroVadOptions
|
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
||||||
from .tracks import EmitType, StreamHandler
|
from .tracks import EmitType, StreamHandler
|
||||||
from .utils import create_message, split_output
|
from .utils import create_message, split_output
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
counter = 0
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def get_vad_model() -> SileroVADModel:
|
|
||||||
"""Returns the VAD model instance and warms it up with dummy data."""
|
|
||||||
try:
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
mod = importlib.util.find_spec("onnxruntime")
|
|
||||||
if mod is None:
|
|
||||||
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
|
|
||||||
except (ValueError, ModuleNotFoundError):
|
|
||||||
raise RuntimeError("Install fastrtc[vad] to use ReplyOnPause")
|
|
||||||
model = SileroVADModel()
|
|
||||||
# Warm up the model with dummy data
|
|
||||||
print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
|
|
||||||
for _ in range(10):
|
|
||||||
dummy_audio = np.zeros(102400, dtype=np.float32)
|
|
||||||
model.vad((24000, dummy_audio), None)
|
|
||||||
print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AlgoOptions:
|
class AlgoOptions:
|
||||||
@@ -94,12 +69,13 @@ class ReplyOnPause(StreamHandler):
|
|||||||
self,
|
self,
|
||||||
fn: ReplyFnGenerator,
|
fn: ReplyFnGenerator,
|
||||||
algo_options: AlgoOptions | None = None,
|
algo_options: AlgoOptions | None = None,
|
||||||
model_options: SileroVadOptions | None = None,
|
model_options: ModelOptions | None = None,
|
||||||
can_interrupt: bool = True,
|
can_interrupt: bool = True,
|
||||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||||
output_sample_rate: int = 24000,
|
output_sample_rate: int = 24000,
|
||||||
output_frame_size: int = 480,
|
output_frame_size: int = 480,
|
||||||
input_sample_rate: int = 48000,
|
input_sample_rate: int = 48000,
|
||||||
|
model: PauseDetectionModel | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
expected_layout,
|
expected_layout,
|
||||||
@@ -111,7 +87,7 @@ class ReplyOnPause(StreamHandler):
|
|||||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||||
self.output_sample_rate = output_sample_rate
|
self.output_sample_rate = output_sample_rate
|
||||||
self.output_frame_size = output_frame_size
|
self.output_frame_size = output_frame_size
|
||||||
self.model = get_vad_model()
|
self.model = model or get_silero_model()
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.is_async = inspect.isasyncgenfunction(fn)
|
self.is_async = inspect.isasyncgenfunction(fn)
|
||||||
self.event = Event()
|
self.event = Event()
|
||||||
@@ -145,7 +121,7 @@ class ReplyOnPause(StreamHandler):
|
|||||||
duration = len(audio) / sampling_rate
|
duration = len(audio) / sampling_rate
|
||||||
|
|
||||||
if duration >= self.algo_options.audio_chunk_duration:
|
if duration >= self.algo_options.audio_chunk_duration:
|
||||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
|
||||||
logger.debug("VAD duration: %s", dur_vad)
|
logger.debug("VAD duration: %s", dur_vad)
|
||||||
if (
|
if (
|
||||||
dur_vad > self.algo_options.started_talking_threshold
|
dur_vad > self.algo_options.started_talking_threshold
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import numpy as np
|
|||||||
from .reply_on_pause import (
|
from .reply_on_pause import (
|
||||||
AlgoOptions,
|
AlgoOptions,
|
||||||
AppState,
|
AppState,
|
||||||
|
ModelOptions,
|
||||||
|
PauseDetectionModel,
|
||||||
ReplyFnGenerator,
|
ReplyFnGenerator,
|
||||||
ReplyOnPause,
|
ReplyOnPause,
|
||||||
SileroVadOptions,
|
|
||||||
)
|
)
|
||||||
from .speech_to_text import get_stt_model
|
from .speech_to_text import get_stt_model
|
||||||
from .utils import audio_to_float32, create_message
|
from .utils import audio_to_float32, create_message
|
||||||
@@ -33,12 +34,13 @@ class ReplyOnStopWords(ReplyOnPause):
|
|||||||
fn: ReplyFnGenerator,
|
fn: ReplyFnGenerator,
|
||||||
stop_words: list[str],
|
stop_words: list[str],
|
||||||
algo_options: AlgoOptions | None = None,
|
algo_options: AlgoOptions | None = None,
|
||||||
model_options: SileroVadOptions | None = None,
|
model_options: ModelOptions | None = None,
|
||||||
can_interrupt: bool = True,
|
can_interrupt: bool = True,
|
||||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||||
output_sample_rate: int = 24000,
|
output_sample_rate: int = 24000,
|
||||||
output_frame_size: int = 480,
|
output_frame_size: int = 480,
|
||||||
input_sample_rate: int = 48000,
|
input_sample_rate: int = 48000,
|
||||||
|
model: PauseDetectionModel | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fn,
|
fn,
|
||||||
@@ -49,6 +51,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
|||||||
output_sample_rate=output_sample_rate,
|
output_sample_rate=output_sample_rate,
|
||||||
output_frame_size=output_frame_size,
|
output_frame_size=output_frame_size,
|
||||||
input_sample_rate=input_sample_rate,
|
input_sample_rate=input_sample_rate,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
self.stop_words = stop_words
|
self.stop_words = stop_words
|
||||||
self.state = ReplyOnStopWordsState()
|
self.state = ReplyOnStopWordsState()
|
||||||
@@ -114,7 +117,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
|||||||
self.send_stopword()
|
self.send_stopword()
|
||||||
state.buffer = None
|
state.buffer = None
|
||||||
else:
|
else:
|
||||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
|
||||||
logger.debug("VAD duration: %s", dur_vad)
|
logger.debug("VAD duration: %s", dur_vad)
|
||||||
if (
|
if (
|
||||||
dur_vad > self.algo_options.started_talking_threshold
|
dur_vad > self.algo_options.started_talking_threshold
|
||||||
|
|||||||
60
docs/vad_gallery.md
Normal file
60
docs/vad_gallery.md
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
<style>
|
||||||
|
.tag-button {
|
||||||
|
cursor: pointer;
|
||||||
|
opacity: 0.5;
|
||||||
|
transition: opacity 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-button > code {
|
||||||
|
color: var(--supernova);
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-button.active {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
||||||
|
A collection of VAD models ready to use with FastRTC. Click on the tags below to find the VAD model you're looking for!
|
||||||
|
|
||||||
|
|
||||||
|
<div class="tag-buttons">
|
||||||
|
<button class="tag-button" data-tag="pytorch"><code>pytorch</code></button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
function filterCards() {
|
||||||
|
const activeButtons = document.querySelectorAll('.tag-button.active');
|
||||||
|
const selectedTags = Array.from(activeButtons).map(button => button.getAttribute('data-tag'));
|
||||||
|
const cards = document.querySelectorAll('.grid.cards > ul > li > p[data-tags]');
|
||||||
|
|
||||||
|
cards.forEach(card => {
|
||||||
|
const cardTags = card.getAttribute('data-tags').split(',');
|
||||||
|
const shouldShow = selectedTags.length === 0 || selectedTags.some(tag => cardTags.includes(tag));
|
||||||
|
card.parentElement.style.display = shouldShow ? 'block' : 'none';
|
||||||
|
});
|
||||||
|
}
|
||||||
|
document.querySelectorAll('.tag-button').forEach(button => {
|
||||||
|
button.addEventListener('click', () => {
|
||||||
|
button.classList.toggle('active');
|
||||||
|
filterCards();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
|
||||||
|
<div class="grid cards" markdown>
|
||||||
|
|
||||||
|
- :speaking_head:{ .lg .middle }:eyes:{ .lg .middle } __Your VAD Model__
|
||||||
|
{: data-tags="pytorch"}
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Description
|
||||||
|
|
||||||
|
Install Instructions
|
||||||
|
|
||||||
|
Usage
|
||||||
|
|
||||||
|
[:octicons-arrow-right-24: Demo](Your demo here)
|
||||||
|
|
||||||
|
[:octicons-code-16: Repository](Code here)
|
||||||
@@ -28,6 +28,7 @@ nav:
|
|||||||
- Cookbook: cookbook.md
|
- Cookbook: cookbook.md
|
||||||
- Deployment: deployment.md
|
- Deployment: deployment.md
|
||||||
- Advanced Configuration: advanced-configuration.md
|
- Advanced Configuration: advanced-configuration.md
|
||||||
|
- VAD Gallery: vad_gallery.md
|
||||||
- Utils: utils.md
|
- Utils: utils.md
|
||||||
- Frequently Asked Questions: faq.md
|
- Frequently Asked Questions: faq.md
|
||||||
extra_javascript:
|
extra_javascript:
|
||||||
@@ -49,4 +50,4 @@ markdown_extensions:
|
|||||||
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
||||||
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
||||||
- admonition
|
- admonition
|
||||||
- pymdownx.details
|
- pymdownx.details
|
||||||
Reference in New Issue
Block a user