diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index d774b0e..1924616 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -1,4 +1,4 @@ -from .webrtc import StreamHandler, WebRTC from .reply_on_pause import ReplyOnPause +from .webrtc import StreamHandler, WebRTC __all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] diff --git a/backend/gradio_webrtc/pause_detection/__init__.py b/backend/gradio_webrtc/pause_detection/__init__.py index 79e57ce..e4874b7 100644 --- a/backend/gradio_webrtc/pause_detection/__init__.py +++ b/backend/gradio_webrtc/pause_detection/__init__.py @@ -1,4 +1,3 @@ from .vad import SileroVADModel, SileroVadOptions - -__all__ = ["SileroVADModel", "SileroVadOptions"] \ No newline at end of file +__all__ = ["SileroVADModel", "SileroVadOptions"] diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py index 18e1d3d..1ff911a 100644 --- a/backend/gradio_webrtc/pause_detection/vad.py +++ b/backend/gradio_webrtc/pause_detection/vad.py @@ -1,14 +1,16 @@ import logging import warnings from dataclasses import dataclass -from huggingface_hub import hf_hub_download - from typing import List import numpy as np +from huggingface_hub import hf_hub_download logger = logging.getLogger(__name__) +# The code below is adapted from https://github.com/snakers4/silero-vad. +# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py + @dataclass class SileroVadOptions: @@ -235,9 +237,10 @@ class SileroVADModel: return speeches def vad( - self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions + self, + audio_tuple: tuple[int, np.ndarray], + vad_parameters: None | SileroVadOptions, ) -> float: - sampling_rate, audio = audio_tuple logger.debug("VAD audio shape input: %s", audio.shape) try: @@ -245,7 +248,7 @@ class SileroVADModel: sr = 16000 if sr != sampling_rate: try: - import librosa # type: ignore + import librosa # type: ignore except ImportError as e: raise RuntimeError( "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" @@ -264,6 +267,7 @@ class SileroVADModel: except Exception as e: import math import traceback + logger.debug("VAD Exception: %s", str(e)) exec = traceback.format_exc() logger.debug("traceback %s", exec) diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 640ea66..6864358 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -1,8 +1,8 @@ -from typing import Callable, Literal, Generator, cast -from functools import lru_cache from dataclasses import dataclass -from threading import Event +from functools import lru_cache from logging import getLogger +from threading import Event +from typing import Callable, Generator, Literal, cast import numpy as np @@ -13,6 +13,7 @@ logger = getLogger(__name__) counter = 0 + @lru_cache def get_vad_model() -> SileroVADModel: """Returns the VAD model instance.""" @@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel: @dataclass class AlgoOptions: """Algorithm options.""" + audio_chunk_duration: float = 0.6 started_talking_threshold: float = 0.2 speech_threshold: float = 0.1 @@ -38,17 +40,27 @@ class AppState: buffer: np.ndarray | None = None -ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]] +ReplyFnGenerator = Callable[ + [tuple[int, np.ndarray]], + Generator[ + tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], + None, + None, + ], +] + class ReplyOnPause(StreamHandler): - def __init__(self, fn: ReplyFnGenerator, - algo_options: AlgoOptions | None = None, - model_options: SileroVadOptions | None = None, - expected_layout: Literal["mono", "stereo"] = "mono", - output_sample_rate: int = 24000, - output_frame_size: int = 960,): - super().__init__(expected_layout, - output_sample_rate, output_frame_size) + def __init__( + self, + fn: ReplyFnGenerator, + algo_options: AlgoOptions | None = None, + model_options: SileroVadOptions | None = None, + expected_layout: Literal["mono", "stereo"] = "mono", + output_sample_rate: int = 24000, + output_frame_size: int = 960, + ): + super().__init__(expected_layout, output_sample_rate, output_frame_size) self.expected_layout: Literal["mono", "stereo"] = expected_layout self.output_sample_rate = output_sample_rate self.output_frame_size = output_frame_size @@ -59,19 +71,30 @@ class ReplyOnPause(StreamHandler): self.generator = None self.model_options = model_options self.algo_options = algo_options or AlgoOptions() - + def copy(self): - return ReplyOnPause(self.fn, self.algo_options, self.model_options, - self.expected_layout, self.output_sample_rate, self.output_frame_size) - - def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool: + return ReplyOnPause( + self.fn, + self.algo_options, + self.model_options, + self.expected_layout, + self.output_sample_rate, + self.output_frame_size, + ) + + def determine_pause( + self, audio: np.ndarray, sampling_rate: int, state: AppState + ) -> bool: """Take in the stream, determine if a pause happened""" duration = len(audio) / sampling_rate if duration >= self.algo_options.audio_chunk_duration: dur_vad = self.model.vad((sampling_rate, audio), self.model_options) logger.debug("VAD duration: %s", dur_vad) - if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking: + if ( + dur_vad > self.algo_options.started_talking_threshold + and not state.started_talking + ): state.started_talking = True logger.debug("Started talking") if state.started_talking: @@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler): return True return False - def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None: frame_rate, array = audio array = np.squeeze(array) @@ -95,9 +117,10 @@ class ReplyOnPause(StreamHandler): else: state.buffer = np.concatenate((state.buffer, array)) - pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state) + pause_detected = self.determine_pause( + state.buffer, state.sampling_rate, self.state + ) state.pause_detected = pause_detected - def receive(self, frame: tuple[int, np.ndarray]) -> None: if self.state.responding: @@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler): return next(self.generator) except StopIteration: self.reset() - - - diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 8762c1b..c380fbe 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -55,7 +55,7 @@ async def player_worker_decode( # Convert to audio frame and resample # This runs in the same timeout context - frame = av.AudioFrame.from_ndarray( # type: ignore + frame = av.AudioFrame.from_ndarray( # type: ignore audio_array, format=format, layout=layout ) frame.sample_rate = sample_rate diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index e210f55..3f82197 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -10,8 +10,8 @@ import time import traceback from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast from copy import deepcopy +from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast import anyio.to_thread import av @@ -122,7 +122,9 @@ class StreamHandler(ABC): try: return deepcopy(self) except Exception: - raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.") + raise ValueError( + "Current StreamHandler implementation cannot be deepcopied. Implement the copy method." + ) def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: if self._resampler is None: