Add ReplyOnStopWords (#35)

* add code

* fix dependencies

* add code:
Freddy Boulton
2024-12-11 18:25:53 -08:00
committed by GitHub
parent b1e4326ae3
commit 6c983482b8
14 changed files with 368 additions and 18 deletions

View File

@@ -15,12 +15,18 @@ Stream video and audio in real time with Gradio using WebRTC.
pip install gradio_webrtc
```
To use built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
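A rough usage sketch of how the handler gets wired to the `WebRTC` component once the extras are installed (the `response` generator below is a silent placeholder, not part of this package):
```python
import gradio as gr
import numpy as np
from gradio_webrtc import ReplyOnStopWords, WebRTC

def response(audio: tuple[int, np.ndarray]):
    # Placeholder: transcribe/answer the user's audio here, then yield audio
    # frames as (sampling_rate, numpy_array, "mono") tuples.
    yield 24000, np.zeros((1, 480), dtype=np.int16), "mono"

with gr.Blocks() as demo:
    webrtc = WebRTC(mode="send-receive", modality="audio")
    webrtc.stream(
        ReplyOnStopWords(response, stop_words=["computer"], input_sample_rate=16000),
        inputs=[webrtc],
        outputs=[webrtc],
        time_limit=90,
    )

demo.launch()
```
See the user guide linked above for a complete example.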
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥

View File

@@ -4,6 +4,8 @@ from .credentials import (
    get_twilio_turn_credentials,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
from .webrtc import StreamHandler, WebRTC
@@ -17,7 +19,10 @@ __all__ = [
"get_twilio_turn_credentials", "get_twilio_turn_credentials",
"get_turn_credentials", "get_turn_credentials",
"ReplyOnPause", "ReplyOnPause",
"ReplyOnStopWords",
"SileroVadOptions", "SileroVadOptions",
"stt",
"stt_for_chunks",
"StreamHandler", "StreamHandler",
"WebRTC", "WebRTC",
] ]

View File

@@ -1,10 +1,13 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List, Literal, overload

import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray

from ..utils import AudioChunk

logger = logging.getLogger(__name__)
@@ -76,7 +79,7 @@ class SileroVADModel:
        return h, c

    @staticmethod
    def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
        """Collects and concatenates audio chunks."""
        if not chunks:
            return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
        audio: np.ndarray,
        vad_options: SileroVadOptions,
        **kwargs,
    ) -> List[AudioChunk]:
        """This method is used for splitting long audios into speech chunks using silero VAD.

        Args:
@@ -236,15 +239,33 @@ class SileroVADModel:
        return speeches

    @overload
    def vad(
        self,
        audio_tuple: tuple[int, NDArray],
        vad_parameters: None | SileroVadOptions,
        return_chunks: Literal[True],
    ) -> tuple[float, List[AudioChunk]]: ...

    @overload
    def vad(
        self,
        audio_tuple: tuple[int, NDArray],
        vad_parameters: None | SileroVadOptions,
        return_chunks: bool = False,
    ) -> float: ...

    def vad(
        self,
        audio_tuple: tuple[int, NDArray],
        vad_parameters: None | SileroVadOptions,
        return_chunks: bool = False,
    ) -> float | tuple[float, List[AudioChunk]]:
        sampling_rate, audio = audio_tuple
        logger.debug("VAD audio shape input: %s", audio.shape)
        try:
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32) / 32768.0
            sr = 16000
            if sr != sampling_rate:
                try:
@@ -262,7 +283,8 @@ class SileroVADModel:
            audio = self.collect_chunks(audio, speech_chunks)
            logger.debug("VAD audio shape: %s", audio.shape)
            duration_after_vad = audio.shape[0] / sr
            if return_chunks:
                return duration_after_vad, speech_chunks
            return duration_after_vad
        except Exception as e:
            import math
@@ -280,7 +302,7 @@ class SileroVADModel:
            raise ValueError(
                f"Too many dimensions for input audio chunk {len(x.shape)}"
            )
        if sr / x.shape[1] > 31.25:  # type: ignore
            raise ValueError("Input audio chunk is too short")

        h, c = state

View File

@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
        )

    def determine_pause(

View File

@@ -0,0 +1,146 @@
import asyncio
import logging
import re
from typing import Literal

import librosa
import numpy as np

from .reply_on_pause import (
    AlgoOptions,
    AppState,
    ReplyFnGenerator,
    ReplyOnPause,
    SileroVadOptions,
)
from .speech_to_text import get_stt_model, stt_for_chunks
from .utils import audio_to_float32

logger = logging.getLogger(__name__)


class ReplyOnStopWordsState(AppState):
    stop_word_detected: bool = False
    post_stop_word_buffer: np.ndarray | None = None
    started_talking_pre_stop_word: bool = False


class ReplyOnStopWords(ReplyOnPause):
    def __init__(
        self,
        fn: ReplyFnGenerator,
        stop_words: list[str],
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 48000,
    ):
        super().__init__(
            fn,
            algo_options=algo_options,
            model_options=model_options,
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            output_frame_size=output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.stop_words = stop_words
        self.state = ReplyOnStopWordsState()
        # Download Model
        get_stt_model()

    def stop_word_detected(self, text: str) -> bool:
        for stop_word in self.stop_words:
            stop_word = stop_word.lower().strip().split(" ")
            if bool(
                re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text)
            ):
                logger.debug("Stop word detected: %s", stop_word)
                return True
        return False

    async def _send_stopword(
        self,
    ):
        if self.channel:
            self.channel.send("stopword")
            logger.debug("Sent stopword")

    def send_stopword(self):
        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ) -> bool:
        """Take in the stream, determine if a pause happened"""
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            if not state.stop_word_detected:
                audio_f32 = audio_to_float32((sampling_rate, audio))
                audio_rs = librosa.resample(
                    audio_f32, orig_sr=sampling_rate, target_sr=16000
                )
                if state.post_stop_word_buffer is None:
                    state.post_stop_word_buffer = audio_rs
                else:
                    state.post_stop_word_buffer = np.concatenate(
                        (state.post_stop_word_buffer, audio_rs)
                    )
                if len(state.post_stop_word_buffer) / 16000 > 2:
                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
                dur_vad, chunks = self.model.vad(
                    (16000, state.post_stop_word_buffer),
                    self.model_options,
                    return_chunks=True,
                )
                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
                logger.debug(f"STT: {text}")
                state.stop_word_detected = self.stop_word_detected(text)
                if state.stop_word_detected:
                    logger.debug("Stop word detected")
                    self.send_stopword()
                state.buffer = None
            else:
                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
                logger.debug("VAD duration: %s", dur_vad)
                if (
                    dur_vad > self.algo_options.started_talking_threshold
                    and not state.started_talking
                    and state.stop_word_detected
                ):
                    state.started_talking = True
                    logger.debug("Started talking")
                if state.started_talking:
                    if state.stream is None:
                        state.stream = audio
                    else:
                        state.stream = np.concatenate((state.stream, audio))
                    state.buffer = None
                if (
                    dur_vad < self.algo_options.speech_threshold
                    and state.started_talking
                    and state.stop_word_detected
                ):
                    return True
        return False

    def reset(self):
        self.args_set.clear()
        self.generator = None
        self.event.clear()
        self.state = ReplyOnStopWordsState()

    def copy(self):
        return ReplyOnStopWords(
            self.fn,
            self.stop_words,
            self.algo_options,
            self.model_options,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
        )

View File

@@ -0,0 +1,3 @@
from .stt_ import get_stt_model, stt, stt_for_chunks
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]

View File

@@ -0,0 +1,52 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable

import numpy as np
from numpy.typing import NDArray
from silero import silero_stt

from ..utils import AudioChunk


@dataclass
class STTModel:
    encoder: Callable
    decoder: Callable


@lru_cache
def get_stt_model() -> STTModel:
    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
    return STTModel(model, decoder)


def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    model = get_stt_model()
    sr, audio_np = audio
    if audio_np.dtype != np.float32:
        print("converting")
        audio_np = audio_np.astype(np.float32) / 32768.0
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
        )
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    print("before")
    res = model.decoder(model.encoder(audio_torch)[0])
    print("after")
    return res


def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    sr, audio_np = audio
    return " ".join(
        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
    )

View File

@@ -3,7 +3,7 @@ import fractions
import io
import logging
import tempfile
from typing import Any, Callable, Protocol, TypedDict, cast

import av
import numpy as np
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020


class AudioChunk(TypedDict):
    start: int
    end: int


class AdditionalOutputs:
    def __init__(self, *args) -> None:
        self.args = args

View File

@@ -36,6 +36,19 @@
    [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)

-   :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__

    ---

    A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama".
    Build a Siri-like coding assistant in 100 lines of code!

    <video width=98% src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls style="text-align: center"></video>

    [:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)

    [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)

-   :robot:{ .lg .middle } __Llama Code Editor__

    ---

View File

@@ -15,11 +15,16 @@ Stream video and audio in real time with Gradio using WebRTC.
pip install gradio_webrtc
```
To use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
## Examples
See the [cookbook](/cookbook)

View File

@@ -65,6 +65,54 @@ and passing it to the `stream` event of the `WebRTC` component.
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time.
### Reply On Stopwords
With the `ReplyOnStopWords` class, you can configure your AI model to run whenever one of a set of "stop words", like "Hey Siri" or "computer", is detected.
The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.
=== "Code"
``` py title="ReplyonPause"
import gradio as gr
from gradio_webrtc import WebRTC, ReplyOnPause
def response(audio: tuple[int, np.ndarray]):
"""This function must yield audio frames"""
...
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono")
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send",
modality="audio",
)
webrtc.stream(ReplyOnStopWords(generate,
input_sample_rate=16000,
stop_words=["computer"]), # (1)
inputs=[webrtc, history, code],
outputs=[webrtc], time_limit=90,
concurrency_limit=10)
demo.launch()
```
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
=== "Notes"
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
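For intuition about what counts as a match: `ReplyOnStopWords` transcribes a rolling two-second audio buffer and scans the transcript with a word-boundary regex, roughly equivalent to this standalone sketch (the `phrase_in_text` helper and the example transcripts are illustrative, not part of the library):

```python
import re

def phrase_in_text(stop_word: str, text: str) -> bool:
    # Split a multi-word phrase and allow any run of whitespace between the words,
    # anchored on word boundaries so "computer" does not match "supercomputers".
    words = stop_word.lower().strip().split(" ")
    pattern = r"\b" + r"\s+".join(map(re.escape, words)) + r"\b"
    return bool(re.search(pattern, text))

print(phrase_in_text("ok computer", "um ok  computer open the editor"))  # True
print(phrase_in_text("computer", "supercomputers are fast"))             # False
```

This is why distinct, multi-word phrases and common misspellings work well: the match is exact at the word level, so near-misses in the transcript are simply ignored.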
### Stream Handler
`ReplyOnPause` is an implementation of a `StreamHandler`. The `StreamHandler` is a low-level

View File

@@ -21,8 +21,6 @@
import AudioWave from "./AudioWave.svelte";
import WebcamPermissions from "./WebcamPermissions.svelte";
export let mode: "send-receive" | "send";
export let value: string | null = null;
export let label: string | undefined = undefined;
@@ -34,6 +32,28 @@
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let on_change_cb: (mg: "tick" | "change") => void;
let stopword_recognized = false;
let notification_sound;
onMount(() => {
if (value === "__webrtc_value__") {
notification_sound = new Audio("https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/pop-sounds.mp3");
}
});
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
if (msg === "stopword") {
console.log("stopword recognized");
stopword_recognized = true;
setTimeout(() => {
stopword_recognized = false;
}, 3000);
} else {
on_change_cb(msg);
}
};
let options_open = false;
let _time_limit: number | null = null;
@@ -144,7 +164,7 @@
}
if (stream == null) return;
start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", _on_change_cb, rtp_params).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
@@ -190,8 +210,10 @@
options_open = false;
};
$: if(stopword_recognized){
notification_sound.play();
}
</script>
<BlockLabel
@@ -220,7 +242,7 @@
{:else}
<AudioWave {audio_source_callback} {stream_state}/>
<StreamingBar time_limit={_time_limit} />
<div class="button-wrap" class:pulse={stopword_recognized}>
<button
on:click={start_stream}
aria-label={"start stream"}
@@ -328,6 +350,27 @@
color: var(--button-secondary-text-color);
}
@keyframes pulse {
0% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0.7);
}
70% {
transform: scale(1.25);
box-shadow: 0 0 0 10px rgba(var(--primary-500-rgb), 0);
}
100% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0);
}
}
.pulse {
animation: pulse 1s infinite;
}
.icon-with-text {
width: var(--size-20);
align-items: center;

View File

@@ -64,7 +64,7 @@ export async function start(
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
if (event.data === "change" || event.data === "tick" || event.data === "stopword") {
console.debug(`${event.data} event received`);
on_change_cb(event.data);
}

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.21"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"
@@ -44,6 +44,7 @@ classifiers = [
[project.optional-dependencies]
dev = ["build", "twine"]
vad = ["librosa", "onnxruntime"]
stopword = ["silero", "librosa", "onnxruntime"]
[tool.hatch.build]
artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]