Merge pull request #11 from freddyaboulton/built-in-pause-detection

Built in pause detection, Add ReplyOnPause
Freddy Boulton
2024-10-25 17:42:54 -07:00
committed by GitHub
8 changed files with 510 additions and 7 deletions

View File

@@ -15,6 +15,12 @@ Stream video and audio in real time with Gradio using WebRTC.
pip install gradio_webrtc
```
To use built-in pause detection (see [Conversational AI](#conversational-ai)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
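If you are unsure whether the optional dependencies are present in your environment, a quick check like the following (illustrative only) confirms they can be imported:
```python
# Illustrative check that the optional VAD dependencies are importable.
try:
    import librosa  # noqa: F401
    import onnxruntime  # noqa: F401
    print("vad extra available")
except ImportError:
    print("Missing VAD dependencies; run: pip install gradio_webrtc[vad]")
```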
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
@@ -176,7 +182,44 @@ if __name__ == "__main__":
* An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples).
* You can also specify the audio layout ("mono" or "stereo") in the `emit` method by returning it as the third element of the tuple. If not specified, the default is "mono". A minimal custom handler illustrating these conventions is sketched after this list.
* The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop.
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None.
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
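For reference, here is a minimal sketch of a custom `StreamHandler` that follows these conventions. It simply replays the most recent frame; the `EchoHandler` name and its buffering logic are illustrative and not part of the library:
```python
import numpy as np
from gradio_webrtc import StreamHandler


class EchoHandler(StreamHandler):
    """Illustrative handler: replays the most recently received frame."""

    def __init__(self):
        super().__init__(
            expected_layout="mono", output_sample_rate=24000, output_frame_size=480
        )
        self.last_frame: tuple[int, np.ndarray] | None = None

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        # frame is (frame_rate, audio_samples); audio_samples has shape (num_channels, num_samples)
        self.last_frame = frame

    def emit(self):
        # Must not block: return None when no frame is ready.
        if self.last_frame is None:
            return None
        frame_rate, samples = self.last_frame
        self.last_frame = None
        return (frame_rate, samples, "mono")

    def copy(self):
        # Give each connection its own handler instance.
        return EchoHandler()
```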
An easy way to get started with Conversational AI is to use the `ReplyOnPause` stream handler, which automatically runs your function when the speaker stops speaking. To use `ReplyOnPause`, the `vad` extra dependencies must be installed.
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnPause
def response(audio: tuple[int, np.ndarray]):
"""This function must yield audio frames"""
...
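    # `sampling_rate` and `generated_audio` below are placeholders for your model's output.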
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono")
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
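            # rtc_configuration (defined elsewhere) holds the connection settings, e.g.
            # TURN servers; see the Deployment section of this README.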
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="send-receive",
modality="audio",
)
audio.stream(fn=ReplyOnPause(response), inputs=[audio], outputs=[audio], time_limit=60)
demo.launch(ssr_mode=False)
```
## Deployment

View File

@@ -1,3 +1,4 @@
from .reply_on_pause import ReplyOnPause
from .webrtc import StreamHandler, WebRTC
__all__ = ["StreamHandler", "WebRTC"]
__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"]

View File

@@ -0,0 +1,3 @@
from .vad import SileroVADModel, SileroVadOptions
__all__ = ["SileroVADModel", "SileroVadOptions"]

View File

@@ -0,0 +1,298 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List
import numpy as np
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)
# The code below is adapted from https://github.com/snakers4/silero-vad.
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
@dataclass
class SileroVadOptions:
"""VAD options.
Attributes:
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
split aggressively just before max_speech_duration_s.
      min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model performance!!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
speech_duration: If the length of the speech is less than this value, a pause will be detected.
"""
threshold: float = 0.5
min_speech_duration_ms: int = 250
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
window_size_samples: int = 1024
speech_pad_ms: int = 400
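# Illustrative note (an assumption, not executed here): these options can be customized
# and passed to ReplyOnPause via its model_options argument, e.g.
#   ReplyOnPause(fn, model_options=SileroVadOptions(threshold=0.6))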
class SileroVADModel:
@staticmethod
def download_model() -> str:
return hf_hub_download(
repo_id="freddyaboulton/silero-vad", filename="silero_vad.onnx"
)
def __init__(self):
try:
import onnxruntime
except ImportError as e:
raise RuntimeError(
"Applying the VAD filter requires the onnxruntime package"
) from e
path = self.download_model()
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.log_severity_level = 4
self.session = onnxruntime.InferenceSession(
path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
def get_initial_state(self, batch_size: int):
h = np.zeros((2, batch_size, 64), dtype=np.float32)
c = np.zeros((2, batch_size, 64), dtype=np.float32)
return h, c
@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
return np.concatenate(
[audio[chunk["start"] : chunk["end"]] for chunk in chunks]
)
def get_speech_timestamps(
self,
audio: np.ndarray,
vad_options: SileroVadOptions,
**kwargs,
) -> List[dict]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
audio: One dimensional float array.
vad_options: Options for VAD processing.
kwargs: VAD options passed as keyword arguments for backward compatibility.
Returns:
List of dicts containing begin and end samples of each speech chunk.
"""
threshold = vad_options.threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = vad_options.window_size_samples
speech_pad_ms = vad_options.speech_pad_ms
if window_size_samples not in [512, 1024, 1536]:
warnings.warn(
"Unusual window_size_samples! Supported window_size_samples:\n"
" - [512, 1024, 1536] for 16000 sampling_rate"
)
sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = (
sampling_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
)
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
audio_length_samples = len(audio)
state = self.get_initial_state(batch_size=1)
speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[
current_start_sample : current_start_sample + window_size_samples
]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state = self(chunk, state, sampling_rate)
speech_probs.append(speech_prob)
triggered = False
speeches = []
current_speech = {}
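        # Hysteresis: once speech has triggered, it only ends after the probability
        # drops below neg_threshold (threshold - 0.15).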
neg_threshold = threshold - 0.15
# to save potential segment end (and tolerate some silence)
temp_end = 0
# to save potential segment limits in case of maximum segment size reached
prev_end = next_start = 0
for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
if (speech_prob >= threshold) and not triggered:
triggered = True
current_speech["start"] = window_size_samples * i
continue
if (
triggered
and (window_size_samples * i) - current_speech["start"]
> max_speech_samples
):
if prev_end:
current_speech["end"] = prev_end
speeches.append(current_speech)
current_speech = {}
# previously reached silence (< neg_thres) and is still not speech (< thres)
if next_start < prev_end:
triggered = False
else:
current_speech["start"] = next_start
prev_end = next_start = temp_end = 0
else:
current_speech["end"] = window_size_samples * i
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if (speech_prob < neg_threshold) and triggered:
if not temp_end:
temp_end = window_size_samples * i
# condition to avoid cutting in very short silence
if (
window_size_samples * i
) - temp_end > min_silence_samples_at_max_speech:
prev_end = temp_end
if (window_size_samples * i) - temp_end < min_silence_samples:
continue
else:
current_speech["end"] = temp_end
if (
current_speech["end"] - current_speech["start"]
) > min_speech_samples:
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if (
current_speech
and (audio_length_samples - current_speech["start"]) > min_speech_samples
):
current_speech["end"] = audio_length_samples
speeches.append(current_speech)
for i, speech in enumerate(speeches):
if i == 0:
speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
if i != len(speeches) - 1:
silence_duration = speeches[i + 1]["start"] - speech["end"]
if silence_duration < 2 * speech_pad_samples:
speech["end"] += int(silence_duration // 2)
speeches[i + 1]["start"] = int(
max(0, speeches[i + 1]["start"] - silence_duration // 2)
)
else:
speech["end"] = int(
min(audio_length_samples, speech["end"] + speech_pad_samples)
)
speeches[i + 1]["start"] = int(
max(0, speeches[i + 1]["start"] - speech_pad_samples)
)
else:
speech["end"] = int(
min(audio_length_samples, speech["end"] + speech_pad_samples)
)
return speeches
def vad(
self,
audio_tuple: tuple[int, np.ndarray],
vad_parameters: None | SileroVadOptions,
) -> float:
sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape)
try:
audio = audio.astype(np.float32) / 32768.0
sr = 16000
if sr != sampling_rate:
try:
import librosa # type: ignore
except ImportError as e:
raise RuntimeError(
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
) from e
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
if not vad_parameters:
vad_parameters = SileroVadOptions()
speech_chunks = self.get_speech_timestamps(audio, vad_parameters)
logger.debug("VAD speech chunks: %s", speech_chunks)
audio = self.collect_chunks(audio, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape)
duration_after_vad = audio.shape[0] / sr
return duration_after_vad
except Exception as e:
import math
import traceback
logger.debug("VAD Exception: %s", str(e))
            exc = traceback.format_exc()
            logger.debug("traceback %s", exc)
return math.inf
def __call__(self, x, state, sr: int):
if len(x.shape) == 1:
x = np.expand_dims(x, 0)
if len(x.shape) > 2:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")
h, c = state
ort_inputs = {
"input": x,
"h": h,
"c": c,
"sr": np.array(sr, dtype="int64"),
}
out, h, c = self.session.run(None, ort_inputs)
state = (h, c)
return out, state
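# Illustrative usage (an assumption, for reference only):
#   model = SileroVADModel()
#   speech_seconds = model.vad((16000, int16_audio), SileroVadOptions())
# vad() returns the seconds of detected speech, or math.inf if VAD fails.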

View File

@@ -0,0 +1,148 @@
from dataclasses import dataclass
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Callable, Generator, Literal, cast
import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.webrtc import StreamHandler
logger = getLogger(__name__)
counter = 0
@lru_cache
def get_vad_model() -> SileroVADModel:
"""Returns the VAD model instance."""
return SileroVADModel()
@dataclass
class AlgoOptions:
"""Algorithm options."""
audio_chunk_duration: float = 0.6
started_talking_threshold: float = 0.2
speech_threshold: float = 0.1
@dataclass
class AppState:
stream: np.ndarray | None = None
sampling_rate: int = 0
pause_detected: bool = False
started_talking: bool = False
responding: bool = False
stopped: bool = False
buffer: np.ndarray | None = None
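# A reply function receives the full utterance as (sampling_rate, audio) and yields
# (sampling_rate, audio_chunk) or (sampling_rate, audio_chunk, "mono" | "stereo") frames.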
ReplyFnGenerator = Callable[
[tuple[int, np.ndarray]],
Generator[
tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]],
None,
None,
],
]
class ReplyOnPause(StreamHandler):
def __init__(
self,
fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
):
super().__init__(expected_layout, output_sample_rate, output_frame_size)
self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.model = get_vad_model()
self.fn = fn
self.event = Event()
self.state = AppState()
self.generator = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
def copy(self):
return ReplyOnPause(
self.fn,
self.algo_options,
self.model_options,
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
)
def determine_pause(
self, audio: np.ndarray, sampling_rate: int, state: AppState
) -> bool:
"""Take in the stream, determine if a pause happened"""
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad)
if (
dur_vad > self.algo_options.started_talking_threshold
and not state.started_talking
):
state.started_talking = True
logger.debug("Started talking")
if state.started_talking:
if state.stream is None:
state.stream = audio
else:
state.stream = np.concatenate((state.stream, audio))
state.buffer = None
if dur_vad < self.algo_options.speech_threshold and state.started_talking:
return True
return False
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
frame_rate, array = audio
array = np.squeeze(array)
if not state.sampling_rate:
state.sampling_rate = frame_rate
if state.buffer is None:
state.buffer = array
else:
state.buffer = np.concatenate((state.buffer, array))
        pause_detected = self.determine_pause(
            state.buffer, state.sampling_rate, state
        )
state.pause_detected = pause_detected
def receive(self, frame: tuple[int, np.ndarray]) -> None:
if self.state.responding:
return
self.process_audio(frame, self.state)
if self.state.pause_detected:
self.event.set()
def reset(self):
self.generator = None
self.event.clear()
self.state = AppState()
def emit(self):
if not self.event.is_set():
return None
else:
if not self.generator:
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
self.generator = self.fn((self.state.sampling_rate, audio))
self.state.responding = True
try:
return next(self.generator)
except StopIteration:
self.reset()

View File

@@ -55,7 +55,7 @@ async def player_worker_decode(
# Convert to audio frame and resample
# This runs in the same timeout context
frame = av.AudioFrame.from_ndarray(
frame = av.AudioFrame.from_ndarray( # type: ignore
audio_array, format=format, layout=layout
)
frame.sample_rate = sample_rate

View File

@@ -10,6 +10,7 @@ import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
import anyio.to_thread
@@ -117,6 +118,14 @@ class StreamHandler(ABC):
self.output_frame_size = output_frame_size
self._resampler = None
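    # copy() gives each new WebRTC connection its own handler instance so per-user
    # state is not shared across streams; handlers that cannot be deep-copied should
    # override it (as ReplyOnPause does).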
def copy(self) -> "StreamHandler":
try:
return deepcopy(self)
except Exception:
raise ValueError(
"Current StreamHandler implementation cannot be deepcopied. Implement the copy method."
)
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
if self._resampler is None:
self._resampler = av.AudioResampler( # type: ignore
@@ -622,7 +631,7 @@ class WebRTC(Component):
elif self.modality == "audio":
cb = AudioCallback(
relay.subscribe(track),
event_handler=cast(StreamHandler, self.event_handler),
event_handler=cast(StreamHandler, self.event_handler).copy(),
)
self.connections[body["webrtc_id"]] = cb
logger.debug("Adding track to peer connection %s", cb)

View File

@@ -8,12 +8,12 @@ build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.8"
version = "0.0.9"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"
requires-python = ">=3.10"
authors = [{ name = "YOUR NAME", email = "YOUREMAIL@domain.com" }]
authors = [{ name = "Freddy Boulton", email = "YOUREMAIL@domain.com" }]
keywords = ["gradio-custom-component", "gradio-template-Video", "streaming", "webrtc", "realtime"]
# Add dependencies here
dependencies = ["gradio>=4.0,<6.0", "aiortc"]
@@ -22,10 +22,10 @@ classifiers = [
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Visualization',
@@ -43,6 +43,7 @@ classifiers = [
[project.optional-dependencies]
dev = ["build", "twine"]
vad = ["librosa", "onnxruntime"]
[tool.hatch.build]
artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]