diff --git a/README.md b/README.md
index f13f121..9134cb5 100644
--- a/README.md
+++ b/README.md
@@ -15,12 +15,22 @@ Stream video and audio in real time with Gradio using WebRTC.
 pip install gradio_webrtc
 ```
 
-to use built-in pause detection (see [conversational ai](#conversational-ai)), install the `vad` extra:
+To use built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-pause)), install the `vad` extra:
 
 ```bash
 pip install gradio_webrtc[vad]
 ```
 
+For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-stopwords)), install the `stopword` extra:
+```bash
+pip install gradio_webrtc[stopword]
+```
+
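+Both extras can be installed in one command; quoting the requirement keeps shells like zsh from expanding the square brackets:
+```bash
+pip install "gradio_webrtc[vad,stopword]"
+```
+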
 ## Examples:
 1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
 2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py
index 59be422..98783db 100644
--- a/backend/gradio_webrtc/__init__.py
+++ b/backend/gradio_webrtc/__init__.py
@@ -4,6 +4,8 @@ from .credentials import (
     get_twilio_turn_credentials,
 )
 from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
+from .reply_on_stopwords import ReplyOnStopWords
+from .speech_to_text import stt, stt_for_chunks
 from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
 from .webrtc import StreamHandler, WebRTC
 
@@ -17,7 +19,10 @@ __all__ = [
     "get_twilio_turn_credentials",
     "get_turn_credentials",
     "ReplyOnPause",
+    "ReplyOnStopWords",
     "SileroVadOptions",
+    "stt",
+    "stt_for_chunks",
     "StreamHandler",
     "WebRTC",
 ]
diff --git a/backend/gradio_webrtc/pause_detection/vad.py b/backend/gradio_webrtc/pause_detection/vad.py
index 1ff911a..bf4bb1e 100644
--- a/backend/gradio_webrtc/pause_detection/vad.py
+++ b/backend/gradio_webrtc/pause_detection/vad.py
@@ -1,10 +1,13 @@
 import logging
 import warnings
 from dataclasses import dataclass
-from typing import List
+from typing import List, Literal, overload
 
 import numpy as np
 from huggingface_hub import hf_hub_download
+from numpy.typing import NDArray
+
+from ..utils import AudioChunk
 
 logger = logging.getLogger(__name__)
 
@@ -76,7 +79,7 @@ class SileroVADModel:
         return h, c
 
     @staticmethod
-    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+    def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
         """Collects and concatenates audio chunks."""
         if not chunks:
             return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
         audio: np.ndarray,
         vad_options: SileroVadOptions,
         **kwargs,
-    ) -> List[dict]:
+    ) -> List[AudioChunk]:
         """This method is used for splitting long audios into speech chunks using silero VAD.
 
         Args:
@@ -236,15 +239,33 @@ class SileroVADModel:
 
         return speeches
 
+    @overload
     def vad(
         self,
-        audio_tuple: tuple[int, np.ndarray],
+        audio_tuple: tuple[int, NDArray],
         vad_parameters: None | SileroVadOptions,
-    ) -> float:
+        return_chunks: Literal[True],
+    ) -> tuple[float, List[AudioChunk]]: ...
+
+    @overload
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float: ...
+
+    def vad(
+        self,
+        audio_tuple: tuple[int, NDArray],
+        vad_parameters: None | SileroVadOptions,
+        return_chunks: bool = False,
+    ) -> float | tuple[float, List[AudioChunk]]:
         sampling_rate, audio = audio_tuple
         logger.debug("VAD audio shape input: %s", audio.shape)
         try:
-            audio = audio.astype(np.float32) / 32768.0
+            if audio.dtype != np.float32:
+                audio = audio.astype(np.float32) / 32768.0  # int16 PCM -> [-1, 1]
 
             sr = 16000
             if sr != sampling_rate:
                 try:
@@ -262,7 +283,8 @@
             audio = self.collect_chunks(audio, speech_chunks)
             logger.debug("VAD audio shape: %s", audio.shape)
             duration_after_vad = audio.shape[0] / sr
-
+            if return_chunks:
+                return duration_after_vad, speech_chunks
             return duration_after_vad
         except Exception as e:
             import math
@@ -280,7 +302,7 @@
             raise ValueError(
                 f"Too many dimensions for input audio chunk {len(x.shape)}"
             )
-        if sr / x.shape[1] > 31.25:
+        if sr / x.shape[1] > 31.25:  # type: ignore
             raise ValueError("Input audio chunk is too short")
 
         h, c = state
diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py
index 372efb4..5cb425f 100644
--- a/backend/gradio_webrtc/reply_on_pause.py
+++ b/backend/gradio_webrtc/reply_on_pause.py
@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
+            self.input_sample_rate,
         )
 
     def determine_pause(
diff --git a/backend/gradio_webrtc/reply_on_stopwords.py b/backend/gradio_webrtc/reply_on_stopwords.py
new file mode 100644
index 0000000..a391e17
--- /dev/null
+++ b/backend/gradio_webrtc/reply_on_stopwords.py
@@ -0,0 +1,148 @@
+import asyncio
+import logging
+import re
+from typing import Literal
+
+import librosa
+import numpy as np
+
+from .reply_on_pause import (
+    AlgoOptions,
+    AppState,
+    ReplyFnGenerator,
+    ReplyOnPause,
+    SileroVadOptions,
+)
+from .speech_to_text import get_stt_model, stt_for_chunks
+from .utils import audio_to_float32
+
+logger = logging.getLogger(__name__)
+
+
+class ReplyOnStopWordsState(AppState):
+    stop_word_detected: bool = False
+    post_stop_word_buffer: np.ndarray | None = None
+    started_talking_pre_stop_word: bool = False
+
+
+class ReplyOnStopWords(ReplyOnPause):
+    def __init__(
+        self,
+        fn: ReplyFnGenerator,
+        stop_words: list[str],
+        algo_options: AlgoOptions | None = None,
+        model_options: SileroVadOptions | None = None,
+        expected_layout: Literal["mono", "stereo"] = "mono",
+        output_sample_rate: int = 24000,
+        output_frame_size: int = 480,
+        input_sample_rate: int = 48000,
+    ):
+        super().__init__(
+            fn,
+            algo_options=algo_options,
+            model_options=model_options,
+            expected_layout=expected_layout,
+            output_sample_rate=output_sample_rate,
+            output_frame_size=output_frame_size,
+            input_sample_rate=input_sample_rate,
+        )
+        self.stop_words = stop_words
+        self.state = ReplyOnStopWordsState()
+        # Download the STT model up front so the first utterance is not delayed
+        get_stt_model()
+
+    def stop_word_detected(self, text: str) -> bool:
+        for stop_word in self.stop_words:
+            words = stop_word.lower().strip().split(" ")
+            # Match the words in order, separated by any amount of whitespace
+            if re.search(
+                r"\b" + r"\s+".join(map(re.escape, words)) + r"\b", text
+            ):
+                logger.debug("Stop word detected: %s", stop_word)
+                return True
+        return False
+
+    async def _send_stopword(
+        self,
+    ):
+        if self.channel:
+            self.channel.send("stopword")
+            logger.debug("Sent stopword")
+
+    def send_stopword(self):
+        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
+
+    def determine_pause(
+        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
+    ) -> bool:
+        """Take in the audio stream and determine if a pause happened."""
+        duration = len(audio) / sampling_rate
+
+        if duration >= self.algo_options.audio_chunk_duration:
+            if not state.stop_word_detected:
+                audio_f32 = audio_to_float32((sampling_rate, audio))
+                audio_rs = librosa.resample(
+                    audio_f32, orig_sr=sampling_rate, target_sr=16000
+                )
+                if state.post_stop_word_buffer is None:
+                    state.post_stop_word_buffer = audio_rs
+                else:
+                    state.post_stop_word_buffer = np.concatenate(
+                        (state.post_stop_word_buffer, audio_rs)
+                    )
+                if len(state.post_stop_word_buffer) / 16000 > 2:
+                    # Keep only the most recent two seconds (32,000 samples at 16 kHz)
+                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
+                dur_vad, chunks = self.model.vad(
+                    (16000, state.post_stop_word_buffer),
+                    self.model_options,
+                    return_chunks=True,
+                )
+                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
+                logger.debug(f"STT: {text}")
+                state.stop_word_detected = self.stop_word_detected(text)
+                if state.stop_word_detected:
+                    logger.debug("Stop word detected")
+                    self.send_stopword()
+                state.buffer = None
+            else:
+                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
+                logger.debug("VAD duration: %s", dur_vad)
+                if (
+                    dur_vad > self.algo_options.started_talking_threshold
+                    and not state.started_talking
+                    and state.stop_word_detected
+                ):
+                    state.started_talking = True
+                    logger.debug("Started talking")
+                if state.started_talking:
+                    if state.stream is None:
+                        state.stream = audio
+                    else:
+                        state.stream = np.concatenate((state.stream, audio))
+                    state.buffer = None
+                if (
+                    dur_vad < self.algo_options.speech_threshold
+                    and state.started_talking
+                    and state.stop_word_detected
+                ):
+                    return True
+        return False
+
+    def reset(self):
+        self.args_set.clear()
+        self.generator = None
+        self.event.clear()
+        self.state = ReplyOnStopWordsState()
+
+    def copy(self):
+        return ReplyOnStopWords(
+            self.fn,
+            self.stop_words,
+            self.algo_options,
+            self.model_options,
+            self.expected_layout,
+            self.output_sample_rate,
+            self.output_frame_size,
+            self.input_sample_rate,
+        )
diff --git a/backend/gradio_webrtc/speech_to_text/__init__.py b/backend/gradio_webrtc/speech_to_text/__init__.py
new file mode 100644
index 0000000..8569c11
--- /dev/null
+++ b/backend/gradio_webrtc/speech_to_text/__init__.py
@@ -0,0 +1,3 @@
+from .stt_ import get_stt_model, stt, stt_for_chunks
+
+__all__ = ["stt", "stt_for_chunks", "get_stt_model"]
diff --git a/backend/gradio_webrtc/speech_to_text/stt_.py b/backend/gradio_webrtc/speech_to_text/stt_.py
new file mode 100644
index 0000000..9987e2b
--- /dev/null
+++ b/backend/gradio_webrtc/speech_to_text/stt_.py
@@ -0,0 +1,50 @@
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Callable
+
+import numpy as np
+from numpy.typing import NDArray
+from silero import silero_stt
+
+from ..utils import AudioChunk
+
+
+@dataclass
+class STTModel:
+    encoder: Callable
+    decoder: Callable
+
+
+@lru_cache
+def get_stt_model() -> STTModel:
+    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
+    return STTModel(model, decoder)
+
+
+def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
+    model = get_stt_model()
+    # The Silero STT model expects 16 kHz mono audio; `sr` is assumed to be 16000
+    sr, audio_np = audio
+    if audio_np.dtype != np.float32:
+        # Normalize 16-bit PCM samples to float32 in [-1, 1]
+        audio_np = audio_np.astype(np.float32) / 32768.0
+    try:
+        import torch
+    except ImportError:
+        raise ImportError(
+            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
+        )
+    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
+    if audio_torch.ndim == 1:
+        audio_torch = audio_torch.unsqueeze(0)
+    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
+    return model.decoder(model.encoder(audio_torch)[0])
+
+
+def stt_for_chunks(
+    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
+) -> str:
+    sr, audio_np = audio
+    return " ".join(
+        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
+    )
diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py
index 6389281..bb93032 100644
--- a/backend/gradio_webrtc/utils.py
+++ b/backend/gradio_webrtc/utils.py
@@ -3,7 +3,7 @@ import fractions
 import io
 import logging
 import tempfile
-from typing import Any, Callable, Protocol, cast
+from typing import Any, Callable, Protocol, TypedDict, cast
 
 import av
 import numpy as np
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
 AUDIO_PTIME = 0.020
 
 
+class AudioChunk(TypedDict):
+    start: int
+    end: int
+
+
 class AdditionalOutputs:
     def __init__(self, *args) -> None:
         self.args = args
diff --git a/docs/cookbook.md b/docs/cookbook.md
index 4a5d8b1..713d51f 100644
--- a/docs/cookbook.md
+++ b/docs/cookbook.md
@@ -36,6 +36,17 @@
 
     [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)
 
+- :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__
+
+    ---
+
+    A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama".
+    Build a Siri-like coding assistant in 100 lines of code!
+
+    [:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)
+
+    [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)
+
 - :robot:{ .lg .middle } __Llama Code Editor__
 
     ---
diff --git a/docs/index.md b/docs/index.md
index 9654531..92923e7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -15,11 +15,21 @@
 pip install gradio_webrtc
 ```
 
-to use built-in pause detection (see [Audio Streaming](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-pause)), install the `vad` extra:
+To use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
 
 ```bash
 pip install gradio_webrtc[vad]
 ```
+For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
+```bash
+pip install gradio_webrtc[stopword]
+```
+
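+Both extras can be installed in one command; quoting the requirement keeps shells like zsh from expanding the square brackets:
+```bash
+pip install "gradio_webrtc[vad,stopword]"
+```
+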
 ## Examples
 
 See the [cookbook](/cookbook)
\ No newline at end of file
diff --git a/docs/user-guide.md b/docs/user-guide.md
index a07169b..04ff7fa 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -65,6 +65,68 @@
 and passing it to the `stream` event of the `WebRTC` component.
 5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time.
+
+### Reply On Stopwords
+
+You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
+
+The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter. When a stop word is detected, the component plays a short notification sound and the handler then records your query until you pause.
+
+=== "Code"
+    ``` py title="ReplyOnStopWords"
+    import gradio as gr
+    import numpy as np
+    from gradio_webrtc import WebRTC, ReplyOnStopWords
+
+    def response(audio: tuple[int, np.ndarray]):
+        """This function must yield audio frames"""
+        ...
+        for numpy_array in generated_audio:
+            yield (sampling_rate, numpy_array, "mono")
+
+
+    with gr.Blocks() as demo:
+        gr.HTML(
+            """
+            <h1 style='text-align: center'>
+            Chat (Powered by WebRTC ⚡️)
+            </h1>
+            """
+        )
+        with gr.Column():
+            with gr.Group():
+                audio = WebRTC(
+                    mode="send-receive",
+                    modality="audio",
+                )
+            audio.stream(ReplyOnStopWords(response,
+                                          input_sample_rate=16000,
+                                          stop_words=["computer"]), # (1)
+                         inputs=[audio],
+                         outputs=[audio], time_limit=90,
+                         concurrency_limit=10)
+
+    demo.launch()
+    ```
+
+    1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
+
+=== "Notes"
+    1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
+
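+The speech-to-text helpers used for stop word detection are also exported as `stt` and `stt_for_chunks`. A minimal sketch of transcribing a numpy array directly is below; it assumes 16 kHz mono audio, since the underlying Silero model expects 16 kHz and `stt` does not resample (the first call downloads the model):
+
+``` py
+import numpy as np
+from gradio_webrtc import stt
+
+# Placeholder: substitute a real 16 kHz mono int16 recording
+audio = np.zeros(16000, dtype=np.int16)
+
+text = stt((16000, audio))
+print(text)
+```
+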
+ """ + ) + with gr.Column(): + with gr.Group(): + audio = WebRTC( + mode="send", + modality="audio", + ) + webrtc.stream(ReplyOnStopWords(generate, + input_sample_rate=16000, + stop_words=["computer"]), # (1) + inputs=[webrtc, history, code], + outputs=[webrtc], time_limit=90, + concurrency_limit=10) + + demo.launch() + ``` + + 1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris". + +=== "Notes" + 1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris". + ### Stream Handler `ReplyOnPause` is an implementation of a `StreamHandler`. The `StreamHandler` is a low-level diff --git a/frontend/shared/InteractiveAudio.svelte b/frontend/shared/InteractiveAudio.svelte index 85762e5..15e7d8e 100644 --- a/frontend/shared/InteractiveAudio.svelte +++ b/frontend/shared/InteractiveAudio.svelte @@ -21,8 +21,6 @@ import AudioWave from "./AudioWave.svelte"; import WebcamPermissions from "./WebcamPermissions.svelte"; - - export let mode: "send-receive" | "send"; export let value: string | null = null; export let label: string | undefined = undefined; @@ -34,6 +32,28 @@ export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters; export let on_change_cb: (mg: "tick" | "change") => void; + let stopword_recognized = false; + + let notification_sound; + + onMount(() => { + if (value === "__webrtc_value__") { + notification_sound = new Audio("https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/pop-sounds.mp3"); + } + }); + + let _on_change_cb = (msg: "change" | "tick" | "stopword") => { + if (msg === "stopword") { + console.log("stopword recognized"); + stopword_recognized = true; + setTimeout(() => { + stopword_recognized = false; + }, 3000); + } else { + on_change_cb(msg); + } + }; + let options_open = false; let _time_limit: number | null = null; @@ -144,7 +164,7 @@ } if (stream == null) return; - start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", on_change_cb, rtp_params).then((connection) => { + start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", _on_change_cb, rtp_params).then((connection) => { pc = connection; }).catch(() => { console.info("catching") @@ -190,8 +210,10 @@ options_open = false; }; + $: if(stopword_recognized){ + notification_sound.play(); + } - -
+