diff --git a/README.md b/README.md index d091fd8..078db92 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,12 @@ Stream video and audio in real time with Gradio using WebRTC. pip install gradio_webrtc ``` +to use built-in pause detection (see [conversational ai](#conversational-ai)), install the `vad` extra: + +```bash +pip install gradio_webrtc[vad] +``` + ## Examples: 1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷 2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥 @@ -176,7 +182,44 @@ if __name__ == "__main__": * An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples). * You can also specify the audio layout ("mono" or "stereo") in the emit method by retuning it as the third element of the tuple. If not specified, the default is "mono". * The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop. -* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None. +* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. + +An easy way to get started with Conversational AI is to use the `ReplyOnPause` stream handler. This will automatically run your function when the speaker has stopped speaking. In order to use `ReplyOnPause`, the `[vad]` extra dependencies must be installed. + +```python +import gradio as gr +from gradio_webrtc import WebRTC, ReplyOnPause + +def response(audio: tuple[int, np.ndarray]): + """This function must yield audio frames""" + ... + for numpy_array in generated_audio: + yield (sampling_rate, numpy_array, "mono") + + +with gr.Blocks() as demo: + gr.HTML( + """ +

+ Chat (Powered by WebRTC ⚡️) +

+ """ + ) + with gr.Column(): + with gr.Group(): + audio = WebRTC( + label="Stream", + rtc_configuration=rtc_configuration, + mode="send-receive", + modality="audio", + ) + audio.stream(fn=ReplyOnPause(response), inputs=[audio], outputs=[audio], time_limit=60) + + +demo.launch(ssr_mode=False) +``` + + ## Deployment diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index af3f6cd..1924616 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -1,3 +1,4 @@ +from .reply_on_pause import ReplyOnPause from .webrtc import StreamHandler, WebRTC -__all__ = ["StreamHandler", "WebRTC"] +__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] diff --git a/backend/gradio_webrtc/pause_detection/__init__.py b/backend/gradio_webrtc/pause_detection/__init__.py new file mode 100644 index 0000000..e4874b7 --- /dev/null +++ b/backend/gradio_webrtc/pause_detection/__init__.py @@ -0,0 +1,3 @@ +from .vad import SileroVADModel, SileroVadOptions + +__all__ = ["SileroVADModel", "SileroVadOptions"] diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py new file mode 100644 index 0000000..1ff911a --- /dev/null +++ b/backend/gradio_webrtc/pause_detection/vad.py @@ -0,0 +1,298 @@ +import logging +import warnings +from dataclasses import dataclass +from typing import List + +import numpy as np +from huggingface_hub import hf_hub_download + +logger = logging.getLogger(__name__) + +# The code below is adapted from https://github.com/snakers4/silero-vad. +# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py + + +@dataclass +class SileroVadOptions: + """VAD options. + + Attributes: + threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms + before separating it + window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. + Values other than these may affect model performance!! + speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + speech_duration: If the length of the speech is less than this value, a pause will be detected. + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +class SileroVADModel: + @staticmethod + def download_model() -> str: + return hf_hub_download( + repo_id="freddyaboulton/silero-vad", filename="silero_vad.onnx" + ) + + def __init__(self): + try: + import onnxruntime + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the onnxruntime package" + ) from e + + path = self.download_model() + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + opts.log_severity_level = 4 + + self.session = onnxruntime.InferenceSession( + path, + providers=["CPUExecutionProvider"], + sess_options=opts, + ) + + def get_initial_state(self, batch_size: int): + h = np.zeros((2, batch_size, 64), dtype=np.float32) + c = np.zeros((2, batch_size, 64), dtype=np.float32) + return h, c + + @staticmethod + def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate( + [audio[chunk["start"] : chunk["end"]] for chunk in chunks] + ) + + def get_speech_timestamps( + self, + audio: np.ndarray, + vad_options: SileroVadOptions, + **kwargs, + ) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + + if window_size_samples not in [512, 1024, 1536]: + warnings.warn( + "Unusual window_size_samples! Supported window_size_samples:\n" + " - [512, 1024, 1536] for 16000 sampling_rate" + ) + + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + state = self.get_initial_state(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + chunk = audio[ + current_start_sample : current_start_sample + window_size_samples + ] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state = self(chunk, state, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] + > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if ( + window_size_samples * i + ) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + def vad( + self, + audio_tuple: tuple[int, np.ndarray], + vad_parameters: None | SileroVadOptions, + ) -> float: + sampling_rate, audio = audio_tuple + logger.debug("VAD audio shape input: %s", audio.shape) + try: + audio = audio.astype(np.float32) / 32768.0 + sr = 16000 + if sr != sampling_rate: + try: + import librosa # type: ignore + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" + ) from e + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr) + + if not vad_parameters: + vad_parameters = SileroVadOptions() + speech_chunks = self.get_speech_timestamps(audio, vad_parameters) + logger.debug("VAD speech chunks: %s", speech_chunks) + audio = self.collect_chunks(audio, speech_chunks) + logger.debug("VAD audio shape: %s", audio.shape) + duration_after_vad = audio.shape[0] / sr + + return duration_after_vad + except Exception as e: + import math + import traceback + + logger.debug("VAD Exception: %s", str(e)) + exec = traceback.format_exc() + logger.debug("traceback %s", exec) + return math.inf + + def __call__(self, x, state, sr: int): + if len(x.shape) == 1: + x = np.expand_dims(x, 0) + if len(x.shape) > 2: + raise ValueError( + f"Too many dimensions for input audio chunk {len(x.shape)}" + ) + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + h, c = state + + ort_inputs = { + "input": x, + "h": h, + "c": c, + "sr": np.array(sr, dtype="int64"), + } + + out, h, c = self.session.run(None, ort_inputs) + state = (h, c) + + return out, state diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py new file mode 100644 index 0000000..e429ecc --- /dev/null +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from functools import lru_cache +from logging import getLogger +from threading import Event +from typing import Callable, Generator, Literal, cast + +import numpy as np + +from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions +from gradio_webrtc.webrtc import StreamHandler + +logger = getLogger(__name__) + +counter = 0 + + +@lru_cache +def get_vad_model() -> SileroVADModel: + """Returns the VAD model instance.""" + return SileroVADModel() + + +@dataclass +class AlgoOptions: + """Algorithm options.""" + + audio_chunk_duration: float = 0.6 + started_talking_threshold: float = 0.2 + speech_threshold: float = 0.1 + + +@dataclass +class AppState: + stream: np.ndarray | None = None + sampling_rate: int = 0 + pause_detected: bool = False + started_talking: bool = False + responding: bool = False + stopped: bool = False + buffer: np.ndarray | None = None + + +ReplyFnGenerator = Callable[ + [tuple[int, np.ndarray]], + Generator[ + tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], + None, + None, + ], +] + + +class ReplyOnPause(StreamHandler): + def __init__( + self, + fn: ReplyFnGenerator, + algo_options: AlgoOptions | None = None, + model_options: SileroVadOptions | None = None, + expected_layout: Literal["mono", "stereo"] = "mono", + output_sample_rate: int = 24000, + output_frame_size: int = 480, + ): + super().__init__(expected_layout, output_sample_rate, output_frame_size) + self.expected_layout: Literal["mono", "stereo"] = expected_layout + self.output_sample_rate = output_sample_rate + self.output_frame_size = output_frame_size + self.model = get_vad_model() + self.fn = fn + self.event = Event() + self.state = AppState() + self.generator = None + self.model_options = model_options + self.algo_options = algo_options or AlgoOptions() + + def copy(self): + return ReplyOnPause( + self.fn, + self.algo_options, + self.model_options, + self.expected_layout, + self.output_sample_rate, + self.output_frame_size, + ) + + def determine_pause( + self, audio: np.ndarray, sampling_rate: int, state: AppState + ) -> bool: + """Take in the stream, determine if a pause happened""" + duration = len(audio) / sampling_rate + + if duration >= self.algo_options.audio_chunk_duration: + dur_vad = self.model.vad((sampling_rate, audio), self.model_options) + logger.debug("VAD duration: %s", dur_vad) + if ( + dur_vad > self.algo_options.started_talking_threshold + and not state.started_talking + ): + state.started_talking = True + logger.debug("Started talking") + if state.started_talking: + if state.stream is None: + state.stream = audio + else: + state.stream = np.concatenate((state.stream, audio)) + state.buffer = None + if dur_vad < self.algo_options.speech_threshold and state.started_talking: + return True + return False + + def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None: + frame_rate, array = audio + array = np.squeeze(array) + if not state.sampling_rate: + state.sampling_rate = frame_rate + if state.buffer is None: + state.buffer = array + else: + state.buffer = np.concatenate((state.buffer, array)) + + pause_detected = self.determine_pause( + state.buffer, state.sampling_rate, self.state + ) + state.pause_detected = pause_detected + + def receive(self, frame: tuple[int, np.ndarray]) -> None: + if self.state.responding: + return + self.process_audio(frame, self.state) + if self.state.pause_detected: + self.event.set() + + def reset(self): + self.generator = None + self.event.clear() + self.state = AppState() + + def emit(self): + if not self.event.is_set(): + return None + else: + if not self.generator: + audio = cast(np.ndarray, self.state.stream).reshape(1, -1) + self.generator = self.fn((self.state.sampling_rate, audio)) + self.state.responding = True + try: + return next(self.generator) + except StopIteration: + self.reset() diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 7348eb2..c380fbe 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -55,7 +55,7 @@ async def player_worker_decode( # Convert to audio frame and resample # This runs in the same timeout context - frame = av.AudioFrame.from_ndarray( + frame = av.AudioFrame.from_ndarray( # type: ignore audio_array, format=format, layout=layout ) frame.sample_rate = sample_rate diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 647b6f2..3f82197 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -10,6 +10,7 @@ import time import traceback from abc import ABC, abstractmethod from collections.abc import Callable +from copy import deepcopy from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast import anyio.to_thread @@ -117,6 +118,14 @@ class StreamHandler(ABC): self.output_frame_size = output_frame_size self._resampler = None + def copy(self) -> "StreamHandler": + try: + return deepcopy(self) + except Exception: + raise ValueError( + "Current StreamHandler implementation cannot be deepcopied. Implement the copy method." + ) + def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: if self._resampler is None: self._resampler = av.AudioResampler( # type: ignore @@ -622,7 +631,7 @@ class WebRTC(Component): elif self.modality == "audio": cb = AudioCallback( relay.subscribe(track), - event_handler=cast(StreamHandler, self.event_handler), + event_handler=cast(StreamHandler, self.event_handler).copy(), ) self.connections[body["webrtc_id"]] = cb logger.debug("Adding track to peer connection %s", cb) diff --git a/pyproject.toml b/pyproject.toml index 0f8ab2b..5e12922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,12 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.8" +version = "0.0.9" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0" requires-python = ">=3.10" -authors = [{ name = "YOUR NAME", email = "YOUREMAIL@domain.com" }] +authors = [{ name = "Freddy Boulton", email = "YOUREMAIL@domain.com" }] keywords = ["gradio-custom-component", "gradio-template-Video", "streaming", "webrtc", "realtime"] # Add dependencies here dependencies = ["gradio>=4.0,<6.0", "aiortc"] @@ -22,10 +22,10 @@ classifiers = [ 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Visualization', @@ -43,6 +43,7 @@ classifiers = [ [project.optional-dependencies] dev = ["build", "twine"] +vad = ["librosa", "onnxruntime"] [tool.hatch.build] artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]