mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add code
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Callable, Literal, Generator, cast
|
||||
from functools import lru_cache
|
||||
from dataclasses import dataclass
|
||||
from threading import Event
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Callable, Generator, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -13,6 +13,7 @@ logger = getLogger(__name__)
|
||||
|
||||
counter = 0
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_vad_model() -> SileroVADModel:
|
||||
"""Returns the VAD model instance."""
|
||||
@@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel:
|
||||
@dataclass
|
||||
class AlgoOptions:
|
||||
"""Algorithm options."""
|
||||
|
||||
audio_chunk_duration: float = 0.6
|
||||
started_talking_threshold: float = 0.2
|
||||
speech_threshold: float = 0.1
|
||||
@@ -38,17 +40,27 @@ class AppState:
|
||||
buffer: np.ndarray | None = None
|
||||
|
||||
|
||||
ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]]
|
||||
ReplyFnGenerator = Callable[
|
||||
[tuple[int, np.ndarray]],
|
||||
Generator[
|
||||
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
|
||||
None,
|
||||
None,
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class ReplyOnPause(StreamHandler):
|
||||
def __init__(self, fn: ReplyFnGenerator,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 960,):
|
||||
super().__init__(expected_layout,
|
||||
output_sample_rate, output_frame_size)
|
||||
def __init__(
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 960,
|
||||
):
|
||||
super().__init__(expected_layout, output_sample_rate, output_frame_size)
|
||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
@@ -59,19 +71,30 @@ class ReplyOnPause(StreamHandler):
|
||||
self.generator = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(self.fn, self.algo_options, self.model_options,
|
||||
self.expected_layout, self.output_sample_rate, self.output_frame_size)
|
||||
|
||||
def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
|
||||
return ReplyOnPause(
|
||||
self.fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
)
|
||||
|
||||
def determine_pause(
|
||||
self, audio: np.ndarray, sampling_rate: int, state: AppState
|
||||
) -> bool:
|
||||
"""Take in the stream, determine if a pause happened"""
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
||||
logger.debug("VAD duration: %s", dur_vad)
|
||||
if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking:
|
||||
if (
|
||||
dur_vad > self.algo_options.started_talking_threshold
|
||||
and not state.started_talking
|
||||
):
|
||||
state.started_talking = True
|
||||
logger.debug("Started talking")
|
||||
if state.started_talking:
|
||||
@@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
|
||||
frame_rate, array = audio
|
||||
array = np.squeeze(array)
|
||||
@@ -95,9 +117,10 @@ class ReplyOnPause(StreamHandler):
|
||||
else:
|
||||
state.buffer = np.concatenate((state.buffer, array))
|
||||
|
||||
pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state)
|
||||
pause_detected = self.determine_pause(
|
||||
state.buffer, state.sampling_rate, self.state
|
||||
)
|
||||
state.pause_detected = pause_detected
|
||||
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
if self.state.responding:
|
||||
@@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler):
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
self.reset()
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user