# Please enter the commit message for your changes. Lines starting with '#' will be ignored, and an empty commit message aborts the commit.

杍超
2025-01-21 13:55:45 +08:00
commit 8a313bd700
62 changed files with 14687 additions and 0 deletions

.github/workflows/docs.yml (new file)

@@ -0,0 +1,28 @@
name: docs
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: pip install mkdocs-material
- run: mkdocs gh-deploy --force

.gitignore (new file)

@@ -0,0 +1,17 @@
.eggs/
dist/
*.pyc
__pycache__/
*.py[cod]
*$py.class
__tmp/*
*.pyi
.mypycache
.ruff_cache
node_modules
backend/**/templates/
demo/MobileNetSSD_deploy.caffemodel
demo/MobileNetSSD_deploy.prototxt.txt
.DS_Store
test/
.env

LICENSE (new file)

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Freddy Boulton
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file)

@@ -0,0 +1,441 @@
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/pypi/v/gradio_webrtc">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" style="display: block; padding-right: 5px; height: 20px;" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
<a href="https://freddyaboulton.github.io/gradio-webrtc/" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/Docs-ffcf40"></a>
</div>
<h3 style='text-align: center'>
Stream video and audio in real time with Gradio using WebRTC.
</h3>
## Installation
```bash
pip install gradio_webrtc
```
To use the built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
## Docs
https://freddyaboulton.github.io/gradio-webrtc/
## Examples
<table>
<tr>
<td width="50%">
<h3>🗣️ Audio Input/Output with mini-omni2</h3>
<p>Build a GPT-4o like experience with mini-omni2, an audio-native LLM.</p>
<video width="100%" src="https://github.com/user-attachments/assets/58c06523-fc38-4f5f-a4ba-a02a28e7fa9e" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Talk to Claude</h3>
<p>Use the Anthropic and Play.Ht APIs to have an audio conversation with Claude.</p>
<video width="100%" src="https://github.com/user-attachments/assets/650bc492-798e-4995-8cef-159e1cfc2185" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-claude">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-claude/blob/main/app.py">Code</a>
</p>
</td>
</tr>
<tr>
<td width="50%">
<h3>🗣️ Kyutai Moshi</h3>
<p>Kyutai's moshi is a novel speech-to-speech model for modeling human conversations.</p>
<video width="100%" src="https://github.com/user-attachments/assets/becc7a13-9e89-4a19-9df2-5fb1467a0137" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-moshi">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Hello Llama: Stop Word Detection</h3>
<p>A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama". Build a Siri-like coding assistant in 100 lines of code!</p>
<video width="100%" src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py">Code</a>
</p>
</td>
</tr>
<tr>
<td width="50%">
<h3>🤖 Llama Code Editor</h3>
<p>Create and edit HTML pages with just your voice! Powered by SambaNova systems.</p>
<video width="100%" src="https://github.com/user-attachments/assets/a09647f1-33e1-4154-a5a3-ffefda8a736a" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/llama-code-editor">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/llama-code-editor/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Talk to Ultravox</h3>
<p>Talk to Fixie.AI's audio-native Ultravox LLM with the transformers library.</p>
<video width="100%" src="https://github.com/user-attachments/assets/e6e62482-518c-4021-9047-9da14cd82be1" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox/blob/main/app.py">Code</a>
</p>
</td>
</tr>
<tr>
<td width="50%">
<h3>🗣️ Talk to Llama 3.2 3b</h3>
<p>Use the Lepton API to make Llama 3.2 talk back to you!</p>
<video width="100%" src="https://github.com/user-attachments/assets/3ee37a6b-0892-45f5-b801-73188fdfad9a" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🤖 Talk to Qwen2-Audio</h3>
<p>Qwen2-Audio is a SOTA audio-to-text LLM developed by Alibaba.</p>
<video width="100%" src="https://github.com/user-attachments/assets/c821ad86-44cc-4d0c-8dc4-8c02ad1e5dc8" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc/blob/main/app.py">Code</a>
</p>
</td>
</tr>
<tr>
<td width="50%">
<h3>📷 Yolov10 Object Detection</h3>
<p>Run the Yolov10 model on a user webcam stream in real time!</p>
<video width="100%" src="https://github.com/user-attachments/assets/c90d8c9d-d2d5-462e-9e9b-af969f2ea73c" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>📷 Video Object Detection with RT-DETR</h3>
<p>Upload a video and stream out frames with detected objects, powered by the RT-DETR model.</p>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc/blob/main/app.py">Code</a>
</p>
</td>
</tr>
<tr>
<td width="50%">
<h3>🔊 Text-to-Speech with Parler</h3>
<p>Stream out audio generated by Parler TTS!</p>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
</td>
</tr>
</table>
## Usage
This is a shortened version of the official [usage guide](https://freddyaboulton.github.io/gradio-webrtc/user-guide/).
To get started with WebRTC streams, all that's needed is to import the `WebRTC` component from this package and implement its `stream` event.
### Reply on Pause
Typically, you want to run an AI model that generates audio when the user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class
and passing it to the `stream` event of the `WebRTC` component.
```py
import gradio as gr
import numpy as np

from gradio_webrtc import WebRTC, ReplyOnPause
def response(audio: tuple[int, np.ndarray]): # (1)
"""This function must yield audio frames"""
...
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono") # (2)
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send-receive", # (3)
modality="audio",
)
audio.stream(fn=ReplyOnPause(response),
inputs=[audio], outputs=[audio], # (4)
time_limit=60) # (5)
demo.launch()
```
1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.
2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio array). Each numpy audio array must have a shape of (1, num_samples).
3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.
4. The `WebRTC` component must be the first input and output component.
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_limit` is 1 (default), only one conversation will be handled at a time.
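For illustration, a minimal `response` generator satisfying notes (1) and (2) could simply echo the recorded speech back in fixed-size chunks (a sketch; the 0.5-second chunk size is an arbitrary choice):

```py
import numpy as np

def response(audio: tuple[int, np.ndarray]):
    sampling_rate, array = audio   # array has shape (1, num_samples)
    chunk = sampling_rate // 2     # roughly 0.5 s of audio per yielded frame
    for start in range(0, array.shape[1], chunk):
        yield (sampling_rate, array[:, start:start + chunk], "mono")
```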
### Reply On Stopwords
You can configure your AI model to run whenever a set of "stop words", like "Hey Siri" or "computer", is detected, with the `ReplyOnStopWords` class.
The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.
```py
import gradio as gr
import numpy as np

from gradio_webrtc import WebRTC, ReplyOnStopWords
def response(audio: tuple[int, np.ndarray]):
"""This function must yield audio frames"""
...
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono")
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send",
modality="audio",
)
    audio.stream(ReplyOnStopWords(response,
                                  input_sample_rate=16000,
                                  stop_words=["computer"]), # (1)
                 inputs=[audio],
                 outputs=[audio], time_limit=90,
                 concurrency_limit=10)
demo.launch()
```
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
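Under the hood, matching is a word-boundary regex over the (lower-cased) speech-to-text transcript, mirroring `ReplyOnStopWords.stop_word_detected` in this commit, which is why including likely misspellings improves robustness. A standalone sketch of the check:

```py
import re

def stop_word_detected(text: str, stop_words: list[str]) -> bool:
    # Each phrase is lower-cased, split on spaces, and matched on word
    # boundaries against the running transcript.
    for stop_word in stop_words:
        words = stop_word.lower().strip().split(" ")
        if re.search(r"\b" + r"\s+".join(map(re.escape, words)) + r"\b", text):
            return True
    return False

print(stop_word_detected("hey computer write a snake game", ["computer"]))     # True
print(stop_word_detected("hey computer write a snake game", ["ok computer"]))  # False
```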
### Audio Server-To-Client
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
```py
import gradio as gr
import numpy as np

from gradio_webrtc import WebRTC
from pydub import AudioSegment
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("audio_file.wav")
array = np.array(segment.get_array_of_samples()).reshape(1, -1)
yield (segment.frame_rate, array)
with gr.Blocks() as demo:
audio = WebRTC(label="Stream", mode="receive", # (1)
modality="audio")
num_steps = gr.Slider(label="Number of Steps", minimum=1,
maximum=10, step=1, value=5)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click # (2)
    )
demo.launch()
```
1. Set `mode="receive"` to only receive audio from the server.
2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.
### Video Input/Output Streaming
Set up a video input/output stream to continuously receive webcam frames from the user and run an arbitrary Python function that returns a modified frame.
```py
import gradio as gr
from gradio_webrtc import WebRTC
def detection(image, conf_threshold=0.3): # (1)
... your detection code here ...
return modified_frame # (2)
with gr.Blocks() as demo:
image = WebRTC(label="Stream", mode="send-receive", modality="video") # (3)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection,
inputs=[image, conf_threshold], # (4)
outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()
```
1. The webcam frame will be represented as a numpy array of shape (height, width, 3), with color channels in RGB order.
2. The function must return a numpy array. It can take arbitrary values from other components.
3. Set the `modality="video"` and `mode="send-receive"`
4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
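As a stand-in for the elided detection code, a minimal `detection` function that respects this contract might simply annotate the frame (a sketch; `cv2.putText` stands in for real model inference):

```py
import cv2
import numpy as np

def detection(image: np.ndarray, conf_threshold: float = 0.3) -> np.ndarray:
    # Placeholder "model": draw the current threshold on a copy of the frame
    # so the round trip is visible. Swap in real inference here.
    frame = image.copy()
    cv2.putText(frame, f"conf_threshold={conf_threshold:.2f}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
    return frame
```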
### Server-to-Client Only
Set up a server-to-client stream to stream video from an arbitrary user interaction.
```py
import gradio as gr
from gradio_webrtc import WebRTC
import cv2
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame # (1)
with gr.Blocks() as demo:
output_video = WebRTC(label="Video Stream", mode="receive", # (2)
modality="video")
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click # (3)
)
demo.launch()
```
1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
2. Set `mode="receive"` to only receive audio from the server.
3. The `trigger` parameter the gradio event that will trigger the stream. In this case, the button click event.
### Additional Outputs
In order to modify other components from within the WebRTC stream, you must yield an instance of `AdditionalOutputs` and add an `on_additional_outputs` event to the `WebRTC` component.
This is common for displaying a multimodal text/audio conversation in a Chatbot UI.
``` py title="Additional Outputs"
import gradio as gr
import numpy as np

from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC

def transcribe(audio: tuple[int, np.ndarray],
               transformers_convo: list[dict],
               gradio_convo: list[dict]):
    # ... run your speech model here (model/processor setup omitted) ...
    response = model.generate(**inputs, max_length=256)
transformers_convo.append({"role": "assistant", "content": response})
gradio_convo.append({"role": "assistant", "content": response})
yield AdditionalOutputs(transformers_convo, gradio_convo) # (1)
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Talk to Qwen2Audio (Powered by WebRTC ⚡️)
</h1>
"""
)
transformers_convo = gr.State(value=[])
with gr.Row():
with gr.Column():
audio = WebRTC(
label="Stream",
mode="send", # (2)
modality="audio",
)
with gr.Column():
transcript = gr.Chatbot(label="transcript", type="messages")
audio.stream(ReplyOnPause(transcribe),
inputs=[audio, transformers_convo, transcript],
outputs=[audio], time_limit=90)
audio.on_additional_outputs(lambda s,a: (s,a), # (3)
outputs=[transformers_convo, transcript],
queue=False, show_progress="hidden")
demo.launch()
```
1. Pass your data to `AdditionalOutputs` and yield it.
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.
=== "Notes"
1. Pass your data to `AdditionalOutputs` and yield it.
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.
## Deployment
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
The easiest way to do this is to use a service like Twilio.
```python
import os

import gradio as gr
from gradio_webrtc import WebRTC
from twilio.rest import Client
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
with gr.Blocks() as demo:
...
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
...
```
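This commit also ships small helpers that wrap this pattern (`get_twilio_turn_credentials`, `get_hf_turn_credentials`, `get_turn_credentials` in `credentials.py`). A minimal sketch using the Twilio helper, which reads `TWILIO_ACCOUNT_SID`/`TWILIO_AUTH_TOKEN` from the environment when no arguments are passed:

```python
import gradio as gr
from gradio_webrtc import WebRTC, get_twilio_turn_credentials

# Requires `pip install twilio` and the two environment variables above.
rtc_configuration = get_twilio_turn_credentials()

with gr.Blocks() as demo:
    audio = WebRTC(rtc_configuration=rtc_configuration,
                   mode="send-receive", modality="audio")

demo.launch()
```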


@@ -0,0 +1,54 @@
from .credentials import (
get_hf_turn_credentials,
get_turn_credentials,
get_twilio_turn_credentials,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import (
AdditionalOutputs,
Warning,
WebRTCError,
aggregate_bytes_to_16bit,
async_aggregate_bytes_to_16bit,
audio_to_bytes,
audio_to_file,
audio_to_float32,
)
from .webrtc import (
AsyncAudioVideoStreamHandler,
AsyncStreamHandler,
AudioVideoStreamHandler,
StreamHandler,
WebRTC,
VideoEmitType,
AudioEmitType,
)
__all__ = [
"AsyncStreamHandler",
"AudioVideoStreamHandler",
"AudioEmitType",
"AsyncAudioVideoStreamHandler",
"AlgoOptions",
"AdditionalOutputs",
"aggregate_bytes_to_16bit",
"async_aggregate_bytes_to_16bit",
"audio_to_bytes",
"audio_to_file",
"audio_to_float32",
"get_hf_turn_credentials",
"get_twilio_turn_credentials",
"get_turn_credentials",
"ReplyOnPause",
"ReplyOnStopWords",
"SileroVadOptions",
"stt",
"stt_for_chunks",
"StreamHandler",
"VideoEmitType",
"WebRTC",
"WebRTCError",
"Warning",
]


@@ -0,0 +1,52 @@
import os
from typing import Literal
import requests
def get_hf_turn_credentials(token=None):
if token is None:
token = os.getenv("HF_TOKEN")
credentials = requests.get(
"https://freddyaboulton-turn-server-login.hf.space/credentials",
headers={"X-HF-Access-Token": token},
)
if not credentials.status_code == 200:
raise ValueError("Failed to get credentials from HF turn server")
return {
"iceServers": [
{
"urls": "turn:gradio-turn.com:80",
**credentials.json(),
},
]
}
def get_twilio_turn_credentials(twilio_sid=None, twilio_token=None):
try:
from twilio.rest import Client
except ImportError:
raise ImportError("Please install twilio with `pip install twilio`")
if not twilio_sid and not twilio_token:
twilio_sid = os.environ.get("TWILIO_ACCOUNT_SID")
twilio_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(twilio_sid, twilio_token)
token = client.tokens.create()
return {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
def get_turn_credentials(method: Literal["hf", "twilio"] = "hf", **kwargs):
if method == "hf":
return get_hf_turn_credentials(**kwargs)
elif method == "twilio":
return get_twilio_turn_credentials(**kwargs)
else:
raise ValueError("Invalid method. Must be 'hf' or 'twilio'")


@@ -0,0 +1,3 @@
from .vad import SileroVADModel, SileroVadOptions
__all__ = ["SileroVADModel", "SileroVadOptions"]


@@ -0,0 +1,320 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List, Literal, overload
import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from ..utils import AudioChunk
logger = logging.getLogger(__name__)
# The code below is adapted from https://github.com/snakers4/silero-vad.
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
@dataclass
class SileroVadOptions:
"""VAD options.
Attributes:
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
split aggressively just before max_speech_duration_s.
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
before separating it
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model performance!!
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
speech_duration: If the length of the speech is less than this value, a pause will be detected.
"""
threshold: float = 0.5
min_speech_duration_ms: int = 250
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
window_size_samples: int = 1024
speech_pad_ms: int = 400
class SileroVADModel:
@staticmethod
def download_model() -> str:
return hf_hub_download(
repo_id="freddyaboulton/silero-vad", filename="silero_vad.onnx"
)
def __init__(self):
try:
import onnxruntime
except ImportError as e:
raise RuntimeError(
"Applying the VAD filter requires the onnxruntime package"
) from e
path = self.download_model()
opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.log_severity_level = 4
self.session = onnxruntime.InferenceSession(
path,
providers=["CPUExecutionProvider"],
sess_options=opts,
)
def get_initial_state(self, batch_size: int):
h = np.zeros((2, batch_size, 64), dtype=np.float32)
c = np.zeros((2, batch_size, 64), dtype=np.float32)
return h, c
@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
return np.concatenate(
[audio[chunk["start"] : chunk["end"]] for chunk in chunks]
)
def get_speech_timestamps(
self,
audio: np.ndarray,
vad_options: SileroVadOptions,
**kwargs,
) -> List[AudioChunk]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:
audio: One dimensional float array.
vad_options: Options for VAD processing.
kwargs: VAD options passed as keyword arguments for backward compatibility.
Returns:
List of dicts containing begin and end samples of each speech chunk.
"""
threshold = vad_options.threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = vad_options.window_size_samples
speech_pad_ms = vad_options.speech_pad_ms
if window_size_samples not in [512, 1024, 1536]:
warnings.warn(
"Unusual window_size_samples! Supported window_size_samples:\n"
" - [512, 1024, 1536] for 16000 sampling_rate"
)
sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
max_speech_samples = (
sampling_rate * max_speech_duration_s
- window_size_samples
- 2 * speech_pad_samples
)
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
audio_length_samples = len(audio)
state = self.get_initial_state(batch_size=1)
speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[
current_start_sample : current_start_sample + window_size_samples
]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state = self(chunk, state, sampling_rate)
speech_probs.append(speech_prob)
triggered = False
speeches = []
current_speech = {}
neg_threshold = threshold - 0.15
# to save potential segment end (and tolerate some silence)
temp_end = 0
# to save potential segment limits in case of maximum segment size reached
prev_end = next_start = 0
for i, speech_prob in enumerate(speech_probs):
if (speech_prob >= threshold) and temp_end:
temp_end = 0
if next_start < prev_end:
next_start = window_size_samples * i
if (speech_prob >= threshold) and not triggered:
triggered = True
current_speech["start"] = window_size_samples * i
continue
if (
triggered
and (window_size_samples * i) - current_speech["start"]
> max_speech_samples
):
if prev_end:
current_speech["end"] = prev_end
speeches.append(current_speech)
current_speech = {}
# previously reached silence (< neg_thres) and is still not speech (< thres)
if next_start < prev_end:
triggered = False
else:
current_speech["start"] = next_start
prev_end = next_start = temp_end = 0
else:
current_speech["end"] = window_size_samples * i
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if (speech_prob < neg_threshold) and triggered:
if not temp_end:
temp_end = window_size_samples * i
# condition to avoid cutting in very short silence
if (
window_size_samples * i
) - temp_end > min_silence_samples_at_max_speech:
prev_end = temp_end
if (window_size_samples * i) - temp_end < min_silence_samples:
continue
else:
current_speech["end"] = temp_end
if (
current_speech["end"] - current_speech["start"]
) > min_speech_samples:
speeches.append(current_speech)
current_speech = {}
prev_end = next_start = temp_end = 0
triggered = False
continue
if (
current_speech
and (audio_length_samples - current_speech["start"]) > min_speech_samples
):
current_speech["end"] = audio_length_samples
speeches.append(current_speech)
for i, speech in enumerate(speeches):
if i == 0:
speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
if i != len(speeches) - 1:
silence_duration = speeches[i + 1]["start"] - speech["end"]
if silence_duration < 2 * speech_pad_samples:
speech["end"] += int(silence_duration // 2)
speeches[i + 1]["start"] = int(
max(0, speeches[i + 1]["start"] - silence_duration // 2)
)
else:
speech["end"] = int(
min(audio_length_samples, speech["end"] + speech_pad_samples)
)
speeches[i + 1]["start"] = int(
max(0, speeches[i + 1]["start"] - speech_pad_samples)
)
else:
speech["end"] = int(
min(audio_length_samples, speech["end"] + speech_pad_samples)
)
return speeches
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: Literal[True],
) -> tuple[float, List[AudioChunk]]: ...
@overload
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float: ...
def vad(
self,
audio_tuple: tuple[int, NDArray],
vad_parameters: None | SileroVadOptions,
return_chunks: bool = False,
) -> float | tuple[float, List[AudioChunk]]:
sampling_rate, audio = audio_tuple
logger.debug("VAD audio shape input: %s", audio.shape)
try:
if audio.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0
sr = 16000
if sr != sampling_rate:
try:
import librosa # type: ignore
except ImportError as e:
raise RuntimeError(
"Applying the VAD filter requires the librosa if the input sampling rate is not 16000hz"
) from e
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
if not vad_parameters:
vad_parameters = SileroVadOptions()
speech_chunks = self.get_speech_timestamps(audio, vad_parameters)
logger.debug("VAD speech chunks: %s", speech_chunks)
audio = self.collect_chunks(audio, speech_chunks)
logger.debug("VAD audio shape: %s", audio.shape)
duration_after_vad = audio.shape[0] / sr
if return_chunks:
return duration_after_vad, speech_chunks
return duration_after_vad
except Exception as e:
import math
import traceback
logger.debug("VAD Exception: %s", str(e))
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
return math.inf
def __call__(self, x, state, sr: int):
if len(x.shape) == 1:
x = np.expand_dims(x, 0)
if len(x.shape) > 2:
raise ValueError(
f"Too many dimensions for input audio chunk {len(x.shape)}"
)
if sr / x.shape[1] > 31.25: # type: ignore
raise ValueError("Input audio chunk is too short")
h, c = state
ort_inputs = {
"input": x,
"h": h,
"c": c,
"sr": np.array(sr, dtype="int64"),
}
out, h, c = self.session.run(None, ort_inputs)
state = (h, c)
return out, state


@@ -0,0 +1,188 @@
import asyncio
import inspect
from dataclasses import dataclass
from functools import lru_cache
from logging import getLogger
from threading import Event
from typing import Any, Callable, Generator, Literal, Union, cast
import numpy as np
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
from gradio_webrtc.webrtc import EmitType, StreamHandler
logger = getLogger(__name__)
counter = 0
@lru_cache
def get_vad_model() -> SileroVADModel:
"""Returns the VAD model instance."""
return SileroVADModel()
@dataclass
class AlgoOptions:
"""Algorithm options."""
audio_chunk_duration: float = 0.6
started_talking_threshold: float = 0.2
speech_threshold: float = 0.1
@dataclass
class AppState:
stream: np.ndarray | None = None
sampling_rate: int = 0
pause_detected: bool = False
started_talking: bool = False
responding: bool = False
stopped: bool = False
buffer: np.ndarray | None = None
ReplyFnGenerator = Union[
# For two arguments
Callable[
[tuple[int, np.ndarray], list[dict[Any, Any]]],
Generator[EmitType, None, None],
],
Callable[
[tuple[int, np.ndarray]],
Generator[EmitType, None, None],
],
]
async def iterate(generator: Generator) -> Any:
return next(generator)
class ReplyOnPause(StreamHandler):
def __init__(
self,
fn: ReplyFnGenerator,
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
input_sample_rate: int = 48000,
):
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=input_sample_rate,
)
self.expected_layout: Literal["mono", "stereo"] = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.model = get_vad_model()
self.fn = fn
self.is_async = inspect.isasyncgenfunction(fn)
self.event = Event()
self.state = AppState()
self.generator: Generator[EmitType, None, None] | None = None
self.model_options = model_options
self.algo_options = algo_options or AlgoOptions()
@property
def _needs_additional_inputs(self) -> bool:
return len(inspect.signature(self.fn).parameters) > 1
def copy(self):
return ReplyOnPause(
self.fn,
self.algo_options,
self.model_options,
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
self.input_sample_rate,
)
def determine_pause(
self, audio: np.ndarray, sampling_rate: int, state: AppState
) -> bool:
"""Take in the stream, determine if a pause happened"""
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad)
if (
dur_vad > self.algo_options.started_talking_threshold
and not state.started_talking
):
state.started_talking = True
logger.debug("Started talking")
if state.started_talking:
if state.stream is None:
state.stream = audio
else:
state.stream = np.concatenate((state.stream, audio))
state.buffer = None
if dur_vad < self.algo_options.speech_threshold and state.started_talking:
return True
return False
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
frame_rate, array = audio
array = np.squeeze(array)
if not state.sampling_rate:
state.sampling_rate = frame_rate
if state.buffer is None:
state.buffer = array
else:
state.buffer = np.concatenate((state.buffer, array))
pause_detected = self.determine_pause(
state.buffer, state.sampling_rate, self.state
)
state.pause_detected = pause_detected
def receive(self, frame: tuple[int, np.ndarray]) -> None:
if self.state.responding:
return
self.process_audio(frame, self.state)
if self.state.pause_detected:
self.event.set()
def reset(self):
super().reset()
self.generator = None
self.event.clear()
self.state = AppState()
async def async_iterate(self, generator) -> EmitType:
return await anext(generator)
def emit(self):
if not self.event.is_set():
return None
else:
if not self.generator:
if self._needs_additional_inputs and not self.args_set.is_set():
asyncio.run_coroutine_threadsafe(
self.wait_for_args(), self.loop
).result()
logger.debug("Creating generator")
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
if self._needs_additional_inputs:
self.latest_args[0] = (self.state.sampling_rate, audio)
self.generator = self.fn(*self.latest_args)
else:
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
logger.debug("Latest args: %s", self.latest_args)
self.state.responding = True
try:
if self.is_async:
return asyncio.run_coroutine_threadsafe(
self.async_iterate(self.generator), self.loop
).result()
else:
return next(self.generator)
except (StopIteration, StopAsyncIteration):
self.reset()


@@ -0,0 +1,147 @@
import asyncio
import logging
import re
from typing import Literal
import numpy as np
from .reply_on_pause import (
AlgoOptions,
AppState,
ReplyFnGenerator,
ReplyOnPause,
SileroVadOptions,
)
from .speech_to_text import get_stt_model, stt_for_chunks
from .utils import audio_to_float32
logger = logging.getLogger(__name__)
class ReplyOnStopWordsState(AppState):
stop_word_detected: bool = False
post_stop_word_buffer: np.ndarray | None = None
started_talking_pre_stop_word: bool = False
class ReplyOnStopWords(ReplyOnPause):
def __init__(
self,
fn: ReplyFnGenerator,
stop_words: list[str],
algo_options: AlgoOptions | None = None,
model_options: SileroVadOptions | None = None,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
input_sample_rate: int = 48000,
):
super().__init__(
fn,
algo_options=algo_options,
model_options=model_options,
expected_layout=expected_layout,
output_sample_rate=output_sample_rate,
output_frame_size=output_frame_size,
input_sample_rate=input_sample_rate,
)
self.stop_words = stop_words
self.state = ReplyOnStopWordsState()
# Download Model
get_stt_model()
def stop_word_detected(self, text: str) -> bool:
for stop_word in self.stop_words:
stop_word = stop_word.lower().strip().split(" ")
if bool(
re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text)
):
logger.debug("Stop word detected: %s", stop_word)
return True
return False
async def _send_stopword(
self,
):
if self.channel:
self.channel.send("stopword")
logger.debug("Sent stopword")
def send_stopword(self):
asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
def determine_pause( # type: ignore
self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
) -> bool:
"""Take in the stream, determine if a pause happened"""
import librosa
duration = len(audio) / sampling_rate
if duration >= self.algo_options.audio_chunk_duration:
if not state.stop_word_detected:
audio_f32 = audio_to_float32((sampling_rate, audio))
audio_rs = librosa.resample(
audio_f32, orig_sr=sampling_rate, target_sr=16000
)
if state.post_stop_word_buffer is None:
state.post_stop_word_buffer = audio_rs
else:
state.post_stop_word_buffer = np.concatenate(
(state.post_stop_word_buffer, audio_rs)
)
if len(state.post_stop_word_buffer) / 16000 > 2:
state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
dur_vad, chunks = self.model.vad(
(16000, state.post_stop_word_buffer),
self.model_options,
return_chunks=True,
)
text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
logger.debug(f"STT: {text}")
state.stop_word_detected = self.stop_word_detected(text)
if state.stop_word_detected:
logger.debug("Stop word detected")
self.send_stopword()
state.buffer = None
else:
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
logger.debug("VAD duration: %s", dur_vad)
if (
dur_vad > self.algo_options.started_talking_threshold
and not state.started_talking
and state.stop_word_detected
):
state.started_talking = True
logger.debug("Started talking")
if state.started_talking:
if state.stream is None:
state.stream = audio
else:
state.stream = np.concatenate((state.stream, audio))
state.buffer = None
if (
dur_vad < self.algo_options.speech_threshold
and state.started_talking
and state.stop_word_detected
):
return True
return False
def reset(self):
super().reset()
self.generator = None
self.event.clear()
self.state = ReplyOnStopWordsState()
def copy(self):
return ReplyOnStopWords(
self.fn,
self.stop_words,
self.algo_options,
self.model_options,
self.expected_layout,
self.output_sample_rate,
self.output_frame_size,
self.input_sample_rate,
)


@@ -0,0 +1,3 @@
from .stt_ import get_stt_model, stt, stt_for_chunks
__all__ = ["stt", "stt_for_chunks", "get_stt_model"]


@@ -0,0 +1,53 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable
import numpy as np
from numpy.typing import NDArray
from ..utils import AudioChunk
@dataclass
class STTModel:
encoder: Callable
decoder: Callable
@lru_cache
def get_stt_model() -> STTModel:
from silero import silero_stt
model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
return STTModel(model, decoder)
def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
model = get_stt_model()
sr, audio_np = audio
if audio_np.dtype != np.float32:
print("converting")
audio_np = audio_np.astype(np.float32) / 32768.0
try:
import torch
except ImportError:
raise ImportError(
"PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
)
audio_torch = torch.tensor(audio_np, dtype=torch.float32)
if audio_torch.ndim == 1:
audio_torch = audio_torch.unsqueeze(0)
assert audio_torch.ndim == 2, "Audio must have a batch dimension"
print("before")
res = model.decoder(model.encoder(audio_torch)[0])
print("after")
return res
def stt_for_chunks(
audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
sr, audio_np = audio
return " ".join(
[stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
)


@@ -0,0 +1,313 @@
import asyncio
import fractions
import io
import json
import logging
import tempfile
from contextvars import ContextVar
from typing import Any, Callable, Protocol, TypedDict, cast
import av
import numpy as np
from pydub import AudioSegment
logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
class AudioChunk(TypedDict):
start: int
end: int
class AdditionalOutputs:
def __init__(self, *args) -> None:
self.args = args
class DataChannel(Protocol):
def send(self, message: str) -> None: ...
current_channel: ContextVar[DataChannel | None] = ContextVar(
"current_channel", default=None
)
def _send_log(message: str, type: str) -> None:
async def _send(channel: DataChannel) -> None:
channel.send(
json.dumps(
{
"type": type,
"message": message,
}
)
)
if channel := current_channel.get():
print("channel", channel)
try:
loop = asyncio.get_running_loop()
asyncio.run_coroutine_threadsafe(_send(channel), loop)
except RuntimeError:
asyncio.run(_send(channel))
def Warning( # noqa: N802
message: str = "Warning issued.",
):
"""
    Send a warning message that is displayed in the UI of the application.
Parameters
----------
    message : str
        The warning message to send
Returns
-------
None
"""
_send_log(message, "warning")
class WebRTCError(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)
_send_log(message, "error")
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, AdditionalOutputs):
return None, data
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
return data, None
if not len(data) == 2:
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], AdditionalOutputs):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)
return data[0], cast(AdditionalOutputs, data[1])
return data, None
async def player_worker_decode(
next_frame: Callable,
queue: asyncio.Queue,
thread_quit: asyncio.Event,
channel: Callable[[], DataChannel | None] | None,
set_additional_outputs: Callable | None,
quit_on_none: bool = False,
sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME),
):
audio_samples = 0
audio_time_base = fractions.Fraction(1, sample_rate)
audio_resampler = av.AudioResampler( # type: ignore
format="s16",
layout="stereo",
rate=sample_rate,
frame_size=frame_size,
)
while not thread_quit.is_set():
try:
# Get next frame
frame, outputs = split_output(
await asyncio.wait_for(next_frame(), timeout=60)
)
if (
isinstance(outputs, AdditionalOutputs)
and set_additional_outputs
and channel
and channel()
):
set_additional_outputs(outputs)
cast(DataChannel, channel()).send("change")
if frame is None:
if quit_on_none:
await queue.put(None)
break
continue
if len(frame) == 2:
sample_rate, audio_array = frame
layout = "mono"
elif len(frame) == 3:
sample_rate, audio_array, layout = frame
logger.debug(
"received array with shape %s sample rate %s layout %s",
audio_array.shape, # type: ignore
sample_rate,
layout, # type: ignore
)
format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore
# Convert to audio frame and resample
# This runs in the same timeout context
frame = av.AudioFrame.from_ndarray( # type: ignore
audio_array, # type: ignore
format=format,
layout=layout, # type: ignore
)
frame.sample_rate = sample_rate
for processed_frame in audio_resampler.resample(frame):
processed_frame.pts = audio_samples
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize())
except (TimeoutError, asyncio.TimeoutError):
logger.warning(
"Timeout in frame processing cycle after %s seconds - resetting", 60
)
continue
except Exception as e:
import traceback
exec = traceback.format_exc()
logger.debug("traceback %s", exec)
logger.error("Error processing frame: %s", str(e))
continue
def audio_to_bytes(audio: tuple[int, np.ndarray]) -> bytes:
"""
Convert an audio tuple containing sample rate and numpy array data into bytes.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
Returns
-------
bytes
The audio data encoded as bytes, suitable for transmission or storage
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_bytes = audio_to_bytes(audio_tuple)
"""
audio_buffer = io.BytesIO()
segment = AudioSegment(
audio[1].tobytes(),
frame_rate=audio[0],
sample_width=audio[1].dtype.itemsize,
channels=1,
)
segment.export(audio_buffer, format="mp3")
return audio_buffer.getvalue()
def audio_to_file(audio: tuple[int, np.ndarray]) -> str:
"""
Save an audio tuple containing sample rate and numpy array data to a file.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
Returns
-------
str
The path to the saved audio file
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> file_path = audio_to_file(audio_tuple)
>>> print(f"Audio saved to: {file_path}")
"""
bytes_ = audio_to_bytes(audio)
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
f.write(bytes_)
return f.name
def audio_to_float32(audio: tuple[int, np.ndarray]) -> np.ndarray:
"""
Convert an audio tuple containing sample rate (int16) and numpy array data to float32.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
Returns
-------
np.ndarray
The audio data as a numpy array with dtype float32
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_float32 = audio_to_float32(audio_tuple)
"""
return audio[1].astype(np.float32) / 32768.0
def aggregate_bytes_to_16bit(chunks_iterator):
leftover = b"" # Store incomplete bytes between chunks
for chunk in chunks_iterator:
# Combine with any leftover bytes from previous chunk
current_bytes = leftover + chunk
# Calculate complete samples
n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes
bytes_to_process = n_complete_samples * 2
# Split into complete samples and leftover
to_process = current_bytes[:bytes_to_process]
leftover = current_bytes[bytes_to_process:]
if to_process: # Only yield if we have complete samples
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
yield audio_array
async def async_aggregate_bytes_to_16bit(chunks_iterator):
leftover = b"" # Store incomplete bytes between chunks
async for chunk in chunks_iterator:
# Combine with any leftover bytes from previous chunk
current_bytes = leftover + chunk
# Calculate complete samples
n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes
bytes_to_process = n_complete_samples * 2
# Split into complete samples and leftover
to_process = current_bytes[:bytes_to_process]
leftover = current_bytes[bytes_to_process:]
if to_process: # Only yield if we have complete samples
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
yield audio_array

(File diff suppressed because it is too large.)

demo/README.md (new file)

@@ -0,0 +1,44 @@
---
license: mit
tags:
- object-detection
- computer-vision
- yolov10
datasets:
- detection-datasets/coco
sdk: gradio
sdk_version: 5.0.0b1
---
### Model Description
[YOLOv10: Real-Time End-to-End Object Detection](https://arxiv.org/abs/2405.14458v1)
- arXiv: https://arxiv.org/abs/2405.14458v1
- github: https://github.com/THU-MIG/yolov10
### Installation
```
pip install supervision git+https://github.com/THU-MIG/yolov10.git
```
### Yolov10 Inference
```python
from ultralytics import YOLOv10
import supervision as sv
import cv2
IMAGE_PATH = 'dog.jpeg'
model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
model.predict(IMAGE_PATH, show=True)
```
### BibTeX Entry and Citation Info
```
@article{wang2024yolov10,
title={YOLOv10: Real-Time End-to-End Object Detection},
author={Wang, Ao and Chen, Hui and Liu, Lihao and Chen, Kai and Lin, Zijia and Han, Jungong and Ding, Guiguang},
journal={arXiv preprint arXiv:2405.14458},
year={2024}
}
```

demo/__init__.py (new, empty file)

demo/also_return_text.py (new file)

@@ -0,0 +1,105 @@
import logging
import os
import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, WebRTC
from pydub import AudioSegment
from twilio.rest import Client
# Configure the root logger to WARNING to suppress debug messages from other libraries
logging.basicConfig(level=logging.WARNING)
# Create a file handler for the library's debug logs
console_handler = logging.FileHandler("gradio_webrtc.log")
console_handler.setLevel(logging.DEBUG)
# Create a formatter
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
# Configure the logger for your specific library
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def generation(num_steps):
for i in range(num_steps):
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3"
)
yield (
(
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
),
AdditionalOutputs(
f"Hello, from step {i}!",
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3",
),
)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Audio Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
)
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
textbox = gr.Textbox(placeholder="Output will appear here.")
audio_file = gr.Audio()
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
)
audio.on_additional_outputs(
fn=lambda t, a: (f"State changed to {t}.", a),
outputs=[textbox, audio_file],
)
if __name__ == "__main__":
demo.launch(
allowed_paths=[
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3"
]
)

demo/app.py (new file)

@@ -0,0 +1,367 @@
import os
import gradio as gr
_docs = {
"WebRTC": {
"description": "Stream audio/video with WebRTC",
"members": {
"__init__": {
"rtc_configuration": {
"type": "dict[str, Any] | None",
"default": "None",
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
},
"height": {
"type": "int | str | None",
"default": "None",
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"width": {
"type": "int | str | None",
"default": "None",
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"label": {
"type": "str | None",
"default": "None",
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
},
"show_label": {
"type": "bool | None",
"default": "None",
"description": "if True, will display label.",
},
"container": {
"type": "bool",
"default": "True",
"description": "if True, will place the component in a container - providing some extra padding around the border.",
},
"scale": {
"type": "int | None",
"default": "None",
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
},
"min_width": {
"type": "int",
"default": "160",
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
},
"interactive": {
"type": "bool | None",
"default": "None",
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
},
"visible": {
"type": "bool",
"default": "True",
"description": "if False, component will be hidden.",
},
"elem_id": {
"type": "str | None",
"default": "None",
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"elem_classes": {
"type": "list[str] | str | None",
"default": "None",
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"render": {
"type": "bool",
"default": "True",
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
},
"key": {
"type": "int | str | None",
"default": "None",
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
},
"mirror_webcam": {
"type": "bool",
"default": "True",
"description": "if True webcam will be mirrored. Default is True.",
},
},
"events": {"tick": {"type": None, "default": None, "description": ""}},
},
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
}
}
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
with gr.Blocks(
css_paths=abs_path,
theme=gr.themes.Default(
font_mono=[
gr.themes.GoogleFont("Inconsolata"),
"monospace",
],
),
) as demo:
gr.Markdown(
"""
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.6%20-%20orange">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
</div>
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
## Installation
```bash
pip install gradio_webrtc
```
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
4. [Conversational AI](https://huggingface.co/spaces/freddyaboulton/omni-mini-webrtc) 🤖🗣️
## Usage
The WebRTC component supports the following four use cases:
1. [Streaming video from the user webcam to the server and back](#h-streaming-video-from-the-user-webcam-to-the-server-and-back)
2. [Streaming Video from the server to the client](#h-streaming-video-from-the-server-to-the-client)
3. [Streaming Audio from the server to the client](#h-streaming-audio-from-the-server-to-the-client)
4. [Streaming Audio from the client to the server and back (conversational AI)](#h-conversational-ai)
## Streaming Video from the User Webcam to the Server and Back
```python
import gradio as gr
from gradio_webrtc import WebRTC
def detection(image, conf_threshold=0.3):
... your detection code here ...
with gr.Blocks() as demo:
image = WebRTC(label="Stream", mode="send-receive", modality="video")
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection,
inputs=[image, conf_threshold],
outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()
```
* Set the `mode` parameter to `send-receive` and `modality` to "video".
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
as a **numpy array** and returns the processed frame also as a **numpy array**.
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
## Streaming Video from the server to the client
```python
import gradio as gr
from gradio_webrtc import WebRTC
import cv2
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame
with gr.Blocks() as demo:
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click
)
if __name__ == "__main__":
demo.launch()
```
* Set the "mode" parameter to "receive" and "modality" to "video".
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
* The only output allowed is the WebRTC component.
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button's click event.
## Streaming Audio from the Server to the Client
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
with gr.Blocks() as demo:
audio = WebRTC(label="Stream", mode="receive", modality="audio")
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click
)
```
* Set the "mode" parameter to "receive" and "modality" to "audio".
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
* The numpy array should be of shape (1, num_samples).
* The `outputs` parameter should be a list with the WebRTC component as the only element.
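As a rough sketch of the same contract (the `tone_generation` name and the synthesized sine tone are illustrative, not taken from the example above), a generator can also build its chunks directly with numpy:
```python
import numpy as np

def tone_generation(num_steps: int, frame_rate: int = 24000, freq: float = 440.0):
    # One second of a sine tone as int16 samples, shaped (1, num_samples).
    t = np.linspace(0, 1.0, frame_rate, endpoint=False)
    chunk = (np.sin(2 * np.pi * freq * t) * 32767).astype(np.int16).reshape(1, -1)
    for _ in range(num_steps):
        yield (frame_rate, chunk)
```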
## Conversational AI
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, StreamHandler
from queue import Queue
import time
class EchoHandler(StreamHandler):
def __init__(self) -> None:
super().__init__()
self.queue = Queue()
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
self.queue.put(frame)
def emit(self) -> None:
return self.queue.get()
with gr.Blocks() as demo:
with gr.Column():
with gr.Group():
audio = WebRTC(
label="Stream",
rtc_configuration=None,
mode="send-receive",
modality="audio",
)
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
if __name__ == "__main__":
demo.launch()
```
* Instead of passing a function to the `stream` event's `fn` parameter, pass a `StreamHandler` implementation. The `StreamHandler` above simply echoes the audio back to the client.
* The `StreamHandler` class has two methods: `receive` and `emit`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client.
* An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples).
* You can also specify the audio layout ("mono" or "stereo") in the `emit` method by returning it as the third element of the tuple. If not specified, the default is "mono".
* The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop.
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None.
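One way to honor the non-blocking requirement is to wait on the queue with a short timeout and fall back to `None`. This is only a sketch of the pattern (the `NonBlockingEcho` name is hypothetical), not the library's implementation:
```python
from queue import Empty, Queue

import numpy as np
from gradio_webrtc import StreamHandler

class NonBlockingEcho(StreamHandler):
    def __init__(self) -> None:
        super().__init__()
        self.queue = Queue()

    def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
        self.queue.put(frame)

    def emit(self):
        try:
            # Wait briefly for a frame; otherwise return None so the stream loop never blocks.
            return self.queue.get(timeout=0.01)
        except Empty:
            return None

    def copy(self):
        # Return a fresh handler for each connection, as in the repo's echo_conversation.py demo.
        return NonBlockingEcho()
```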
## Deployment
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
The easiest way to do this is to use a service like Twilio.
```python
from twilio.rest import Client
import os
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
with gr.Blocks() as demo:
...
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
...
```
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
##
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
demo.load(
None,
js=r"""function() {
const refs = {};
const user_fn_refs = {
WebRTC: [], };
requestAnimationFrame(() => {
Object.entries(user_fn_refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}-user-fn`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
Object.entries(refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
})
}
""",
)
demo.launch()

367
demo/app_.py Normal file
@@ -0,0 +1,367 @@
import os
import gradio as gr
_docs = {
"WebRTC": {
"description": "Stream audio/video with WebRTC",
"members": {
"__init__": {
"rtc_configuration": {
"type": "dict[str, Any] | None",
"default": "None",
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
},
"height": {
"type": "int | str | None",
"default": "None",
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"width": {
"type": "int | str | None",
"default": "None",
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"label": {
"type": "str | None",
"default": "None",
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
},
"show_label": {
"type": "bool | None",
"default": "None",
"description": "if True, will display label.",
},
"container": {
"type": "bool",
"default": "True",
"description": "if True, will place the component in a container - providing some extra padding around the border.",
},
"scale": {
"type": "int | None",
"default": "None",
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
},
"min_width": {
"type": "int",
"default": "160",
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
},
"interactive": {
"type": "bool | None",
"default": "None",
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
},
"visible": {
"type": "bool",
"default": "True",
"description": "if False, component will be hidden.",
},
"elem_id": {
"type": "str | None",
"default": "None",
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"elem_classes": {
"type": "list[str] | str | None",
"default": "None",
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"render": {
"type": "bool",
"default": "True",
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
},
"key": {
"type": "int | str | None",
"default": "None",
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
},
"mirror_webcam": {
"type": "bool",
"default": "True",
"description": "if True webcam will be mirrored. Default is True.",
},
},
"events": {"tick": {"type": None, "default": None, "description": ""}},
},
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
}
}
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
with gr.Blocks(
css_paths=abs_path,
theme=gr.themes.Default(
font_mono=[
gr.themes.GoogleFont("Inconsolata"),
"monospace",
],
),
) as demo:
gr.Markdown(
"""
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.6%20-%20orange">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
</div>
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
## Installation
```bash
pip install gradio_webrtc
```
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
4. [Conversational AI](https://huggingface.co/spaces/freddyaboulton/omni-mini-webrtc) 🤖🗣️
## Usage
The WebRTC component supports the following four use cases:
1. [Streaming video from the user webcam to the server and back](#h-streaming-video-from-the-user-webcam-to-the-server-and-back)
2. [Streaming Video from the server to the client](#h-streaming-video-from-the-server-to-the-client)
3. [Streaming Audio from the server to the client](#h-streaming-audio-from-the-server-to-the-client)
4. [Streaming Audio from the client to the server and back (conversational AI)](#h-conversational-ai)
## Streaming Video from the User Webcam to the Server and Back
```python
import gradio as gr
from gradio_webrtc import WebRTC
def detection(image, conf_threshold=0.3):
... your detection code here ...
with gr.Blocks() as demo:
image = WebRTC(label="Stream", mode="send-receive", modality="video")
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection,
inputs=[image, conf_threshold],
outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()
```
* Set the `mode` parameter to `send-receive` and `modality` to "video".
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
as a **numpy array** and returns the processed frame also as a **numpy array**.
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
## Streaming Video from the Server to the Client
```python
import gradio as gr
from gradio_webrtc import WebRTC
import cv2
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame
with gr.Blocks() as demo:
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click
)
if __name__ == "__main__":
demo.launch()
```
* Set the "mode" parameter to "receive" and "modality" to "video".
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
* The only output allowed is the WebRTC component.
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button's click event.
## Streaming Audio from the Server to the Client
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
with gr.Blocks() as demo:
audio = WebRTC(label="Stream", mode="receive", modality="audio")
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click
)
```
* Set the "mode" parameter to "receive" and "modality" to "audio".
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
* The numpy array should be of shape (1, num_samples).
* The `outputs` parameter should be a list with the WebRTC component as the only element.
## Conversational AI
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, StreamHandler
from queue import Queue
import time
class EchoHandler(StreamHandler):
def __init__(self) -> None:
super().__init__()
self.queue = Queue()
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
self.queue.put(frame)
def emit(self) -> None:
return self.queue.get()
with gr.Blocks() as demo:
with gr.Column():
with gr.Group():
audio = WebRTC(
label="Stream",
rtc_configuration=None,
mode="send-receive",
modality="audio",
)
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
if __name__ == "__main__":
demo.launch()
```
* Instead of passing a function to the `stream` event's `fn` parameter, pass a `StreamHandler` implementation. The `StreamHandler` above simply echoes the audio back to the client.
* The `StreamHandler` class has two methods: `receive` and `emit`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client.
* An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples).
* You can also specify the audio layout ("mono" or "stereo") in the `emit` method by returning it as the third element of the tuple. If not specified, the default is "mono".
* The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop.
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None.
## Deployment
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
The easiest way to do this is to use a service like Twilio.
```python
from twilio.rest import Client
import os
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
with gr.Blocks() as demo:
...
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
...
```
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
##
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
demo.load(
None,
js=r"""function() {
const refs = {};
const user_fn_refs = {
WebRTC: [], };
requestAnimationFrame(() => {
Object.entries(user_fn_refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}-user-fn`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
Object.entries(refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
})
}
""",
)
demo.launch()

73
demo/app_orig.py Normal file
@@ -0,0 +1,73 @@
import os
import cv2
import gradio as gr
from gradio_webrtc import WebRTC
from huggingface_hub import hf_hub_download
from inference import YOLOv10
from twilio.rest import Client
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def detection(image, conf_threshold=0.3):
image = cv2.resize(image, (model.input_width, model.input_height))
new_image = model.detect_objects(image, conf_threshold)
return cv2.resize(new_image, (500, 500))
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()

71
demo/audio_out.py Normal file
@@ -0,0 +1,71 @@
import os
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
from twilio.rest import Client
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Audio Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
)
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
)
if __name__ == "__main__":
demo.launch()

64
demo/audio_out_2.py Normal file
@@ -0,0 +1,64 @@
import os
import time
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
from twilio.rest import Client
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
time.sleep(3.5)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Audio Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Row():
with gr.Column():
gr.Slider()
with gr.Column():
# audio = gr.Audio(interactive=False)
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
)
if __name__ == "__main__":
demo.launch()

161
demo/css.css Normal file
@@ -0,0 +1,161 @@
html {
font-family: Inter;
font-size: 16px;
font-weight: 400;
line-height: 1.5;
-webkit-text-size-adjust: 100%;
background: #fff;
color: #323232;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
}
:root {
--space: 1;
--vspace: calc(var(--space) * 1rem);
--vspace-0: calc(3 * var(--space) * 1rem);
--vspace-1: calc(2 * var(--space) * 1rem);
--vspace-2: calc(1.5 * var(--space) * 1rem);
--vspace-3: calc(0.5 * var(--space) * 1rem);
}
.app {
max-width: 748px !important;
}
.prose p {
margin: var(--vspace) 0;
line-height: calc(var(--vspace) * 2);
font-size: 1rem;
}
code {
font-family: "Inconsolata", sans-serif;
font-size: 16px;
}
h1,
h1 code {
font-weight: 400;
line-height: calc(2.5 / var(--space) * var(--vspace));
}
h1 code {
background: none;
border: none;
letter-spacing: 0.05em;
padding-bottom: 5px;
position: relative;
padding: 0;
}
h2 {
margin: var(--vspace-1) 0 var(--vspace-2) 0;
line-height: 1em;
}
h3,
h3 code {
margin: var(--vspace-1) 0 var(--vspace-2) 0;
line-height: 1em;
}
h4,
h5,
h6 {
margin: var(--vspace-3) 0 var(--vspace-3) 0;
line-height: var(--vspace);
}
.bigtitle,
h1,
h1 code {
font-size: calc(8px * 4.5);
word-break: break-word;
}
.title,
h2,
h2 code {
font-size: calc(8px * 3.375);
font-weight: lighter;
word-break: break-word;
border: none;
background: none;
}
.subheading1,
h3,
h3 code {
font-size: calc(8px * 1.8);
font-weight: 600;
border: none;
background: none;
letter-spacing: 0.1em;
text-transform: uppercase;
}
h2 code {
padding: 0;
position: relative;
letter-spacing: 0.05em;
}
blockquote {
font-size: calc(8px * 1.1667);
font-style: italic;
line-height: calc(1.1667 * var(--vspace));
margin: var(--vspace-2) var(--vspace-2);
}
.subheading2,
h4 {
font-size: calc(8px * 1.4292);
text-transform: uppercase;
font-weight: 600;
}
.subheading3,
h5 {
font-size: calc(8px * 1.2917);
line-height: calc(1.2917 * var(--vspace));
font-weight: lighter;
text-transform: uppercase;
letter-spacing: 0.15em;
}
h6 {
font-size: calc(8px * 1.1667);
font-size: 1.1667em;
font-weight: normal;
font-style: italic;
font-family: "le-monde-livre-classic-byol", serif !important;
letter-spacing: 0px !important;
}
#start .md > *:first-child {
margin-top: 0;
}
h2 + h3 {
margin-top: 0;
}
.md hr {
border: none;
border-top: 1px solid var(--block-border-color);
margin: var(--vspace-2) 0 var(--vspace-2) 0;
}
.prose ul {
margin: var(--vspace-2) 0 var(--vspace-1) 0;
}
.gap {
gap: 0;
}
.md-custom {
overflow: hidden;
}

99
demo/docs.py Normal file
@@ -0,0 +1,99 @@
_docs = {
"WebRTC": {
"description": "Stream audio/video with WebRTC",
"members": {
"__init__": {
"rtc_configuration": {
"type": "dict[str, Any] | None",
"default": "None",
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
},
"height": {
"type": "int | str | None",
"default": "None",
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"width": {
"type": "int | str | None",
"default": "None",
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"label": {
"type": "str | None",
"default": "None",
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
},
"show_label": {
"type": "bool | None",
"default": "None",
"description": "if True, will display label.",
},
"container": {
"type": "bool",
"default": "True",
"description": "if True, will place the component in a container - providing some extra padding around the border.",
},
"scale": {
"type": "int | None",
"default": "None",
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
},
"min_width": {
"type": "int",
"default": "160",
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
},
"interactive": {
"type": "bool | None",
"default": "None",
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
},
"visible": {
"type": "bool",
"default": "True",
"description": "if False, component will be hidden.",
},
"elem_id": {
"type": "str | None",
"default": "None",
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"elem_classes": {
"type": "list[str] | str | None",
"default": "None",
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"render": {
"type": "bool",
"default": "True",
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
},
"key": {
"type": "int | str | None",
"default": "None",
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
},
"mirror_webcam": {
"type": "bool",
"default": "True",
"description": "if True webcam will be mirrored. Default is True.",
},
"postprocess": {
"value": {
"type": "typing.Any",
"description": "Expects a {str} or {pathlib.Path} filepath to a video which is displayed, or a {Tuple[str | pathlib.Path, str | pathlib.Path | None]} where the first element is a filepath to a video and the second element is an optional filepath to a subtitle file.",
}
},
"preprocess": {
"return": {
"type": "str",
"description": "Passes the uploaded video as a `str` filepath or URL whose extension can be modified by `format`.",
},
"value": None,
},
},
"events": {"tick": {"type": None, "default": None, "description": ""}},
},
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
}
}

61
demo/echo_conversation.py Normal file
@@ -0,0 +1,61 @@
import logging
from queue import Queue
import gradio as gr
import numpy as np
from gradio_webrtc import StreamHandler, WebRTC
# Configure the root logger to WARNING to suppress debug messages from other libraries
logging.basicConfig(level=logging.WARNING)
# Create a console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# Create a formatter
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
# Configure the logger for your specific library
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
class EchoHandler(StreamHandler):
def __init__(self) -> None:
super().__init__()
self.queue = Queue()
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
self.queue.put(frame)
def emit(self) -> None:
return self.queue.get()
def copy(self) -> StreamHandler:
return EchoHandler()
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Conversational AI (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
label="Stream",
rtc_configuration=None,
mode="send-receive",
modality="audio",
)
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
if __name__ == "__main__":
demo.launch()

149
demo/inference.py Normal file
@@ -0,0 +1,149 @@
import time
import cv2
import numpy as np
import onnxruntime
from utils import draw_detections
class YOLOv10:
def __init__(self, path):
# Initialize model
self.initialize_model(path)
def __call__(self, image):
return self.detect_objects(image)
def initialize_model(self, path):
self.session = onnxruntime.InferenceSession(
path, providers=onnxruntime.get_available_providers()
)
# Get model info
self.get_input_details()
self.get_output_details()
def detect_objects(self, image, conf_threshold=0.3):
input_tensor = self.prepare_input(image)
# Perform inference on the image
new_image = self.inference(image, input_tensor, conf_threshold)
return new_image
def prepare_input(self, image):
self.img_height, self.img_width = image.shape[:2]
input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Resize input image
input_img = cv2.resize(input_img, (self.input_width, self.input_height))
# Scale input pixel values to 0 to 1
input_img = input_img / 255.0
input_img = input_img.transpose(2, 0, 1)
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
return input_tensor
def inference(self, image, input_tensor, conf_threshold=0.3):
start = time.perf_counter()
outputs = self.session.run(
self.output_names, {self.input_names[0]: input_tensor}
)
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
(
boxes,
scores,
class_ids,
) = self.process_output(outputs, conf_threshold)
return self.draw_detections(image, boxes, scores, class_ids)
def process_output(self, output, conf_threshold=0.3):
predictions = np.squeeze(output[0])
# Filter out object confidence scores below threshold
scores = predictions[:, 4]
predictions = predictions[scores > conf_threshold, :]
scores = scores[scores > conf_threshold]
if len(scores) == 0:
return [], [], []
# Get the class with the highest confidence
class_ids = np.argmax(predictions[:, 4:], axis=1)
# Get bounding boxes for each object
boxes = self.extract_boxes(predictions)
return boxes, scores, class_ids
def extract_boxes(self, predictions):
# Extract boxes from predictions
boxes = predictions[:, :4]
# Scale boxes to original image dimensions
boxes = self.rescale_boxes(boxes)
# Convert boxes to xyxy format
# boxes = xywh2xyxy(boxes)
return boxes
def rescale_boxes(self, boxes):
# Rescale boxes to original image dimensions
input_shape = np.array(
[self.input_width, self.input_height, self.input_width, self.input_height]
)
boxes = np.divide(boxes, input_shape, dtype=np.float32)
boxes *= np.array(
[self.img_width, self.img_height, self.img_width, self.img_height]
)
return boxes
def draw_detections(
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
):
return draw_detections(image, boxes, scores, class_ids, mask_alpha)
def get_input_details(self):
model_inputs = self.session.get_inputs()
self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
self.input_shape = model_inputs[0].shape
self.input_height = self.input_shape[2]
self.input_width = self.input_shape[3]
def get_output_details(self):
model_outputs = self.session.get_outputs()
self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
if __name__ == "__main__":
import tempfile
import requests
from huggingface_hub import hf_hub_download
model_file = hf_hub_download(
repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
)
yolov8_detector = YOLOv10(model_file)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
f.write(
requests.get(
"https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
).content
)
f.seek(0)
img = cv2.imread(f.name)
# # Detect Objects
combined_image = yolov8_detector.detect_objects(img)
# Draw detections
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", combined_image)
cv2.waitKey(0)

74
demo/old_app.py Normal file
@@ -0,0 +1,74 @@
import os
import cv2
import gradio as gr
from gradio_webrtc import WebRTC
from huggingface_hub import hf_hub_download
from inference import YOLOv10
from twilio.rest import Client
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def detection(frame, conf_threshold=0.3):
frame = cv2.flip(frame, 0)
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
number = gr.Number()
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
image.on_additional_outputs(lambda n: n, outputs=[number])
if __name__ == "__main__":
demo.launch()

6
demo/requirements.txt Normal file
@@ -0,0 +1,6 @@
safetensors==0.4.3
opencv-python
twilio
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio-5.0.0b3-py3-none-any.whl
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio_webrtc-0.0.1-py3-none-any.whl
onnxruntime-gpu

321
demo/space.py Normal file
@@ -0,0 +1,321 @@
import os
import gradio as gr
_docs = {
"WebRTC": {
"description": "Stream audio/video with WebRTC",
"members": {
"__init__": {
"rtc_configuration": {
"type": "dict[str, Any] | None",
"default": "None",
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
},
"height": {
"type": "int | str | None",
"default": "None",
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"width": {
"type": "int | str | None",
"default": "None",
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
},
"label": {
"type": "str | None",
"default": "None",
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
},
"show_label": {
"type": "bool | None",
"default": "None",
"description": "if True, will display label.",
},
"container": {
"type": "bool",
"default": "True",
"description": "if True, will place the component in a container - providing some extra padding around the border.",
},
"scale": {
"type": "int | None",
"default": "None",
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
},
"min_width": {
"type": "int",
"default": "160",
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
},
"interactive": {
"type": "bool | None",
"default": "None",
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
},
"visible": {
"type": "bool",
"default": "True",
"description": "if False, component will be hidden.",
},
"elem_id": {
"type": "str | None",
"default": "None",
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"elem_classes": {
"type": "list[str] | str | None",
"default": "None",
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
},
"render": {
"type": "bool",
"default": "True",
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
},
"key": {
"type": "int | str | None",
"default": "None",
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
},
"mirror_webcam": {
"type": "bool",
"default": "True",
"description": "if True webcam will be mirrored. Default is True.",
},
},
"events": {"tick": {"type": None, "default": None, "description": ""}},
},
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
}
}
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
with gr.Blocks(
css_paths=abs_path,
theme=gr.themes.Default(
font_mono=[
gr.themes.GoogleFont("Inconsolata"),
"monospace",
],
),
) as demo:
gr.Markdown(
"""
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.5%20-%20orange">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
</div>
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
## Installation
```bash
pip install gradio_webrtc
```
## Examples:
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
## Usage
The WebRTC component supports the following three use cases:
1. Streaming video from the user webcam to the server and back
2. Streaming Video from the server to the client
3. Streaming Audio from the server to the client
Streaming audio from the client to the server and back (conversational AI) is not supported yet.
## Streaming Video from the User Webcam to the Server and Back
```python
import gradio as gr
from gradio_webrtc import WebRTC
def detection(image, conf_threshold=0.3):
... your detection code here ...
with gr.Blocks() as demo:
image = WebRTC(label="Stream", mode="send-receive", modality="video")
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection,
inputs=[image, conf_threshold],
outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()
```
* Set the `mode` parameter to `send-receive` and `modality` to "video".
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
as a **numpy array** and returns the processed frame also as a **numpy array**.
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
## Streaming Video from the Server to the Client
```python
import gradio as gr
from gradio_webrtc import WebRTC
import cv2
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame
with gr.Blocks() as demo:
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click
)
if __name__ == "__main__":
demo.launch()
```
* Set the "mode" parameter to "receive" and "modality" to "video".
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
* The only output allowed is the WebRTC component.
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button's click event.
## Streaming Audio from the Server to the Client
```python
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
with gr.Blocks() as demo:
audio = WebRTC(label="Stream", mode="receive", modality="audio")
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click
)
```
* Set the "mode" parameter to "receive" and "modality" to "audio".
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
* The numpy array should be of shape (1, num_samples).
* The `outputs` parameter should be a list with the WebRTC component as the only element.
## Deployment
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
The easiest way to do this is to use a service like Twilio.
```python
from twilio.rest import Client
import os
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
with gr.Blocks() as demo:
...
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
...
```
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.Markdown(
"""
##
""",
elem_classes=["md-custom"],
header_links=True,
)
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
demo.load(
None,
js=r"""function() {
const refs = {};
const user_fn_refs = {
WebRTC: [], };
requestAnimationFrame(() => {
Object.entries(user_fn_refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}-user-fn`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
Object.entries(refs).forEach(([key, refs]) => {
if (refs.length > 0) {
const el = document.querySelector(`.${key}`);
if (!el) return;
refs.forEach(ref => {
el.innerHTML = el.innerHTML.replace(
new RegExp("\\b"+ref+"\\b", "g"),
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
);
})
}
})
})
}
""",
)
demo.launch()

53
demo/stream_whisper.py Normal file
@@ -0,0 +1,53 @@
import tempfile
import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
from openai import OpenAI
from pydub import AudioSegment
from dotenv import load_dotenv
load_dotenv()
client = OpenAI()
def transcribe(audio: tuple[int, np.ndarray], transcript: list[dict]):
print("audio", audio)
segment = AudioSegment(
audio[1].tobytes(),
frame_rate=audio[0],
sample_width=audio[1].dtype.itemsize,
channels=1,
)
transcript.append({"role": "user", "content": gr.Audio((audio[0], audio[1].squeeze()))})
with tempfile.NamedTemporaryFile(suffix=".mp3") as temp_audio:
segment.export(temp_audio.name, format="mp3")
next_chunk = client.audio.transcriptions.create(
model="whisper-1", file=open(temp_audio.name, "rb")
).text
transcript.append({"role": "assistant", "content": next_chunk})
yield AdditionalOutputs(transcript)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
audio = WebRTC(
label="Stream",
mode="send",
modality="audio",
)
with gr.Column():
transcript = gr.Chatbot(label="transcript", type="messages")
audio.stream(ReplyOnPause(transcribe), inputs=[audio, transcript], outputs=[audio],
time_limit=30)
audio.on_additional_outputs(lambda s: s, outputs=transcript)
if __name__ == "__main__":
demo.launch()

237
demo/utils.py Normal file
@@ -0,0 +1,237 @@
import cv2
import numpy as np
class_names = [
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
# Create a list of colors for each class where each color is a tuple of 3 integer values
rng = np.random.default_rng(3)
colors = rng.uniform(0, 255, size=(len(class_names), 3))
def nms(boxes, scores, iou_threshold):
# Sort by score
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# Pick the last box
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# Compute IoU of the picked box with the rest
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# Remove boxes with IoU over the threshold
keep_indices = np.where(ious < iou_threshold)[0]
# print(keep_indices.shape, sorted_indices.shape)
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
def multiclass_nms(boxes, scores, class_ids, iou_threshold):
unique_class_ids = np.unique(class_ids)
keep_boxes = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
keep_boxes.extend(class_indices[class_keep_boxes])
return keep_boxes
def compute_iou(box, boxes):
# Compute xmin, ymin, xmax, ymax for both boxes
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# Compute intersection area
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# Compute union area
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# Compute IoU
iou = intersection_area / union_area
return iou
def xywh2xyxy(x):
# Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
y = np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2
y[..., 1] = x[..., 1] - x[..., 3] / 2
y[..., 2] = x[..., 0] + x[..., 2] / 2
y[..., 3] = x[..., 1] + x[..., 3] / 2
return y
def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
det_img = image.copy()
img_height, img_width = image.shape[:2]
font_size = min([img_height, img_width]) * 0.0006
text_thickness = int(min([img_height, img_width]) * 0.001)
# det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
# Draw bounding boxes and labels of detections
for class_id, box, score in zip(class_ids, boxes, scores):
color = colors[class_id]
draw_box(det_img, box, color)
label = class_names[class_id]
caption = f"{label} {int(score * 100)}%"
draw_text(det_img, caption, box, color, font_size, text_thickness)
return det_img
def draw_box(
image: np.ndarray,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
thickness: int = 2,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
def draw_text(
image: np.ndarray,
text: str,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
font_size: float = 0.001,
text_thickness: int = 2,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
(tw, th), _ = cv2.getTextSize(
text=text,
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=font_size,
thickness=text_thickness,
)
th = int(th * 1.2)
cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
return cv2.putText(
image,
text,
(x1, y1),
cv2.FONT_HERSHEY_SIMPLEX,
font_size,
(255, 255, 255),
text_thickness,
cv2.LINE_AA,
)
def draw_masks(
image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
) -> np.ndarray:
mask_img = image.copy()
# Draw bounding boxes and labels of detections
for box, class_id in zip(boxes, classes):
color = colors[class_id]
x1, y1, x2, y2 = box.astype(int)
# Draw fill rectangle in mask image
cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)

65
demo/video_out.py Normal file
@@ -0,0 +1,65 @@
import os
import cv2
import gradio as gr
from gradio_webrtc import WebRTC
from twilio.rest import Client
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def generation(input_video):
cap = cv2.VideoCapture(input_video)
iterating = True
while iterating:
iterating, frame = cap.read()
# flip frame vertically
frame = cv2.flip(frame, 0)
display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
yield display_frame
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Video Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Row():
with gr.Column():
input_video = gr.Video(sources="upload")
with gr.Column():
output_video = WebRTC(
label="Video Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="video",
)
output_video.stream(
fn=generation,
inputs=[input_video],
outputs=[output_video],
trigger=input_video.upload,
)
if __name__ == "__main__":
demo.launch()

54
demo/video_out_stream.py Normal file
@@ -0,0 +1,54 @@
import os
import cv2
import gradio as gr
from gradio_webrtc import WebRTC
from twilio.rest import Client
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Video Streaming (Powered by WebRTC ⚡️)
</h1>
"""
)
output_video = WebRTC(
label="Video Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="video",
)
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video], trigger=button.click
)
if __name__ == "__main__":
demo.launch()

100
demo/video_send_output.py Normal file
@@ -0,0 +1,100 @@
import logging
import os
import random
import cv2
import gradio as gr
from gradio_webrtc import AdditionalOutputs, WebRTC
from huggingface_hub import hf_hub_download
from inference import YOLOv10
from twilio.rest import Client
# Configure the root logger to WARNING to suppress debug messages from other libraries
logging.basicConfig(level=logging.WARNING)
# Create a file handler that writes debug logs to gradio_webrtc.log
console_handler = logging.FileHandler("gradio_webrtc.log")
console_handler.setLevel(logging.DEBUG)
# Create a formatter
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
# Configure the logger for your specific library
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
model = YOLOv10(model_file)
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
client = Client(account_sid, auth_token)
token = client.tokens.create()
rtc_configuration = {
"iceServers": token.ice_servers,
"iceTransportPolicy": "relay",
}
else:
rtc_configuration = None
def detection(frame, conf_threshold=0.3):
print("frame.shape", frame.shape)
frame = cv2.flip(frame, 0)
return AdditionalOutputs(1)
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
</h1>
"""
)
gr.HTML(
"""
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
</h3>
"""
)
with gr.Column(elem_classes=["my-column"]):
with gr.Group(elem_classes=["my-group"]):
image = WebRTC(
label="Stream", rtc_configuration=rtc_configuration,
mode="send",
track_constraints={"width": {"exact": 800},
"height": {"exact": 600},
"aspectRatio": {"exact": 1.33333}
},
rtp_params={"degradationPreference": "maintain-resolution"}
)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
number = gr.Number()
image.stream(
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
)
image.on_additional_outputs(lambda n: n, outputs=number)
demo.launch()

@@ -0,0 +1,160 @@
## Track Constraints
You can specify the `track_constraints` parameter to control how the data is streamed to the server. The full documentation on track constraints is [here](https://developer.mozilla.org/en-US/docs/Web/API/MediaTrackConstraints#constraints).
For example, you can control the size of the frames captured from the webcam like so:
```python
track_constraints = {
"width": {"exact": 500},
"height": {"exact": 500},
"frameRate": {"ideal": 30},
}
webrtc = WebRTC(track_constraints=track_constraints,
modality="video",
mode="send-receive")
```
!!! warning
WebRTC may not enforce your constraints. For example, it may rescale your video
(while keeping the aspect ratio) in order to maintain the desired (or reach a better) frame rate. If you
really want to enforce height, width and resolution constraints, use the `rtp_params` parameter and set `"degradationPreference": "maintain-resolution"`.
```python
image = WebRTC(
label="Stream",
mode="send",
track_constraints=track_constraints,
rtp_params={"degradationPreference": "maintain-resolution"}
)
```
## The RTC Configuration
You can configure how the connection is created on the client by passing an `rtc_configuration` parameter to the `WebRTC` component constructor.
See the list of available arguments [here](https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection#configuration).
When deploying on a remote server, an `rtc_configuration` parameter must be passed in. See [Deployment](/deployment).
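As a minimal sketch (assuming a public STUN server is sufficient, which is typically only true for local development; remote deployments usually need TURN servers as described in the Deployment guide):
```python
import gradio as gr
from gradio_webrtc import WebRTC

# Assumption: a STUN-only configuration. Swap in TURN servers for remote deployments.
rtc_configuration = {
    "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}],
}

with gr.Blocks() as demo:
    webrtc = WebRTC(
        rtc_configuration=rtc_configuration,
        mode="send-receive",
        modality="audio",
    )

demo.launch()
```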
## Reply on Pause Voice-Activity-Detection
The `ReplyOnPause` class runs a Voice Activity Detection (VAD) algorithm to determine when a user has stopped speaking.
1. First, the algorithm determines when the user has started speaking.
2. Then it groups the audio into chunks.
3. On each chunk, we determine the length of human speech in the chunk.
4. If the length of human speech is below a threshold, a pause is detected.
The following parameters control this algorithm:
```python
import gradio as gr
from gradio_webrtc import AlgoOptions, ReplyOnPause, WebRTC

algo_options = AlgoOptions(
    audio_chunk_duration=0.6,       # (1)
    started_talking_threshold=0.2,  # (2)
    speech_threshold=0.1,           # (3)
)

with gr.Blocks() as demo:
    audio = WebRTC(...)
    audio.stream(ReplyOnPause(..., algo_options=algo_options))

demo.launch()
```
1. This is the length (in seconds) of audio chunks.
2. If the chunk has more than 0.2 seconds of speech, the user started talking.
3. If, after the user started speaking, there is a chunk with less than 0.1 seconds of speech, the user stopped speaking.
## Stream Handler Input Audio
You can configure the sampling rate of the audio passed to the `ReplyOnPause` or `StreamHandler` instance with the `input_sample_rate` parameter. The current default is `48000`.
```python
import gradio as gr
from gradio_webrtc import ReplyOnPause, WebRTC

with gr.Blocks() as demo:
    audio = WebRTC(...)
    audio.stream(ReplyOnPause(..., input_sample_rate=24000))

demo.launch()
```
## Stream Handler Output Audio
You can configure the output audio chunk size of `ReplyOnPause` (and any `StreamHandler`)
with the `output_sample_rate` and `output_frame_size` parameters.
The following code (which uses the default values of these parameters) states that each output chunk will be a frame of 960 samples at a sample rate of `24,000` Hz, so each chunk corresponds to `0.04` seconds of audio.
```python
import gradio as gr
from gradio_webrtc import ReplyOnPause, WebRTC

with gr.Blocks() as demo:
    audio = WebRTC(...)
    audio.stream(ReplyOnPause(..., output_sample_rate=24000, output_frame_size=960))

demo.launch()
```
!!! tip
In general it is best to leave these settings untouched. In some cases,
lowering the output_frame_size can yield smoother audio playback.
## Audio Icon
You can display an icon of your choice instead of the default wave animation for audio streaming.
Pass any local path or URL to an image (svg, png, jpeg) to the component's `icon` parameter. This will display the icon as a circular button. When audio is sent or received (depending on the `mode` parameter), a pulse animation will emanate from the button.
You can control the button color and pulse color with `icon_button_color` and `pulse_color` parameters. They can take any valid css color.
=== "Code"
``` python
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
icon="phone-solid.svg",
)
```
<img src="https://github.com/user-attachments/assets/fd2e70a3-1698-4805-a8cb-9b7b3bcf2198">
=== "Code Custom colors"
``` python
audio = WebRTC(
label="Stream",
rtc_configuration=rtc_configuration,
mode="receive",
modality="audio",
icon="phone-solid.svg",
icon_button_color="black",
pulse_color="black",
)
```
<img src="https://github.com/user-attachments/assets/39e9bb0b-53fb-448e-be44-d37f6785b4b6">
## Changing the Button Text
You can supply a `button_labels` dictionary to change the text of the `Start`, `Stop` and `Waiting` buttons displayed in the UI.
The keys must be `"start"`, `"stop"`, and `"waiting"`.
``` python
webrtc = WebRTC(
label="Video Chat",
modality="audio-video",
mode="send-receive",
button_labels={"start": "Start Talking to Gemini"}
)
```
<img src="https://github.com/user-attachments/assets/04be0b95-189c-4b4b-b8cc-1eb598529dd3" />

1
docs/bolt.svg Normal file
View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e8eaed"><path d="m422-232 207-248H469l29-227-185 267h139l-30 208ZM320-80l40-280H160l360-520h80l-40 320h240L400-80h-80Zm151-390Z"/></svg>


172
docs/cookbook.md Normal file
View File

@@ -0,0 +1,172 @@
<div class="grid cards" markdown>
- :speaking_head:{ .lg .middle }:eyes:{ .lg .middle } __Gemini Audio Video Chat__
---
Stream BOTH your webcam video and audio feeds to Google Gemini. You can also upload images to augment your conversation!
<video width=98% src="https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Google Gemini Real Time Voice API__
---
Talk to Gemini in real time using Google's voice API.
<video width=98% src="https://github.com/user-attachments/assets/da8c8a2a-5d99-4ac7-8927-0f7812e4146f" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/gemini-voice)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/gemini-voice/blob/main/app.py)
- :speaking_head:{ .lg .middle } __OpenAI Real Time Voice API__
---
Talk to ChatGPT in real time using OpenAI's voice API.
<video width=98% src="https://github.com/user-attachments/assets/41a63376-43ec-496a-9b31-4f067d3903d6" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/openai-realtime-voice)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/openai-realtime-voice/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__
---
A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama".
Build a Siri-like coding assistant in 100 lines of code!
<video width=98% src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)
- :robot:{ .lg .middle } __Llama Code Editor__
---
Create and edit HTML pages with just your voice! Powered by SambaNova systems.
<video width=98% src="https://github.com/user-attachments/assets/a09647f1-33e1-4154-a5a3-ffefda8a736a" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/llama-code-editor)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/llama-code-editor/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Audio Input/Output with mini-omni2__
---
Build a GPT-4o like experience with mini-omni2, an audio-native LLM.
<video width=98% src="https://github.com/user-attachments/assets/58c06523-fc38-4f5f-a4ba-a02a28e7fa9e" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Talk to Claude__
---
Use the Anthropic and Play.Ht APIs to have an audio conversation with Claude.
<video width=98% src="https://github.com/user-attachments/assets/650bc492-798e-4995-8cef-159e1cfc2185" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-claude)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-claude/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Kyutai Moshi__
---
Kyutai's moshi is a novel speech-to-speech model for modeling human conversations.
<video width=98% src="https://github.com/user-attachments/assets/becc7a13-9e89-4a19-9df2-5fb1467a0137" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Talk to Ultravox__
---
Talk to Fixie.AI's audio-native Ultravox LLM with the transformers library.
<video width=98% src="https://github.com/user-attachments/assets/e6e62482-518c-4021-9047-9da14cd82be1" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox/blob/main/app.py)
- :speaking_head:{ .lg .middle } __Talk to Llama 3.2 3b__
---
Use the Lepton API to make Llama 3.2 talk back to you!
<video width=98% src="https://github.com/user-attachments/assets/3ee37a6b-0892-45f5-b801-73188fdfad9a" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc/blob/main/app.py)
- :robot:{ .lg .middle } __Talk to Qwen2-Audio__
---
Qwen2-Audio is a SOTA audio-to-text LLM developed by Alibaba.
<video width=98% src="https://github.com/user-attachments/assets/c821ad86-44cc-4d0c-8dc4-8c02ad1e5dc8" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc/blob/main/app.py)
- :camera:{ .lg .middle } __Yolov10 Object Detection__
---
Run the Yolov10 model on a user webcam stream in real time!
<video width=98% src="https://github.com/user-attachments/assets/c90d8c9d-d2d5-462e-9e9b-af969f2ea73c" controls style="text-align: center"></video>
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n/blob/main/app.py)
- :camera:{ .lg .middle } __Video Object Detection with RT-DETR__
---
Upload a video and stream out frames with detected objects (powered by the RT-DETR model).
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc/blob/main/app.py)
- :speaker:{ .lg .middle } __Text-to-Speech with Parler__
---
Stream out audio generated by Parler TTS!
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc)
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc/blob/main/app.py)
</div>

165
docs/deployment.md Normal file
View File

@@ -0,0 +1,165 @@
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
## Community Server
Hugging Face graciously provides a TURN server for the community.
In order to use it, you first need to create a Hugging Face account at [huggingface.co](https://huggingface.co/).
Then navigate to this [space](https://huggingface.co/spaces/freddyaboulton/turn-server-login) and follow the instructions on the page. You just have to click the "Log in" button and then the "Sign Up" button.
![turn_login](https://github.com/user-attachments/assets/d077c3a3-7059-45d6-8e50-eb3d8a4aa43f)
Then you can use the `get_hf_turn_credentials` helper to get your credentials:
```python
import gradio as gr
from gradio_webrtc import get_hf_turn_credentials, WebRTC

# Pass a valid access token for your Hugging Face account
# or set the HF_TOKEN environment variable
credentials = get_hf_turn_credentials(token=None)

with gr.Blocks() as demo:
    webrtc = WebRTC(rtc_configuration=credentials)
    ...

demo.launch()
```
!!! warning
This is a shared resource so we make no latency/availability guarantees.
For more robust options, see the Twilio and self-hosting options below.
## Twilio API
The easiest way to do this is to use a service like Twilio.
Create a **free** [account](https://login.twilio.com/u/signup) and then install the `twilio` package with pip (`pip install twilio`). You can then connect from the WebRTC component like so:
```python
import os

import gradio as gr
from gradio_webrtc import WebRTC
from twilio.rest import Client

account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

client = Client(account_sid, auth_token)
token = client.tokens.create()

rtc_configuration = {
    "iceServers": token.ice_servers,
    "iceTransportPolicy": "relay",
}

with gr.Blocks() as demo:
    ...
    rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
    ...
```
!!! tip "Automatic Login"
You can log in automatically with the `get_twilio_turn_credentials` helper
```python
from gradio_webrtc import get_twilio_turn_credentials
# Will automatically read the TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
# env variables but you can also pass in the tokens as parameters
rtc_configuration = get_twilio_turn_credentials()
```
## Self Hosting
We have developed a script that can automatically deploy a TURN server to Amazon Web Services (AWS). You can follow the instructions in that [repository](https://github.com/freddyaboulton/turn-server-deploy) or continue with this guide.
### Prerequisites
Clone the following [repository](https://github.com/freddyaboulton/turn-server-deploy) and install the `aws` cli if you have not done so already (`pip install awscli`).
Log into your AWS account and create an IAM user with the following permissions:
- [AWSCloudFormationFullAccess](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/policies/details/arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FAWSCloudFormationFullAccess)
- [AmazonEC2FullAccess](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/policies/details/arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FAmazonEC2FullAccess)
Create a key pair for this user and write down the "access key" and "secret access key". Then log into the aws cli with these credentials (`aws configure`).
Finally, create an EC2 key pair (replace `your-key-name` with the name you want to give it):
```bash
aws ec2 create-key-pair --key-name your-key-name --query 'KeyMaterial' --output text > your-key-name.pem
```
### Running the script
Open the `parameters.json` file and fill in the correct values for all the parameters:
- `KeyName`: The key file we just created, e.g. `your-key-name` (omit `.pem`).
- `TurnUserName`: The username needed to connect to the server.
- `TurnPassword`: The password needed to connect to the server.
- `InstanceType`: One of the following values `t3.micro`, `t3.small`, `t3.medium`, `c4.large`, `c5.large`.
Then run the deployment script:
```bash
aws cloudformation create-stack \
--stack-name turn-server \
--template-body file://deployment.yml \
--parameters file://parameters.json \
--capabilities CAPABILITY_IAM
```
You can then wait for the stack to come up with:
```bash
aws cloudformation wait stack-create-complete \
--stack-name turn-server
```
Next, grab your EC2 server's public IP with:
```bash
aws cloudformation describe-stacks \
--stack-name turn-server \
--query 'Stacks[0].Outputs' > server-info.json
```
The `server-info.json` file will have the server's public IP and public DNS:
```json
[
{
"OutputKey": "PublicIP",
"OutputValue": "35.173.254.80",
"Description": "Public IP address of the TURN server"
},
{
"OutputKey": "PublicDNS",
"OutputValue": "ec2-35-173-254-80.compute-1.amazonaws.com",
"Description": "Public DNS name of the TURN server"
}
]
```
Finally, you can connect to your EC2 server from the gradio WebRTC component via the `rtc_configuration` argument:
```python
import gradio as gr
from gradio_webrtc import WebRTC
rtc_configuration = {
"iceServers": [
{
"urls": "turn:35.173.254.80:80",
"username": "<my-username>",
"credential": "<my-password>"
},
]
}
with gr.Blocks() as demo:
webrtc = WebRTC(rtc_configuration=rtc_configuration)
```

67
docs/faq.md Normal file
View File

@@ -0,0 +1,67 @@
## Demo does not work when deploying to the cloud
Make sure you are using a TURN server. See [deployment](/deployment).
## Recorded input audio sounds muffled during output audio playback
By default, the microphone is [configured](https://github.com/freddyaboulton/gradio-webrtc/blob/903f1f70bd586f638ad3b5a3940c7a8ec70ad1f5/backend/gradio_webrtc/webrtc.py#L575) to do echoCancellation.
This is what's causing the recorded audio to sound muffled when the streamed audio starts playing.
You can disable this via the `track_constraints` (see [advanced configuration](./advanced-configuration)) with the following code:
```python
audio = WebRTC(
label="Stream",
track_constraints={
"echoCancellation": False,
"noiseSuppression": {"exact": True},
"autoGainControl": {"exact": True},
"sampleRate": {"ideal": 24000},
"sampleSize": {"ideal": 16},
"channelCount": {"exact": 1},
},
rtc_configuration=None,
mode="send-receive",
modality="audio",
)
```
## How to raise errors in the UI
You can raise `WebRTCError` in order for an error message to show up in the user's screen. This is similar to how `gr.Error` works.
Here is a simple example:
```python
import time

import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, WebRTCError
from pydub import AudioSegment

def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file(
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
)
yield (
segment.frame_rate,
np.array(segment.get_array_of_samples()).reshape(1, -1),
)
time.sleep(3.5)
raise WebRTCError("This is a test error")
with gr.Blocks() as demo:
audio = WebRTC(
label="Stream",
mode="receive",
modality="audio",
)
num_steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=10,
step=1,
value=5,
)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
)
demo.launch()
```

30
docs/index.md Normal file
View File

@@ -0,0 +1,30 @@
<h1 style='text-align: center; margin-bottom: 1rem; color: white;'> Gradio WebRTC ⚡️ </h1>
<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/pypi/v/gradio_webrtc">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
</div>
<h3 style='text-align: center'>
Stream video and audio in real time with Gradio using WebRTC.
</h3>
## Installation
```bash
pip install gradio_webrtc
```
To use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
```bash
pip install gradio_webrtc[vad]
```
For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
```bash
pip install gradio_webrtc[stopword]
```
## Examples
See the [cookbook](/cookbook)

505
docs/user-guide.md Normal file
View File

@@ -0,0 +1,505 @@
# User Guide
To get started with WebRTC streams, all that's needed is to import the `WebRTC` component from this package and implement its `stream` event.
This page will show how to do so with simple code examples.
For complete implementations of common tasks, see the [cookbook](/cookbook).
## Audio Streaming
### Reply on Pause
Typically, you want to run an AI model that generates audio when the user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class
and passing it to the `stream` event of the `WebRTC` component.
=== "Code"
``` py title="ReplyonPause"
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnPause
def response(audio: tuple[int, np.ndarray]): # (1)
"""This function must yield audio frames"""
...
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono") # (2)
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send-receive", # (3)
modality="audio",
)
audio.stream(fn=ReplyOnPause(response),
inputs=[audio], outputs=[audio], # (4)
time_limit=60) # (5)
demo.launch()
```
1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.
2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio array). Each numpy audio array must have a shape of (1, num_samples).
3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.
4. The `WebRTC` component must be the first input and output component.
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time.
=== "Notes"
1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.
2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio arrays). Each numpy audio array must have a shape of (1, num_samples).
3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.
4. The `WebRTC` component must be the first input and output component.
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time.
### Reply On Stopwords
You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.
=== "Code"
``` py title="ReplyonPause"
import gradio as gr
from gradio_webrtc import WebRTC, ReplyOnPause
def response(audio: tuple[int, np.ndarray]):
"""This function must yield audio frames"""
...
for numpy_array in generated_audio:
yield (sampling_rate, numpy_array, "mono")
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Chat (Powered by WebRTC ⚡️)
</h1>
"""
)
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send-receive",
modality="audio",
)
audio.stream(ReplyOnStopWords(response,
input_sample_rate=16000,
stop_words=["computer"]), # (1)
inputs=[audio],
outputs=[audio], time_limit=90,
concurrency_limit=10)
demo.launch()
```
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
=== "Notes"
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
### Stream Handler
`ReplyOnPause` is an implementation of a `StreamHandler`. The `StreamHandler` is a low-level
abstraction that gives you arbitrary control over how the input audio stream and output audio stream are created. The following example echoes back the user's audio.
=== "Code"
``` py title="Stream Handler"
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, StreamHandler
from queue import Queue
class EchoHandler(StreamHandler):
def __init__(self) -> None:
super().__init__()
self.queue = Queue()
def receive(self, frame: tuple[int, np.ndarray]) -> None: # (1)
self.queue.put(frame)
def emit(self) -> tuple[int, np.ndarray]: # (2)
return self.queue.get()
def copy(self) -> StreamHandler:
return EchoHandler()
with gr.Blocks() as demo:
with gr.Column():
with gr.Group():
audio = WebRTC(
mode="send-receive",
modality="audio",
)
audio.stream(fn=EchoHandler(),
inputs=[audio], outputs=[audio],
time_limit=15)
demo.launch()
```
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
=== "Notes"
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
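As a hedged sketch of that last point, a non-blocking variant of the `EchoHandler.emit` method above could look like this (reusing the same `queue.Queue`):

``` python
from queue import Empty

# Sketch only: return a frame if one is queued, otherwise None instead of blocking.
def emit(self):
    try:
        return self.queue.get_nowait()
    except Empty:
        return None
```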
### Async Stream Handlers
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
Here is a complete example of using `AsyncStreamHandler` for using the Google Gemini real time API:
=== "Code"
``` py title="AsyncStreamHandler"
import asyncio
import base64
import logging
import os
import gradio as gr
import numpy as np
from google import genai
from gradio_webrtc import (
AsyncStreamHandler,
WebRTC,
async_aggregate_bytes_to_16bit,
get_twilio_turn_credentials,
)
class GeminiHandler(AsyncStreamHandler):
def __init__(
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
) -> None:
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=16000,
)
self.client: genai.Client | None = None
self.input_queue = asyncio.Queue()
self.output_queue = asyncio.Queue()
self.quit = asyncio.Event()
def copy(self) -> "GeminiHandler":
return GeminiHandler(
expected_layout=self.expected_layout,
output_sample_rate=self.output_sample_rate,
output_frame_size=self.output_frame_size,
)
async def stream(self):
while not self.quit.is_set():
audio = await self.input_queue.get()
yield audio
async def connect(self, api_key: str):
client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
config = {"response_modalities": ["AUDIO"]}
async with client.aio.live.connect(
model="gemini-2.0-flash-exp", config=config
) as session:
async for audio in session.start_stream(
stream=self.stream(), mime_type="audio/pcm"
):
if audio.data:
yield audio.data
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
_, array = frame
array = array.squeeze()
audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
self.input_queue.put_nowait(audio_message)
async def generator(self):
async for audio_response in async_aggregate_bytes_to_16bit(
self.connect(api_key=self.latest_args[1])
):
self.output_queue.put_nowait(audio_response)
async def emit(self):
if not self.args_set.is_set():
await self.wait_for_args()
asyncio.create_task(self.generator())
array = await self.output_queue.get()
return (self.output_sample_rate, array)
def shutdown(self) -> None:
self.quit.set()
with gr.Blocks() as demo:
gr.HTML(
"""
<div style='text-align: center'>
<h1>Gen AI SDK Voice Chat</h1>
<p>Speak with Gemini using real-time audio streaming</p>
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
</div>
"""
)
with gr.Row() as api_key_row:
api_key = gr.Textbox(
label="API Key",
placeholder="Enter your API Key",
value=os.getenv("GOOGLE_API_KEY", ""),
type="password",
)
with gr.Row(visible=False) as row:
webrtc = WebRTC(
label="Audio",
modality="audio",
mode="send-receive",
rtc_configuration=get_twilio_turn_credentials(),
pulse_color="rgb(35, 157, 225)",
icon_button_color="rgb(35, 157, 225)",
icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
)
webrtc.stream(
GeminiHandler(),
inputs=[webrtc, api_key],
outputs=[webrtc],
time_limit=90,
concurrency_limit=2,
)
api_key.submit(
lambda: (gr.update(visible=False), gr.update(visible=True)),
None,
[api_key_row, row],
)
demo.launch()
```
### Accessing Other Component Values from a StreamHandler
In the Gemini demo above, you'll notice that we have the user input their Google API key. This is stored in a `gr.Textbox` component.
We can access the value of this component via the `latest_args` prop of the `StreamHandler`. The `latest_args` is a list storing the values of each component in the WebRTC `stream` event `inputs` parameter. The value of the `WebRTC` component is the 0th index and it's always the dummy string `__webrtc_value__`.
In order to fetch the latest value from the user, however, we `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()`.
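As a sketch of the pattern (mirroring the Gemini handler above, where the stream event was wired with `inputs=[webrtc, api_key]`):

``` python
# Inside an AsyncStreamHandler subclass (sketch only):
async def emit(self):
    if not self.args_set.is_set():
        await self.wait_for_args()   # wait until the client has sent the input values
    api_key = self.latest_args[1]    # index 0 is the WebRTC component's dummy value
    ...
```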
### Server-To-Client Only
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
=== "Code"
``` py title="Server-To-CLient"
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
def generation(num_steps):
for _ in range(num_steps):
segment = AudioSegment.from_file("audio_file.wav")
array = np.array(segment.get_array_of_samples()).reshape(1, -1)
yield (segment.frame_rate, array)
with gr.Blocks() as demo:
audio = WebRTC(label="Stream", mode="receive", # (1)
modality="audio")
num_steps = gr.Slider(label="Number of Steps", minimum=1,
maximum=10, step=1, value=5)
button = gr.Button("Generate")
audio.stream(
fn=generation, inputs=[num_steps], outputs=[audio],
trigger=button.click # (2)
)
```
1. Set `mode="receive"` to only receive audio from the server.
2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.
=== "Notes"
1. Set `mode="receive"` to only receive audio from the server.
2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.
## Video Streaming
### Input/Output Streaming
Set up a video Input/Output stream to continuously receive webcam frames from the user and run an arbitrary python function to return a modified frame.
=== "Code"
``` py title="Input/Output Streaming"
import gradio as gr
from gradio_webrtc import WebRTC
def detection(image, conf_threshold=0.3): # (1)
... your detection code here ...
return modified_frame # (2)
with gr.Blocks() as demo:
image = WebRTC(label="Stream", mode="send-receive", modality="video") # (3)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.30,
)
image.stream(
fn=detection,
inputs=[image, conf_threshold], # (4)
outputs=[image], time_limit=10
)
if __name__ == "__main__":
demo.launch()
```
1. The webcam frame will be represented as a numpy array of shape (height, width, RGB).
2. The function must return a numpy array. It can take arbitrary values from other components.
3. Set the `modality="video"` and `mode="send-receive"`
4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
=== "Notes"
1. The webcam frame will be represented as a numpy array of shape (height, width, RGB).
2. The function must return a numpy array. It can take arbitrary values from other components.
3. Set the `modality="video"` and `mode="send-receive"`
4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
### Server-to-Client Only
Set up a server-to-client stream to stream video from an arbitrary user interaction.
=== "Code"
``` py title="Server-To-Client"
import gradio as gr
from gradio_webrtc import WebRTC
import cv2
def generation():
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
cap = cv2.VideoCapture(url)
iterating = True
while iterating:
iterating, frame = cap.read()
yield frame # (1)
with gr.Blocks() as demo:
output_video = WebRTC(label="Video Stream", mode="receive", # (2)
modality="video")
button = gr.Button("Start", variant="primary")
output_video.stream(
fn=generation, inputs=None, outputs=[output_video],
trigger=button.click # (3)
)
demo.launch()
```
1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
2. Set `mode="receive"` to only receive audio from the server.
3. The `trigger` parameter the gradio event that will trigger the stream. In this case, the button click event.
=== "Notes"
1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
2. Set `mode="receive"` to only receive audio from the server.
3. The `trigger` parameter the gradio event that will trigger the stream. In this case, the button click event.
## Audio-Video Streaming
You can stream audio and video simultaneously to/from a server using `AudioVideoStreamHandler` or `AsyncAudioVideoStreamHandler`.
They are identical to the audio `StreamHandlers` with the addition of `video_receive` and `video_emit` methods which take and return a `numpy` array, respectively.
Here is an example of the video handling functions for connecting with the Gemini multimodal API. In this case, we simply reflect the webcam feed back to the user but every second we'll send the latest webcam frame (and an additional image component) to the Gemini server.
Please see the "Gemini Audio Video Chat" example in the [cookbook](/cookbook) for the complete code.
``` python title="Async Gemini Video Handling"
async def video_receive(self, frame: np.ndarray):
"""Send video frames to the server"""
if self.session:
# send image every 1 second
# otherwise we flood the API
if time.time() - self.last_frame_time > 1:
self.last_frame_time = time.time()
await self.session.send(encode_image(frame))
if self.latest_args[2] is not None:
await self.session.send(encode_image(self.latest_args[2]))
self.video_queue.put_nowait(frame)
async def video_emit(self) -> VideoEmitType:
"""Return video frames to the client"""
return await self.video_queue.get()
```
## Additional Outputs
In order to modify other components from within the WebRTC stream, you must yield an instance of `AdditionalOutputs` and add an `on_additional_outputs` event to the `WebRTC` component.
This is common for displaying a multimodal text/audio conversation in a Chatbot UI.
=== "Code"
``` py title="Additional Outputs"
import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
def transcribe(audio: tuple[int, np.ndarray],
transformers_convo: list[dict],
gradio_convo: list[dict]):
response = model.generate(**inputs, max_length=256)
transformers_convo.append({"role": "assistant", "content": response})
gradio_convo.append({"role": "assistant", "content": response})
yield AdditionalOutputs(transformers_convo, gradio_convo) # (1)
with gr.Blocks() as demo:
gr.HTML(
"""
<h1 style='text-align: center'>
Talk to Qwen2Audio (Powered by WebRTC ⚡️)
</h1>
"""
)
transformers_convo = gr.State(value=[])
with gr.Row():
with gr.Column():
audio = WebRTC(
label="Stream",
mode="send", # (2)
modality="audio",
)
with gr.Column():
transcript = gr.Chatbot(label="transcript", type="messages")
audio.stream(ReplyOnPause(transcribe),
inputs=[audio, transformers_convo, transcript],
outputs=[audio], time_limit=90)
audio.on_additional_outputs(lambda s,a: (s,a), # (3)
outputs=[transformers_convo, transcript],
queue=False, show_progress="hidden")
demo.launch()
```
1. Pass your data to `AdditionalOutputs` and yield it.
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.
=== "Notes"
1. Pass your data to `AdditionalOutputs` and yield it.
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.

54
docs/utils.md Normal file
View File

@@ -0,0 +1,54 @@
# Utils
## `audio_to_bytes`
Convert an audio tuple containing sample rate and numpy array data into bytes.
Useful for sending data to external APIs from `ReplyOnPause` handler.
Parameters
```
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
```
Returns
```
bytes
The audio data encoded as bytes, suitable for transmission or storage
```
Example
```python
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_bytes = audio_to_bytes(audio_tuple)
```
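For example, inside a `ReplyOnPause` handler you might forward the bytes to a speech-to-text API. The sketch below assumes an OpenAI-compatible client and a Whisper-style transcription endpoint; the model name and file name are placeholders, so adjust them to the service you actually use:
```python
import numpy as np
from gradio_webrtc import audio_to_bytes
from openai import OpenAI  # assumed client; any SDK that accepts (filename, bytes) works

client = OpenAI()

def response(audio: tuple[int, np.ndarray]):
    # Encode the (sample_rate, data) tuple and send it off for transcription.
    transcript = client.audio.transcriptions.create(
        model="whisper-1",  # placeholder model name
        file=("audio-file.mp3", audio_to_bytes(audio)),
    )
    print(transcript.text)
    ...
```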
## `audio_to_file`
Save an audio tuple containing sample rate and numpy array data to a file.
Parameters
```
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
```
Returns
```
str
The path to the saved audio file
```
Example
```python
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> file_path = audio_to_file(audio_tuple)
>>> print(f"Audio saved to: {file_path}")
```

73
frontend/Example.svelte Normal file
View File

@@ -0,0 +1,73 @@
<script lang="ts">
import { playable } from "./shared/utils";
import { type FileData } from "@gradio/client";
export let type: "gallery" | "table";
export let selected = false;
export let value: { video: FileData; subtitles: FileData | null } | null;
export let loop: boolean;
let video: HTMLVideoElement;
async function init(): Promise<void> {
video.muted = true;
video.playsInline = true;
video.controls = false;
video.setAttribute("muted", "");
await video.play();
video.pause();
}
</script>
{#if value}
{#if playable()}
<div
class="container"
class:table={type === "table"}
class:gallery={type === "gallery"}
class:selected
>
<video
bind:this={video}
on:loadeddata={init}
on:mouseover={video.play.bind(video)}
on:mouseout={video.pause.bind(video)}
src={value?.video.url}
/>
</div>
{:else}
<div>{value}</div>
{/if}
{/if}
<style>
.container {
flex: none;
max-width: none;
}
.container :global(video) {
width: var(--size-full);
height: var(--size-full);
object-fit: cover;
}
.container:hover,
.container.selected {
border-color: var(--border-color-accent);
}
.container.table {
margin: 0 auto;
border: 2px solid var(--border-color-primary);
border-radius: var(--radius-lg);
overflow: hidden;
width: var(--size-20);
height: var(--size-20);
object-fit: cover;
}
.container.gallery {
height: var(--size-20);
max-height: var(--size-20);
object-fit: cover;
}
</style>

158
frontend/Index.svelte Normal file
View File

@@ -0,0 +1,158 @@
<svelte:options accessors={true} />
<script lang="ts">
import { Block, UploadText } from "@gradio/atoms";
import Video from "./shared/InteractiveVideo.svelte";
import { StatusTracker } from "@gradio/statustracker";
import type { LoadingStatus } from "@gradio/statustracker";
import StaticVideo from "./shared/StaticVideo.svelte";
import StaticAudio from "./shared/StaticAudio.svelte";
import InteractiveAudio from "./shared/InteractiveAudio.svelte";
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let value: string = "__webrtc_value__";
export let button_labels: {start: string, stop: string, waiting: string};
export let label: string;
export let root: string;
export let show_label: boolean;
export let loading_status: LoadingStatus;
export let height: number | undefined;
export let width: number | undefined;
export let server: {
offer: (body: any) => Promise<any>;
};
export let container = false;
export let scale: number | null = null;
export let min_width: number | undefined = undefined;
export let gradio;
export let rtc_configuration: Object;
export let time_limit: number | null = null;
export let modality: "video" | "audio" | "audio-video" = "video";
export let mode: "send-receive" | "receive" | "send" = "send-receive";
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let track_constraints: MediaTrackConstraints = {};
export let icon: string | undefined = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
const on_change_cb = (msg: "change" | "tick" | any) => {
if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
console.log("dispatching info", msg.message);
gradio.dispatch(msg?.type === "error"? "error": "warning", msg.message);
}
gradio.dispatch(msg === "change" ? "state_change" : "tick");
}
let dragging = false;
$: console.log("value", value);
</script>
<Block
{visible}
variant={"solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
{elem_id}
{elem_classes}
{height}
{width}
{container}
{scale}
{min_width}
allow_overflow={false}
>
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{#if mode == "receive" && modality === "video"}
<StaticVideo
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}
{rtc_configuration}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else if mode == "receive" && modality === "audio"}
<StaticAudio
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}
{rtc_configuration}
{icon}
{icon_button_color}
{pulse_color}
i18n={gradio.i18n}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
/>
{:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}
<Video
bind:value={value}
{label}
{show_label}
active_source={"webcam"}
include_audio={modality === "audio-video"}
{server}
{rtc_configuration}
{time_limit}
{mode}
{track_constraints}
{rtp_params}
{on_change_cb}
{icon}
{icon_button_color}
{pulse_color}
{button_labels}
on:clear={() => gradio.dispatch("clear")}
on:play={() => gradio.dispatch("play")}
on:pause={() => gradio.dispatch("pause")}
on:upload={() => gradio.dispatch("upload")}
on:stop={() => gradio.dispatch("stop")}
on:end={() => gradio.dispatch("end")}
on:start_recording={() => gradio.dispatch("start_recording")}
on:stop_recording={() => gradio.dispatch("stop_recording")}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
i18n={gradio.i18n}
stream_handler={(...args) => gradio.client.stream(...args)}
>
<UploadText i18n={gradio.i18n} type="video" />
</Video>
{:else if (mode === "send-receive" || mode === "send") && modality === "audio"}
<InteractiveAudio
bind:value={value}
{on_change_cb}
{label}
{show_label}
{server}
{rtc_configuration}
{time_limit}
{track_constraints}
{mode}
{rtp_params}
i18n={gradio.i18n}
{icon}
{icon_button_color}
{pulse_color}
{button_labels}
on:tick={() => gradio.dispatch("tick")}
on:error={({ detail }) => gradio.dispatch("error", detail)}
on:warning={({ detail }) => gradio.dispatch("warning", detail)}
/>
{/if}
</Block>

View File

@@ -0,0 +1,9 @@
export default {
plugins: [],
svelte: {
preprocess: [],
},
build: {
target: "modules",
},
};

5
frontend/index.ts Normal file
View File

@@ -0,0 +1,5 @@
export { default as BaseInteractiveVideo } from "./shared/InteractiveVideo.svelte";
export { prettyBytes, playable, loaded } from "./shared/utils";
export { default as BaseExample } from "./Example.svelte";
import { default as Index } from "./Index.svelte";
export default Index;

5900
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large

51
frontend/package.json Normal file
View File

@@ -0,0 +1,51 @@
{
"name": "gradio_webrtc",
"version": "0.11.0-beta.3",
"description": "Gradio UI packages",
"type": "module",
"author": "",
"license": "ISC",
"private": false,
"dependencies": {
"@ffmpeg/ffmpeg": "^0.12.10",
"@ffmpeg/util": "^0.12.1",
"@gradio/atoms": "0.9.2",
"@gradio/client": "1.7.0",
"@gradio/icons": "0.8.0",
"@gradio/image": "0.16.4",
"@gradio/markdown": "^0.10.3",
"@gradio/statustracker": "0.9.1",
"@gradio/upload": "0.13.3",
"@gradio/utils": "0.7.0",
"@gradio/wasm": "0.14.2",
"hls.js": "^1.5.16",
"mrmime": "^2.0.0"
},
"devDependencies": {
"@gradio/preview": "0.12.0",
"prettier": "3.3.3"
},
"exports": {
"./package.json": "./package.json",
".": {
"gradio": "./index.ts",
"svelte": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"./example": {
"gradio": "./Example.svelte",
"svelte": "./dist/Example.svelte",
"types": "./dist/Example.svelte.d.ts"
}
},
"peerDependencies": {
"svelte": "^4.0.0"
},
"main": "index.ts",
"main_changeset": true,
"repository": {
"type": "git",
"url": "git+https://github.com/gradio-app/gradio.git",
"directory": "js/video"
}
}

View File

@@ -0,0 +1,164 @@
<script lang="ts">
import { onDestroy } from 'svelte';
import type {ComponentType} from 'svelte';
import PulsingIcon from './PulsingIcon.svelte';
export let numBars = 16;
export let stream_state: "open" | "closed" | "waiting" = "closed";
export let audio_source_callback: () => MediaStream;
export let icon: string | undefined | ComponentType = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
let audioContext: AudioContext;
let analyser: AnalyserNode;
let dataArray: Uint8Array;
let animationId: number;
let pulseScale = 1;
$: containerWidth = icon
? "128px"
: `calc((var(--boxSize) + var(--gutter)) * ${numBars})`;
$: if(stream_state === "open") setupAudioContext();
onDestroy(() => {
if (animationId) {
cancelAnimationFrame(animationId);
}
if (audioContext) {
audioContext.close();
}
});
function setupAudioContext() {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
const source = audioContext.createMediaStreamSource(audio_source_callback());
source.connect(analyser);
analyser.fftSize = 64;
analyser.smoothingTimeConstant = 0.8;
dataArray = new Uint8Array(analyser.frequencyBinCount);
updateVisualization();
}
function updateVisualization() {
analyser.getByteFrequencyData(dataArray);
// Update bars
const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
for (let i = 0; i < bars.length; i++) {
const barHeight = (dataArray[i] / 255) * 2;
bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
}
animationId = requestAnimationFrame(updateVisualization);
}
</script>
<div class="gradio-webrtc-waveContainer">
{#if icon}
<div class="gradio-webrtc-icon-container">
<div
class="gradio-webrtc-icon"
style:transform={`scale(${pulseScale})`}
style:background={icon_button_color}
>
<PulsingIcon
{stream_state}
{pulse_color}
{icon}
{icon_button_color}
{audio_source_callback}/>
</div>
</div>
{:else}
<div class="gradio-webrtc-boxContainer" style:width={containerWidth}>
{#each Array(numBars) as _}
<div class="gradio-webrtc-box"></div>
{/each}
</div>
{/if}
</div>
<style>
.gradio-webrtc-waveContainer {
position: relative;
display: flex;
min-height: 100px;
max-height: 128px;
justify-content: center;
align-items: center;
}
.gradio-webrtc-boxContainer {
display: flex;
justify-content: space-between;
height: 64px;
--boxSize: 8px;
--gutter: 4px;
}
.gradio-webrtc-box {
height: 100%;
width: var(--boxSize);
background: var(--color-accent);
border-radius: 8px;
transition: transform 0.05s ease;
}
.gradio-webrtc-icon-container {
position: relative;
width: 128px;
height: 128px;
display: flex;
justify-content: center;
align-items: center;
}
.gradio-webrtc-icon {
position: relative;
width: 48px;
height: 48px;
border-radius: 50%;
transition: transform 0.1s ease;
display: flex;
justify-content: center;
align-items: center;
z-index: 2;
}
.icon-image {
width: 32px;
height: 32px;
object-fit: contain;
filter: brightness(0) invert(1);
}
.pulse-ring {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 48px;
height: 48px;
border-radius: 50%;
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
opacity: 0.5;
}
@keyframes pulse {
0% {
transform: translate(-50%, -50%) scale(1);
opacity: 0.5;
}
100% {
transform: translate(-50%, -50%) scale(var(--max-scale, 3));
opacity: 0;
}
}
</style>

View File

@@ -0,0 +1,454 @@
<script lang="ts">
import {
BlockLabel,
} from "@gradio/atoms";
import type { I18nFormatter } from "@gradio/utils";
import { createEventDispatcher } from "svelte";
import { onMount } from "svelte";
import { fade } from "svelte/transition";
import { StreamingBar } from "@gradio/statustracker";
import {
Circle,
Square,
Spinner,
Music,
DropdownArrow,
Microphone
} from "@gradio/icons";
import { start, stop } from "./webrtc_utils";
import { get_devices, set_available_devices } from "./stream_utils";
import AudioWave from "./AudioWave.svelte";
import WebcamPermissions from "./WebcamPermissions.svelte";
export let mode: "send-receive" | "send";
export let value: string | null = null;
export let label: string | undefined = undefined;
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter;
export let time_limit: number | null = null;
export let track_constraints: MediaTrackConstraints = {};
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let on_change_cb: (mg: "tick" | "change") => void;
export let icon: string | undefined = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
export let button_labels: {start: string, stop: string, waiting: string};
let stopword_recognized = false;
let notification_sound;
onMount(() => {
if (value === "__webrtc_value__") {
notification_sound = new Audio("https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/pop-sounds.mp3");
}
});
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
console.log("msg", msg);
if (msg === "stopword") {
console.log("stopword recognized");
stopword_recognized = true;
setTimeout(() => {
stopword_recognized = false;
}, 3000);
} else {
console.log("calling on_change_cb with msg", msg);
on_change_cb(msg);
}
};
let options_open = false;
let _time_limit: number | null = null;
export let server: {
offer: (body: any) => Promise<any>;
};
let stream_state: "open" | "closed" | "waiting" = "closed";
let audio_player: HTMLAudioElement;
let pc: RTCPeerConnection;
let _webrtc_id = null;
let stream: MediaStream;
let available_audio_devices: MediaDeviceInfo[];
let selected_device: MediaDeviceInfo | null = null;
let mic_accessed = false;
const audio_source_callback = () => {
console.log("stream in callback", stream);
if(mode==="send") return stream;
else return audio_player.srcObject as MediaStream
}
const dispatch = createEventDispatcher<{
tick: undefined;
state_change: undefined;
error: string
play: undefined;
stop: undefined;
}>();
async function access_mic(): Promise<void> {
try {
const constraints = selected_device ? { deviceId: { exact: selected_device.deviceId }, ...track_constraints } : track_constraints;
const stream_ = await navigator.mediaDevices.getUserMedia({ audio: constraints });
stream = stream_;
} catch (err) {
if (!navigator.mediaDevices) {
dispatch("error", i18n("audio.no_device_support"));
return;
}
if (err instanceof DOMException && err.name == "NotAllowedError") {
dispatch("error", i18n("audio.allow_recording_access"));
return;
}
throw err;
}
available_audio_devices = set_available_devices(await get_devices(), "audioinput");
mic_accessed = true;
const used_devices = stream
.getTracks()
.map((track) => track.getSettings()?.deviceId)[0];
selected_device = used_devices
? available_audio_devices.find((device) => device.deviceId === used_devices) ||
available_audio_devices[0]
: available_audio_devices[0];
}
async function start_stream(): Promise<void> {
if( stream_state === "open"){
stop(pc);
stream_state = "closed";
_time_limit = null;
await access_mic();
return;
}
_webrtc_id = Math.random().toString(36).substring(2);
value = _webrtc_id;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
_time_limit = time_limit;
break;
case "disconnected":
console.info("closed");
stream_state = "closed";
_time_limit = null;
stop(pc);
break;
default:
break;
}
}
)
stream_state = "waiting"
stream = null
try {
await access_mic();
} catch (err) {
if (!navigator.mediaDevices) {
dispatch("error", i18n("audio.no_device_support"));
return;
}
if (err instanceof DOMException && err.name == "NotAllowedError") {
dispatch("error", i18n("audio.allow_recording_access"));
return;
}
throw err;
}
if (stream == null) return;
start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", _on_change_cb, rtp_params).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
function handle_click_outside(event: MouseEvent): void {
event.preventDefault();
event.stopPropagation();
options_open = false;
}
function click_outside(node: Node, cb: any): any {
const handle_click = (event: MouseEvent): void => {
if (
node &&
!node.contains(event.target as Node) &&
!event.defaultPrevented
) {
cb(event);
}
};
document.addEventListener("click", handle_click, true);
return {
destroy() {
document.removeEventListener("click", handle_click, true);
}
};
}
const handle_device_change = async (event: InputEvent): Promise<void> => {
const target = event.target as HTMLInputElement;
const device_id = target.value;
stream = await navigator.mediaDevices.getUserMedia({ audio: {deviceId: { exact: device_id }, ...track_constraints }});
selected_device =
available_audio_devices.find(
(device) => device.deviceId === device_id
) || null;
options_open = false;
};
$: if(stopword_recognized){
notification_sound.play();
}
</script>
<BlockLabel
{show_label}
Icon={Music}
float={false}
label={label || i18n("audio.audio")}
/>
<div class="audio-container">
<audio
class="standard-player"
class:hidden={value === "__webrtc_value__"}
on:load
bind:this={audio_player}
on:ended={() => dispatch("stop")}
on:play={() => dispatch("play")}
/>
{#if !mic_accessed}
<div
in:fade={{ delay: 100, duration: 200 }}
title="grant webcam access"
style="height: 100%"
>
<WebcamPermissions icon={Microphone} on:click={async () => access_mic()} />
</div>
{:else}
<AudioWave {audio_source_callback} {stream_state} {icon} {icon_button_color} {pulse_color}/>
<StreamingBar time_limit={_time_limit} />
<div class="button-wrap" class:pulse={stopword_recognized}>
<button
on:click={start_stream}
aria-label={"start stream"}
>
{#if stream_state === "waiting"}
<div class="icon-with-text">
<div class="icon color-primary" title="spinner">
<Spinner />
</div>
{button_labels.waiting || i18n("audio.waiting")}
</div>
{:else if stream_state === "open"}
<div class="icon-with-text">
<div class="icon color-primary" title="stop recording">
<Square />
</div>
{button_labels.stop || i18n("audio.stop")}
</div>
{:else}
<div class="icon-with-text">
<div class="icon color-primary" title="start recording">
<Circle />
</div>
{button_labels.start || i18n("audio.record")}
</div>
{/if}
</button>
{#if stream_state === "closed"}
<button
class="icon"
on:click={() => (options_open = true)}
aria-label="select input source"
>
<DropdownArrow />
</button>
{/if}
{#if options_open && selected_device}
<select
class="select-wrap"
aria-label="select source"
use:click_outside={handle_click_outside}
on:change={handle_device_change}
>
<button
class="inset-icon"
on:click|stopPropagation={() => (options_open = false)}
>
<DropdownArrow />
</button>
{#if available_audio_devices.length === 0}
<option value="">{i18n("common.no_devices")}</option>
{:else}
{#each available_audio_devices as device}
<option
value={device.deviceId}
selected={selected_device.deviceId === device.deviceId}
>
{device.label}
</option>
{/each}
{/if}
</select>
{/if}
</div>
{/if}
</div>
<style>
.audio-container {
display: flex;
height: 100%;
flex-direction: column;
justify-content: center;
align-items: center;
}
:global(::part(wrapper)) {
margin-bottom: var(--size-2);
}
.standard-player {
width: 100%;
padding: var(--size-2);
}
.hidden {
display: none;
}
.button-wrap {
margin-top: var(--size-2);
margin-bottom: var(--size-2);
background-color: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: var(--radius-xl);
padding: var(--size-1-5);
display: flex;
bottom: var(--size-2);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--radius-xl);
line-height: var(--size-3);
color: var(--button-secondary-text-color);
}
@keyframes pulse {
0% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0.7);
}
70% {
transform: scale(1.25);
box-shadow: 0 0 0 10px rgba(var(--primary-500-rgb), 0);
}
100% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0);
}
}
.pulse {
animation: pulse 1s infinite;
}
.icon-with-text {
min-width: var(--size-16);
align-items: center;
margin: 0 var(--spacing-xl);
display: flex;
justify-content: space-evenly;
gap: var(--size-2);
}
@media (--screen-md) {
button {
bottom: var(--size-4);
}
}
@media (--screen-xl) {
button {
bottom: var(--size-8);
}
}
.icon {
width: 18px;
height: 18px;
display: flex;
justify-content: space-between;
align-items: center;
}
.color-primary {
fill: var(--primary-600);
stroke: var(--primary-600);
color: var(--primary-600);
}
.select-wrap {
-webkit-appearance: none;
-moz-appearance: none;
appearance: none;
color: var(--button-secondary-text-color);
background-color: transparent;
width: 95%;
font-size: var(--text-md);
position: absolute;
bottom: var(--size-2);
background-color: var(--block-background-fill);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--radius-xl);
z-index: var(--layer-top);
border: 1px solid var(--border-color-primary);
text-align: left;
line-height: var(--size-4);
white-space: nowrap;
text-overflow: ellipsis;
left: 50%;
transform: translate(-50%, 0);
max-width: var(--size-52);
}
.select-wrap > option {
padding: 0.25rem 0.5rem;
border-bottom: 1px solid var(--border-color-accent);
padding-right: var(--size-8);
text-overflow: ellipsis;
overflow: hidden;
}
.select-wrap > option:hover {
background-color: var(--color-accent);
}
.select-wrap > option:last-child {
border: none;
}
</style>

View File

@@ -0,0 +1,89 @@
<script lang="ts">
import { createEventDispatcher } from "svelte";
import type { ComponentType } from "svelte";
import type { FileData, Client } from "@gradio/client";
import { BlockLabel } from "@gradio/atoms";
import Webcam from "./Webcam.svelte";
import { Video } from "@gradio/icons";
import type { I18nFormatter } from "@gradio/utils";
	export let value: string | null = null;
export let label: string | undefined = undefined;
export let show_label = true;
export let include_audio: boolean;
export let i18n: I18nFormatter;
export let active_source: "webcam" | "upload" = "webcam";
export let handle_reset_value: () => void = () => {};
export let stream_handler: Client["stream"];
export let time_limit: number | null = null;
export let button_labels: {start: string, stop: string, waiting: string};
export let server: {
offer: (body: any) => Promise<any>;
};
export let rtc_configuration: Object;
export let track_constraints: MediaTrackConstraints = {};
export let mode: "send" | "send-receive";
export let on_change_cb: (msg: "change" | "tick") => void;
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let icon: string | undefined | ComponentType = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
const dispatch = createEventDispatcher<{
change: FileData | null;
clear?: never;
play?: never;
pause?: never;
end?: never;
drag: boolean;
error: string;
upload: FileData;
start_recording?: never;
stop_recording?: never;
tick: never;
}>();
let dragging = false;
$: dispatch("drag", dragging);
$: console.log("value", value)
</script>
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
<div data-testid="video" class="video-container">
<Webcam
{rtc_configuration}
{include_audio}
{time_limit}
{track_constraints}
{mode}
{rtp_params}
{on_change_cb}
{icon}
{icon_button_color}
{pulse_color}
{button_labels}
on:error
on:start_recording
on:stop_recording
on:tick
{i18n}
stream_every={0.5}
{server}
bind:webrtc_id={value}
/>
<!-- <SelectSource {sources} bind:active_source /> -->
</div>
<style>
.video-container {
display: flex;
height: 100%;
flex-direction: column;
justify-content: center;
align-items: center;
}
</style>

View File

@@ -0,0 +1,151 @@
<script lang="ts">
import { onDestroy } from 'svelte';
import type {ComponentType} from 'svelte';
export let stream_state: "open" | "closed" | "waiting" = "closed";
export let audio_source_callback: () => MediaStream;
	export let icon: string | ComponentType | undefined = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
let audioContext: AudioContext;
let analyser: AnalyserNode;
let dataArray: Uint8Array;
let animationId: number;
let pulseScale = 1;
let pulseIntensity = 0;
$: if(stream_state === "open") setupAudioContext();
onDestroy(() => {
if (animationId) {
cancelAnimationFrame(animationId);
}
if (audioContext) {
audioContext.close();
}
});
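	// Build an AnalyserNode fed by the live MediaStream and start the
	// requestAnimationFrame loop that drives the pulse animation.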
function setupAudioContext() {
		audioContext = new (window.AudioContext || (window as any).webkitAudioContext)();
analyser = audioContext.createAnalyser();
const source = audioContext.createMediaStreamSource(audio_source_callback());
source.connect(analyser);
analyser.fftSize = 64;
analyser.smoothingTimeConstant = 0.8;
dataArray = new Uint8Array(analyser.frequencyBinCount);
updateVisualization();
}
function updateVisualization() {
analyser.getByteFrequencyData(dataArray);
// Calculate average amplitude for pulse effect
const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
const normalizedAverage = average / 255;
pulseScale = 1 + (normalizedAverage * 0.15);
pulseIntensity = normalizedAverage;
animationId = requestAnimationFrame(updateVisualization);
}
	$: maxPulseScale = 1 + (pulseIntensity * 10); // Outer ring scale grows from 1x up to 11x with intensity
</script>
<div class="gradio-webrtc-icon-wrapper">
<div class="gradio-webrtc-pulsing-icon-container">
{#if pulseIntensity > 0}
{#each Array(3) as _, i}
<div
class="pulse-ring"
style:background={pulse_color}
style:animation-delay={`${i * 0.4}s`}
style:--max-scale={maxPulseScale}
style:opacity={0.5 * pulseIntensity}
/>
{/each}
{/if}
<div
class="gradio-webrtc-pulsing-icon"
style:transform={`scale(${pulseScale})`}
style:background={icon_button_color}
>
{#if typeof icon === "string"}
<img
src={icon}
alt="Audio visualization icon"
class="icon-image"
/>
{:else}
<svelte:component this={icon} />
{/if}
</div>
</div>
</div>
<style>
.gradio-webrtc-icon-wrapper {
position: relative;
display: flex;
max-height: 128px;
justify-content: center;
align-items: center;
}
.gradio-webrtc-pulsing-icon-container {
position: relative;
width: 100%;
height: 100%;
display: flex;
justify-content: center;
align-items: center;
}
.gradio-webrtc-pulsing-icon {
position: relative;
width: 100%;
height: 100%;
border-radius: 50%;
transition: transform 0.1s ease;
display: flex;
justify-content: center;
align-items: center;
z-index: 2;
}
.icon-image {
width: 100%;
height: 100%;
object-fit: contain;
filter: brightness(0) invert(1);
}
.pulse-ring {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 100%;
height: 100%;
border-radius: 50%;
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
opacity: 0.5;
}
@keyframes pulse {
0% {
transform: translate(-50%, -50%) scale(1);
opacity: 0.5;
}
100% {
transform: translate(-50%, -50%) scale(var(--max-scale, 3));
opacity: 0;
}
}
</style>

View File

@@ -0,0 +1,135 @@
<script lang="ts">
import { Empty } from "@gradio/atoms";
import {
BlockLabel,
} from "@gradio/atoms";
import { Music } from "@gradio/icons";
import type { I18nFormatter } from "@gradio/utils";
import { createEventDispatcher } from "svelte";
import { onMount } from "svelte";
import { start, stop } from "./webrtc_utils";
import AudioWave from "./AudioWave.svelte";
export let value: string | null = null;
export let label: string | undefined = undefined;
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let i18n: I18nFormatter;
export let on_change_cb: (msg: "change" | "tick") => void;
export let icon: string | undefined = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
export let server: {
offer: (body: any) => Promise<any>;
};
let stream_state: "open" | "closed" | "waiting" = "closed";
let audio_player: HTMLAudioElement;
let pc: RTCPeerConnection;
let _webrtc_id = Math.random().toString(36).substring(2);
const dispatch = createEventDispatcher<{
tick: undefined;
error: string
play: undefined;
stop: undefined;
}>();
onMount(() => {
window.setInterval(() => {
if (stream_state == "open") {
dispatch("tick");
}
}, 1000);
}
)
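	// When `value` is set to the sentinel "start_webrtc_stream", generate a fresh
	// webrtc_id, open a peer connection, and negotiate a receive-only audio stream.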
async function start_stream(value: string): Promise<string> {
if( value === "start_webrtc_stream") {
stream_state = "waiting";
_webrtc_id = Math.random().toString(36).substring(2)
value = _webrtc_id;
console.log("set value to ", value);
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.info("connected");
stream_state = "open";
break;
case "disconnected":
console.info("closed");
stop(pc);
break;
default:
break;
}
}
)
let stream = null;
start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
return value;
}
$: start_stream(value).then((val) => {
value = val;
});
</script>
<BlockLabel
{show_label}
Icon={Music}
float={false}
label={label || i18n("audio.audio")}
/>
<audio
class="standard-player"
class:hidden={true}
on:load
bind:this={audio_player}
on:ended={() => dispatch("stop")}
on:play={() => dispatch("play")}
/>
{#if value !== "__webrtc_value__"}
<div class="audio-container">
<AudioWave audio_source_callback={() => audio_player.srcObject} {stream_state} {icon} {icon_button_color} {pulse_color}/>
</div>
{/if}
{#if value === "__webrtc_value__"}
<Empty size="small">
<Music />
</Empty>
{/if}
<style>
.audio-container {
display: flex;
height: 100%;
flex-direction: column;
justify-content: center;
align-items: center;
}
.standard-player {
width: 100%;
}
.hidden {
display: none;
}
</style>

View File

@@ -0,0 +1,119 @@
<script lang="ts">
import { createEventDispatcher, onMount} from "svelte";
import {
BlockLabel,
Empty
} from "@gradio/atoms";
import { Video } from "@gradio/icons";
import { start, stop } from "./webrtc_utils";
export let value: string | null = null;
export let label: string | undefined = undefined;
export let show_label = true;
export let rtc_configuration: Object | null = null;
export let on_change_cb: (msg: "change" | "tick") => void;
export let server: {
offer: (body: any) => Promise<any>;
};
let video_element: HTMLVideoElement;
let _webrtc_id = Math.random().toString(36).substring(2);
let pc: RTCPeerConnection;
const dispatch = createEventDispatcher<{
error: string;
tick: undefined;
}>();
let stream_state = "closed";
onMount(() => {
window.setInterval(() => {
if (stream_state == "open") {
dispatch("tick");
}
}, 1000);
}
)
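	// When `value` is set to the sentinel "start_webrtc_stream", open a new peer
	// connection and negotiate a receive-only video stream into the <video> element.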
$: if( value === "start_webrtc_stream") {
_webrtc_id = Math.random().toString(36).substring(2);
value = _webrtc_id;
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
console.log("connected");
stream_state = "open";
break;
case "disconnected":
console.log("closed");
stop(pc);
break;
default:
break;
}
}
)
start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb).then((connection) => {
pc = connection;
}).catch(() => {
console.log("catching")
dispatch("error", "Too many concurrent users. Come back later!");
});
}
</script>
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
{#if value === "__webrtc_value__"}
<Empty unpadded_box={true} size="large"><Video /></Empty>
{/if}
<div class="wrap">
<video
class:hidden={value === "__webrtc_value__"}
bind:this={video_element}
autoplay={true}
on:loadeddata={dispatch.bind(null, "loadeddata")}
on:click={dispatch.bind(null, "click")}
on:play={dispatch.bind(null, "play")}
on:pause={dispatch.bind(null, "pause")}
on:ended={dispatch.bind(null, "ended")}
on:mouseover={dispatch.bind(null, "mouseover")}
on:mouseout={dispatch.bind(null, "mouseout")}
on:focus={dispatch.bind(null, "focus")}
on:blur={dispatch.bind(null, "blur")}
on:load
data-testid={$$props["data-testid"]}
crossorigin="anonymous"
>
<track kind="captions" />
</video>
</div>
<style>
.hidden {
display: none;
}
.wrap {
position: relative;
background-color: var(--background-fill-secondary);
height: var(--size-full);
width: var(--size-full);
border-radius: var(--radius-xl);
}
.wrap :global(video) {
height: var(--size-full);
width: var(--size-full);
}
</style>

View File

@@ -0,0 +1,434 @@
<script lang="ts">
import { createEventDispatcher, onMount } from "svelte";
import type { ComponentType } from "svelte";
import {
Circle,
Square,
DropdownArrow,
Spinner,
Microphone as Mic
} from "@gradio/icons";
import type { I18nFormatter } from "@gradio/utils";
import { StreamingBar } from "@gradio/statustracker";
import WebcamPermissions from "./WebcamPermissions.svelte";
import { fade } from "svelte/transition";
import {
get_devices,
get_video_stream,
set_available_devices
} from "./stream_utils";
import { start, stop } from "./webrtc_utils";
import PulsingIcon from "./PulsingIcon.svelte";
let video_source: HTMLVideoElement;
let available_video_devices: MediaDeviceInfo[] = [];
let selected_device: MediaDeviceInfo | null = null;
let _time_limit: number | null = null;
export let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed";
export let on_change_cb: (msg: "tick" | "change") => void;
export let mode: "send-receive" | "send";
const _webrtc_id = Math.random().toString(36).substring(2);
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
export let icon: string | undefined | ComponentType = undefined;
export let icon_button_color: string = "var(--color-accent)";
export let pulse_color: string = "var(--color-accent)";
export let button_labels: {start: string, stop: string, waiting: string};
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
state: "open" | "closed" | "waiting"
) => {
if (state === "closed") {
_time_limit = null;
stream_state = "closed";
} else if (state === "waiting") {
stream_state = "waiting";
} else {
stream_state = "open";
}
};
let canvas: HTMLCanvasElement;
export let track_constraints: MediaTrackConstraints | null = null;
export let rtc_configuration: Object;
export let stream_every = 1;
export let server: {
offer: (body: any) => Promise<any>;
};
export let include_audio: boolean;
export let i18n: I18nFormatter;
const dispatch = createEventDispatcher<{
tick: undefined;
error: string;
start_recording: undefined;
stop_recording: undefined;
close_stream: undefined;
}>();
onMount(() => (canvas = document.createElement("canvas")));
const handle_device_change = async (event: InputEvent): Promise<void> => {
const target = event.target as HTMLInputElement;
const device_id = target.value;
await get_video_stream(include_audio, video_source, device_id, track_constraints).then(
async (local_stream) => {
stream = local_stream;
selected_device =
available_video_devices.find(
(device) => device.deviceId === device_id
) || null;
options_open = false;
}
);
};
async function access_webcam(): Promise<void> {
try {
get_video_stream(include_audio, video_source, null, track_constraints)
.then(async (local_stream) => {
webcam_accessed = true;
available_video_devices = await get_devices();
stream = local_stream;
})
.then(() => set_available_devices(available_video_devices))
.then((devices) => {
available_video_devices = devices;
const used_devices = stream
.getTracks()
.map((track) => track.getSettings()?.deviceId)[0];
selected_device = used_devices
? devices.find((device) => device.deviceId === used_devices) ||
available_video_devices[0]
: available_video_devices[0];
});
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
dispatch("error", i18n("image.no_webcam_support"));
}
} catch (err) {
if (err instanceof DOMException && err.name == "NotAllowedError") {
dispatch("error", i18n("image.allow_webcam_access"));
} else {
throw err;
}
}
}
let recording = false;
let stream: MediaStream;
let webcam_accessed = false;
let pc: RTCPeerConnection;
export let webrtc_id;
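	// Toggle streaming: when closed, open a peer connection, add the webcam tracks,
	// and negotiate with the server; when open, stop the connection and fall back to
	// the local webcam preview.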
async function start_webrtc(): Promise<void> {
if (stream_state === 'closed') {
pc = new RTCPeerConnection(rtc_configuration);
pc.addEventListener("connectionstatechange",
async (event) => {
switch(pc.connectionState) {
case "connected":
stream_state = "open";
_time_limit = time_limit;
break;
case "disconnected":
stream_state = "closed";
_time_limit = null;
stop(pc);
await access_webcam();
break;
default:
break;
}
}
)
stream_state = "waiting"
webrtc_id = Math.random().toString(36).substring(2);
start(stream, pc, mode === "send" ? null: video_source, server.offer, webrtc_id, "video", on_change_cb, rtp_params).then((connection) => {
pc = connection;
}).catch(() => {
console.info("catching")
stream_state = "closed";
dispatch("error", "Too many concurrent users. Come back later!");
});
} else {
stop(pc);
stream_state = "closed";
_time_limit = null;
await access_webcam();
}
}
let options_open = false;
export function click_outside(node: Node, cb: any): any {
const handle_click = (event: MouseEvent): void => {
if (
node &&
!node.contains(event.target as Node) &&
!event.defaultPrevented
) {
cb(event);
}
};
document.addEventListener("click", handle_click, true);
return {
destroy() {
document.removeEventListener("click", handle_click, true);
}
};
}
function handle_click_outside(event: MouseEvent): void {
event.preventDefault();
event.stopPropagation();
options_open = false;
}
const audio_source_callback = () => video_source.srcObject as MediaStream;
</script>
<div class="wrap">
<StreamingBar time_limit={_time_limit} />
{#if stream_state === "open" && include_audio}
<div class="audio-indicator">
<PulsingIcon
stream_state={stream_state}
audio_source_callback={audio_source_callback}
icon={icon || Mic}
icon_button_color={icon_button_color}
pulse_color={pulse_color}
/>
</div>
{/if}
<!-- svelte-ignore a11y-media-has-caption -->
<!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
<video
bind:this={video_source}
class:hide={!webcam_accessed}
		class:flip={stream_state !== "open" || include_audio}
autoplay={true}
playsinline={true}
/>
<!-- svelte-ignore a11y-missing-attribute -->
{#if !webcam_accessed}
<div
in:fade={{ delay: 100, duration: 200 }}
title="grant webcam access"
style="height: 100%"
>
<WebcamPermissions on:click={async () => access_webcam()} />
</div>
{:else}
<div class="button-wrap">
<button
on:click={start_webrtc}
aria-label={"start stream"}
>
{#if stream_state === "waiting"}
<div class="icon-with-text">
<div class="icon color-primary" title="spinner">
<Spinner />
</div>
{button_labels.waiting || i18n("audio.waiting")}
</div>
{:else if stream_state === "open"}
<div class="icon-with-text">
<div class="icon color-primary" title="stop recording">
<Square />
</div>
{button_labels.stop || i18n("audio.stop")}
</div>
{:else}
<div class="icon-with-text">
<div class="icon color-primary" title="start recording">
<Circle />
</div>
{button_labels.start || i18n("audio.record")}
</div>
{/if}
</button>
{#if !recording}
<button
class="icon"
on:click={() => (options_open = true)}
aria-label="select input source"
>
<DropdownArrow />
</button>
{/if}
</div>
{#if options_open && selected_device}
<select
class="select-wrap"
aria-label="select source"
use:click_outside={handle_click_outside}
on:change={handle_device_change}
>
<button
class="inset-icon"
on:click|stopPropagation={() => (options_open = false)}
>
<DropdownArrow />
</button>
{#if available_video_devices.length === 0}
<option value="">{i18n("common.no_devices")}</option>
{:else}
{#each available_video_devices as device}
<option
value={device.deviceId}
selected={selected_device.deviceId === device.deviceId}
>
{device.label}
</option>
{/each}
{/if}
</select>
{/if}
{/if}
</div>
<style>
.wrap {
position: relative;
width: var(--size-full);
height: var(--size-full);
}
.hide {
display: none;
}
video {
width: var(--size-full);
height: var(--size-full);
object-fit: cover;
}
.button-wrap {
position: absolute;
background-color: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
border-radius: var(--radius-xl);
padding: var(--size-1-5);
display: flex;
bottom: var(--size-2);
left: 50%;
transform: translate(-50%, 0);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--radius-xl);
line-height: var(--size-3);
color: var(--button-secondary-text-color);
}
.icon-with-text {
min-width: var(--size-16);
align-items: center;
margin: 0 var(--spacing-xl);
display: flex;
justify-content: space-evenly;
/* Add gap between icon and text */
gap: var(--size-2);
}
.audio-indicator {
position: absolute;
top: var(--size-2);
right: var(--size-2);
z-index: var(--layer-2);
height: var(--size-5);
width: var(--size-5);
}
@media (--screen-md) {
button {
bottom: var(--size-4);
}
}
@media (--screen-xl) {
button {
bottom: var(--size-8);
}
}
.icon {
width: 18px;
height: 18px;
display: flex;
justify-content: space-between;
align-items: center;
}
.color-primary {
fill: var(--primary-600);
stroke: var(--primary-600);
color: var(--primary-600);
}
.flip {
transform: scaleX(-1);
}
.select-wrap {
-webkit-appearance: none;
-moz-appearance: none;
appearance: none;
color: var(--button-secondary-text-color);
background-color: transparent;
width: 95%;
font-size: var(--text-md);
position: absolute;
bottom: var(--size-2);
background-color: var(--block-background-fill);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--radius-xl);
z-index: var(--layer-top);
border: 1px solid var(--border-color-primary);
text-align: left;
line-height: var(--size-4);
white-space: nowrap;
text-overflow: ellipsis;
left: 50%;
transform: translate(-50%, 0);
max-width: var(--size-52);
}
.select-wrap > option {
padding: 0.25rem 0.5rem;
border-bottom: 1px solid var(--border-color-accent);
padding-right: var(--size-8);
text-overflow: ellipsis;
overflow: hidden;
}
.select-wrap > option:hover {
background-color: var(--color-accent);
}
.select-wrap > option:last-child {
border: none;
}
.inset-icon {
position: absolute;
top: 5px;
right: -6.5px;
width: var(--size-10);
height: var(--size-5);
opacity: 0.8;
}
@media (--screen-md) {
.wrap {
font-size: var(--text-lg);
}
}
</style>

View File

@@ -0,0 +1,49 @@
<script lang="ts">
import { Webcam } from "@gradio/icons";
import { createEventDispatcher } from "svelte";
export let icon = Webcam;
$: text = icon === Webcam ? "Click to Access Webcam" : "Click to Access Microphone";
const dispatch = createEventDispatcher<{
click: undefined;
}>();
</script>
<button style:height="100%" on:click={() => dispatch("click")}>
<div class="wrap">
<span class="icon-wrap">
<svelte:component this={icon} />
</span>
{text}
</div>
</button>
<style>
button {
cursor: pointer;
width: var(--size-full);
}
.wrap {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
min-height: var(--size-60);
color: var(--block-label-text-color);
height: 100%;
padding-top: var(--size-3);
}
.icon-wrap {
width: 30px;
margin-bottom: var(--spacing-lg);
}
@media (--screen-md) {
.wrap {
font-size: var(--text-lg);
}
}
</style>

1
frontend/shared/index.ts Normal file
View File

@@ -0,0 +1 @@
export { default as Video } from "./Video.svelte";

View File

@@ -0,0 +1,53 @@
export function get_devices(): Promise<MediaDeviceInfo[]> {
return navigator.mediaDevices.enumerateDevices();
}
export function handle_error(error: string): void {
throw new Error(error);
}
export function set_local_stream(
local_stream: MediaStream | null,
video_source: HTMLVideoElement,
): void {
video_source.srcObject = local_stream;
video_source.muted = true;
video_source.play();
}
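// Request a camera stream (optionally with microphone audio), attach it to the
// given <video> element, and return it. A specific device id and extra track
// constraints may be supplied; otherwise an ideal 500x500 video is requested.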
export async function get_video_stream(
include_audio: boolean,
video_source: HTMLVideoElement,
device_id?: string,
track_constraints?: MediaTrackConstraints,
): Promise<MediaStream> {
const fallback_constraints = track_constraints || {
width: { ideal: 500 },
height: { ideal: 500 },
};
const constraints = {
video: device_id
? { deviceId: { exact: device_id }, ...fallback_constraints }
: fallback_constraints,
audio: include_audio,
};
return navigator.mediaDevices
.getUserMedia(constraints)
.then((local_stream: MediaStream) => {
set_local_stream(local_stream, video_source);
return local_stream;
});
}
export function set_available_devices(
devices: MediaDeviceInfo[],
kind: "videoinput" | "audioinput" = "videoinput",
): MediaDeviceInfo[] {
const cameras = devices.filter(
(device: MediaDeviceInfo) => device.kind === kind,
);
return cameras;
}

146
frontend/shared/utils.ts Normal file
View File

@@ -0,0 +1,146 @@
import { toBlobURL } from "@ffmpeg/util";
import { FFmpeg } from "@ffmpeg/ffmpeg";
import { lookup } from "mrmime";
export const prettyBytes = (bytes: number): string => {
	const units = ["B", "KB", "MB", "GB", "TB", "PB"];
	let i = 0;
	while (bytes > 1024 && i < units.length - 1) {
		bytes /= 1024;
		i++;
	}
	const unit = units[i];
	return bytes.toFixed(1) + " " + unit;
};
export const playable = (): boolean => {
	// TODO: check the actual mime type once the lookup import issue is resolved:
	// let video_element = document.createElement("video");
	// let mime_type = lookup(filename);
	// return video_element.canPlayType(mime_type) != "";
	return true;
};
export function loaded(
node: HTMLVideoElement,
{ autoplay }: { autoplay: boolean },
): any {
async function handle_playback(): Promise<void> {
if (!autoplay) return;
await node.play();
}
node.addEventListener("loadeddata", handle_playback);
return {
destroy(): void {
node.removeEventListener("loadeddata", handle_playback);
},
};
}
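// Load the ffmpeg.wasm core (JS + WASM) from the unpkg CDN via blob URLs.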
export default async function loadFfmpeg(): Promise<FFmpeg> {
const ffmpeg = new FFmpeg();
const baseURL = "https://unpkg.com/@ffmpeg/core@0.12.4/dist/esm";
await ffmpeg.load({
coreURL: await toBlobURL(`${baseURL}/ffmpeg-core.js`, "text/javascript"),
wasmURL: await toBlobURL(`${baseURL}/ffmpeg-core.wasm`, "application/wasm"),
});
return ffmpeg;
}
export function blob_to_data_url(blob: Blob): Promise<string> {
return new Promise((fulfill, reject) => {
let reader = new FileReader();
reader.onerror = reject;
reader.onload = () => fulfill(reader.result as string);
reader.readAsDataURL(blob);
});
}
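// Trim the <video> element's source between startTime and endTime (seconds) with
// ffmpeg.wasm, copying the audio codec; returns the trimmed Blob, or the original
// Blob when no trim is requested or trimming fails.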
export async function trimVideo(
ffmpeg: FFmpeg,
startTime: number,
endTime: number,
videoElement: HTMLVideoElement,
): Promise<any> {
const videoUrl = videoElement.src;
const mimeType = lookup(videoElement.src) || "video/mp4";
const blobUrl = await toBlobURL(videoUrl, mimeType);
const response = await fetch(blobUrl);
const vidBlob = await response.blob();
const type = getVideoExtensionFromMimeType(mimeType) || "mp4";
const inputName = `input.${type}`;
const outputName = `output.${type}`;
try {
if (startTime === 0 && endTime === 0) {
return vidBlob;
}
await ffmpeg.writeFile(
inputName,
new Uint8Array(await vidBlob.arrayBuffer()),
);
let command = [
"-i",
inputName,
...(startTime !== 0 ? ["-ss", startTime.toString()] : []),
...(endTime !== 0 ? ["-to", endTime.toString()] : []),
"-c:a",
"copy",
outputName,
];
await ffmpeg.exec(command);
const outputData = await ffmpeg.readFile(outputName);
const outputBlob = new Blob([outputData], {
type: `video/${type}`,
});
return outputBlob;
} catch (error) {
console.error("Error initializing FFmpeg:", error);
return vidBlob;
}
}
const getVideoExtensionFromMimeType = (mimeType: string): string | null => {
const videoMimeToExtensionMap: { [key: string]: string } = {
"video/mp4": "mp4",
"video/webm": "webm",
"video/ogg": "ogv",
"video/quicktime": "mov",
"video/x-msvideo": "avi",
"video/x-matroska": "mkv",
"video/mpeg": "mpeg",
"video/3gpp": "3gp",
"video/3gpp2": "3g2",
"video/h261": "h261",
"video/h263": "h263",
"video/h264": "h264",
"video/jpeg": "jpgv",
"video/jpm": "jpm",
"video/mj2": "mj2",
"video/mpv": "mpv",
"video/vnd.ms-playready.media.pyv": "pyv",
"video/vnd.uvvu.mp4": "uvu",
"video/vnd.vivo": "viv",
"video/x-f4v": "f4v",
"video/x-fli": "fli",
"video/x-flv": "flv",
"video/x-m4v": "m4v",
"video/x-ms-asf": "asf",
"video/x-ms-wm": "wm",
"video/x-ms-wmv": "wmv",
"video/x-ms-wmx": "wmx",
"video/x-ms-wvx": "wvx",
"video/x-sgi-movie": "movie",
"video/x-smv": "smv",
};
return videoMimeToExtensionMap[mimeType] || null;
};

View File

@@ -0,0 +1,184 @@
export function createPeerConnection(pc: RTCPeerConnection, node: HTMLMediaElement | null) {
// register some listeners to help debugging
pc.addEventListener(
"icegatheringstatechange",
() => {
console.debug(pc.iceGatheringState);
},
false,
);
pc.addEventListener(
"iceconnectionstatechange",
() => {
console.debug(pc.iceConnectionState);
},
false,
);
pc.addEventListener(
"signalingstatechange",
() => {
console.debug(pc.signalingState);
},
false,
);
// connect audio / video from server to local
pc.addEventListener("track", (evt) => {
console.debug("track event listener");
if (node && node.srcObject !== evt.streams[0]) {
console.debug("streams", evt.streams);
node.srcObject = evt.streams[0];
console.debug("node.srcOject", node.srcObject);
if (evt.track.kind === "audio") {
node.volume = 1.0; // Ensure volume is up
node.muted = false;
node.autoplay = true;
// Attempt to play (needed for some browsers)
node.play().catch((e) => console.debug("Autoplay failed:", e));
}
}
});
return pc;
}
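// Set up a connection for one modality: open a "text" data channel that forwards
// control messages ("change", "tick", "stopword", warnings/errors) to on_change_cb,
// add the local tracks (applying rtp_params) or a receive-only transceiver when no
// stream is provided, then run offer/answer negotiation with the server.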
export async function start(
	stream: MediaStream | null,
	pc: RTCPeerConnection,
	node: HTMLMediaElement | null,
	server_fn: (body: any) => Promise<any>,
	webrtc_id: string,
	modality: "video" | "audio" = "video",
	on_change_cb: (msg: "change" | "tick") => void = () => {},
	rtp_params = {},
) {
pc = createPeerConnection(pc, node);
const data_channel = pc.createDataChannel("text");
data_channel.onopen = () => {
console.debug("Data channel is open");
data_channel.send("handshake");
};
data_channel.onmessage = (event) => {
console.debug("Received message:", event.data);
let event_json;
try {
event_json = JSON.parse(event.data);
} catch (e) {
console.debug("Error parsing JSON");
}
console.log("event_json", event_json);
if (
event.data === "change" ||
event.data === "tick" ||
event.data === "stopword" ||
event_json?.type === "warning" ||
event_json?.type === "error"
) {
console.debug(`${event.data} event received`);
on_change_cb(event_json ?? event.data);
}
};
if (stream) {
stream.getTracks().forEach(async (track) => {
console.debug("Track stream callback", track);
const sender = pc.addTrack(track, stream);
const params = sender.getParameters();
const updated_params = { ...params, ...rtp_params };
await sender.setParameters(updated_params);
console.debug("sender params", sender.getParameters());
});
} else {
console.debug("Creating transceiver!");
pc.addTransceiver(modality, { direction: "recvonly" });
}
await negotiate(pc, server_fn, webrtc_id);
return pc;
}
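// Send the offer to the server via server_fn; reject if the backend reports
// {status: "failed"} (callers surface this as a "too many concurrent users" error).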
function make_offer(server_fn: any, body: any): Promise<object> {
return new Promise((resolve, reject) => {
server_fn(body).then((data) => {
console.debug("data", data);
if (data?.status === "failed") {
console.debug("rejecting");
reject("error");
}
resolve(data);
});
});
}
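// Standard WebRTC negotiation: create an SDP offer, wait for ICE gathering to
// complete, exchange it with the server, and apply the returned answer.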
async function negotiate(
pc: RTCPeerConnection,
server_fn: any,
webrtc_id: string,
): Promise<void> {
return pc
.createOffer()
.then((offer) => {
return pc.setLocalDescription(offer);
})
.then(() => {
// wait for ICE gathering to complete
return new Promise<void>((resolve) => {
console.debug("ice gathering state", pc.iceGatheringState);
if (pc.iceGatheringState === "complete") {
resolve();
} else {
const checkState = () => {
if (pc.iceGatheringState === "complete") {
console.debug("ice complete");
pc.removeEventListener("icegatheringstatechange", checkState);
resolve();
}
};
pc.addEventListener("icegatheringstatechange", checkState);
}
});
})
		.then(() => {
			const offer = pc.localDescription;
			return make_offer(server_fn, {
				sdp: offer.sdp,
				type: offer.type,
				webrtc_id: webrtc_id,
			});
		})
		.then((answer) => {
			return pc.setRemoteDescription(answer);
		});
}
export function stop(pc: RTCPeerConnection) {
console.debug("Stopping peer connection");
// close transceivers
if (pc.getTransceivers) {
pc.getTransceivers().forEach((transceiver) => {
if (transceiver.stop) {
transceiver.stop();
}
});
}
// close local audio / video
	if (pc.getSenders) {
pc.getSenders().forEach((sender) => {
console.log("sender", sender);
if (sender.track && sender.track.stop) sender.track.stop();
});
}
// close peer connection
setTimeout(() => {
pc.close();
}, 500);
}

40
mkdocs.yml Normal file
View File

@@ -0,0 +1,40 @@
site_name: Gradio WebRTC
site_url: https://freddyaboulton.github.io/gradio-webrtc/
repo_name: gradio-webrtc
repo_url: https://github.com/freddyaboulton/gradio-webrtc
theme:
name: material
palette:
scheme: slate
primary: black
accent: yellow
features:
- content.code.copy
- content.code.annotate
logo: bolt.svg
favicon: bolt.svg
nav:
- Home: index.md
- User Guide: user-guide.md
- Cookbook: cookbook.md
- Deployment: deployment.md
- Advanced Configuration: advanced-configuration.md
- Utils: utils.md
- Frequently Asked Questions: faq.md
markdown_extensions:
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- pymdownx.tabbed:
alternate_style: true
- attr_list
- md_in_html
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- admonition
- pymdownx.details

56
pyproject.toml Normal file
View File

@@ -0,0 +1,56 @@
[build-system]
requires = [
"hatchling",
"hatch-requirements-txt",
"hatch-fancy-pypi-readme>=22.5.0",
]
build-backend = "hatchling.build"
[project]
name = "gradio_webrtc"
version = "0.0.29"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"
requires-python = ">=3.10"
authors = [{ name = "Freddy Boulton", email = "YOUREMAIL@domain.com" }]
keywords = ["gradio-custom-component", "gradio-template-Video", "streaming", "webrtc", "realtime"]
# Add dependencies here
dependencies = ["gradio>=4.0,<6.0", "aiortc"]
classifiers = [
'Development Status :: 3 - Alpha',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Visualization',
]
# The repository and space URLs are optional, but recommended.
# Adding a repository URL will create a badge in the auto-generated README that links to the repository.
# Adding a space URL will create a badge in the auto-generated README that links to the space.
# This will make it easy for people to find your deployed demo or source code when they
# encounter your project in the wild.
# [project.urls]
# repository = "your github repository"
# space = "your space url"
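# For example (the space URL below is a hypothetical placeholder):
# [project.urls]
# repository = "https://github.com/freddyaboulton/gradio-webrtc"
# space = "https://huggingface.co/spaces/<your-username>/<your-space>"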
[project.optional-dependencies]
dev = ["build", "twine"]
vad = ["librosa", "onnxruntime"]
stopword = ["silero", "librosa", "onnxruntime"]
[tool.hatch.build]
artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]
[tool.hatch.build.targets.wheel]
packages = ["/backend/gradio_webrtc"]
[tool.ruff]
target-version = "py310"