This commit is contained in:
freddyaboulton
2024-10-25 16:28:33 -07:00
parent a5dbaaf49b
commit 50611d3772
6 changed files with 60 additions and 35 deletions

View File

@@ -1,4 +1,4 @@
from .webrtc import StreamHandler, WebRTC
from .reply_on_pause import ReplyOnPause from .reply_on_pause import ReplyOnPause
from .webrtc import StreamHandler, WebRTC
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] __all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]

View File

@@ -1,4 +1,3 @@
from .vad import SileroVADModel, SileroVadOptions from .vad import SileroVADModel, SileroVadOptions
__all__ = ["SileroVADModel", "SileroVadOptions"] __all__ = ["SileroVADModel", "SileroVadOptions"]

View File

@@ -1,14 +1,16 @@
import logging import logging
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from typing import List from typing import List
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The code below is adapted from https://github.com/snakers4/silero-vad.
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
@dataclass @dataclass
class SileroVadOptions: class SileroVadOptions:
@@ -235,9 +237,10 @@ class SileroVADModel:
return speeches return speeches
def vad( def vad(
self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions self,
audio_tuple: tuple[int, np.ndarray],
vad_parameters: None | SileroVadOptions,
) -> float: ) -> float:
sampling_rate, audio = audio_tuple sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape) logger.debug("VAD audio shape input: %s", audio.shape)
try: try:
@@ -264,6 +267,7 @@ class SileroVADModel:
except Exception as e: except Exception as e:
import math import math
import traceback import traceback
logger.debug("VAD Exception: %s", str(e)) logger.debug("VAD Exception: %s", str(e))
exec = traceback.format_exc() exec = traceback.format_exc()
logger.debug("traceback %s", exec) logger.debug("traceback %s", exec)

View File

@@ -1,8 +1,8 @@
from typing import Callable, Literal, Generator, cast
from functools import lru_cache
from dataclasses import dataclass from dataclasses import dataclass
from threading import Event from functools import lru_cache
from logging import getLogger from logging import getLogger
from threading import Event
from typing import Callable, Generator, Literal, cast
import numpy as np import numpy as np
@@ -13,6 +13,7 @@ logger = getLogger(__name__)
counter = 0 counter = 0
@lru_cache @lru_cache
def get_vad_model() -> SileroVADModel: def get_vad_model() -> SileroVADModel:
"""Returns the VAD model instance.""" """Returns the VAD model instance."""
@@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel:
@dataclass @dataclass
class AlgoOptions: class AlgoOptions:
"""Algorithm options.""" """Algorithm options."""
audio_chunk_duration: float = 0.6 audio_chunk_duration: float = 0.6
started_talking_threshold: float = 0.2 started_talking_threshold: float = 0.2
speech_threshold: float = 0.1 speech_threshold: float = 0.1
@@ -38,17 +40,27 @@ class AppState:
buffer: np.ndarray | None = None buffer: np.ndarray | None = None
ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]] ReplyFnGenerator = Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
None,
None,
],
]
class ReplyOnPause(StreamHandler): class ReplyOnPause(StreamHandler):
def __init__(self, fn: ReplyFnGenerator, def __init__(
self,
fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None, algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None, model_options: SileroVadOptions | None = None,
expected_layout: Literal["mono", "stereo"] = "mono", expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000, output_sample_rate: int = 24000,
output_frame_size: int = 960,): output_frame_size: int = 960,
super().__init__(expected_layout, ):
output_sample_rate, output_frame_size) super().__init__(expected_layout, output_sample_rate, output_frame_size)
self.expected_layout: Literal["mono", "stereo"] = expected_layout self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size self.output_frame_size = output_frame_size
@@ -61,17 +73,28 @@ class ReplyOnPause(StreamHandler):
self.algo_options = algo_options or AlgoOptions() self.algo_options = algo_options or AlgoOptions()
def copy(self): def copy(self):
return ReplyOnPause(self.fn, self.algo_options, self.model_options, return ReplyOnPause(
self.expected_layout, self.output_sample_rate, self.output_frame_size) self.fn,
self.algo_options,
self.model_options,
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
)
def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool: def determine_pause(
self, audio: np.ndarray, sampling_rate: int, state: AppState
) -> bool:
"""Take in the stream, determine if a pause happened""" """Take in the stream, determine if a pause happened"""
duration = len(audio) / sampling_rate duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration: if duration >= self.algo_options.audio_chunk_duration:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options) dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad) logger.debug("VAD duration: %s", dur_vad)
if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking: if (
dur_vad > self.algo_options.started_talking_threshold
and not state.started_talking
):
state.started_talking = True state.started_talking = True
logger.debug("Started talking") logger.debug("Started talking")
if state.started_talking: if state.started_talking:
@@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler):
return True return True
return False return False
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None: def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
frame_rate, array = audio frame_rate, array = audio
array = np.squeeze(array) array = np.squeeze(array)
@@ -95,10 +117,11 @@ class ReplyOnPause(StreamHandler):
else: else:
state.buffer = np.concatenate((state.buffer, array)) state.buffer = np.concatenate((state.buffer, array))
pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state) pause_detected = self.determine_pause(
state.buffer, state.sampling_rate, self.state
)
state.pause_detected = pause_detected state.pause_detected = pause_detected
def receive(self, frame: tuple[int, np.ndarray]) -> None: def receive(self, frame: tuple[int, np.ndarray]) -> None:
if self.state.responding: if self.state.responding:
return return
@@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler):
return next(self.generator) return next(self.generator)
except StopIteration: except StopIteration:
self.reset() self.reset()

View File

@@ -10,8 +10,8 @@ import time
import traceback import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
import anyio.to_thread import anyio.to_thread
import av import av
@@ -122,7 +122,9 @@ class StreamHandler(ABC):
try: try:
return deepcopy(self) return deepcopy(self)
except Exception: except Exception:
raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.") raise ValueError(
"Current StreamHandler implementation cannot be deepcopied. Implement the copy method."
)
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
if self._resampler is None: if self._resampler is None: