freddyaboulton
2024-10-25 16:28:33 -07:00
parent a5dbaaf49b
commit 50611d3772
6 changed files with 60 additions and 35 deletions

View File

@@ -1,4 +1,4 @@
-from .webrtc import StreamHandler, WebRTC
 from .reply_on_pause import ReplyOnPause
+from .webrtc import StreamHandler, WebRTC

 __all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
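Quick orientation for the hunk above: the import order changes but the re-exported names stay the same. A one-line usage sketch; the top-level package name is not shown in this diff, so gradio_webrtc below is an assumption:

from gradio_webrtc import ReplyOnPause, StreamHandler, WebRTC  # assumed package name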

View File

@@ -1,4 +1,3 @@
 from .vad import SileroVADModel, SileroVadOptions

 __all__ = ["SileroVADModel", "SileroVadOptions"]

View File

@@ -1,14 +1,16 @@
 import logging
 import warnings
 from dataclasses import dataclass
-from huggingface_hub import hf_hub_download
 from typing import List

 import numpy as np
+from huggingface_hub import hf_hub_download

 logger = logging.getLogger(__name__)

+# The code below is adapted from https://github.com/snakers4/silero-vad.
+# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py


 @dataclass
 class SileroVadOptions:
@@ -235,9 +237,10 @@ class SileroVADModel:
         return speeches

     def vad(
-        self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions
+        self,
+        audio_tuple: tuple[int, np.ndarray],
+        vad_parameters: None | SileroVadOptions,
     ) -> float:
         sampling_rate, audio = audio_tuple
         logger.debug("VAD audio shape input: %s", audio.shape)
         try:
@@ -245,7 +248,7 @@ class SileroVADModel:
             sr = 16000
             if sr != sampling_rate:
                 try:
                     import librosa  # type: ignore
                 except ImportError as e:
                     raise RuntimeError(
                         "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
@@ -264,6 +267,7 @@ class SileroVADModel:
         except Exception as e:
             import math
             import traceback
             logger.debug("VAD Exception: %s", str(e))
             exec = traceback.format_exc()
             logger.debug("traceback %s", exec)

View File

@@ -1,8 +1,8 @@
-from typing import Callable, Literal, Generator, cast
-from functools import lru_cache
 from dataclasses import dataclass
-from threading import Event
+from functools import lru_cache
 from logging import getLogger
+from threading import Event
+from typing import Callable, Generator, Literal, cast

 import numpy as np
@@ -13,6 +13,7 @@ logger = getLogger(__name__)
 counter = 0

 @lru_cache
 def get_vad_model() -> SileroVADModel:
     """Returns the VAD model instance."""
@@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel:
 @dataclass
 class AlgoOptions:
     """Algorithm options."""
     audio_chunk_duration: float = 0.6
     started_talking_threshold: float = 0.2
     speech_threshold: float = 0.1
@@ -38,17 +40,27 @@ class AppState:
     buffer: np.ndarray | None = None


-ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]]
+ReplyFnGenerator = Callable[
+    [tuple[int, np.ndarray]],
+    Generator[
+        tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
+        None,
+        None,
+    ],
+]


 class ReplyOnPause(StreamHandler):
-    def __init__(self, fn: ReplyFnGenerator,
-                 algo_options: AlgoOptions | None = None,
-                 model_options: SileroVadOptions | None = None,
-                 expected_layout: Literal["mono", "stereo"] = "mono",
-                 output_sample_rate: int = 24000,
-                 output_frame_size: int = 960,):
-        super().__init__(expected_layout,
-                         output_sample_rate, output_frame_size)
+    def __init__(
+        self,
+        fn: ReplyFnGenerator,
+        algo_options: AlgoOptions | None = None,
+        model_options: SileroVadOptions | None = None,
+        expected_layout: Literal["mono", "stereo"] = "mono",
+        output_sample_rate: int = 24000,
+        output_frame_size: int = 960,
+    ):
+        super().__init__(expected_layout, output_sample_rate, output_frame_size)
         self.expected_layout: Literal["mono", "stereo"] = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
@@ -61,17 +73,28 @@ class ReplyOnPause(StreamHandler):
         self.algo_options = algo_options or AlgoOptions()

     def copy(self):
-        return ReplyOnPause(self.fn, self.algo_options, self.model_options,
-                            self.expected_layout, self.output_sample_rate, self.output_frame_size)
+        return ReplyOnPause(
+            self.fn,
+            self.algo_options,
+            self.model_options,
+            self.expected_layout,
+            self.output_sample_rate,
+            self.output_frame_size,
+        )

-    def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
+    def determine_pause(
+        self, audio: np.ndarray, sampling_rate: int, state: AppState
+    ) -> bool:
         """Take in the stream, determine if a pause happened"""
         duration = len(audio) / sampling_rate
         if duration >= self.algo_options.audio_chunk_duration:
             dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
             logger.debug("VAD duration: %s", dur_vad)
-            if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking:
+            if (
+                dur_vad > self.algo_options.started_talking_threshold
+                and not state.started_talking
+            ):
                 state.started_talking = True
                 logger.debug("Started talking")
             if state.started_talking:
@@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler):
                 return True
         return False

     def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
         frame_rate, array = audio
         array = np.squeeze(array)
@@ -95,10 +117,11 @@ class ReplyOnPause(StreamHandler):
         else:
             state.buffer = np.concatenate((state.buffer, array))

-        pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state)
+        pause_detected = self.determine_pause(
+            state.buffer, state.sampling_rate, self.state
+        )
         state.pause_detected = pause_detected

     def receive(self, frame: tuple[int, np.ndarray]) -> None:
         if self.state.responding:
             return
@@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler):
             return next(self.generator)
         except StopIteration:
             self.reset()
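Taken together, the reformatted ReplyFnGenerator alias and ReplyOnPause constructor above describe a handler built around a plain generator: it receives one (sample_rate, samples) tuple and yields output audio tuples, optionally with a "mono"/"stereo" layout hint. A minimal sketch of a conforming function and its use; echo is purely illustrative and not part of the codebase:

import numpy as np


def echo(audio: tuple[int, np.ndarray]):
    # Trivial reply generator: return the caller's audio in a single chunk.
    sample_rate, samples = audio
    yield (sample_rate, samples)


handler = ReplyOnPause(
    echo,
    algo_options=AlgoOptions(audio_chunk_duration=0.6, started_talking_threshold=0.2),
    expected_layout="mono",
    output_sample_rate=24000,
    output_frame_size=960,
)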

View File

@@ -55,7 +55,7 @@ async def player_worker_decode(
             # Convert to audio frame and resample
             # This runs in the same timeout context
             frame = av.AudioFrame.from_ndarray(  # type: ignore
                 audio_array, format=format, layout=layout
             )
             frame.sample_rate = sample_rate
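Context for the hunk above: av.AudioFrame.from_ndarray is PyAV's constructor from a NumPy buffer; the array dtype has to match the declared sample format, and sample_rate is set separately, as the diff shows. A small standalone sketch with made-up values (960 samples, roughly 20 ms at 48 kHz):

import av
import numpy as np

silence = np.zeros((1, 960), dtype=np.int16)  # int16 matches the "s16" packed format
frame = av.AudioFrame.from_ndarray(silence, format="s16", layout="mono")
frame.sample_rate = 48000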

View File

@@ -10,8 +10,8 @@ import time
 import traceback
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
 from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast

 import anyio.to_thread
 import av
@@ -122,7 +122,9 @@ class StreamHandler(ABC):
         try:
             return deepcopy(self)
         except Exception:
-            raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.")
+            raise ValueError(
+                "Current StreamHandler implementation cannot be deepcopied. Implement the copy method."
+            )

     def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
         if self._resampler is None:
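The copy() fallback above deep-copies the handler and tells subclass authors to override copy() when that fails. A hedged sketch of such an override on a hypothetical subclass; the constructor arguments mirror the super().__init__ call seen in ReplyOnPause above, and the remaining abstract methods are omitted for brevity:

class ModelBackedHandler(StreamHandler):  # hypothetical subclass, for illustration only
    def __init__(self, model):
        super().__init__("mono", 24000, 960)  # expected_layout, output_sample_rate, output_frame_size
        self.model = model  # e.g. a loaded model handle that deepcopy cannot clone

    def copy(self):
        # Reuse the shared, non-copyable model rather than relying on the deepcopy fallback.
        return ModelBackedHandler(self.model)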