mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add Method for loading community Vad Models (#136)
* Add code * add code
This commit is contained in:
@@ -1,44 +1,19 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .pause_detection import SileroVADModel, SileroVadOptions
|
||||
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
||||
from .tracks import EmitType, StreamHandler
|
||||
from .utils import create_message, split_output
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
counter = 0
|
||||
|
||||
|
||||
@lru_cache
def get_vad_model() -> SileroVADModel:
    """Return the VAD model instance, warmed up with dummy data.

    The function is wrapped in ``lru_cache`` and takes no arguments, so the
    model is built and warmed up on the first call only; later calls return
    the cached instance.

    Returns:
        The warmed-up ``SileroVADModel``.

    Raises:
        RuntimeError: If ``onnxruntime`` cannot be located, i.e. the optional
            ``vad`` extra is not installed.
    """
    import importlib.util

    install_hint = "Install fastrtc[vad] to use ReplyOnPause"

    # Keep the try body limited to the lookup itself so unrelated errors are
    # never mistaken for a missing dependency.
    try:
        spec = importlib.util.find_spec("onnxruntime")
    except (ValueError, ModuleNotFoundError) as err:
        # Chain the original failure so the real cause stays in the traceback.
        raise RuntimeError(install_hint) from err
    if spec is None:
        raise RuntimeError(install_hint)

    model = SileroVADModel()

    # Warm up the model with dummy data: ten passes over a silent float32
    # buffer, passed as a (sample_rate, samples) tuple with no options.
    print(click.style("INFO", fg="green") + ":\t Warming up VAD model.")
    for _ in range(10):
        dummy_audio = np.zeros(102400, dtype=np.float32)
        model.vad((24000, dummy_audio), None)
    print(click.style("INFO", fg="green") + ":\t VAD model warmed up.")
    return model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlgoOptions:
|
||||
@@ -94,12 +69,13 @@ class ReplyOnPause(StreamHandler):
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
model_options: ModelOptions | None = None,
|
||||
can_interrupt: bool = True,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 480,
|
||||
input_sample_rate: int = 48000,
|
||||
model: PauseDetectionModel | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
expected_layout,
|
||||
@@ -111,7 +87,7 @@ class ReplyOnPause(StreamHandler):
|
||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
self.model = get_vad_model()
|
||||
self.model = model or get_silero_model()
|
||||
self.fn = fn
|
||||
self.is_async = inspect.isasyncgenfunction(fn)
|
||||
self.event = Event()
|
||||
@@ -145,7 +121,7 @@ class ReplyOnPause(StreamHandler):
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
||||
dur_vad, _ = self.model.vad((sampling_rate, audio), self.model_options)
|
||||
logger.debug("VAD duration: %s", dur_vad)
|
||||
if (
|
||||
dur_vad > self.algo_options.started_talking_threshold
|
||||
|
||||
Reference in New Issue
Block a user