diff --git a/README.md b/README.md
index f13f121..9134cb5 100644
--- a/README.md
+++ b/README.md
@@ -15,12 +15,17 @@ Stream video and audio in real time with Gradio using WebRTC.
 pip install gradio_webrtc
 ```
 
-to use built-in pause detection (see [conversational ai](#conversational-ai)), install the `vad` extra:
+To use built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-pause)), install the `vad` extra:
 ```bash
 pip install gradio_webrtc[vad]
 ```
 
+For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-stopwords)), install the `stopword` extra:
+```bash
+pip install gradio_webrtc[stopword]
+```
+
 ## Examples:
 1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
 2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
 
diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py
index 59be422..98783db 100644
--- a/backend/gradio_webrtc/__init__.py
+++ b/backend/gradio_webrtc/__init__.py
@@ -4,6 +4,8 @@ from .credentials import (
     get_twilio_turn_credentials,
 )
 from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
+from .reply_on_stopwords import ReplyOnStopWords
+from .speech_to_text import stt, stt_for_chunks
 from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
 from .webrtc import StreamHandler, WebRTC
 
@@ -17,7 +19,10 @@ __all__ = [
     "get_twilio_turn_credentials",
     "get_turn_credentials",
     "ReplyOnPause",
+    "ReplyOnStopWords",
     "SileroVadOptions",
+    "stt",
+    "stt_for_chunks",
     "StreamHandler",
     "WebRTC",
 ]
diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py
index 1ff911a..bf4bb1e 100644
--- a/backend/gradio_webrtc/pause_detection/vad.py
+++ b/backend/gradio_webrtc/pause_detection/vad.py
@@ -1,10 +1,13 @@
 import logging
 import warnings
 from dataclasses import dataclass
-from typing import List
+from typing import List, Literal, overload
 
 import numpy as np
 from huggingface_hub import hf_hub_download
+from numpy.typing import NDArray
+
+from ..utils import AudioChunk
 
 logger = logging.getLogger(__name__)
 
@@ -76,7 +79,7 @@ class SileroVADModel:
         return h, c
 
     @staticmethod
-    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+    def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
         """Collects and concatenates audio chunks."""
         if not chunks:
             return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
         audio: np.ndarray,
         vad_options: SileroVadOptions,
         **kwargs,
-    ) -> List[dict]:
+    ) -> List[AudioChunk]:
         """This method is used for splitting long audios into speech chunks using silero VAD.
 
         Args:
@@ -236,15 +239,36 @@
         return speeches
 
 
+    @overload
     def vad(
         self,
-        audio_tuple: tuple[int, np.ndarray],
+        audio_tuple: tuple[int, NDArray],
         vad_parameters: None | SileroVadOptions,
-    ) -> float:
+        return_chunks: Literal[True],
+    ) -> tuple[float, List[AudioChunk]]: ...
+
+    @overload
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float: ...
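+    # NOTE: The overloads above only narrow the return type for static type
+    # checkers. The implementation returns (duration, speech_chunks) when
+    # return_chunks=True, and just the detected speech duration otherwise.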
+
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float | tuple[float, List[AudioChunk]]:
         sampling_rate, audio = audio_tuple
         logger.debug("VAD audio shape input: %s", audio.shape)
         try:
-            audio = audio.astype(np.float32) / 32768.0
+            if audio.dtype != np.float32:
+                audio = audio.astype(np.float32) / 32768.0
             sr = 16000
             if sr != sampling_rate:
                 try:
@@ -262,7 +283,8 @@
             audio = self.collect_chunks(audio, speech_chunks)
             logger.debug("VAD audio shape: %s", audio.shape)
             duration_after_vad = audio.shape[0] / sr
-
+            if return_chunks:
+                return duration_after_vad, speech_chunks
             return duration_after_vad
         except Exception as e:
             import math
@@ -280,7 +302,7 @@
                 raise ValueError(
                     f"Too many dimensions for input audio chunk {len(x.shape)}"
                 )
-        if sr / x.shape[1] > 31.25:
+        if sr / x.shape[1] > 31.25:  # type: ignore
             raise ValueError("Input audio chunk is too short")
 
         h, c = state
diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py
index 372efb4..5cb425f 100644
--- a/backend/gradio_webrtc/reply_on_pause.py
+++ b/backend/gradio_webrtc/reply_on_pause.py
@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
+            self.input_sample_rate,
         )
 
     def determine_pause(
diff --git a/backend/gradio_webrtc/reply_on_stopwords.py b/backend/gradio_webrtc/reply_on_stopwords.py
new file mode 100644
index 0000000..a391e17
--- /dev/null
+++ b/backend/gradio_webrtc/reply_on_stopwords.py
@@ -0,0 +1,150 @@
+import asyncio
+import logging
+import re
+from typing import Literal
+
+import librosa
+import numpy as np
+
+from .reply_on_pause import (
+    AlgoOptions,
+    AppState,
+    ReplyFnGenerator,
+    ReplyOnPause,
+    SileroVadOptions,
+)
+from .speech_to_text import get_stt_model, stt_for_chunks
+from .utils import audio_to_float32
+
+logger = logging.getLogger(__name__)
+
+
+class ReplyOnStopWordsState(AppState):
+    stop_word_detected: bool = False
+    post_stop_word_buffer: np.ndarray | None = None
+    started_talking_pre_stop_word: bool = False
+
+
+class ReplyOnStopWords(ReplyOnPause):
+    def __init__(
+        self,
+        fn: ReplyFnGenerator,
+        stop_words: list[str],
+        algo_options: AlgoOptions | None = None,
+        model_options: SileroVadOptions | None = None,
+        expected_layout: Literal["mono", "stereo"] = "mono",
+        output_sample_rate: int = 24000,
+        output_frame_size: int = 480,
+        input_sample_rate: int = 48000,
+    ):
+        super().__init__(
+            fn,
+            algo_options=algo_options,
+            model_options=model_options,
+            expected_layout=expected_layout,
+            output_sample_rate=output_sample_rate,
+            output_frame_size=output_frame_size,
+            input_sample_rate=input_sample_rate,
+        )
+        self.stop_words = stop_words
+        self.state = ReplyOnStopWordsState()
+        # Download Model
+        get_stt_model()
+
+    def stop_word_detected(self, text: str) -> bool:
+        for stop_word in self.stop_words:
+            stop_word = stop_word.lower().strip().split(" ")
+            if bool(
+                re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text.lower())
+            ):
+                logger.debug("Stop word detected: %s", stop_word)
+                return True
+        return False
+
+    async def _send_stopword(
+        self,
+    ):
+        if self.channel:
+            self.channel.send("stopword")
+            logger.debug("Sent stopword")
+
+    def send_stopword(self):
+        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
+
+    def determine_pause(
+        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
+    ) -> bool:
+        """Take in the audio stream and determine if a pause happened."""
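+        # Phase 1: while no stop word has been heard, transcribe a rolling
+        # ~2 second buffer with speech-to-text and scan it for the stop words.
+        # Phase 2: once a stop word is detected, fall back to ReplyOnPause-style
+        # VAD to decide when the user has finished speaking.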
+        duration = len(audio) / sampling_rate
+
+        if duration >= self.algo_options.audio_chunk_duration:
+            if not state.stop_word_detected:
+                audio_f32 = audio_to_float32((sampling_rate, audio))
+                audio_rs = librosa.resample(
+                    audio_f32, orig_sr=sampling_rate, target_sr=16000
+                )
+                if state.post_stop_word_buffer is None:
+                    state.post_stop_word_buffer = audio_rs
+                else:
+                    state.post_stop_word_buffer = np.concatenate(
+                        (state.post_stop_word_buffer, audio_rs)
+                    )
+                if len(state.post_stop_word_buffer) / 16000 > 2:
+                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
+                dur_vad, chunks = self.model.vad(
+                    (16000, state.post_stop_word_buffer),
+                    self.model_options,
+                    return_chunks=True,
+                )
+                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
+                logger.debug(f"STT: {text}")
+                state.stop_word_detected = self.stop_word_detected(text)
+                if state.stop_word_detected:
+                    logger.debug("Stop word detected")
+                    self.send_stopword()
+                state.buffer = None
+            else:
+                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
+                logger.debug("VAD duration: %s", dur_vad)
+                if (
+                    dur_vad > self.algo_options.started_talking_threshold
+                    and not state.started_talking
+                    and state.stop_word_detected
+                ):
+                    state.started_talking = True
+                    logger.debug("Started talking")
+                if state.started_talking:
+                    if state.stream is None:
+                        state.stream = audio
+                    else:
+                        state.stream = np.concatenate((state.stream, audio))
+                    state.buffer = None
+                if (
+                    dur_vad < self.algo_options.speech_threshold
+                    and state.started_talking
+                    and state.stop_word_detected
+                ):
+                    return True
+        return False
+
+    def reset(self):
+        self.args_set.clear()
+        self.generator = None
+        self.event.clear()
+        self.state = ReplyOnStopWordsState()
+
+    def copy(self):
+        return ReplyOnStopWords(
+            self.fn,
+            self.stop_words,
+            self.algo_options,
+            self.model_options,
+            self.expected_layout,
+            self.output_sample_rate,
+            self.output_frame_size,
+            self.input_sample_rate,
+        )
diff --git a/backend/gradio_webrtc/speech_to_text/__init__.py b/backend/gradio_webrtc/speech_to_text/__init__.py
new file mode 100644
index 0000000..8569c11
--- /dev/null
+++ b/backend/gradio_webrtc/speech_to_text/__init__.py
@@ -0,0 +1,3 @@
+from .stt_ import get_stt_model, stt, stt_for_chunks
+
+__all__ = ["stt", "stt_for_chunks", "get_stt_model"]
diff --git a/backend/gradio_webrtc/speech_to_text/stt_.py b/backend/gradio_webrtc/speech_to_text/stt_.py
new file mode 100644
index 0000000..9987e2b
--- /dev/null
+++ b/backend/gradio_webrtc/speech_to_text/stt_.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Callable
+
+import numpy as np
+from numpy.typing import NDArray
+from silero import silero_stt
+
+from ..utils import AudioChunk
+
+
+@dataclass
+class STTModel:
+    encoder: Callable
+    decoder: Callable
+
+
+@lru_cache
+def get_stt_model() -> STTModel:
+    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
+    return STTModel(model, decoder)
+
+
+def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
+    model = get_stt_model()
+    sr, audio_np = audio
+    if audio_np.dtype != np.float32:
+        audio_np = audio_np.astype(np.float32) / 32768.0
+    try:
+        import torch
+    except ImportError:
+        raise ImportError(
+            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
+        )
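+    # Silero's STT model consumes batched float32 waveforms of shape
+    # (batch, num_samples); callers in this package pass 16 kHz mono audio.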
+    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
+    if audio_torch.ndim == 1:
+        audio_torch = audio_torch.unsqueeze(0)
+    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
+    res = model.decoder(model.encoder(audio_torch)[0])
+    return res
+
+
+def stt_for_chunks(
+    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
+) -> str:
+    sr, audio_np = audio
+    return " ".join(
+        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
+    )
diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py
index 6389281..bb93032 100644
--- a/backend/gradio_webrtc/utils.py
+++ b/backend/gradio_webrtc/utils.py
@@ -3,7 +3,7 @@ import fractions
 import io
 import logging
 import tempfile
-from typing import Any, Callable, Protocol, cast
+from typing import Any, Callable, Protocol, TypedDict, cast
 
 import av
 import numpy as np
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
 AUDIO_PTIME = 0.020
 
 
+class AudioChunk(TypedDict):
+    start: int
+    end: int
+
+
 class AdditionalOutputs:
     def __init__(self, *args) -> None:
         self.args = args
diff --git a/docs/cookbook.md b/docs/cookbook.md
index 4a5d8b1..713d51f 100644
--- a/docs/cookbook.md
+++ b/docs/cookbook.md
@@ -36,6 +36,17 @@
     [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)
 
+- :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__
+
+    ---
+
+    A code editor built with Llama 3.3 70B that is triggered by the phrase "Hello Llama".
+    Build a Siri-like coding assistant in 100 lines of code!
+
+    [:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)
+
+    [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)
+
 - :robot:{ .lg .middle } __Llama Code Editor__
 
     ---
diff --git a/docs/index.md b/docs/index.md
index 9654531..92923e7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -15,11 +15,16 @@ Stream video and audio in real time with Gradio using WebRTC.
 pip install gradio_webrtc
 ```
 
-to use built-in pause detection (see [Audio Streaming](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-pause)), install the `vad` extra:
+To use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
 ```bash
 pip install gradio_webrtc[vad]
 ```
 
+For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
+```bash
+pip install gradio_webrtc[stopword]
+```
+
 ## Examples
 
 See the [cookbook](/cookbook)
\ No newline at end of file
diff --git a/docs/user-guide.md b/docs/user-guide.md
index a07169b..04ff7fa 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -65,6 +65,56 @@ and passing it to the `stream` event of the `WebRTC` component.
 5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time.
 
+
+### Reply On Stopwords
+
+You can configure your AI model to run whenever a set of "stop words" is detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
+
+The API is similar to `ReplyOnPause`, with the addition of a `stop_words` parameter.
+
+=== "Code"
+    ``` py title="ReplyOnStopWords"
+    import gradio as gr
+    import numpy as np
+    from gradio_webrtc import WebRTC, ReplyOnStopWords
+
+    def response(audio: tuple[int, np.ndarray]):
+        """This function must yield audio frames"""
+        ...
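+        # `generated_audio` below is a placeholder for your model's streaming output.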
+        for numpy_array in generated_audio:
+            yield (sampling_rate, numpy_array, "mono")
+
+
+    with gr.Blocks() as demo:
+        gr.HTML(
+            """
+