Add ReplyOnStopWords (#35)

* add code

* fix dependencies

* add code:
Freddy Boulton
2024-12-11 18:25:53 -08:00
committed by GitHub
parent b1e4326ae3
commit 6c983482b8
14 changed files with 368 additions and 18 deletions

View File

@@ -15,12 +15,18 @@ Stream video and audio in real time with Gradio using WebRTC.
pip install gradio_webrtc
```
to use built-in pause detection (see [conversational ai](#conversational-ai)), install the `vad` extra:
to use built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
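Once installed, the handler is constructed like `ReplyOnPause` with an extra `stop_words` argument. A minimal sketch (`respond` is a placeholder reply generator; see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-stopwords) for a full demo):
```python
# Minimal sketch; `respond` is a placeholder generator that yields audio frames.
from gradio_webrtc import ReplyOnStopWords

def respond(audio):
    ...  # yield (sampling_rate, numpy_array) tuples in reply

handler = ReplyOnStopWords(respond, stop_words=["computer"], input_sample_rate=16000)
```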
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥

View File

@@ -4,6 +4,8 @@ from .credentials import (
get_twilio_turn_credentials,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file, audio_to_float32
from .webrtc import StreamHandler, WebRTC
@@ -17,7 +19,10 @@ __all__ = [
"get_twilio_turn_credentials",
"get_turn_credentials",
"ReplyOnPause",
"ReplyOnStopWords",
"SileroVadOptions",
"stt",
"stt_for_chunks",
"StreamHandler",
"WebRTC",
]

View File

@@ -1,10 +1,13 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List
from typing import List, Literal, overload
import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from ..utils import AudioChunk
logger = logging.getLogger(__name__)
@@ -76,7 +79,7 @@ class SileroVADModel:
return h, c
@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
@@ -90,7 +93,7 @@ class SileroVADModel:
audio: np.ndarray,
vad_options: SileroVadOptions,
**kwargs,
) -> List[dict]:
) -> List[AudioChunk]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
@@ -236,14 +239,32 @@ class SileroVADModel:
return speeches
@overload
def vad(
self,
audio_tuple: tuple[int, np.ndarray],
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
) -> float:
return_chunks: Literal[True],
) -> tuple[float, List[AudioChunk]]: ...
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float: ...
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float | tuple[float, List[AudioChunk]]:
sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape)
try:
if audio.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0
sr = 16000
if sr != sampling_rate:
@@ -262,7 +283,8 @@ class SileroVADModel:
audio = self.collect_chunks(audio, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape)
duration_after_vad = audio.shape[0] / sr
if return_chunks:
return duration_after_vad, speech_chunks
return duration_after_vad
except Exception as e:
import math
@@ -280,7 +302,7 @@ class SileroVADModel:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25:
if sr / x.shape[1] > 31.25: # type: ignore
raise ValueError("Input audio chunk is too short")
h, c = state
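For downstream callers, the new overload keeps the old single-float return while optionally exposing the detected speech chunks. A rough usage sketch (illustrative only; `model` stands for an initialized `SileroVADModel` and `audio` for a numpy array at `sampling_rate` Hz):

```python
duration = model.vad((sampling_rate, audio), None)  # seconds of speech, as before

# New in this commit: also return the chunk boundaries, given as start/end
# sample indices in the 16 kHz audio the VAD analyzes internally.
duration, chunks = model.vad((sampling_rate, audio), None, return_chunks=True)
for chunk in chunks:
    print(chunk["start"], chunk["end"])
```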

View File

@@ -117,6 +117,7 @@ class ReplyOnPause(StreamHandler):
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
self.input_sample_rate,
)
def determine_pause(

View File

@@ -0,0 +1,146 @@
import asyncio
import logging
import re
from typing import Literal

import librosa
import numpy as np

from .reply_on_pause import (
    AlgoOptions,
    AppState,
    ReplyFnGenerator,
    ReplyOnPause,
    SileroVadOptions,
)
from .speech_to_text import get_stt_model, stt_for_chunks
from .utils import audio_to_float32

logger = logging.getLogger(__name__)


class ReplyOnStopWordsState(AppState):
    stop_word_detected: bool = False
    post_stop_word_buffer: np.ndarray | None = None
    started_talking_pre_stop_word: bool = False


class ReplyOnStopWords(ReplyOnPause):
    def __init__(
        self,
        fn: ReplyFnGenerator,
        stop_words: list[str],
        algo_options: AlgoOptions | None = None,
        model_options: SileroVadOptions | None = None,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 48000,
    ):
        super().__init__(
            fn,
            algo_options=algo_options,
            model_options=model_options,
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            output_frame_size=output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.stop_words = stop_words
        self.state = ReplyOnStopWordsState()
        # Download Model
        get_stt_model()

    def stop_word_detected(self, text: str) -> bool:
        for stop_word in self.stop_words:
            stop_word = stop_word.lower().strip().split(" ")
            if bool(
                re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text)
            ):
                logger.debug("Stop word detected: %s", stop_word)
                return True
        return False

    async def _send_stopword(
        self,
    ):
        if self.channel:
            self.channel.send("stopword")
            logger.debug("Sent stopword")

    def send_stopword(self):
        asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)

    def determine_pause(
        self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
    ) -> bool:
        """Take in the stream, determine if a pause happened"""
        duration = len(audio) / sampling_rate

        if duration >= self.algo_options.audio_chunk_duration:
            if not state.stop_word_detected:
                audio_f32 = audio_to_float32((sampling_rate, audio))
                audio_rs = librosa.resample(
                    audio_f32, orig_sr=sampling_rate, target_sr=16000
                )
                if state.post_stop_word_buffer is None:
                    state.post_stop_word_buffer = audio_rs
                else:
                    state.post_stop_word_buffer = np.concatenate(
                        (state.post_stop_word_buffer, audio_rs)
                    )
                if len(state.post_stop_word_buffer) / 16000 > 2:
                    state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
                dur_vad, chunks = self.model.vad(
                    (16000, state.post_stop_word_buffer),
                    self.model_options,
                    return_chunks=True,
                )
                text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
                logger.debug(f"STT: {text}")
                state.stop_word_detected = self.stop_word_detected(text)
                if state.stop_word_detected:
                    logger.debug("Stop word detected")
                    self.send_stopword()
                state.buffer = None
            else:
                dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
                logger.debug("VAD duration: %s", dur_vad)
                if (
                    dur_vad > self.algo_options.started_talking_threshold
                    and not state.started_talking
                    and state.stop_word_detected
                ):
                    state.started_talking = True
                    logger.debug("Started talking")
                if state.started_talking:
                    if state.stream is None:
                        state.stream = audio
                    else:
                        state.stream = np.concatenate((state.stream, audio))
                    state.buffer = None
                if (
                    dur_vad < self.algo_options.speech_threshold
                    and state.started_talking
                    and state.stop_word_detected
                ):
                    return True
        return False

    def reset(self):
        self.args_set.clear()
        self.generator = None
        self.event.clear()
        self.state = ReplyOnStopWordsState()

    def copy(self):
        return ReplyOnStopWords(
            self.fn,
            self.stop_words,
            self.algo_options,
            self.model_options,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
            self.input_sample_rate,
        )
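For reference, the matching in `stop_word_detected` boils down to a word-boundary regex over the (already lowercase) transcript, so multi-word phrases tolerate variable whitespace between words. A standalone sketch of the same pattern:

```python
import re

def matches_stop_word(stop_word: str, transcript: str) -> bool:
    # Mirrors the handler's check for a single phrase:
    # "ok computer" -> r"\bok\s+computer\b"
    words = stop_word.lower().strip().split(" ")
    pattern = r"\b" + r"\s+".join(map(re.escape, words)) + r"\b"
    return bool(re.search(pattern, transcript))

print(matches_stop_word("ok computer", "uh ok   computer open the editor"))  # True
print(matches_stop_word("ok computer", "ok computers are great"))            # False
```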

View File

@@ -0,0 +1,3 @@
from .stt_ import get_stt_model, stt, stt_for_chunks
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]

View File

@@ -0,0 +1,52 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable

import numpy as np
from numpy.typing import NDArray
from silero import silero_stt

from ..utils import AudioChunk


@dataclass
class STTModel:
    encoder: Callable
    decoder: Callable


@lru_cache
def get_stt_model() -> STTModel:
    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
    return STTModel(model, decoder)


def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    model = get_stt_model()
    sr, audio_np = audio
    if audio_np.dtype != np.float32:
        print("converting")
        audio_np = audio_np.astype(np.float32) / 32768.0
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
        )
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    print("before")
    res = model.decoder(model.encoder(audio_torch)[0])
    print("after")
    return res


def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    sr, audio_np = audio
    return " ".join(
        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
    )
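A hedged usage sketch for the new helpers, both exported from the package root; the audio below is placeholder silence at 16 kHz:

```python
import numpy as np
from gradio_webrtc import stt, stt_for_chunks

sr = 16000
audio = np.zeros(sr * 2, dtype=np.int16)  # two seconds of placeholder audio

text = stt((sr, audio))  # transcribe the whole buffer

# Or transcribe only the VAD-detected speech regions (AudioChunk dicts of sample indices):
chunks = [{"start": 0, "end": sr}]
text = stt_for_chunks((sr, audio), chunks)
```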

View File

@@ -3,7 +3,7 @@ import fractions
import io
import logging
import tempfile
from typing import Any, Callable, Protocol, cast
from typing import Any, Callable, Protocol, TypedDict, cast
import av
import numpy as np
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
class AudioChunk(TypedDict):
    start: int
    end: int


class AdditionalOutputs:
    def __init__(self, *args) -> None:
        self.args = args
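`AudioChunk` boundaries are sample indices into the audio array rather than seconds, so converting to time divides by the sampling rate. An illustrative sketch with assumed values:

```python
chunk: AudioChunk = {"start": 8000, "end": 24000}
sr = 16000
start_s, end_s = chunk["start"] / sr, chunk["end"] / sr  # 0.5 s to 1.5 s
```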

View File

@@ -36,6 +36,19 @@
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__
---
A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama".
Build a Siri-like coding assistant in 100 lines of code!
<video width=98% src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)
- :robot:{ .lg .middle } __Llama Code Editor__
---

View File

@@ -15,11 +15,16 @@ Stream video and audio in real time with Gradio using WebRTC.
pip install gradio_webrtc
```
to use built-in pause detection (see [Audio Streaming](https://freddyaboulton.github.io/gradio-webrtc/user-guide/#reply-on-pause)), install the `vad` extra:
to use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
## Examples
See the [cookbook](/cookbook)

View File

@@ -65,6 +65,54 @@ and passing it to the `stream` event of the `WebRTC` component.
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_limit` is 1 (default), only one conversation will be handled at a time.
### Reply On Stopwords
You can configure your AI model to run whenever one of a set of "stop words" is detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.
=== "Code"
``` py title="ReplyOnStopWords"
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnStopWords

def response(audio: tuple[int, np.ndarray]):
    """This function must yield audio frames"""
    ...
    for numpy_array in generated_audio:
        yield (sampling_rate, numpy_array, "mono")

with gr.Blocks() as demo:
    gr.HTML(
    """
    <h1 style='text-align: center'>
    Chat (Powered by WebRTC ⚡️)
    </h1>
    """
    )
    with gr.Column():
        with gr.Group():
            audio = WebRTC(
                mode="send-receive",
                modality="audio",
            )
        audio.stream(ReplyOnStopWords(response,
                                      input_sample_rate=16000,
                                      stop_words=["computer"]), # (1)
                     inputs=[audio],
                     outputs=[audio], time_limit=90,
                     concurrency_limit=10)

demo.launch()
```
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
=== "Notes"
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
### Stream Handler
`ReplyOnPause` is an implementation of a `StreamHandler`. The `StreamHandler` is a low-level

View File

@@ -21,8 +21,6 @@
import AudioWave from "./AudioWave.svelte";
import WebcamPermissions from "./WebcamPermissions.svelte";
export let mode: "send-receive" | "send";
export let value: string | null = null;
export let label: string | undefined = undefined;
@@ -34,6 +32,28 @@
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let on_change_cb: (mg: "tick" | "change") => void;
let stopword_recognized = false;
let notification_sound;
onMount(() => {
if (value === "__webrtc_value__") {
notification_sound = new Audio("https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/pop-sounds.mp3");
}
});
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
if (msg === "stopword") {
console.log("stopword recognized");
stopword_recognized = true;
setTimeout(() => {
stopword_recognized = false;
}, 3000);
} else {
on_change_cb(msg);
}
};
let options_open = false;
let _time_limit: number | null = null;
@@ -144,7 +164,7 @@
}
if (stream == null) return;
start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", on_change_cb, rtp_params).then((connection) => {
start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", _on_change_cb, rtp_params).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
@@ -190,7 +210,9 @@
options_open = false;
};
$: if(stopword_recognized){
notification_sound.play();
}
</script>
@@ -220,7 +242,7 @@
{:else}
<AudioWave {audio_source_callback} {stream_state}/>
<StreamingBar time_limit={_time_limit} />
<div class="button-wrap">
<div class="button-wrap" class:pulse={stopword_recognized}>
<button
on:click={start_stream}
aria-label={"start stream"}
@@ -328,6 +350,27 @@
color: var(--button-secondary-text-color);
}
@keyframes pulse {
0% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0.7);
}
70% {
transform: scale(1.25);
box-shadow: 0 0 0 10px rgba(var(--primary-500-rgb), 0);
}
100% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0);
}
}
.pulse {
animation: pulse 1s infinite;
}
.icon-with-text {
width: var(--size-20);
align-items: center;

View File

@@ -64,7 +64,7 @@ export async function start(
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
if (event.data === "change" || event.data === "tick") {
if (event.data === "change" || event.data === "tick" || event.data === "stopword") {
console.debug(`${event.data} event received`);
on_change_cb(event.data);
}

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.20"
version = "0.0.21"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"
@@ -44,6 +44,7 @@ classifiers = [
[project.optional-dependencies]
dev = ["build", "twine"]
vad = ["librosa", "onnxruntime"]
stopword = ["silero", "librosa", "onnxruntime"]
[tool.hatch.build]
artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]