From a5dbaaf49b8ae6332ed721a25b9d25fbb28fcf61 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 16:24:17 -0700 Subject: [PATCH 1/6] first cut --- backend/gradio_webrtc/__init__.py | 3 +- .../gradio_webrtc/pause_detection/__init__.py | 4 + backend/gradio_webrtc/pause_detection/vad.py | 294 ++++++++++++++++++ backend/gradio_webrtc/reply_on_pause.py | 128 ++++++++ backend/gradio_webrtc/utils.py | 2 +- backend/gradio_webrtc/webrtc.py | 9 +- pyproject.toml | 4 +- 7 files changed, 439 insertions(+), 5 deletions(-) create mode 100644 backend/gradio_webrtc/pause_detection/__init__.py create mode 100644 backend/gradio_webrtc/pause_detection/vad.py create mode 100644 backend/gradio_webrtc/reply_on_pause.py diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index af3f6cd..d774b0e 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -1,3 +1,4 @@ from .webrtc import StreamHandler, WebRTC +from .reply_on_pause import ReplyOnPause -__all__ = ["StreamHandler", "WebRTC"] +__all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] diff --git a/backend/gradio_webrtc/pause_detection/__init__.py b/backend/gradio_webrtc/pause_detection/__init__.py new file mode 100644 index 0000000..79e57ce --- /dev/null +++ b/backend/gradio_webrtc/pause_detection/__init__.py @@ -0,0 +1,4 @@ +from .vad import SileroVADModel, SileroVadOptions + + +__all__ = ["SileroVADModel", "SileroVadOptions"] \ No newline at end of file diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py new file mode 100644 index 0000000..18e1d3d --- /dev/null +++ b/backend/gradio_webrtc/pause_detection/vad.py @@ -0,0 +1,294 @@ +import logging +import warnings +from dataclasses import dataclass +from huggingface_hub import hf_hub_download + +from typing import List + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class SileroVadOptions: + """VAD options. + + Attributes: + threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms + before separating it + window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. + Values other than these may affect model performance!! + speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + speech_duration: If the length of the speech is less than this value, a pause will be detected. + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +class SileroVADModel: + @staticmethod + def download_model() -> str: + return hf_hub_download( + repo_id="freddyaboulton/silero-vad", filename="silero_vad.onnx" + ) + + def __init__(self): + try: + import onnxruntime + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the onnxruntime package" + ) from e + + path = self.download_model() + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + opts.log_severity_level = 4 + + self.session = onnxruntime.InferenceSession( + path, + providers=["CPUExecutionProvider"], + sess_options=opts, + ) + + def get_initial_state(self, batch_size: int): + h = np.zeros((2, batch_size, 64), dtype=np.float32) + c = np.zeros((2, batch_size, 64), dtype=np.float32) + return h, c + + @staticmethod + def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate( + [audio[chunk["start"] : chunk["end"]] for chunk in chunks] + ) + + def get_speech_timestamps( + self, + audio: np.ndarray, + vad_options: SileroVadOptions, + **kwargs, + ) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + + if window_size_samples not in [512, 1024, 1536]: + warnings.warn( + "Unusual window_size_samples! Supported window_size_samples:\n" + " - [512, 1024, 1536] for 16000 sampling_rate" + ) + + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + state = self.get_initial_state(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + chunk = audio[ + current_start_sample : current_start_sample + window_size_samples + ] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state = self(chunk, state, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] + > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if ( + window_size_samples * i + ) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + def vad( + self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions + ) -> float: + + sampling_rate, audio = audio_tuple + logger.debug("VAD audio shape input: %s", audio.shape) + try: + audio = audio.astype(np.float32) / 32768.0 + sr = 16000 + if sr != sampling_rate: + try: + import librosa # type: ignore + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" + ) from e + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr) + + if not vad_parameters: + vad_parameters = SileroVadOptions() + speech_chunks = self.get_speech_timestamps(audio, vad_parameters) + logger.debug("VAD speech chunks: %s", speech_chunks) + audio = self.collect_chunks(audio, speech_chunks) + logger.debug("VAD audio shape: %s", audio.shape) + duration_after_vad = audio.shape[0] / sr + + return duration_after_vad + except Exception as e: + import math + import traceback + logger.debug("VAD Exception: %s", str(e)) + exec = traceback.format_exc() + logger.debug("traceback %s", exec) + return math.inf + + def __call__(self, x, state, sr: int): + if len(x.shape) == 1: + x = np.expand_dims(x, 0) + if len(x.shape) > 2: + raise ValueError( + f"Too many dimensions for input audio chunk {len(x.shape)}" + ) + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + h, c = state + + ort_inputs = { + "input": x, + "h": h, + "c": c, + "sr": np.array(sr, dtype="int64"), + } + + out, h, c = self.session.run(None, ort_inputs) + state = (h, c) + + return out, state diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py new file mode 100644 index 0000000..640ea66 --- /dev/null +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -0,0 +1,128 @@ +from typing import Callable, Literal, Generator, cast +from functools import lru_cache +from dataclasses import dataclass +from threading import Event +from logging import getLogger + +import numpy as np + +from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions +from gradio_webrtc.webrtc import StreamHandler + +logger = getLogger(__name__) + +counter = 0 + +@lru_cache +def get_vad_model() -> SileroVADModel: + """Returns the VAD model instance.""" + return SileroVADModel() + + +@dataclass +class AlgoOptions: + """Algorithm options.""" + audio_chunk_duration: float = 0.6 + started_talking_threshold: float = 0.2 + speech_threshold: float = 0.1 + + +@dataclass +class AppState: + stream: np.ndarray | None = None + sampling_rate: int = 0 + pause_detected: bool = False + started_talking: bool = False + responding: bool = False + stopped: bool = False + buffer: np.ndarray | None = None + + +ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]] + +class ReplyOnPause(StreamHandler): + def __init__(self, fn: ReplyFnGenerator, + algo_options: AlgoOptions | None = None, + model_options: SileroVadOptions | None = None, + expected_layout: Literal["mono", "stereo"] = "mono", + output_sample_rate: int = 24000, + output_frame_size: int = 960,): + super().__init__(expected_layout, + output_sample_rate, output_frame_size) + self.expected_layout: Literal["mono", "stereo"] = expected_layout + self.output_sample_rate = output_sample_rate + self.output_frame_size = output_frame_size + self.model = get_vad_model() + self.fn = fn + self.event = Event() + self.state = AppState() + self.generator = None + self.model_options = model_options + self.algo_options = algo_options or AlgoOptions() + + def copy(self): + return ReplyOnPause(self.fn, self.algo_options, self.model_options, + self.expected_layout, self.output_sample_rate, self.output_frame_size) + + def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool: + """Take in the stream, determine if a pause happened""" + duration = len(audio) / sampling_rate + + if duration >= self.algo_options.audio_chunk_duration: + dur_vad = self.model.vad((sampling_rate, audio), self.model_options) + logger.debug("VAD duration: %s", dur_vad) + if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking: + state.started_talking = True + logger.debug("Started talking") + if state.started_talking: + if state.stream is None: + state.stream = audio + else: + state.stream = np.concatenate((state.stream, audio)) + state.buffer = None + if dur_vad < self.algo_options.speech_threshold and state.started_talking: + return True + return False + + + def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None: + frame_rate, array = audio + array = np.squeeze(array) + if not state.sampling_rate: + state.sampling_rate = frame_rate + if state.buffer is None: + state.buffer = array + else: + state.buffer = np.concatenate((state.buffer, array)) + + pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state) + state.pause_detected = pause_detected + + + def receive(self, frame: tuple[int, np.ndarray]) -> None: + if self.state.responding: + return + self.process_audio(frame, self.state) + if self.state.pause_detected: + self.event.set() + + def reset(self): + self.generator = None + self.event.clear() + self.state = AppState() + + def emit(self): + if not self.event.is_set(): + return None + else: + if not self.generator: + audio = cast(np.ndarray, self.state.stream).reshape(1, -1) + self.generator = self.fn((self.state.sampling_rate, audio)) + self.state.responding = True + try: + return next(self.generator) + except StopIteration: + self.reset() + + + diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 7348eb2..8762c1b 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -55,7 +55,7 @@ async def player_worker_decode( # Convert to audio frame and resample # This runs in the same timeout context - frame = av.AudioFrame.from_ndarray( + frame = av.AudioFrame.from_ndarray( # type: ignore audio_array, format=format, layout=layout ) frame.sample_rate = sample_rate diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 647b6f2..e210f55 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -11,6 +11,7 @@ import traceback from abc import ABC, abstractmethod from collections.abc import Callable from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast +from copy import deepcopy import anyio.to_thread import av @@ -117,6 +118,12 @@ class StreamHandler(ABC): self.output_frame_size = output_frame_size self._resampler = None + def copy(self) -> "StreamHandler": + try: + return deepcopy(self) + except Exception: + raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.") + def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: if self._resampler is None: self._resampler = av.AudioResampler( # type: ignore @@ -622,7 +629,7 @@ class WebRTC(Component): elif self.modality == "audio": cb = AudioCallback( relay.subscribe(track), - event_handler=cast(StreamHandler, self.event_handler), + event_handler=cast(StreamHandler, self.event_handler).copy(), ) self.connections[body["webrtc_id"]] = cb logger.debug("Adding track to peer connection %s", cb) diff --git a/pyproject.toml b/pyproject.toml index 0f8ab2b..0b331aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0" requires-python = ">=3.10" -authors = [{ name = "YOUR NAME", email = "YOUREMAIL@domain.com" }] +authors = [{ name = "Freddy Boulton", email = "YOUREMAIL@domain.com" }] keywords = ["gradio-custom-component", "gradio-template-Video", "streaming", "webrtc", "realtime"] # Add dependencies here dependencies = ["gradio>=4.0,<6.0", "aiortc"] @@ -22,10 +22,10 @@ classifiers = [ 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Visualization', From 50611d3772b91ec4e143c0f9c3cfc47dab66f38e Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 16:28:33 -0700 Subject: [PATCH 2/6] Add code --- backend/gradio_webrtc/__init__.py | 2 +- .../gradio_webrtc/pause_detection/__init__.py | 3 +- backend/gradio_webrtc/pause_detection/vad.py | 14 ++-- backend/gradio_webrtc/reply_on_pause.py | 68 ++++++++++++------- backend/gradio_webrtc/utils.py | 2 +- backend/gradio_webrtc/webrtc.py | 6 +- 6 files changed, 60 insertions(+), 35 deletions(-) diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index d774b0e..1924616 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -1,4 +1,4 @@ -from .webrtc import StreamHandler, WebRTC from .reply_on_pause import ReplyOnPause +from .webrtc import StreamHandler, WebRTC __all__ = ["ReplyOnPause", "StreamHandler", "WebRTC"] diff --git a/backend/gradio_webrtc/pause_detection/__init__.py b/backend/gradio_webrtc/pause_detection/__init__.py index 79e57ce..e4874b7 100644 --- a/backend/gradio_webrtc/pause_detection/__init__.py +++ b/backend/gradio_webrtc/pause_detection/__init__.py @@ -1,4 +1,3 @@ from .vad import SileroVADModel, SileroVadOptions - -__all__ = ["SileroVADModel", "SileroVadOptions"] \ No newline at end of file +__all__ = ["SileroVADModel", "SileroVadOptions"] diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py index 18e1d3d..1ff911a 100644 --- a/backend/gradio_webrtc/pause_detection/vad.py +++ b/backend/gradio_webrtc/pause_detection/vad.py @@ -1,14 +1,16 @@ import logging import warnings from dataclasses import dataclass -from huggingface_hub import hf_hub_download - from typing import List import numpy as np +from huggingface_hub import hf_hub_download logger = logging.getLogger(__name__) +# The code below is adapted from https://github.com/snakers4/silero-vad. +# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py + @dataclass class SileroVadOptions: @@ -235,9 +237,10 @@ class SileroVADModel: return speeches def vad( - self, audio_tuple: tuple[int, np.ndarray], vad_parameters: None | SileroVadOptions + self, + audio_tuple: tuple[int, np.ndarray], + vad_parameters: None | SileroVadOptions, ) -> float: - sampling_rate, audio = audio_tuple logger.debug("VAD audio shape input: %s", audio.shape) try: @@ -245,7 +248,7 @@ class SileroVADModel: sr = 16000 if sr != sampling_rate: try: - import librosa # type: ignore + import librosa # type: ignore except ImportError as e: raise RuntimeError( "Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz" @@ -264,6 +267,7 @@ class SileroVADModel: except Exception as e: import math import traceback + logger.debug("VAD Exception: %s", str(e)) exec = traceback.format_exc() logger.debug("traceback %s", exec) diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 640ea66..6864358 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -1,8 +1,8 @@ -from typing import Callable, Literal, Generator, cast -from functools import lru_cache from dataclasses import dataclass -from threading import Event +from functools import lru_cache from logging import getLogger +from threading import Event +from typing import Callable, Generator, Literal, cast import numpy as np @@ -13,6 +13,7 @@ logger = getLogger(__name__) counter = 0 + @lru_cache def get_vad_model() -> SileroVADModel: """Returns the VAD model instance.""" @@ -22,6 +23,7 @@ def get_vad_model() -> SileroVADModel: @dataclass class AlgoOptions: """Algorithm options.""" + audio_chunk_duration: float = 0.6 started_talking_threshold: float = 0.2 speech_threshold: float = 0.1 @@ -38,17 +40,27 @@ class AppState: buffer: np.ndarray | None = None -ReplyFnGenerator = Callable[[tuple[int, np.ndarray]], Generator[tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], None, None]] +ReplyFnGenerator = Callable[ + [tuple[int, np.ndarray]], + Generator[ + tuple[int, np.ndarray] | tuple[int, np.ndarray, Literal["mono", "stereo"]], + None, + None, + ], +] + class ReplyOnPause(StreamHandler): - def __init__(self, fn: ReplyFnGenerator, - algo_options: AlgoOptions | None = None, - model_options: SileroVadOptions | None = None, - expected_layout: Literal["mono", "stereo"] = "mono", - output_sample_rate: int = 24000, - output_frame_size: int = 960,): - super().__init__(expected_layout, - output_sample_rate, output_frame_size) + def __init__( + self, + fn: ReplyFnGenerator, + algo_options: AlgoOptions | None = None, + model_options: SileroVadOptions | None = None, + expected_layout: Literal["mono", "stereo"] = "mono", + output_sample_rate: int = 24000, + output_frame_size: int = 960, + ): + super().__init__(expected_layout, output_sample_rate, output_frame_size) self.expected_layout: Literal["mono", "stereo"] = expected_layout self.output_sample_rate = output_sample_rate self.output_frame_size = output_frame_size @@ -59,19 +71,30 @@ class ReplyOnPause(StreamHandler): self.generator = None self.model_options = model_options self.algo_options = algo_options or AlgoOptions() - + def copy(self): - return ReplyOnPause(self.fn, self.algo_options, self.model_options, - self.expected_layout, self.output_sample_rate, self.output_frame_size) - - def determine_pause(self, audio: np.ndarray, sampling_rate: int, state: AppState) -> bool: + return ReplyOnPause( + self.fn, + self.algo_options, + self.model_options, + self.expected_layout, + self.output_sample_rate, + self.output_frame_size, + ) + + def determine_pause( + self, audio: np.ndarray, sampling_rate: int, state: AppState + ) -> bool: """Take in the stream, determine if a pause happened""" duration = len(audio) / sampling_rate if duration >= self.algo_options.audio_chunk_duration: dur_vad = self.model.vad((sampling_rate, audio), self.model_options) logger.debug("VAD duration: %s", dur_vad) - if dur_vad > self.algo_options.started_talking_threshold and not state.started_talking: + if ( + dur_vad > self.algo_options.started_talking_threshold + and not state.started_talking + ): state.started_talking = True logger.debug("Started talking") if state.started_talking: @@ -84,7 +107,6 @@ class ReplyOnPause(StreamHandler): return True return False - def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None: frame_rate, array = audio array = np.squeeze(array) @@ -95,9 +117,10 @@ class ReplyOnPause(StreamHandler): else: state.buffer = np.concatenate((state.buffer, array)) - pause_detected = self.determine_pause(state.buffer, state.sampling_rate, self.state) + pause_detected = self.determine_pause( + state.buffer, state.sampling_rate, self.state + ) state.pause_detected = pause_detected - def receive(self, frame: tuple[int, np.ndarray]) -> None: if self.state.responding: @@ -123,6 +146,3 @@ class ReplyOnPause(StreamHandler): return next(self.generator) except StopIteration: self.reset() - - - diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 8762c1b..c380fbe 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -55,7 +55,7 @@ async def player_worker_decode( # Convert to audio frame and resample # This runs in the same timeout context - frame = av.AudioFrame.from_ndarray( # type: ignore + frame = av.AudioFrame.from_ndarray( # type: ignore audio_array, format=format, layout=layout ) frame.sample_rate = sample_rate diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index e210f55..3f82197 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -10,8 +10,8 @@ import time import traceback from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast from copy import deepcopy +from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast import anyio.to_thread import av @@ -122,7 +122,9 @@ class StreamHandler(ABC): try: return deepcopy(self) except Exception: - raise ValueError("Current StreamHandler implementation cannot be deepcopied. Implement the copy method.") + raise ValueError( + "Current StreamHandler implementation cannot be deepcopied. Implement the copy method." + ) def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]: if self._resampler is None: From 792e0a3663c0bc954e26c9036a0a0facd14bf002 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 16:30:34 -0700 Subject: [PATCH 3/6] add code --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0b331aa..ddfc887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.8" +version = "0.0.9a1" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0" @@ -43,6 +43,7 @@ classifiers = [ [project.optional-dependencies] dev = ["build", "twine"] +vad = ["librosa", "onnxruntime"] [tool.hatch.build] artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"] From e0b11c1cb0b39edaeb6ee3b9e0ab9d7f57eff99d Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 17:20:31 -0700 Subject: [PATCH 4/6] fix frame_size --- backend/gradio_webrtc/reply_on_pause.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 6864358..e429ecc 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -58,7 +58,7 @@ class ReplyOnPause(StreamHandler): model_options: SileroVadOptions | None = None, expected_layout: Literal["mono", "stereo"] = "mono", output_sample_rate: int = 24000, - output_frame_size: int = 960, + output_frame_size: int = 480, ): super().__init__(expected_layout, output_sample_rate, output_frame_size) self.expected_layout: Literal["mono", "stereo"] = expected_layout From 792f4c9af2f4db9a175f896ff59c3c9c135d174a Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 17:21:10 -0700 Subject: [PATCH 5/6] add pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ddfc887..5e12922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.9a1" +version = "0.0.9" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0" From a1c289973b65e88ed0e2f11a5e6138e3ffb358d4 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 25 Oct 2024 17:37:25 -0700 Subject: [PATCH 6/6] README --- README.md | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d091fd8..078db92 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,12 @@ Stream video and audio in real time with Gradio using WebRTC. pip install gradio_webrtc ``` +to use built-in pause detection (see [conversational ai](#conversational-ai)), install the `vad` extra: + +```bash +pip install gradio_webrtc[vad] +``` + ## Examples: 1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷 2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥 @@ -176,7 +182,44 @@ if __name__ == "__main__": * An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples). * You can also specify the audio layout ("mono" or "stereo") in the emit method by retuning it as the third element of the tuple. If not specified, the default is "mono". * The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop. -* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None. +* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. + +An easy way to get started with Conversational AI is to use the `ReplyOnPause` stream handler. This will automatically run your function when the speaker has stopped speaking. In order to use `ReplyOnPause`, the `[vad]` extra dependencies must be installed. + +```python +import gradio as gr +from gradio_webrtc import WebRTC, ReplyOnPause + +def response(audio: tuple[int, np.ndarray]): + """This function must yield audio frames""" + ... + for numpy_array in generated_audio: + yield (sampling_rate, numpy_array, "mono") + + +with gr.Blocks() as demo: + gr.HTML( + """ +

+ Chat (Powered by WebRTC ⚡️) +

+ """ + ) + with gr.Column(): + with gr.Group(): + audio = WebRTC( + label="Stream", + rtc_configuration=rtc_configuration, + mode="send-receive", + modality="audio", + ) + audio.stream(fn=ReplyOnPause(response), inputs=[audio], outputs=[audio], time_limit=60) + + +demo.launch(ssr_mode=False) +``` + + ## Deployment