mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add code
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from .webrtc import StreamHandler, WebRTC
|
|
||||||
from .reply_on_pause import ReplyOnPause
|
from .reply_on_pause import ReplyOnPause
|
||||||
|
from .webrtc import StreamHandler, WebRTC
|
||||||
|
|
||||||
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
|
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from .vad import SileroVADModel, SileroVadOptions
|
from .vad import SileroVADModel, SileroVadOptions
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SileroVADModel", "SileroVadOptions"]
|
__all__ = ["SileroVADModel", "SileroVadOptions"]
|
||||||
@@ -1,14 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
||||||
|
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SileroVadOptions:
|
class SileroVadOptions:
|
||||||
@@ -235,9 +237,10 @@ class SileroVADModel:
|
|||||||
return speeches
|
return speeches
|
||||||
|
|
||||||
def vad(
|
def vad(
|
||||||
self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions
|
self,
|
||||||
|
audio_tuple: tuple[int, np.ndarray],
|
||||||
|
vad_parameters: None | SileroVadOptions,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
|
||||||
sampling_rate, audio = audio_tuple
|
sampling_rate, audio = audio_tuple
|
||||||
logger.debug("VAD audio shape input: %s", audio.shape)
|
logger.debug("VAD audio shape input: %s", audio.shape)
|
||||||
try:
|
try:
|
||||||
@@ -245,7 +248,7 @@ class SileroVADModel:
|
|||||||
sr = 16000
|
sr = 16000
|
||||||
if sr != sampling_rate:
|
if sr != sampling_rate:
|
||||||
try:
|
try:
|
||||||
import librosa # type: ignore
|
import librosa # type: ignore
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
|
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
|
||||||
@@ -264,6 +267,7 @@ class SileroVADModel:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import math
|
import math
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug("VAD Exception: %s", str(e))
|
logger.debug("VAD Exception: %s", str(e))
|
||||||
exec = traceback.format_exc()
|
exec = traceback.format_exc()
|
||||||
logger.debug("traceback %s", exec)
|
logger.debug("traceback %s", exec)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from typing import Callable, Literal, Generator, cast
|
|
||||||
from functools import lru_cache
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from threading import Event
|
from functools import lru_cache
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from threading import Event
|
||||||
|
from typing import Callable, Generator, Literal, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ logger = getLogger(__name__)
|
|||||||
|
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_vad_model() -> SileroVADModel:
|
def get_vad_model() -> SileroVADModel:
|
||||||
"""Returns the VAD model instance."""
|
"""Returns the VAD model instance."""
|
||||||
@@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AlgoOptions:
|
class AlgoOptions:
|
||||||
"""Algorithm options."""
|
"""Algorithm options."""
|
||||||
|
|
||||||
audio_chunk_duration: float = 0.6
|
audio_chunk_duration: float = 0.6
|
||||||
started_talking_threshold: float = 0.2
|
started_talking_threshold: float = 0.2
|
||||||
speech_threshold: float = 0.1
|
speech_threshold: float = 0.1
|
||||||
@@ -38,17 +40,27 @@ class AppState:
|
|||||||
buffer: np.ndarray | None = None
|
buffer: np.ndarray | None = None
|
||||||
|
|
||||||
|
|
||||||
ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]]
|
ReplyFnGenerator = Callable[
|
||||||
|
[tuple[int, np.ndarray]],
|
||||||
|
Generator[
|
||||||
|
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ReplyOnPause(StreamHandler):
|
class ReplyOnPause(StreamHandler):
|
||||||
def __init__(self, fn: ReplyFnGenerator,
|
def __init__(
|
||||||
algo_options: AlgoOptions | None = None,
|
self,
|
||||||
model_options: SileroVadOptions | None = None,
|
fn: ReplyFnGenerator,
|
||||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
algo_options: AlgoOptions | None = None,
|
||||||
output_sample_rate: int = 24000,
|
model_options: SileroVadOptions | None = None,
|
||||||
output_frame_size: int = 960,):
|
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||||
super().__init__(expected_layout,
|
output_sample_rate: int = 24000,
|
||||||
output_sample_rate, output_frame_size)
|
output_frame_size: int = 960,
|
||||||
|
):
|
||||||
|
super().__init__(expected_layout, output_sample_rate, output_frame_size)
|
||||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||||
self.output_sample_rate = output_sample_rate
|
self.output_sample_rate = output_sample_rate
|
||||||
self.output_frame_size = output_frame_size
|
self.output_frame_size = output_frame_size
|
||||||
@@ -61,17 +73,28 @@ class ReplyOnPause(StreamHandler):
|
|||||||
self.algo_options = algo_options or AlgoOptions()
|
self.algo_options = algo_options or AlgoOptions()
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return ReplyOnPause(self.fn, self.algo_options, self.model_options,
|
return ReplyOnPause(
|
||||||
self.expected_layout, self.output_sample_rate, self.output_frame_size)
|
self.fn,
|
||||||
|
self.algo_options,
|
||||||
|
self.model_options,
|
||||||
|
self.expected_layout,
|
||||||
|
self.output_sample_rate,
|
||||||
|
self.output_frame_size,
|
||||||
|
)
|
||||||
|
|
||||||
def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
|
def determine_pause(
|
||||||
|
self, audio: np.ndarray, sampling_rate: int, state: AppState
|
||||||
|
) -> bool:
|
||||||
"""Take in the stream, determine if a pause happened"""
|
"""Take in the stream, determine if a pause happened"""
|
||||||
duration = len(audio) / sampling_rate
|
duration = len(audio) / sampling_rate
|
||||||
|
|
||||||
if duration >= self.algo_options.audio_chunk_duration:
|
if duration >= self.algo_options.audio_chunk_duration:
|
||||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
||||||
logger.debug("VAD duration: %s", dur_vad)
|
logger.debug("VAD duration: %s", dur_vad)
|
||||||
if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking:
|
if (
|
||||||
|
dur_vad > self.algo_options.started_talking_threshold
|
||||||
|
and not state.started_talking
|
||||||
|
):
|
||||||
state.started_talking = True
|
state.started_talking = True
|
||||||
logger.debug("Started talking")
|
logger.debug("Started talking")
|
||||||
if state.started_talking:
|
if state.started_talking:
|
||||||
@@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
|
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
|
||||||
frame_rate, array = audio
|
frame_rate, array = audio
|
||||||
array = np.squeeze(array)
|
array = np.squeeze(array)
|
||||||
@@ -95,10 +117,11 @@ class ReplyOnPause(StreamHandler):
|
|||||||
else:
|
else:
|
||||||
state.buffer = np.concatenate((state.buffer, array))
|
state.buffer = np.concatenate((state.buffer, array))
|
||||||
|
|
||||||
pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state)
|
pause_detected = self.determine_pause(
|
||||||
|
state.buffer, state.sampling_rate, self.state
|
||||||
|
)
|
||||||
state.pause_detected = pause_detected
|
state.pause_detected = pause_detected
|
||||||
|
|
||||||
|
|
||||||
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||||
if self.state.responding:
|
if self.state.responding:
|
||||||
return
|
return
|
||||||
@@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler):
|
|||||||
return next(self.generator)
|
return next(self.generator)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ async def player_worker_decode(
|
|||||||
|
|
||||||
# Convert to audio frame and resample
|
# Convert to audio frame and resample
|
||||||
# This runs in the same timeout context
|
# This runs in the same timeout context
|
||||||
frame = av.AudioFrame.from_ndarray( # type: ignore
|
frame = av.AudioFrame.from_ndarray( # type: ignore
|
||||||
audio_array, format=format, layout=layout
|
audio_array, format=format, layout=layout
|
||||||
)
|
)
|
||||||
frame.sample_rate = sample_rate
|
frame.sample_rate = sample_rate
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
|
||||||
|
|
||||||
import anyio.to_thread
|
import anyio.to_thread
|
||||||
import av
|
import av
|
||||||
@@ -122,7 +122,9 @@ class StreamHandler(ABC):
|
|||||||
try:
|
try:
|
||||||
return deepcopy(self)
|
return deepcopy(self)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.")
|
raise ValueError(
|
||||||
|
"Current StreamHandler implementation cannot be deepcopied. Implement the copy method."
|
||||||
|
)
|
||||||
|
|
||||||
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
|
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
|
||||||
if self._resampler is None:
|
if self._resampler is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user