mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Literal, overload
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..utils import AudioChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,7 +79,7 @@ class SileroVADModel:
|
||||
return h, c
|
||||
|
||||
@staticmethod
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
|
||||
"""Collects and concatenates audio chunks."""
|
||||
if not chunks:
|
||||
return np.array([], dtype=np.float32)
|
||||
@@ -90,7 +93,7 @@ class SileroVADModel:
|
||||
audio: np.ndarray,
|
||||
vad_options: SileroVadOptions,
|
||||
**kwargs,
|
||||
) -> List[dict]:
|
||||
) -> List[AudioChunk]:
|
||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||
|
||||
Args:
|
||||
@@ -236,15 +239,33 @@ class SileroVADModel:
|
||||
|
||||
return speeches
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, np.ndarray],
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
) -> float:
|
||||
return_chunks: Literal[True],
|
||||
) -> tuple[float, List[AudioChunk]]: ...
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float: ...
|
||||
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float | tuple[float, List[AudioChunk]]:
|
||||
sampling_rate, audio = audio_tuple
|
||||
logger.debug("VAD audio shape input: %s", audio.shape)
|
||||
try:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
sr = 16000
|
||||
if sr != sampling_rate:
|
||||
try:
|
||||
@@ -262,7 +283,8 @@ class SileroVADModel:
|
||||
audio = self.collect_chunks(audio, speech_chunks)
|
||||
logger.debug("VAD audio shape: %s", audio.shape)
|
||||
duration_after_vad = audio.shape[0] / sr
|
||||
|
||||
if return_chunks:
|
||||
return duration_after_vad, speech_chunks
|
||||
return duration_after_vad
|
||||
except Exception as e:
|
||||
import math
|
||||
@@ -280,7 +302,7 @@ class SileroVADModel:
|
||||
raise ValueError(
|
||||
f"Too many dimensions for input audio chunk {len(x.shape)}"
|
||||
)
|
||||
if sr / x.shape[1] > 31.25:
|
||||
if sr / x.shape[1] > 31.25: # type: ignore
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
h, c = state
|
||||
|
||||
Reference in New Issue
Block a user