mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 09:29:23 +08:00

t :# Please enter a commit message for your changes. Lines starting with '#' will be ignored, and an empty commit
This commit is contained in:
28	.github/workflows/docs.yml	vendored	Normal file
@@ -0,0 +1,28 @@
name: docs
on:
  push:
    branches:
      - main
permissions:
  contents: write
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Configure Git Credentials
        run: |
          git config user.name github-actions[bot]
          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
      - uses: actions/setup-python@v5
        with:
          python-version: 3.x
      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
      - uses: actions/cache@v4
        with:
          key: mkdocs-material-${{ env.cache_id }}
          path: .cache
          restore-keys: |
            mkdocs-material-
      - run: pip install mkdocs-material
      - run: mkdocs gh-deploy --force
17	.gitignore	vendored	Normal file
@@ -0,0 +1,17 @@
.eggs/
dist/
*.pyc
__pycache__/
*.py[cod]
*$py.class
__tmp/*
*.pyi
.mypycache
.ruff_cache
node_modules
backend/**/templates/
demo/MobileNetSSD_deploy.caffemodel
demo/MobileNetSSD_deploy.prototxt.txt
.DS_Store
test/
.env
21	LICENSE	Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Freddy Boulton

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
441	README.md	Normal file
@@ -0,0 +1,441 @@
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>

<div style="display: flex; flex-direction: row; justify-content: center">
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/pypi/v/gradio_webrtc">
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" style="display: block; padding-right: 5px; height: 20px;" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
<a href="https://freddyaboulton.github.io/gradio-webrtc/" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/Docs-ffcf40"></a>
</div>

<h3 style='text-align: center'>
Stream video and audio in real time with Gradio using WebRTC.
</h3>

## Installation

```bash
pip install gradio_webrtc
```

To use built-in pause detection (see [ReplyOnPause](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-pause)), install the `vad` extra:

```bash
pip install gradio_webrtc[vad]
```

For stop word detection (see [ReplyOnStopWords](https://freddyaboulton.github.io/gradio-webrtc//user-guide/#reply-on-stopwords)), install the `stopword` extra:

```bash
pip install gradio_webrtc[stopword]
```

## Docs

https://freddyaboulton.github.io/gradio-webrtc/

## Examples

<table>
<tr>
<td width="50%">
<h3>🗣️ Audio Input/Output with mini-omni2</h3>
<p>Build a GPT-4o-like experience with mini-omni2, an audio-native LLM.</p>
<video width="100%" src="https://github.com/user-attachments/assets/58c06523-fc38-4f5f-a4ba-a02a28e7fa9e" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Talk to Claude</h3>
<p>Use the Anthropic and Play.Ht APIs to have an audio conversation with Claude.</p>
<video width="100%" src="https://github.com/user-attachments/assets/650bc492-798e-4995-8cef-159e1cfc2185" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-claude">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-claude/blob/main/app.py">Code</a>
</p>
</td>
</tr>

<tr>
<td width="50%">
<h3>🗣️ Kyutai Moshi</h3>
<p>Kyutai's Moshi is a novel speech-to-speech model for modeling human conversations.</p>
<video width="100%" src="https://github.com/user-attachments/assets/becc7a13-9e89-4a19-9df2-5fb1467a0137" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-moshi">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Hello Llama: Stop Word Detection</h3>
<p>A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama". Build a Siri-like coding assistant in 100 lines of code!</p>
<video width="100%" src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py">Code</a>
</p>
</td>
</tr>

<tr>
<td width="50%">
<h3>🤖 Llama Code Editor</h3>
<p>Create and edit HTML pages with just your voice! Powered by SambaNova Systems.</p>
<video width="100%" src="https://github.com/user-attachments/assets/a09647f1-33e1-4154-a5a3-ffefda8a736a" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/llama-code-editor">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/llama-code-editor/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🗣️ Talk to Ultravox</h3>
<p>Talk to Fixie.AI's audio-native Ultravox LLM with the transformers library.</p>
<video width="100%" src="https://github.com/user-attachments/assets/e6e62482-518c-4021-9047-9da14cd82be1" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox/blob/main/app.py">Code</a>
</p>
</td>
</tr>

<tr>
<td width="50%">
<h3>🗣️ Talk to Llama 3.2 3b</h3>
<p>Use the Lepton API to make Llama 3.2 talk back to you!</p>
<video width="100%" src="https://github.com/user-attachments/assets/3ee37a6b-0892-45f5-b801-73188fdfad9a" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>🤖 Talk to Qwen2-Audio</h3>
<p>Qwen2-Audio is a SOTA audio-to-text LLM developed by Alibaba.</p>
<video width="100%" src="https://github.com/user-attachments/assets/c821ad86-44cc-4d0c-8dc4-8c02ad1e5dc8" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc/blob/main/app.py">Code</a>
</p>
</td>
</tr>

<tr>
<td width="50%">
<h3>📷 Yolov10 Object Detection</h3>
<p>Run the Yolov10 model on a user webcam stream in real time!</p>
<video width="100%" src="https://github.com/user-attachments/assets/c90d8c9d-d2d5-462e-9e9b-af969f2ea73c" controls></video>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
<h3>📷 Video Object Detection with RT-DETR</h3>
<p>Upload a video and stream out frames with detected objects (powered by the RT-DETR model).</p>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc/blob/main/app.py">Code</a>
</p>
</td>
</tr>

<tr>
<td width="50%">
<h3>🔊 Text-to-Speech with Parler</h3>
<p>Stream out audio generated by Parler TTS!</p>
<p>
<a href="https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc">Demo</a> |
<a href="https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc/blob/main/app.py">Code</a>
</p>
</td>
<td width="50%">
</td>
</tr>
</table>

## Usage

This is a shortened version of the official [usage guide](https://freddyaboulton.github.io/gradio-webrtc/user-guide/).

To get started with WebRTC streams, all that's needed is to import the `WebRTC` component from this package and implement its `stream` event.

### Reply on Pause

Typically, you want to run an AI model that generates audio when the user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class
and passing it to the `stream` event of the `WebRTC` component.

```py
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnPause


def response(audio: tuple[int, np.ndarray]):  # (1)
    """This function must yield audio frames"""
    ...
    for numpy_array in generated_audio:
        yield (sampling_rate, numpy_array, "mono")  # (2)


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Chat (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Column():
        with gr.Group():
            audio = WebRTC(
                mode="send-receive",  # (3)
                modality="audio",
            )
        audio.stream(fn=ReplyOnPause(response),
                     inputs=[audio], outputs=[audio],  # (4)
                     time_limit=60)  # (5)

demo.launch()
```

1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.

2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio array). Each numpy audio array must have a shape of (1, num_samples).

3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.

4. The `WebRTC` component must be the first input and output component.

5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_count` is 1 (default), only one conversation will be handled at a time. A minimal concrete `response` sketch follows below.
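
For concreteness, here is a minimal sketch of a `response` generator that simply plays back a fixed tone whenever a pause is detected. The 24 kHz rate, tone frequency, and chunk size are arbitrary choices for this example, not requirements of the library.

```py
import numpy as np


def response(audio: tuple[int, np.ndarray]):
    """Yield a short 440 Hz beep in reply to whatever the user said."""
    sampling_rate = 24000  # arbitrary output rate for this sketch
    t = np.linspace(0, 1.0, sampling_rate, endpoint=False)
    beep = (0.2 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
    # Stream the beep in ~0.5 second chunks, shaped (1, num_samples) as required.
    chunk = sampling_rate // 2
    for start in range(0, len(beep), chunk):
        yield (sampling_rate, beep[np.newaxis, start:start + chunk], "mono")
```

A real model would replace the tone synthesis with its own generated audio, yielding chunks in the same (sampling_rate, (1, num_samples) array, "mono") format.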

### Reply On Stopwords

You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.

The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.

```py
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnStopWords


def response(audio: tuple[int, np.ndarray]):
    """This function must yield audio frames"""
    ...
    for numpy_array in generated_audio:
        yield (sampling_rate, numpy_array, "mono")


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Chat (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Column():
        with gr.Group():
            audio = WebRTC(
                mode="send-receive",
                modality="audio",
            )
        audio.stream(ReplyOnStopWords(response,
                                      input_sample_rate=16000,
                                      stop_words=["computer"]),  # (1)
                     inputs=[audio], outputs=[audio],
                     time_limit=90, concurrency_limit=10)

demo.launch()
```

1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".

### Audio Server-To-Client

To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.

```py
import gradio as gr
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment


def generation(num_steps):
    for _ in range(num_steps):
        segment = AudioSegment.from_file("audio_file.wav")
        array = np.array(segment.get_array_of_samples()).reshape(1, -1)
        yield (segment.frame_rate, array)


with gr.Blocks() as demo:
    audio = WebRTC(label="Stream", mode="receive",  # (1)
                   modality="audio")
    num_steps = gr.Slider(label="Number of Steps", minimum=1,
                          maximum=10, step=1, value=5)
    button = gr.Button("Generate")

    audio.stream(
        fn=generation, inputs=[num_steps], outputs=[audio],
        trigger=button.click  # (2)
    )
```

1. Set `mode="receive"` to only receive audio from the server.

2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.

### Video Input/Output Streaming

Set up a video Input/Output stream to continuously receive webcam frames from the user and run an arbitrary python function to return a modified frame.

```py
import gradio as gr
from gradio_webrtc import WebRTC


def detection(image, conf_threshold=0.3):  # (1)
    # ... your detection code here ...
    return modified_frame  # (2)


with gr.Blocks() as demo:
    image = WebRTC(label="Stream", mode="send-receive", modality="video")  # (3)
    conf_threshold = gr.Slider(
        label="Confidence Threshold",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.30,
    )
    image.stream(
        fn=detection,
        inputs=[image, conf_threshold],  # (4)
        outputs=[image], time_limit=10
    )

if __name__ == "__main__":
    demo.launch()
```

1. The webcam frame will be represented as a numpy array of shape (height, width, RGB).

2. The function must return a numpy array. It can take arbitrary values from other components.

3. Set `modality="video"` and `mode="send-receive"`.

4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component. A hedged example `detection` function is sketched below.
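
As an illustration only, the placeholder `detection` function above can be any per-frame transform. The sketch below uses OpenCV's Canny edge detector as a stand-in for a real model; `cv2` and the RGB uint8 frame layout from note 1 are the only assumptions.

```py
import cv2
import numpy as np


def detection(image: np.ndarray, conf_threshold: float = 0.3) -> np.ndarray:
    """Toy stand-in for a real model: overlay Canny edges on the webcam frame."""
    # Reuse the slider value to scale edge-detector sensitivity so it has a visible effect.
    lower = int(255 * conf_threshold)
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, lower, min(255, lower * 2 + 1))
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    # Blend the edge map with the original frame; the return value must be a numpy array.
    return cv2.addWeighted(image, 0.8, edges_rgb, 0.7, 0)
```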

### Server-to-Client Only

Set up a server-to-client stream to stream video from an arbitrary user interaction.

```py
import gradio as gr
from gradio_webrtc import WebRTC
import cv2


def generation():
    url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
    cap = cv2.VideoCapture(url)
    iterating = True
    while iterating:
        iterating, frame = cap.read()
        yield frame  # (1)


with gr.Blocks() as demo:
    output_video = WebRTC(label="Video Stream", mode="receive",  # (2)
                          modality="video")
    button = gr.Button("Start", variant="primary")
    output_video.stream(
        fn=generation, inputs=None, outputs=[output_video],
        trigger=button.click  # (3)
    )
    demo.launch()
```

1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.

2. Set `mode="receive"` to only receive video from the server.

3. The `trigger` parameter is the gradio event that will trigger the stream. In this case, the button click event.

### Additional Outputs

In order to modify other components from within the WebRTC stream, you must yield an instance of `AdditionalOutputs` and add an `on_additional_outputs` event to the `WebRTC` component.

This is common for displaying a multimodal text/audio conversation in a Chatbot UI.

``` py title="Additional Outputs"
import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC


def transcribe(audio: tuple[int, np.ndarray],
               transformers_convo: list[dict],
               gradio_convo: list[dict]):
    # `model` and `inputs` come from the full Qwen2-Audio demo; shown here as an excerpt.
    response = model.generate(**inputs, max_length=256)
    transformers_convo.append({"role": "assistant", "content": response})
    gradio_convo.append({"role": "assistant", "content": response})
    yield AdditionalOutputs(transformers_convo, gradio_convo)  # (1)


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Talk to Qwen2Audio (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    transformers_convo = gr.State(value=[])
    with gr.Row():
        with gr.Column():
            audio = WebRTC(
                label="Stream",
                mode="send",  # (2)
                modality="audio",
            )
        with gr.Column():
            transcript = gr.Chatbot(label="transcript", type="messages")

    audio.stream(ReplyOnPause(transcribe),
                 inputs=[audio, transformers_convo, transcript],
                 outputs=[audio], time_limit=90)
    audio.on_additional_outputs(lambda s, a: (s, a),  # (3)
                                outputs=[transformers_convo, transcript],
                                queue=False, show_progress="hidden")
    demo.launch()
```

1. Pass your data to `AdditionalOutputs` and yield it.

2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.

3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.

## Deployment

When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc.), you need to set up a TURN server to relay the WebRTC traffic.
The easiest way to do this is to use a service like Twilio.

```python
import gradio as gr
import os
from twilio.rest import Client
from gradio_webrtc import WebRTC

account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

client = Client(account_sid, auth_token)

token = client.tokens.create()

rtc_configuration = {
    "iceServers": token.ice_servers,
    "iceTransportPolicy": "relay",
}

with gr.Blocks() as demo:
    ...
    rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
    ...
```
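
Alternatively, the credential helpers shipped in this package (see `backend/gradio_webrtc/credentials.py` below) wrap the same logic. A short sketch, assuming `TWILIO_ACCOUNT_SID`/`TWILIO_AUTH_TOKEN` (or `HF_TOKEN` for the Hugging Face TURN server) are set in the environment:

```python
import gradio as gr
from gradio_webrtc import WebRTC, get_turn_credentials

# method="twilio" reads TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN from the environment;
# method="hf" instead requests credentials for the community TURN server using HF_TOKEN.
rtc_configuration = get_turn_credentials(method="twilio")

with gr.Blocks() as demo:
    audio = WebRTC(rtc_configuration=rtc_configuration, mode="send-receive", modality="audio")
```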
54	backend/gradio_webrtc/__init__.py	Normal file
@@ -0,0 +1,54 @@
from .credentials import (
    get_hf_turn_credentials,
    get_turn_credentials,
    get_twilio_turn_credentials,
)
from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions
from .reply_on_stopwords import ReplyOnStopWords
from .speech_to_text import stt, stt_for_chunks
from .utils import (
    AdditionalOutputs,
    Warning,
    WebRTCError,
    aggregate_bytes_to_16bit,
    async_aggregate_bytes_to_16bit,
    audio_to_bytes,
    audio_to_file,
    audio_to_float32,
)
from .webrtc import (
    AsyncAudioVideoStreamHandler,
    AsyncStreamHandler,
    AudioVideoStreamHandler,
    StreamHandler,
    WebRTC,
    VideoEmitType,
    AudioEmitType,
)

__all__ = [
    "AsyncStreamHandler",
    "AudioVideoStreamHandler",
    "AudioEmitType",
    "AsyncAudioVideoStreamHandler",
    "AlgoOptions",
    "AdditionalOutputs",
    "aggregate_bytes_to_16bit",
    "async_aggregate_bytes_to_16bit",
    "audio_to_bytes",
    "audio_to_file",
    "audio_to_float32",
    "get_hf_turn_credentials",
    "get_twilio_turn_credentials",
    "get_turn_credentials",
    "ReplyOnPause",
    "ReplyOnStopWords",
    "SileroVadOptions",
    "stt",
    "stt_for_chunks",
    "StreamHandler",
    "VideoEmitType",
    "WebRTC",
    "WebRTCError",
    "Warning",
]
52	backend/gradio_webrtc/credentials.py	Normal file
@@ -0,0 +1,52 @@
import os
from typing import Literal

import requests


def get_hf_turn_credentials(token=None):
    if token is None:
        token = os.getenv("HF_TOKEN")
    credentials = requests.get(
        "https://freddyaboulton-turn-server-login.hf.space/credentials",
        headers={"X-HF-Access-Token": token},
    )
    if not credentials.status_code == 200:
        raise ValueError("Failed to get credentials from HF turn server")
    return {
        "iceServers": [
            {
                "urls": "turn:gradio-turn.com:80",
                **credentials.json(),
            },
        ]
    }


def get_twilio_turn_credentials(twilio_sid=None, twilio_token=None):
    try:
        from twilio.rest import Client
    except ImportError:
        raise ImportError("Please install twilio with `pip install twilio`")

    if not twilio_sid and not twilio_token:
        twilio_sid = os.environ.get("TWILIO_ACCOUNT_SID")
        twilio_token = os.environ.get("TWILIO_AUTH_TOKEN")

    client = Client(twilio_sid, twilio_token)

    token = client.tokens.create()

    return {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }


def get_turn_credentials(method: Literal["hf", "twilio"] = "hf", **kwargs):
    if method == "hf":
        return get_hf_turn_credentials(**kwargs)
    elif method == "twilio":
        return get_twilio_turn_credentials(**kwargs)
    else:
        raise ValueError("Invalid method. Must be 'hf' or 'twilio'")
3	backend/gradio_webrtc/pause_detection/__init__.py	Normal file
@@ -0,0 +1,3 @@
from .vad import SileroVADModel, SileroVadOptions

__all__ = ["SileroVADModel", "SileroVadOptions"]
320	backend/gradio_webrtc/pause_detection/vad.py	Normal file
@@ -0,0 +1,320 @@
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, overload
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..utils import AudioChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
||||
# The code below is adapted from https://github.com/gpt-omni/mini-omni/blob/main/utils/vad.py
|
||||
|
||||
|
||||
@dataclass
|
||||
class SileroVadOptions:
|
||||
"""VAD options.
|
||||
|
||||
Attributes:
|
||||
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
|
||||
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
||||
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
|
||||
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
|
||||
than max_speech_duration_s will be split at the timestamp of the last silence that
|
||||
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
|
||||
split aggressively just before max_speech_duration_s.
|
||||
min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
|
||||
before separating it
|
||||
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
||||
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
||||
Values other than these may affect model performance!!
|
||||
speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side
|
||||
speech_duration: If the length of the speech is less than this value, a pause will be detected.
|
||||
"""
|
||||
|
||||
threshold: float = 0.5
|
||||
min_speech_duration_ms: int = 250
|
||||
max_speech_duration_s: float = float("inf")
|
||||
min_silence_duration_ms: int = 2000
|
||||
window_size_samples: int = 1024
|
||||
speech_pad_ms: int = 400
|
||||
|
||||
|
||||
class SileroVADModel:
|
||||
@staticmethod
|
||||
def download_model() -> str:
|
||||
return hf_hub_download(
|
||||
repo_id="freddyaboulton/silero-vad", filename="silero_vad.onnx"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
import onnxruntime
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Applying the VAD filter requires the onnxruntime package"
|
||||
) from e
|
||||
|
||||
path = self.download_model()
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
opts.log_severity_level = 4
|
||||
|
||||
self.session = onnxruntime.InferenceSession(
|
||||
path,
|
||||
providers=["CPUExecutionProvider"],
|
||||
sess_options=opts,
|
||||
)
|
||||
|
||||
def get_initial_state(self, batch_size: int):
|
||||
h = np.zeros((2, batch_size, 64), dtype=np.float32)
|
||||
c = np.zeros((2, batch_size, 64), dtype=np.float32)
|
||||
return h, c
|
||||
|
||||
@staticmethod
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
|
||||
"""Collects and concatenates audio chunks."""
|
||||
if not chunks:
|
||||
return np.array([], dtype=np.float32)
|
||||
|
||||
return np.concatenate(
|
||||
[audio[chunk["start"] : chunk["end"]] for chunk in chunks]
|
||||
)
|
||||
|
||||
def get_speech_timestamps(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
vad_options: SileroVadOptions,
|
||||
**kwargs,
|
||||
) -> List[AudioChunk]:
|
||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||
|
||||
Args:
|
||||
audio: One dimensional float array.
|
||||
vad_options: Options for VAD processing.
|
||||
kwargs: VAD options passed as keyword arguments for backward compatibility.
|
||||
|
||||
Returns:
|
||||
List of dicts containing begin and end samples of each speech chunk.
|
||||
"""
|
||||
|
||||
threshold = vad_options.threshold
|
||||
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
||||
max_speech_duration_s = vad_options.max_speech_duration_s
|
||||
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
||||
window_size_samples = vad_options.window_size_samples
|
||||
speech_pad_ms = vad_options.speech_pad_ms
|
||||
|
||||
if window_size_samples not in [512, 1024, 1536]:
|
||||
warnings.warn(
|
||||
"Unusual window_size_samples! Supported window_size_samples:\n"
|
||||
" - [512, 1024, 1536] for 16000 sampling_rate"
|
||||
)
|
||||
|
||||
sampling_rate = 16000
|
||||
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
||||
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
max_speech_samples = (
|
||||
sampling_rate * max_speech_duration_s
|
||||
- window_size_samples
|
||||
- 2 * speech_pad_samples
|
||||
)
|
||||
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
|
||||
|
||||
audio_length_samples = len(audio)
|
||||
|
||||
state = self.get_initial_state(batch_size=1)
|
||||
|
||||
speech_probs = []
|
||||
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
||||
chunk = audio[
|
||||
current_start_sample : current_start_sample + window_size_samples
|
||||
]
|
||||
if len(chunk) < window_size_samples:
|
||||
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||
speech_prob, state = self(chunk, state, sampling_rate)
|
||||
speech_probs.append(speech_prob)
|
||||
|
||||
triggered = False
|
||||
speeches = []
|
||||
current_speech = {}
|
||||
neg_threshold = threshold - 0.15
|
||||
|
||||
# to save potential segment end (and tolerate some silence)
|
||||
temp_end = 0
|
||||
# to save potential segment limits in case of maximum segment size reached
|
||||
prev_end = next_start = 0
|
||||
|
||||
for i, speech_prob in enumerate(speech_probs):
|
||||
if (speech_prob >= threshold) and temp_end:
|
||||
temp_end = 0
|
||||
if next_start < prev_end:
|
||||
next_start = window_size_samples * i
|
||||
|
||||
if (speech_prob >= threshold) and not triggered:
|
||||
triggered = True
|
||||
current_speech["start"] = window_size_samples * i
|
||||
continue
|
||||
|
||||
if (
|
||||
triggered
|
||||
and (window_size_samples * i) - current_speech["start"]
|
||||
> max_speech_samples
|
||||
):
|
||||
if prev_end:
|
||||
current_speech["end"] = prev_end
|
||||
speeches.append(current_speech)
|
||||
current_speech = {}
|
||||
# previously reached silence (< neg_thres) and is still not speech (< thres)
|
||||
if next_start < prev_end:
|
||||
triggered = False
|
||||
else:
|
||||
current_speech["start"] = next_start
|
||||
prev_end = next_start = temp_end = 0
|
||||
else:
|
||||
current_speech["end"] = window_size_samples * i
|
||||
speeches.append(current_speech)
|
||||
current_speech = {}
|
||||
prev_end = next_start = temp_end = 0
|
||||
triggered = False
|
||||
continue
|
||||
|
||||
if (speech_prob < neg_threshold) and triggered:
|
||||
if not temp_end:
|
||||
temp_end = window_size_samples * i
|
||||
# condition to avoid cutting in very short silence
|
||||
if (
|
||||
window_size_samples * i
|
||||
) - temp_end > min_silence_samples_at_max_speech:
|
||||
prev_end = temp_end
|
||||
if (window_size_samples * i) - temp_end < min_silence_samples:
|
||||
continue
|
||||
else:
|
||||
current_speech["end"] = temp_end
|
||||
if (
|
||||
current_speech["end"] - current_speech["start"]
|
||||
) > min_speech_samples:
|
||||
speeches.append(current_speech)
|
||||
current_speech = {}
|
||||
prev_end = next_start = temp_end = 0
|
||||
triggered = False
|
||||
continue
|
||||
|
||||
if (
|
||||
current_speech
|
||||
and (audio_length_samples - current_speech["start"]) > min_speech_samples
|
||||
):
|
||||
current_speech["end"] = audio_length_samples
|
||||
speeches.append(current_speech)
|
||||
|
||||
for i, speech in enumerate(speeches):
|
||||
if i == 0:
|
||||
speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
|
||||
if i != len(speeches) - 1:
|
||||
silence_duration = speeches[i + 1]["start"] - speech["end"]
|
||||
if silence_duration < 2 * speech_pad_samples:
|
||||
speech["end"] += int(silence_duration // 2)
|
||||
speeches[i + 1]["start"] = int(
|
||||
max(0, speeches[i + 1]["start"] - silence_duration // 2)
|
||||
)
|
||||
else:
|
||||
speech["end"] = int(
|
||||
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
||||
)
|
||||
speeches[i + 1]["start"] = int(
|
||||
max(0, speeches[i + 1]["start"] - speech_pad_samples)
|
||||
)
|
||||
else:
|
||||
speech["end"] = int(
|
||||
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
||||
)
|
||||
|
||||
return speeches
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: Literal[True],
|
||||
) -> tuple[float, List[AudioChunk]]: ...
|
||||
|
||||
@overload
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float: ...
|
||||
|
||||
def vad(
|
||||
self,
|
||||
audio_tuple: tuple[int, NDArray],
|
||||
vad_parameters: None | SileroVadOptions,
|
||||
return_chunks: bool = False,
|
||||
) -> float | tuple[float, List[AudioChunk]]:
|
||||
sampling_rate, audio = audio_tuple
|
||||
logger.debug("VAD audio shape input: %s", audio.shape)
|
||||
try:
|
||||
if audio.dtype != np.float32:
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
sr = 16000
|
||||
if sr != sampling_rate:
|
||||
try:
|
||||
import librosa # type: ignore
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
"Applying the VAD filter requires librosa if the input sampling rate is not 16000 Hz"
) from e
|
||||
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
|
||||
|
||||
if not vad_parameters:
|
||||
vad_parameters = SileroVadOptions()
|
||||
speech_chunks = self.get_speech_timestamps(audio, vad_parameters)
|
||||
logger.debug("VAD speech chunks: %s", speech_chunks)
|
||||
audio = self.collect_chunks(audio, speech_chunks)
|
||||
logger.debug("VAD audio shape: %s", audio.shape)
|
||||
duration_after_vad = audio.shape[0] / sr
|
||||
if return_chunks:
|
||||
return duration_after_vad, speech_chunks
|
||||
return duration_after_vad
|
||||
except Exception as e:
|
||||
import math
|
||||
import traceback
|
||||
|
||||
logger.debug("VAD Exception: %s", str(e))
|
||||
exec = traceback.format_exc()
|
||||
logger.debug("traceback %s", exec)
|
||||
return math.inf
|
||||
|
||||
def __call__(self, x, state, sr: int):
|
||||
if len(x.shape) == 1:
|
||||
x = np.expand_dims(x, 0)
|
||||
if len(x.shape) > 2:
|
||||
raise ValueError(
|
||||
f"Too many dimensions for input audio chunk {len(x.shape)}"
|
||||
)
|
||||
if sr / x.shape[1] > 31.25: # type: ignore
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
h, c = state
|
||||
|
||||
ort_inputs = {
|
||||
"input": x,
|
||||
"h": h,
|
||||
"c": c,
|
||||
"sr": np.array(sr, dtype="int64"),
|
||||
}
|
||||
|
||||
out, h, c = self.session.run(None, ort_inputs)
|
||||
state = (h, c)
|
||||
|
||||
return out, state
|
||||
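
These options are what callers tune, through `ReplyOnPause` in the next file, to change how aggressively pauses are detected. A hedged sketch; the specific values are illustrative, and `response` refers to the generator from the README examples:

```py
from gradio_webrtc import AlgoOptions, ReplyOnPause, SileroVadOptions

handler = ReplyOnPause(
    response,  # audio generator as in the README's Reply on Pause example
    # Analyze ~0.6 s chunks and treat anything under 0.1 s of detected speech as a pause.
    algo_options=AlgoOptions(audio_chunk_duration=0.6, speech_threshold=0.1),
    # Raise the speech-probability threshold and require 1 s of silence before replying.
    model_options=SileroVadOptions(threshold=0.6, min_silence_duration_ms=1000),
)
```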
188	backend/gradio_webrtc/reply_on_pause.py	Normal file
@@ -0,0 +1,188 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Any, Callable, Generator, Literal, Union, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gradio_webrtc.pause_detection import SileroVADModel, SileroVadOptions
|
||||
from gradio_webrtc.webrtc import EmitType, StreamHandler
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
counter = 0
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_vad_model() -> SileroVADModel:
|
||||
"""Returns the VAD model instance."""
|
||||
return SileroVADModel()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlgoOptions:
|
||||
"""Algorithm options."""
|
||||
|
||||
audio_chunk_duration: float = 0.6
|
||||
started_talking_threshold: float = 0.2
|
||||
speech_threshold: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppState:
|
||||
stream: np.ndarray | None = None
|
||||
sampling_rate: int = 0
|
||||
pause_detected: bool = False
|
||||
started_talking: bool = False
|
||||
responding: bool = False
|
||||
stopped: bool = False
|
||||
buffer: np.ndarray | None = None
|
||||
|
||||
|
||||
ReplyFnGenerator = Union[
|
||||
# For two arguments
|
||||
Callable[
|
||||
[tuple[int, np.ndarray], list[dict[Any, Any]]],
|
||||
Generator[EmitType, None, None],
|
||||
],
|
||||
Callable[
|
||||
[tuple[int, np.ndarray]],
|
||||
Generator[EmitType, None, None],
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
async def iterate(generator: Generator) -> Any:
|
||||
return next(generator)
|
||||
|
||||
|
||||
class ReplyOnPause(StreamHandler):
|
||||
def __init__(
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 480,
|
||||
input_sample_rate: int = 48000,
|
||||
):
|
||||
super().__init__(
|
||||
expected_layout,
|
||||
output_sample_rate,
|
||||
output_frame_size,
|
||||
input_sample_rate=input_sample_rate,
|
||||
)
|
||||
self.expected_layout: Literal["mono", "stereo"] = expected_layout
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_frame_size = output_frame_size
|
||||
self.model = get_vad_model()
|
||||
self.fn = fn
|
||||
self.is_async = inspect.isasyncgenfunction(fn)
|
||||
self.event = Event()
|
||||
self.state = AppState()
|
||||
self.generator: Generator[EmitType, None, None] | None = None
|
||||
self.model_options = model_options
|
||||
self.algo_options = algo_options or AlgoOptions()
|
||||
|
||||
@property
|
||||
def _needs_additional_inputs(self) -> bool:
|
||||
return len(inspect.signature(self.fn).parameters) > 1
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnPause(
|
||||
self.fn,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
self.input_sample_rate,
|
||||
)
|
||||
|
||||
def determine_pause(
|
||||
self, audio: np.ndarray, sampling_rate: int, state: AppState
|
||||
) -> bool:
|
||||
"""Take in the stream, determine if a pause happened"""
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
||||
logger.debug("VAD duration: %s", dur_vad)
|
||||
if (
|
||||
dur_vad > self.algo_options.started_talking_threshold
|
||||
and not state.started_talking
|
||||
):
|
||||
state.started_talking = True
|
||||
logger.debug("Started talking")
|
||||
if state.started_talking:
|
||||
if state.stream is None:
|
||||
state.stream = audio
|
||||
else:
|
||||
state.stream = np.concatenate((state.stream, audio))
|
||||
state.buffer = None
|
||||
if dur_vad < self.algo_options.speech_threshold and state.started_talking:
|
||||
return True
|
||||
return False
|
||||
|
||||
def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
|
||||
frame_rate, array = audio
|
||||
array = np.squeeze(array)
|
||||
if not state.sampling_rate:
|
||||
state.sampling_rate = frame_rate
|
||||
if state.buffer is None:
|
||||
state.buffer = array
|
||||
else:
|
||||
state.buffer = np.concatenate((state.buffer, array))
|
||||
|
||||
pause_detected = self.determine_pause(
|
||||
state.buffer, state.sampling_rate, self.state
|
||||
)
|
||||
state.pause_detected = pause_detected
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
if self.state.responding:
|
||||
return
|
||||
self.process_audio(frame, self.state)
|
||||
if self.state.pause_detected:
|
||||
self.event.set()
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.generator = None
|
||||
self.event.clear()
|
||||
self.state = AppState()
|
||||
|
||||
async def async_iterate(self, generator) -> EmitType:
|
||||
return await anext(generator)
|
||||
|
||||
def emit(self):
|
||||
if not self.event.is_set():
|
||||
return None
|
||||
else:
|
||||
if not self.generator:
|
||||
if self._needs_additional_inputs and not self.args_set.is_set():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.wait_for_args(), self.loop
|
||||
).result()
|
||||
logger.debug("Creating generator")
|
||||
audio = cast(np.ndarray, self.state.stream).reshape(1, -1)
|
||||
if self._needs_additional_inputs:
|
||||
self.latest_args[0] = (self.state.sampling_rate, audio)
|
||||
self.generator = self.fn(*self.latest_args)
|
||||
else:
|
||||
self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
|
||||
logger.debug("Latest args: %s", self.latest_args)
|
||||
self.state.responding = True
|
||||
try:
|
||||
if self.is_async:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.async_iterate(self.generator), self.loop
|
||||
).result()
|
||||
else:
|
||||
return next(self.generator)
|
||||
except (StopIteration, StopAsyncIteration):
|
||||
self.reset()
|
||||
147	backend/gradio_webrtc/reply_on_stopwords.py	Normal file
@@ -0,0 +1,147 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .reply_on_pause import (
|
||||
AlgoOptions,
|
||||
AppState,
|
||||
ReplyFnGenerator,
|
||||
ReplyOnPause,
|
||||
SileroVadOptions,
|
||||
)
|
||||
from .speech_to_text import get_stt_model, stt_for_chunks
|
||||
from .utils import audio_to_float32
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplyOnStopWordsState(AppState):
|
||||
stop_word_detected: bool = False
|
||||
post_stop_word_buffer: np.ndarray | None = None
|
||||
started_talking_pre_stop_word: bool = False
|
||||
|
||||
|
||||
class ReplyOnStopWords(ReplyOnPause):
|
||||
def __init__(
|
||||
self,
|
||||
fn: ReplyFnGenerator,
|
||||
stop_words: list[str],
|
||||
algo_options: AlgoOptions | None = None,
|
||||
model_options: SileroVadOptions | None = None,
|
||||
expected_layout: Literal["mono", "stereo"] = "mono",
|
||||
output_sample_rate: int = 24000,
|
||||
output_frame_size: int = 480,
|
||||
input_sample_rate: int = 48000,
|
||||
):
|
||||
super().__init__(
|
||||
fn,
|
||||
algo_options=algo_options,
|
||||
model_options=model_options,
|
||||
expected_layout=expected_layout,
|
||||
output_sample_rate=output_sample_rate,
|
||||
output_frame_size=output_frame_size,
|
||||
input_sample_rate=input_sample_rate,
|
||||
)
|
||||
self.stop_words = stop_words
|
||||
self.state = ReplyOnStopWordsState()
|
||||
# Download Model
|
||||
get_stt_model()
|
||||
|
||||
def stop_word_detected(self, text: str) -> bool:
|
||||
for stop_word in self.stop_words:
|
||||
stop_word = stop_word.lower().strip().split(" ")
|
||||
if bool(
|
||||
re.search(r"\b" + r"\s+".join(map(re.escape, stop_word)) + r"\b", text)
|
||||
):
|
||||
logger.debug("Stop word detected: %s", stop_word)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _send_stopword(
|
||||
self,
|
||||
):
|
||||
if self.channel:
|
||||
self.channel.send("stopword")
|
||||
logger.debug("Sent stopword")
|
||||
|
||||
def send_stopword(self):
|
||||
asyncio.run_coroutine_threadsafe(self._send_stopword(), self.loop)
|
||||
|
||||
def determine_pause( # type: ignore
|
||||
self, audio: np.ndarray, sampling_rate: int, state: ReplyOnStopWordsState
|
||||
) -> bool:
|
||||
"""Take in the stream, determine if a pause happened"""
|
||||
import librosa
|
||||
|
||||
duration = len(audio) / sampling_rate
|
||||
|
||||
if duration >= self.algo_options.audio_chunk_duration:
|
||||
if not state.stop_word_detected:
|
||||
audio_f32 = audio_to_float32((sampling_rate, audio))
|
||||
audio_rs = librosa.resample(
|
||||
audio_f32, orig_sr=sampling_rate, target_sr=16000
|
||||
)
|
||||
if state.post_stop_word_buffer is None:
|
||||
state.post_stop_word_buffer = audio_rs
|
||||
else:
|
||||
state.post_stop_word_buffer = np.concatenate(
|
||||
(state.post_stop_word_buffer, audio_rs)
|
||||
)
|
||||
if len(state.post_stop_word_buffer) / 16000 > 2:
|
||||
state.post_stop_word_buffer = state.post_stop_word_buffer[-32000:]
|
||||
dur_vad, chunks = self.model.vad(
|
||||
(16000, state.post_stop_word_buffer),
|
||||
self.model_options,
|
||||
return_chunks=True,
|
||||
)
|
||||
text = stt_for_chunks((16000, state.post_stop_word_buffer), chunks)
|
||||
logger.debug(f"STT: {text}")
|
||||
state.stop_word_detected = self.stop_word_detected(text)
|
||||
if state.stop_word_detected:
|
||||
logger.debug("Stop word detected")
|
||||
self.send_stopword()
|
||||
state.buffer = None
|
||||
else:
|
||||
dur_vad = self.model.vad((sampling_rate, audio), self.model_options)
|
||||
logger.debug("VAD duration: %s", dur_vad)
|
||||
if (
|
||||
dur_vad > self.algo_options.started_talking_threshold
|
||||
and not state.started_talking
|
||||
and state.stop_word_detected
|
||||
):
|
||||
state.started_talking = True
|
||||
logger.debug("Started talking")
|
||||
if state.started_talking:
|
||||
if state.stream is None:
|
||||
state.stream = audio
|
||||
else:
|
||||
state.stream = np.concatenate((state.stream, audio))
|
||||
state.buffer = None
|
||||
if (
|
||||
dur_vad < self.algo_options.speech_threshold
|
||||
and state.started_talking
|
||||
and state.stop_word_detected
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.generator = None
|
||||
self.event.clear()
|
||||
self.state = ReplyOnStopWordsState()
|
||||
|
||||
def copy(self):
|
||||
return ReplyOnStopWords(
|
||||
self.fn,
|
||||
self.stop_words,
|
||||
self.algo_options,
|
||||
self.model_options,
|
||||
self.expected_layout,
|
||||
self.output_sample_rate,
|
||||
self.output_frame_size,
|
||||
self.input_sample_rate,
|
||||
)
|
||||
3	backend/gradio_webrtc/speech_to_text/__init__.py	Normal file
@@ -0,0 +1,3 @@
from .stt_ import get_stt_model, stt, stt_for_chunks

__all__ = ["stt", "stt_for_chunks", "get_stt_model"]
53	backend/gradio_webrtc/speech_to_text/stt_.py	Normal file
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable

import numpy as np
from numpy.typing import NDArray

from ..utils import AudioChunk


@dataclass
class STTModel:
    encoder: Callable
    decoder: Callable


@lru_cache
def get_stt_model() -> STTModel:
    from silero import silero_stt

    model, decoder, _ = silero_stt(language="en", version="v6", jit_model="jit_xlarge")
    return STTModel(model, decoder)


def stt(audio: tuple[int, NDArray[np.int16]]) -> str:
    model = get_stt_model()
    sr, audio_np = audio
    if audio_np.dtype != np.float32:
        print("converting")
        audio_np = audio_np.astype(np.float32) / 32768.0
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is required to run speech-to-text for stopword detection. Run `pip install torch`."
        )
    audio_torch = torch.tensor(audio_np, dtype=torch.float32)
    if audio_torch.ndim == 1:
        audio_torch = audio_torch.unsqueeze(0)
    assert audio_torch.ndim == 2, "Audio must have a batch dimension"
    print("before")
    res = model.decoder(model.encoder(audio_torch)[0])
    print("after")
    return res


def stt_for_chunks(
    audio: tuple[int, NDArray[np.int16]], chunks: list[AudioChunk]
) -> str:
    sr, audio_np = audio
    return " ".join(
        [stt((sr, audio_np[chunk["start"] : chunk["end"]])) for chunk in chunks]
    )
313	backend/gradio_webrtc/utils.py	Normal file
@@ -0,0 +1,313 @@
|
||||
import asyncio
|
||||
import fractions
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
AUDIO_PTIME = 0.020
|
||||
|
||||
|
||||
class AudioChunk(TypedDict):
|
||||
start: int
|
||||
end: int
|
||||
|
||||
|
||||
class AdditionalOutputs:
|
||||
def __init__(self, *args) -> None:
|
||||
self.args = args
|
||||
|
||||
|
||||
class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
|
||||
current_channel: ContextVar[DataChannel | None] = ContextVar(
|
||||
"current_channel", default=None
|
||||
)
|
||||
|
||||
|
||||
def _send_log(message: str, type: str) -> None:
|
||||
async def _send(channel: DataChannel) -> None:
|
||||
channel.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": type,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if channel := current_channel.get():
|
||||
print("channel", channel)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
asyncio.run_coroutine_threadsafe(_send(channel), loop)
|
||||
except RuntimeError:
|
||||
asyncio.run(_send(channel))
|
||||
|
||||
|
||||
def Warning( # noqa: N802
|
||||
message: str = "Warning issued.",
|
||||
):
|
||||
"""
|
||||
Send a warning message that is displayed in the UI of the application.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message : str
|
||||
The warning message to send
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
_send_log(message, "warning")
|
||||
|
||||
|
||||
class WebRTCError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
_send_log(message, "error")
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
if isinstance(data, AdditionalOutputs):
|
||||
return None, data
|
||||
if isinstance(data, tuple):
|
||||
# handle the bare audio case
|
||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||
return data, None
|
||||
if not len(data) == 2:
|
||||
raise ValueError(
|
||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||
)
|
||||
if not isinstance(data[-1], AdditionalOutputs):
|
||||
raise ValueError(
|
||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||
)
|
||||
return data[0], cast(AdditionalOutputs, data[1])
|
||||
return data, None
|
||||
|
||||
|
||||
async def player_worker_decode(
|
||||
next_frame: Callable,
|
||||
queue: asyncio.Queue,
|
||||
thread_quit: asyncio.Event,
|
||||
channel: Callable[[], DataChannel | None] | None,
|
||||
set_additional_outputs: Callable | None,
|
||||
quit_on_none: bool = False,
|
||||
sample_rate: int = 48000,
|
||||
frame_size: int = int(48000 * AUDIO_PTIME),
|
||||
):
|
||||
audio_samples = 0
|
||||
audio_time_base = fractions.Fraction(1, sample_rate)
|
||||
audio_resampler = av.AudioResampler( # type: ignore
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
rate=sample_rate,
|
||||
frame_size=frame_size,
|
||||
)
|
||||
|
||||
while not thread_quit.is_set():
|
||||
try:
|
||||
# Get next frame
|
||||
frame, outputs = split_output(
|
||||
await asyncio.wait_for(next_frame(), timeout=60)
|
||||
)
|
||||
if (
|
||||
isinstance(outputs, AdditionalOutputs)
|
||||
and set_additional_outputs
|
||||
and channel
|
||||
and channel()
|
||||
):
|
||||
set_additional_outputs(outputs)
|
||||
cast(DataChannel, channel()).send("change")
|
||||
|
||||
if frame is None:
|
||||
if quit_on_none:
|
||||
await queue.put(None)
|
||||
break
|
||||
continue
|
||||
|
||||
if len(frame) == 2:
|
||||
sample_rate, audio_array = frame
|
||||
layout = "mono"
|
||||
elif len(frame) == 3:
|
||||
sample_rate, audio_array, layout = frame
|
||||
|
||||
logger.debug(
|
||||
"received array with shape %s sample rate %s layout %s",
|
||||
audio_array.shape, # type: ignore
|
||||
sample_rate,
|
||||
layout, # type: ignore
|
||||
)
|
||||
format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore
|
||||
|
||||
# Convert to audio frame and resample
|
||||
# This runs in the same timeout context
|
||||
frame = av.AudioFrame.from_ndarray( # type: ignore
|
||||
audio_array, # type: ignore
|
||||
format=format,
|
||||
layout=layout, # type: ignore
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
|
||||
for processed_frame in audio_resampler.resample(frame):
|
||||
processed_frame.pts = audio_samples
|
||||
processed_frame.time_base = audio_time_base
|
||||
audio_samples += processed_frame.samples
|
||||
await queue.put(processed_frame)
|
||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
logger.warning(
|
||||
"Timeout in frame processing cycle after %s seconds - resetting", 60
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
exec = traceback.format_exc()
|
||||
logger.debug("traceback %s", exec)
|
||||
logger.error("Error processing frame: %s", str(e))
|
||||
continue
|
||||
|
||||
|
||||
def audio_to_bytes(audio: tuple[int, np.ndarray]) -> bytes:
|
||||
"""
|
||||
Convert an audio tuple containing sample rate and numpy array data into bytes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : tuple[int, np.ndarray]
|
||||
A tuple containing:
|
||||
- sample_rate (int): The audio sample rate in Hz
|
||||
- data (np.ndarray): The audio data as a numpy array
|
||||
|
||||
Returns
|
||||
-------
|
||||
bytes
|
||||
The audio data encoded as bytes, suitable for transmission or storage
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> sample_rate = 44100
|
||||
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
||||
>>> audio_tuple = (sample_rate, audio_data)
|
||||
>>> audio_bytes = audio_to_bytes(audio_tuple)
|
||||
"""
|
||||
audio_buffer = io.BytesIO()
|
||||
segment = AudioSegment(
|
||||
audio[1].tobytes(),
|
||||
frame_rate=audio[0],
|
||||
sample_width=audio[1].dtype.itemsize,
|
||||
channels=1,
|
||||
)
|
||||
segment.export(audio_buffer, format="mp3")
|
||||
return audio_buffer.getvalue()
|
||||
|
||||
|
||||
def audio_to_file(audio: tuple[int, np.ndarray]) -> str:
|
||||
"""
|
||||
Save an audio tuple containing sample rate and numpy array data to a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : tuple[int, np.ndarray]
|
||||
A tuple containing:
|
||||
- sample_rate (int): The audio sample rate in Hz
|
||||
- data (np.ndarray): The audio data as a numpy array
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The path to the saved audio file
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> sample_rate = 44100
|
||||
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
||||
>>> audio_tuple = (sample_rate, audio_data)
|
||||
>>> file_path = audio_to_file(audio_tuple)
|
||||
>>> print(f"Audio saved to: {file_path}")
|
||||
"""
|
||||
bytes_ = audio_to_bytes(audio)
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
|
||||
f.write(bytes_)
|
||||
return f.name
|
||||
|
||||
|
||||
def audio_to_float32(audio: tuple[int, np.ndarray]) -> np.ndarray:
|
||||
"""
|
||||
Convert an audio tuple containing a sample rate and int16 numpy array data to float32.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : tuple[int, np.ndarray]
|
||||
A tuple containing:
|
||||
- sample_rate (int): The audio sample rate in Hz
|
||||
- data (np.ndarray): The audio data as a numpy array
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The audio data as a numpy array with dtype float32
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> sample_rate = 44100
|
||||
>>> audio_data = np.array([100, -200, 300], dtype=np.int16) # Example int16 audio samples
|
||||
>>> audio_tuple = (sample_rate, audio_data)
|
||||
>>> audio_float32 = audio_to_float32(audio_tuple)
|
||||
"""
|
||||
return audio[1].astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def aggregate_bytes_to_16bit(chunks_iterator):
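"""Regroup an iterator of raw byte chunks into whole int16 sample arrays.

Bytes that do not yet form a complete 16-bit sample are carried over to the
next chunk, so each yielded numpy array has shape (1, num_samples) and
contains only complete samples.
"""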
|
||||
leftover = b"" # Store incomplete bytes between chunks
|
||||
|
||||
for chunk in chunks_iterator:
|
||||
# Combine with any leftover bytes from previous chunk
|
||||
current_bytes = leftover + chunk
|
||||
|
||||
# Calculate complete samples
|
||||
n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes
|
||||
bytes_to_process = n_complete_samples * 2
|
||||
|
||||
# Split into complete samples and leftover
|
||||
to_process = current_bytes[:bytes_to_process]
|
||||
leftover = current_bytes[bytes_to_process:]
|
||||
|
||||
if to_process: # Only yield if we have complete samples
|
||||
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
|
||||
yield audio_array
|
||||
|
||||
|
||||
async def async_aggregate_bytes_to_16bit(chunks_iterator):
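"""Asynchronous variant of `aggregate_bytes_to_16bit` for async iterators of byte chunks."""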
|
||||
leftover = b"" # Store incomplete bytes between chunks
|
||||
|
||||
async for chunk in chunks_iterator:
|
||||
# Combine with any leftover bytes from previous chunk
|
||||
current_bytes = leftover + chunk
|
||||
|
||||
# Calculate complete samples
|
||||
n_complete_samples = len(current_bytes) // 2 # int16 = 2 bytes
|
||||
bytes_to_process = n_complete_samples * 2
|
||||
|
||||
# Split into complete samples and leftover
|
||||
to_process = current_bytes[:bytes_to_process]
|
||||
leftover = current_bytes[bytes_to_process:]
|
||||
|
||||
if to_process: # Only yield if we have complete samples
|
||||
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
|
||||
yield audio_array
|
||||
1151
backend/gradio_webrtc/webrtc.py
Normal file
File diff suppressed because it is too large
44
demo/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
license: mit
|
||||
tags:
|
||||
- object-detection
|
||||
- computer-vision
|
||||
- yolov10
|
||||
datasets:
|
||||
- detection-datasets/coco
|
||||
sdk: gradio
|
||||
sdk_version: 5.0.0b1
|
||||
---
|
||||
|
||||
### Model Description
|
||||
[YOLOv10: Real-Time End-to-End Object Detection](https://arxiv.org/abs/2405.14458v1)
|
||||
|
||||
- arXiv: https://arxiv.org/abs/2405.14458v1
|
||||
- github: https://github.com/THU-MIG/yolov10
|
||||
|
||||
### Installation
|
||||
```
|
||||
pip install supervision git+https://github.com/THU-MIG/yolov10.git
|
||||
```
|
||||
|
||||
### Yolov10 Inference
|
||||
```python
|
||||
from ultralytics import YOLOv10
|
||||
import supervision as sv
|
||||
import cv2
|
||||
|
||||
IMAGE_PATH = 'dog.jpeg'
|
||||
|
||||
model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||
model.predict(IMAGE_PATH, show=True)
|
||||
```
|
||||
|
||||
### BibTeX Entry and Citation Info
|
||||
```
|
||||
@article{wang2024yolov10,
|
||||
title={YOLOv10: Real-Time End-to-End Object Detection},
|
||||
author={Wang, Ao and Chen, Hui and Liu, Lihao and Chen, Kai and Lin, Zijia and Han, Jungong and Ding, Guiguang},
|
||||
journal={arXiv preprint arXiv:2405.14458},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
0
demo/__init__.py
Normal file
105
demo/also_return_text.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import AdditionalOutputs, WebRTC
|
||||
from pydub import AudioSegment
|
||||
from twilio.rest import Client
|
||||
|
||||
# Configure the root logger to WARNING to suppress debug messages from other libraries
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# Create a console handler
|
||||
console_handler = logging.FileHandler("gradio_webrtc.log")
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Create a formatter
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Configure the logger for your specific library
|
||||
logger = logging.getLogger("gradio_webrtc")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def generation(num_steps):
|
||||
for i in range(num_steps):
|
||||
segment = AudioSegment.from_file(
|
||||
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3"
|
||||
)
|
||||
yield (
|
||||
(
|
||||
segment.frame_rate,
|
||||
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||
),
|
||||
AdditionalOutputs(
|
||||
f"Hello, from step {i}!",
|
||||
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Audio Streaming (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Column(elem_classes=["my-column"]):
|
||||
with gr.Group(elem_classes=["my-group"]):
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
)
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
textbox = gr.Textbox(placeholder="Output will appear here.")
|
||||
audio_file = gr.Audio()
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
|
||||
)
|
||||
audio.on_additional_outputs(
|
||||
fn=lambda t, a: (f"State changed to {t}.", a),
|
||||
outputs=[textbox, audio_file],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
allowed_paths=[
|
||||
"/Users/freddy/sources/gradio/demo/scratch/audio-streaming/librispeech.mp3"
|
||||
]
|
||||
)
|
||||
367
demo/app.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
_docs = {
|
||||
"WebRTC": {
|
||||
"description": "Stream audio/video with WebRTC",
|
||||
"members": {
|
||||
"__init__": {
|
||||
"rtc_configuration": {
|
||||
"type": "dict[str, Any] | None",
|
||||
"default": "None",
|
||||
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
|
||||
},
|
||||
"height": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"width": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"label": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
|
||||
},
|
||||
"show_label": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will display label.",
|
||||
},
|
||||
"container": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True, will place the component in a container - providing some extra padding around the border.",
|
||||
},
|
||||
"scale": {
|
||||
"type": "int | None",
|
||||
"default": "None",
|
||||
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
|
||||
},
|
||||
"min_width": {
|
||||
"type": "int",
|
||||
"default": "160",
|
||||
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
|
||||
},
|
||||
"interactive": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
|
||||
},
|
||||
"visible": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will be hidden.",
|
||||
},
|
||||
"elem_id": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"elem_classes": {
|
||||
"type": "list[str] | str | None",
|
||||
"default": "None",
|
||||
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"render": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
|
||||
},
|
||||
"key": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
|
||||
},
|
||||
"mirror_webcam": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True webcam will be mirrored. Default is True.",
|
||||
},
|
||||
},
|
||||
"events": {"tick": {"type": None, "default": None, "description": ""}},
|
||||
},
|
||||
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
|
||||
|
||||
with gr.Blocks(
|
||||
css_paths=abs_path,
|
||||
theme=gr.themes.Default(
|
||||
font_mono=[
|
||||
gr.themes.GoogleFont("Inconsolata"),
|
||||
"monospace",
|
||||
],
|
||||
),
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
|
||||
|
||||
<div style="display: flex; flex-direction: row; justify-content: center">
|
||||
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.6%20-%20orange">
|
||||
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
|
||||
</div>
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install gradio_webrtc
|
||||
```
|
||||
|
||||
## Examples:
|
||||
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
|
||||
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
|
||||
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
|
||||
4. [Conversational AI](https://huggingface.co/spaces/freddyaboulton/omni-mini-webrtc) 🤖🗣️
|
||||
|
||||
## Usage
|
||||
|
||||
The WebRTC component supports the following four use cases:
|
||||
1. [Streaming video from the user webcam to the server and back](#h-streaming-video-from-the-user-webcam-to-the-server-and-back)
|
||||
2. [Streaming Video from the server to the client](#h-streaming-video-from-the-server-to-the-client)
|
||||
3. [Streaming Audio from the server to the client](#h-streaming-audio-from-the-server-to-the-client)
|
||||
4. [Streaming Audio from the client to the server and back (conversational AI)](#h-conversational-ai)
|
||||
|
||||
|
||||
## Streaming Video from the User Webcam to the Server and Back
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
... your detection code here ...
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
image = WebRTC(label="Stream", mode="send-receive", modality="video")
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
image.stream(
|
||||
fn=detection,
|
||||
inputs=[image, conf_threshold],
|
||||
outputs=[image], time_limit=10
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
||||
```
|
||||
* Set the `mode` parameter to `send-receive` and `modality` to "video".
|
||||
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
|
||||
as a **numpy array** and returns the processed frame also as a **numpy array** (a minimal sketch of such a function follows this list).
|
||||
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
|
||||
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
|
||||
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
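
For readers who want to run the snippet end to end, here is a minimal, hedged stand-in for `detection`; it assumes OpenCV is installed and simply mirrors the frame, and any real model call would use the same signature:

```python
import cv2
import numpy as np


def detection(image: np.ndarray, conf_threshold: float = 0.3) -> np.ndarray:
    # The frame arrives as an RGB array of shape (height, width, 3).
    # A real implementation would run a detector here; this sketch just
    # mirrors the frame so the round trip is visible in the browser.
    return cv2.flip(image, 1)
```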
|
||||
|
||||
## Streaming Video from the server to the client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
import cv2
|
||||
|
||||
def generation():
|
||||
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
|
||||
cap = cv2.VideoCapture(url)
|
||||
iterating = True
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
|
||||
yield frame
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
|
||||
button = gr.Button("Start", variant="primary")
|
||||
output_video.stream(
|
||||
fn=generation, inputs=None, outputs=[output_video],
|
||||
trigger=button.click
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "video".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
|
||||
* The only output allowed is the WebRTC component.
|
||||
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button click event.
|
||||
|
||||
## Streaming Audio from the Server to the Client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from pydub import AudioSegment
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
|
||||
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(label="Stream", mode="receive", modality="audio")
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio],
|
||||
trigger=button.click
|
||||
)
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "audio".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
|
||||
* The numpy array should be of shape (1, num_samples).
|
||||
* The `outputs` parameter should be a list with the WebRTC component as the only element.
|
||||
|
||||
## Conversational AI
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import WebRTC, StreamHandler
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
|
||||
class EchoHandler(StreamHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.queue = Queue()
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
|
||||
self.queue.put(frame)
|
||||
|
||||
def emit(self) -> None:
|
||||
return self.queue.get()
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=None,
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
|
||||
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
* Instead of passing a function to the `stream` event's `fn` parameter, pass a `StreamHandler` implementation. The `StreamHandler` above simply echoes the audio back to the client.
|
||||
* The `StreamHandler` class has two methods: `receive` and `emit`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client.
|
||||
* An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples).
|
||||
* You can also specify the audio layout ("mono" or "stereo") in the `emit` method by returning it as the third element of the tuple. If not specified, the default is "mono".
|
||||
* The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop.
|
||||
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None. A minimal non-blocking handler is sketched below.
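
The following sketch shows one way to keep `emit` non-blocking. It assumes the same `StreamHandler` API as the `EchoHandler` above and uses the standard library queue; treat it as illustrative rather than the canonical implementation:

```python
import numpy as np
from queue import Queue, Empty
from gradio_webrtc import StreamHandler


class NonBlockingEchoHandler(StreamHandler):
    def __init__(self) -> None:
        super().__init__()
        self.queue = Queue()

    def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
        # Buffer incoming frames; never block the WebRTC event loop here.
        self.queue.put(frame)

    def emit(self) -> tuple[int, np.ndarray, str] | None:
        try:
            sample_rate, samples = self.queue.get_nowait()
            # The optional third element sets the layout ("mono" or "stereo").
            return (sample_rate, samples, "mono")
        except Empty:
            # Nothing buffered yet: returning None keeps emit() non-blocking.
            return None
```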
|
||||
|
||||
## Deployment
|
||||
|
||||
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
|
||||
The easiest way to do this is to use a service like Twilio.
|
||||
|
||||
```python
|
||||
from twilio.rest import Client
|
||||
import os
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
...
|
||||
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
|
||||
...
|
||||
```
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
##
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
|
||||
|
||||
demo.load(
|
||||
None,
|
||||
js=r"""function() {
|
||||
const refs = {};
|
||||
const user_fn_refs = {
|
||||
WebRTC: [], };
|
||||
requestAnimationFrame(() => {
|
||||
|
||||
Object.entries(user_fn_refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}-user-fn`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Object.entries(refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
""",
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
367
demo/app_.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
_docs = {
|
||||
"WebRTC": {
|
||||
"description": "Stream audio/video with WebRTC",
|
||||
"members": {
|
||||
"__init__": {
|
||||
"rtc_configuration": {
|
||||
"type": "dict[str, Any] | None",
|
||||
"default": "None",
|
||||
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
|
||||
},
|
||||
"height": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"width": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"label": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
|
||||
},
|
||||
"show_label": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will display label.",
|
||||
},
|
||||
"container": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True, will place the component in a container - providing some extra padding around the border.",
|
||||
},
|
||||
"scale": {
|
||||
"type": "int | None",
|
||||
"default": "None",
|
||||
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
|
||||
},
|
||||
"min_width": {
|
||||
"type": "int",
|
||||
"default": "160",
|
||||
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
|
||||
},
|
||||
"interactive": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
|
||||
},
|
||||
"visible": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will be hidden.",
|
||||
},
|
||||
"elem_id": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"elem_classes": {
|
||||
"type": "list[str] | str | None",
|
||||
"default": "None",
|
||||
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"render": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
|
||||
},
|
||||
"key": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
|
||||
},
|
||||
"mirror_webcam": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True webcam will be mirrored. Default is True.",
|
||||
},
|
||||
},
|
||||
"events": {"tick": {"type": None, "default": None, "description": ""}},
|
||||
},
|
||||
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
|
||||
|
||||
with gr.Blocks(
|
||||
css_paths=abs_path,
|
||||
theme=gr.themes.Default(
|
||||
font_mono=[
|
||||
gr.themes.GoogleFont("Inconsolata"),
|
||||
"monospace",
|
||||
],
|
||||
),
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
|
||||
|
||||
<div style="display: flex; flex-direction: row; justify-content: center">
|
||||
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.6%20-%20orange">
|
||||
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
|
||||
</div>
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install gradio_webrtc
|
||||
```
|
||||
|
||||
## Examples:
|
||||
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
|
||||
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
|
||||
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
|
||||
4. [Conversational AI](https://huggingface.co/spaces/freddyaboulton/omni-mini-webrtc) 🤖🗣️
|
||||
|
||||
## Usage
|
||||
|
||||
The WebRTC component supports the following four use cases:
|
||||
1. [Streaming video from the user webcam to the server and back](#h-streaming-video-from-the-user-webcam-to-the-server-and-back)
|
||||
2. [Streaming Video from the server to the client](#h-streaming-video-from-the-server-to-the-client)
|
||||
3. [Streaming Audio from the server to the client](#h-streaming-audio-from-the-server-to-the-client)
|
||||
4. [Streaming Audio from the client to the server and back (conversational AI)](#h-conversational-ai)
|
||||
|
||||
|
||||
## Streaming Video from the User Webcam to the Server and Back
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
... your detection code here ...
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
image = WebRTC(label="Stream", mode="send-receive", modality="video")
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
image.stream(
|
||||
fn=detection,
|
||||
inputs=[image, conf_threshold],
|
||||
outputs=[image], time_limit=10
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
||||
```
|
||||
* Set the `mode` parameter to `send-receive` and `modality` to "video".
|
||||
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
|
||||
as a **numpy array** and returns the processed frame also as a **numpy array**.
|
||||
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
|
||||
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
|
||||
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
|
||||
|
||||
## Streaming Video from the server to the client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
import cv2
|
||||
|
||||
def generation():
|
||||
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
|
||||
cap = cv2.VideoCapture(url)
|
||||
iterating = True
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
|
||||
yield frame
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
|
||||
button = gr.Button("Start", variant="primary")
|
||||
output_video.stream(
|
||||
fn=generation, inputs=None, outputs=[output_video],
|
||||
trigger=button.click
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "video".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
|
||||
* The only output allowed is the WebRTC component.
|
||||
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button click event.
|
||||
|
||||
## Streaming Audio from the Server to the Client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from pydub import AudioSegment
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
|
||||
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(label="Stream", mode="receive", modality="audio")
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio],
|
||||
trigger=button.click
|
||||
)
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "audio".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
|
||||
* The numpy array should be of shape (1, num_samples).
|
||||
* The `outputs` parameter should be a list with the WebRTC component as the only element.
|
||||
|
||||
## Conversational AI
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import WebRTC, StreamHandler
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
|
||||
class EchoHandler(StreamHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.queue = Queue()
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
|
||||
self.queue.put(frame)
|
||||
|
||||
def emit(self) -> None:
|
||||
return self.queue.get()
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=None,
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
|
||||
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
* Instead of passing a function to the `stream` event's `fn` parameter, pass a `StreamHandler` implementation. The `StreamHandler` above simply echoes the audio back to the client.
|
||||
* The `StreamHandler` class has two methods: `receive` and `emit`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client.
|
||||
* An audio frame is represented as a tuple of (frame_rate, audio_samples) where `audio_samples` is a numpy array of shape (num_channels, num_samples).
|
||||
* You can also specify the audio layout ("mono" or "stereo") in the `emit` method by returning it as the third element of the tuple. If not specified, the default is "mono".
|
||||
* The `time_limit` parameter is the maximum time in seconds the conversation will run. If the time limit is reached, the audio stream will stop.
|
||||
* The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return None.
|
||||
|
||||
## Deployment
|
||||
|
||||
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
|
||||
The easiest way to do this is to use a service like Twilio.
|
||||
|
||||
```python
|
||||
from twilio.rest import Client
|
||||
import os
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
...
|
||||
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
|
||||
...
|
||||
```
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
##
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
|
||||
|
||||
demo.load(
|
||||
None,
|
||||
js=r"""function() {
|
||||
const refs = {};
|
||||
const user_fn_refs = {
|
||||
WebRTC: [], };
|
||||
requestAnimationFrame(() => {
|
||||
|
||||
Object.entries(user_fn_refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}-user-fn`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Object.entries(refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
""",
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
73
demo/app_orig.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
from huggingface_hub import hf_hub_download
|
||||
from inference import YOLOv10
|
||||
from twilio.rest import Client
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
|
||||
)
|
||||
|
||||
model = YOLOv10(model_file)
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
image = cv2.resize(image, (model.input_width, model.input_height))
|
||||
new_image = model.detect_objects(image, conf_threshold)
|
||||
return cv2.resize(new_image, (500, 500))
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
gr.HTML(
|
||||
"""
|
||||
<h3 style='text-align: center'>
|
||||
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
|
||||
</h3>
|
||||
"""
|
||||
)
|
||||
with gr.Column(elem_classes=["my-column"]):
|
||||
with gr.Group(elem_classes=["my-group"]):
|
||||
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
|
||||
image.stream(
|
||||
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
71
demo/audio_out.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import WebRTC
|
||||
from pydub import AudioSegment
|
||||
from twilio.rest import Client
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file(
|
||||
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
|
||||
)
|
||||
yield (
|
||||
segment.frame_rate,
|
||||
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||
)
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Audio Streaming (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Column(elem_classes=["my-column"]):
|
||||
with gr.Group(elem_classes=["my-group"]):
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
)
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
64
demo/audio_out_2.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import WebRTC
|
||||
from pydub import AudioSegment
|
||||
from twilio.rest import Client
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file(
|
||||
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
|
||||
)
|
||||
yield (
|
||||
segment.frame_rate,
|
||||
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||
)
|
||||
time.sleep(3.5)
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Audio Streaming (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Slider()
|
||||
with gr.Column():
|
||||
# audio = gr.Audio(interactive=False)
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
161
demo/css.css
Normal file
@@ -0,0 +1,161 @@
|
||||
html {
|
||||
font-family: Inter;
|
||||
font-size: 16px;
|
||||
font-weight: 400;
|
||||
line-height: 1.5;
|
||||
-webkit-text-size-adjust: 100%;
|
||||
background: #fff;
|
||||
color: #323232;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
text-rendering: optimizeLegibility;
|
||||
}
|
||||
|
||||
:root {
|
||||
--space: 1;
|
||||
--vspace: calc(var(--space) * 1rem);
|
||||
--vspace-0: calc(3 * var(--space) * 1rem);
|
||||
--vspace-1: calc(2 * var(--space) * 1rem);
|
||||
--vspace-2: calc(1.5 * var(--space) * 1rem);
|
||||
--vspace-3: calc(0.5 * var(--space) * 1rem);
|
||||
}
|
||||
|
||||
.app {
|
||||
max-width: 748px !important;
|
||||
}
|
||||
|
||||
.prose p {
|
||||
margin: var(--vspace) 0;
|
||||
line-height: calc(var(--vspace) * 2);
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
code {
|
||||
font-family: "Inconsolata", sans-serif;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
h1,
|
||||
h1 code {
|
||||
font-weight: 400;
|
||||
line-height: calc(2.5 / var(--space) * var(--vspace));
|
||||
}
|
||||
|
||||
h1 code {
|
||||
background: none;
|
||||
border: none;
|
||||
letter-spacing: 0.05em;
|
||||
padding-bottom: 5px;
|
||||
position: relative;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
h2 {
|
||||
margin: var(--vspace-1) 0 var(--vspace-2) 0;
|
||||
line-height: 1em;
|
||||
}
|
||||
|
||||
h3,
|
||||
h3 code {
|
||||
margin: var(--vspace-1) 0 var(--vspace-2) 0;
|
||||
line-height: 1em;
|
||||
}
|
||||
|
||||
h4,
|
||||
h5,
|
||||
h6 {
|
||||
margin: var(--vspace-3) 0 var(--vspace-3) 0;
|
||||
line-height: var(--vspace);
|
||||
}
|
||||
|
||||
.bigtitle,
|
||||
h1,
|
||||
h1 code {
|
||||
font-size: calc(8px * 4.5);
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.title,
|
||||
h2,
|
||||
h2 code {
|
||||
font-size: calc(8px * 3.375);
|
||||
font-weight: lighter;
|
||||
word-break: break-word;
|
||||
border: none;
|
||||
background: none;
|
||||
}
|
||||
|
||||
.subheading1,
|
||||
h3,
|
||||
h3 code {
|
||||
font-size: calc(8px * 1.8);
|
||||
font-weight: 600;
|
||||
border: none;
|
||||
background: none;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
h2 code {
|
||||
padding: 0;
|
||||
position: relative;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
|
||||
blockquote {
|
||||
font-size: calc(8px * 1.1667);
|
||||
font-style: italic;
|
||||
line-height: calc(1.1667 * var(--vspace));
|
||||
margin: var(--vspace-2) var(--vspace-2);
|
||||
}
|
||||
|
||||
.subheading2,
|
||||
h4 {
|
||||
font-size: calc(8px * 1.4292);
|
||||
text-transform: uppercase;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.subheading3,
|
||||
h5 {
|
||||
font-size: calc(8px * 1.2917);
|
||||
line-height: calc(1.2917 * var(--vspace));
|
||||
|
||||
font-weight: lighter;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.15em;
|
||||
}
|
||||
|
||||
h6 {
|
||||
font-size: calc(8px * 1.1667);
|
||||
font-size: 1.1667em;
|
||||
font-weight: normal;
|
||||
font-style: italic;
|
||||
font-family: "le-monde-livre-classic-byol", serif !important;
|
||||
letter-spacing: 0px !important;
|
||||
}
|
||||
|
||||
#start .md > *:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
h2 + h3 {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.md hr {
|
||||
border: none;
|
||||
border-top: 1px solid var(--block-border-color);
|
||||
margin: var(--vspace-2) 0 var(--vspace-2) 0;
|
||||
}
|
||||
.prose ul {
|
||||
margin: var(--vspace-2) 0 var(--vspace-1) 0;
|
||||
}
|
||||
|
||||
.gap {
|
||||
gap: 0;
|
||||
}
|
||||
|
||||
.md-custom {
|
||||
overflow: hidden;
|
||||
}
|
||||
99
demo/docs.py
Normal file
@@ -0,0 +1,99 @@
|
||||
_docs = {
|
||||
"WebRTC": {
|
||||
"description": "Stream audio/video with WebRTC",
|
||||
"members": {
|
||||
"__init__": {
|
||||
"rtc_configuration": {
|
||||
"type": "dict[str, Any] | None",
|
||||
"default": "None",
|
||||
"description": "The configration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
|
||||
},
|
||||
"height": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"width": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"label": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
|
||||
},
|
||||
"show_label": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will display label.",
|
||||
},
|
||||
"container": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True, will place the component in a container - providing some extra padding around the border.",
|
||||
},
|
||||
"scale": {
|
||||
"type": "int | None",
|
||||
"default": "None",
|
||||
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
|
||||
},
|
||||
"min_width": {
|
||||
"type": "int",
|
||||
"default": "160",
|
||||
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
|
||||
},
|
||||
"interactive": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
|
||||
},
|
||||
"visible": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will be hidden.",
|
||||
},
|
||||
"elem_id": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"elem_classes": {
|
||||
"type": "list[str] | str | None",
|
||||
"default": "None",
|
||||
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"render": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
|
||||
},
|
||||
"key": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
|
||||
},
|
||||
"mirror_webcam": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True webcam will be mirrored. Default is True.",
|
||||
},
|
||||
"postprocess": {
|
||||
"value": {
|
||||
"type": "typing.Any",
|
||||
"description": "Expects a {str} or {pathlib.Path} filepath to a video which is displayed, or a {Tuple[str | pathlib.Path, str | pathlib.Path | None]} where the first element is a filepath to a video and the second element is an optional filepath to a subtitle file.",
|
||||
}
|
||||
},
|
||||
"preprocess": {
|
||||
"return": {
|
||||
"type": "str",
|
||||
"description": "Passes the uploaded video as a `str` filepath or URL whose extension can be modified by `format`.",
|
||||
},
|
||||
"value": None,
|
||||
},
|
||||
},
|
||||
"events": {"tick": {"type": None, "default": None, "description": ""}},
|
||||
},
|
||||
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
|
||||
}
|
||||
}
|
||||
61
demo/echo_conversation.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from queue import Queue
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import StreamHandler, WebRTC
|
||||
|
||||
# Configure the root logger to WARNING to suppress debug messages from other libraries
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# Create a console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Create a formatter
|
||||
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Configure the logger for your specific library
|
||||
logger = logging.getLogger("gradio_webrtc")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class EchoHandler(StreamHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.queue = Queue()
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
|
||||
self.queue.put(frame)
|
||||
|
||||
def emit(self) -> None:
|
||||
return self.queue.get()
|
||||
|
||||
def copy(self) -> StreamHandler:
|
||||
return EchoHandler()
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Conversational AI (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=None,
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
|
||||
audio.stream(fn=EchoHandler(), inputs=[audio], outputs=[audio], time_limit=15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
149
demo/inference.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
from utils import draw_detections
|
||||
|
||||
|
||||
class YOLOv10:
|
||||
def __init__(self, path):
|
||||
# Initialize model
|
||||
self.initialize_model(path)
|
||||
|
||||
def __call__(self, image):
|
||||
return self.detect_objects(image)
|
||||
|
||||
def initialize_model(self, path):
|
||||
self.session = onnxruntime.InferenceSession(
|
||||
path, providers=onnxruntime.get_available_providers()
|
||||
)
|
||||
# Get model info
|
||||
self.get_input_details()
|
||||
self.get_output_details()
|
||||
|
||||
def detect_objects(self, image, conf_threshold=0.3):
|
||||
input_tensor = self.prepare_input(image)
|
||||
|
||||
# Perform inference on the image
|
||||
new_image = self.inference(image, input_tensor, conf_threshold)
|
||||
|
||||
return new_image
|
||||
|
||||
def prepare_input(self, image):
|
||||
self.img_height, self.img_width = image.shape[:2]
|
||||
|
||||
input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize input image
|
||||
input_img = cv2.resize(input_img, (self.input_width, self.input_height))
|
||||
|
||||
# Scale input pixel values to 0 to 1
|
||||
input_img = input_img / 255.0
|
||||
input_img = input_img.transpose(2, 0, 1)
|
||||
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def inference(self, image, input_tensor, conf_threshold=0.3):
|
||||
start = time.perf_counter()
|
||||
outputs = self.session.run(
|
||||
self.output_names, {self.input_names[0]: input_tensor}
|
||||
)
|
||||
|
||||
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
|
||||
(
|
||||
boxes,
|
||||
scores,
|
||||
class_ids,
|
||||
) = self.process_output(outputs, conf_threshold)
|
||||
return self.draw_detections(image, boxes, scores, class_ids)
|
||||
|
||||
def process_output(self, output, conf_threshold=0.3):
|
||||
predictions = np.squeeze(output[0])
|
||||
|
||||
# Filter out object confidence scores below threshold
|
||||
scores = predictions[:, 4]
|
||||
predictions = predictions[scores > conf_threshold, :]
|
||||
scores = scores[scores > conf_threshold]
|
||||
|
||||
if len(scores) == 0:
|
||||
return [], [], []
|
||||
|
||||
# Get the class with the highest confidence
|
||||
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
||||
|
||||
# Get bounding boxes for each object
|
||||
boxes = self.extract_boxes(predictions)
|
||||
|
||||
return boxes, scores, class_ids
|
||||
|
||||
def extract_boxes(self, predictions):
|
||||
# Extract boxes from predictions
|
||||
boxes = predictions[:, :4]
|
||||
|
||||
# Scale boxes to original image dimensions
|
||||
boxes = self.rescale_boxes(boxes)
|
||||
|
||||
# Convert boxes to xyxy format
|
||||
# boxes = xywh2xyxy(boxes)
|
||||
|
||||
return boxes
|
||||
|
||||
def rescale_boxes(self, boxes):
|
||||
# Rescale boxes to original image dimensions
|
||||
input_shape = np.array(
|
||||
[self.input_width, self.input_height, self.input_width, self.input_height]
|
||||
)
|
||||
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
||||
boxes *= np.array(
|
||||
[self.img_width, self.img_height, self.img_width, self.img_height]
|
||||
)
|
||||
return boxes
|
||||
|
||||
def draw_detections(
|
||||
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
|
||||
):
|
||||
return draw_detections(image, boxes, scores, class_ids, mask_alpha)
|
||||
|
||||
def get_input_details(self):
|
||||
model_inputs = self.session.get_inputs()
|
||||
self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
|
||||
|
||||
self.input_shape = model_inputs[0].shape
|
||||
self.input_height = self.input_shape[2]
|
||||
self.input_width = self.input_shape[3]
|
||||
|
||||
def get_output_details(self):
|
||||
model_outputs = self.session.get_outputs()
|
||||
self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import tempfile
|
||||
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
|
||||
)
|
||||
|
||||
yolov10_detector = YOLOv10(model_file)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
|
||||
f.write(
|
||||
requests.get(
|
||||
"https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
|
||||
).content
|
||||
)
|
||||
f.seek(0)
|
||||
img = cv2.imread(f.name)
|
||||
|
||||
# # Detect Objects
|
||||
combined_image = yolov10_detector.detect_objects(img)
|
||||
|
||||
# Draw detections
|
||||
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("Output", combined_image)
|
||||
cv2.waitKey(0)
|
||||
74
demo/old_app.py
Normal file
74
demo/old_app.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
from huggingface_hub import hf_hub_download
|
||||
from inference import YOLOv10
|
||||
from twilio.rest import Client
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
|
||||
)
|
||||
|
||||
model = YOLOv10(model_file)
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def detection(frame, conf_threshold=0.3):
|
||||
frame = cv2.flip(frame, 0)
|
||||
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
gr.HTML(
|
||||
"""
|
||||
<h3 style='text-align: center'>
|
||||
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
|
||||
</h3>
|
||||
"""
|
||||
)
|
||||
with gr.Column(elem_classes=["my-column"]):
|
||||
with gr.Group(elem_classes=["my-group"]):
|
||||
image = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
number = gr.Number()
|
||||
|
||||
image.stream(
|
||||
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
|
||||
)
|
||||
image.on_additional_outputs(lambda n: n, outputs=[number])
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
6
demo/requirements.txt
Normal file
6
demo/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
safetensors==0.4.3
|
||||
opencv-python
|
||||
twilio
|
||||
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio-5.0.0b3-py3-none-any.whl
|
||||
https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/gradio_webrtc-0.0.1-py3-none-any.whl
|
||||
onnxruntime-gpu
|
||||
321
demo/space.py
Normal file
321
demo/space.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
_docs = {
|
||||
"WebRTC": {
|
||||
"description": "Stream audio/video with WebRTC",
|
||||
"members": {
|
||||
"__init__": {
|
||||
"rtc_configuration": {
|
||||
"type": "dict[str, Any] | None",
|
||||
"default": "None",
|
||||
"description": "The configuration dictionary to pass to the RTCPeerConnection constructor. If None, the default configuration is used.",
|
||||
},
|
||||
"height": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The height of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"width": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed video file, but will affect the displayed video.",
|
||||
},
|
||||
"label": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.",
|
||||
},
|
||||
"show_label": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will display label.",
|
||||
},
|
||||
"container": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True, will place the component in a container - providing some extra padding around the border.",
|
||||
},
|
||||
"scale": {
|
||||
"type": "int | None",
|
||||
"default": "None",
|
||||
"description": "relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.",
|
||||
},
|
||||
"min_width": {
|
||||
"type": "int",
|
||||
"default": "160",
|
||||
"description": "minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.",
|
||||
},
|
||||
"interactive": {
|
||||
"type": "bool | None",
|
||||
"default": "None",
|
||||
"description": "if True, will allow users to upload a video; if False, can only be used to display videos. If not provided, this is inferred based on whether the component is used as an input or output.",
|
||||
},
|
||||
"visible": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will be hidden.",
|
||||
},
|
||||
"elem_id": {
|
||||
"type": "str | None",
|
||||
"default": "None",
|
||||
"description": "an optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"elem_classes": {
|
||||
"type": "list[str] | str | None",
|
||||
"default": "None",
|
||||
"description": "an optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.",
|
||||
},
|
||||
"render": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if False, component will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.",
|
||||
},
|
||||
"key": {
|
||||
"type": "int | str | None",
|
||||
"default": "None",
|
||||
"description": "if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.",
|
||||
},
|
||||
"mirror_webcam": {
|
||||
"type": "bool",
|
||||
"default": "True",
|
||||
"description": "if True webcam will be mirrored. Default is True.",
|
||||
},
|
||||
},
|
||||
"events": {"tick": {"type": None, "default": None, "description": ""}},
|
||||
},
|
||||
"__meta__": {"additional_interfaces": {}, "user_fn_refs": {"WebRTC": []}},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
abs_path = os.path.join(os.path.dirname(__file__), "css.css")
|
||||
|
||||
with gr.Blocks(
|
||||
css_paths=abs_path,
|
||||
theme=gr.themes.Default(
|
||||
font_mono=[
|
||||
gr.themes.GoogleFont("Inconsolata"),
|
||||
"monospace",
|
||||
],
|
||||
),
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
<h1 style='text-align: center; margin-bottom: 1rem'> Gradio WebRTC ⚡️ </h1>
|
||||
|
||||
<div style="display: flex; flex-direction: row; justify-content: center">
|
||||
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/badge/version%20-%200.0.5%20-%20orange">
|
||||
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
|
||||
</div>
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install gradio_webrtc
|
||||
```
|
||||
|
||||
## Examples:
|
||||
1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷
|
||||
2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥
|
||||
3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️
|
||||
|
||||
## Usage
|
||||
|
||||
The WebRTC component supports the following three use cases:
|
||||
1. Streaming video from the user webcam to the server and back
|
||||
2. Streaming Video from the server to the client
|
||||
3. Streaming Audio from the server to the client
|
||||
|
||||
Streaming Audio from client to the server and back (conversational AI) is not supported yet.
|
||||
|
||||
|
||||
## Streaming Video from the User Webcam to the Server and Back
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
... your detection code here ...
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
image = WebRTC(label="Stream", mode="send-receive", modality="video")
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
image.stream(
|
||||
fn=detection,
|
||||
inputs=[image, conf_threshold],
|
||||
outputs=[image], time_limit=10
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
||||
```
|
||||
* Set the `mode` parameter to `send-receive` and `modality` to "video".
|
||||
* The `stream` event's `fn` parameter is a function that receives the next frame from the webcam
|
||||
as a **numpy array** and returns the processed frame also as a **numpy array**.
|
||||
* Numpy arrays are in (height, width, 3) format where the color channels are in RGB format.
|
||||
* The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
|
||||
* The `time_limit` parameter is the maximum time in seconds the video stream will run. If the time limit is reached, the video stream will stop.
|
||||
|
||||
## Streaming Video from the Server to the Client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
import cv2
|
||||
|
||||
def generation():
|
||||
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
|
||||
cap = cv2.VideoCapture(url)
|
||||
iterating = True
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
# stop once the video ends so we do not yield a None frame
if not iterating:
break
|
||||
yield frame
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
output_video = WebRTC(label="Video Stream", mode="receive", modality="video")
|
||||
button = gr.Button("Start", variant="primary")
|
||||
output_video.stream(
|
||||
fn=generation, inputs=None, outputs=[output_video],
|
||||
trigger=button.click
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "video".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
|
||||
* The only output allowed is the WebRTC component.
|
||||
* The `trigger` parameter is the Gradio event that will trigger the WebRTC connection. In this case, the button click event.
|
||||
|
||||
## Streaming Audio from the Server to the Client
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
from gradio_webrtc import WebRTC
from pydub import AudioSegment
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file("/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav")
|
||||
yield (segment.frame_rate, np.array(segment.get_array_of_samples()).reshape(1, -1))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(label="Stream", mode="receive", modality="audio")
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio],
|
||||
trigger=button.click
|
||||
)
|
||||
```
|
||||
|
||||
* Set the "mode" parameter to "receive" and "modality" to "audio".
|
||||
* The `stream` event's `fn` parameter is a generator function that yields the next audio segment as a tuple of (frame_rate, audio_samples).
|
||||
* The numpy array should be of shape (1, num_samples).
|
||||
* The `outputs` parameter should be a list with the WebRTC component as the only element.
|
||||
|
||||
## Deployment
|
||||
|
||||
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
|
||||
The easiest way to do this is to use a service like Twilio.
|
||||
|
||||
```python
|
||||
from twilio.rest import Client
|
||||
import os
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
...
|
||||
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
|
||||
...
|
||||
```
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
##
|
||||
""",
|
||||
elem_classes=["md-custom"],
|
||||
header_links=True,
|
||||
)
|
||||
|
||||
gr.ParamViewer(value=_docs["WebRTC"]["members"]["__init__"], linkify=[])
|
||||
|
||||
demo.load(
|
||||
None,
|
||||
js=r"""function() {
|
||||
const refs = {};
|
||||
const user_fn_refs = {
|
||||
WebRTC: [], };
|
||||
requestAnimationFrame(() => {
|
||||
|
||||
Object.entries(user_fn_refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}-user-fn`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Object.entries(refs).forEach(([key, refs]) => {
|
||||
if (refs.length > 0) {
|
||||
const el = document.querySelector(`.${key}`);
|
||||
if (!el) return;
|
||||
refs.forEach(ref => {
|
||||
el.innerHTML = el.innerHTML.replace(
|
||||
new RegExp("\\b"+ref+"\\b", "g"),
|
||||
`<a href="#h-${ref.toLowerCase()}">${ref}</a>`
|
||||
);
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
""",
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
53
demo/stream_whisper.py
Normal file
53
demo/stream_whisper.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import tempfile
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
|
||||
from openai import OpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def transcribe(audio: tuple[int, np.ndarray], transcript: list[dict]):
|
||||
print("audio", audio)
|
||||
segment = AudioSegment(
|
||||
audio[1].tobytes(),
|
||||
frame_rate=audio[0],
|
||||
sample_width=audio[1].dtype.itemsize,
|
||||
channels=1,
|
||||
)
|
||||
|
||||
transcript.append({"role": "user", "content": gr.Audio((audio[0], audio[1].squeeze()))})
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3") as temp_audio:
|
||||
segment.export(temp_audio.name, format="mp3")
|
||||
next_chunk = client.audio.transcriptions.create(
|
||||
model="whisper-1", file=open(temp_audio.name, "rb")
|
||||
).text
|
||||
transcript.append({"role": "assistant", "content": next_chunk})
|
||||
yield AdditionalOutputs(transcript)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
mode="send",
|
||||
modality="audio",
|
||||
)
|
||||
with gr.Column():
|
||||
transcript = gr.Chatbot(label="transcript", type="messages")
|
||||
|
||||
audio.stream(ReplyOnPause(transcribe), inputs=[audio, transcript], outputs=[audio],
|
||||
time_limit=30)
|
||||
audio.on_additional_outputs(lambda s: s, outputs=transcript)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
237
demo/utils.py
Normal file
237
demo/utils.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
class_names = [
|
||||
"person",
|
||||
"bicycle",
|
||||
"car",
|
||||
"motorcycle",
|
||||
"airplane",
|
||||
"bus",
|
||||
"train",
|
||||
"truck",
|
||||
"boat",
|
||||
"traffic light",
|
||||
"fire hydrant",
|
||||
"stop sign",
|
||||
"parking meter",
|
||||
"bench",
|
||||
"bird",
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"sheep",
|
||||
"cow",
|
||||
"elephant",
|
||||
"bear",
|
||||
"zebra",
|
||||
"giraffe",
|
||||
"backpack",
|
||||
"umbrella",
|
||||
"handbag",
|
||||
"tie",
|
||||
"suitcase",
|
||||
"frisbee",
|
||||
"skis",
|
||||
"snowboard",
|
||||
"sports ball",
|
||||
"kite",
|
||||
"baseball bat",
|
||||
"baseball glove",
|
||||
"skateboard",
|
||||
"surfboard",
|
||||
"tennis racket",
|
||||
"bottle",
|
||||
"wine glass",
|
||||
"cup",
|
||||
"fork",
|
||||
"knife",
|
||||
"spoon",
|
||||
"bowl",
|
||||
"banana",
|
||||
"apple",
|
||||
"sandwich",
|
||||
"orange",
|
||||
"broccoli",
|
||||
"carrot",
|
||||
"hot dog",
|
||||
"pizza",
|
||||
"donut",
|
||||
"cake",
|
||||
"chair",
|
||||
"couch",
|
||||
"potted plant",
|
||||
"bed",
|
||||
"dining table",
|
||||
"toilet",
|
||||
"tv",
|
||||
"laptop",
|
||||
"mouse",
|
||||
"remote",
|
||||
"keyboard",
|
||||
"cell phone",
|
||||
"microwave",
|
||||
"oven",
|
||||
"toaster",
|
||||
"sink",
|
||||
"refrigerator",
|
||||
"book",
|
||||
"clock",
|
||||
"vase",
|
||||
"scissors",
|
||||
"teddy bear",
|
||||
"hair drier",
|
||||
"toothbrush",
|
||||
]
|
||||
|
||||
# Create a list of colors for each class where each color is a tuple of 3 integer values
|
||||
rng = np.random.default_rng(3)
|
||||
colors = rng.uniform(0, 255, size=(len(class_names), 3))
|
||||
|
||||
|
||||
def nms(boxes, scores, iou_threshold):
|
||||
# Sort by score
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
|
||||
keep_boxes = []
|
||||
while sorted_indices.size > 0:
|
||||
# Pick the last box
|
||||
box_id = sorted_indices[0]
|
||||
keep_boxes.append(box_id)
|
||||
|
||||
# Compute IoU of the picked box with the rest
|
||||
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
|
||||
|
||||
# Remove boxes with IoU over the threshold
|
||||
keep_indices = np.where(ious < iou_threshold)[0]
|
||||
|
||||
# print(keep_indices.shape, sorted_indices.shape)
|
||||
sorted_indices = sorted_indices[keep_indices + 1]
|
||||
|
||||
return keep_boxes
|
||||
|
||||
|
||||
def multiclass_nms(boxes, scores, class_ids, iou_threshold):
|
||||
unique_class_ids = np.unique(class_ids)
|
||||
|
||||
keep_boxes = []
|
||||
for class_id in unique_class_ids:
|
||||
class_indices = np.where(class_ids == class_id)[0]
|
||||
class_boxes = boxes[class_indices, :]
|
||||
class_scores = scores[class_indices]
|
||||
|
||||
class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
|
||||
keep_boxes.extend(class_indices[class_keep_boxes])
|
||||
|
||||
return keep_boxes
|
||||
|
||||
|
||||
def compute_iou(box, boxes):
|
||||
# Compute xmin, ymin, xmax, ymax for both boxes
|
||||
xmin = np.maximum(box[0], boxes[:, 0])
|
||||
ymin = np.maximum(box[1], boxes[:, 1])
|
||||
xmax = np.minimum(box[2], boxes[:, 2])
|
||||
ymax = np.minimum(box[3], boxes[:, 3])
|
||||
|
||||
# Compute intersection area
|
||||
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
|
||||
|
||||
# Compute union area
|
||||
box_area = (box[2] - box[0]) * (box[3] - box[1])
|
||||
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
union_area = box_area + boxes_area - intersection_area
|
||||
|
||||
# Compute IoU
|
||||
iou = intersection_area / union_area
|
||||
|
||||
return iou
|
||||
|
||||
|
||||
def xywh2xyxy(x):
|
||||
# Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
|
||||
y = np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2
|
||||
return y
|
||||
|
||||
|
||||
def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
|
||||
det_img = image.copy()
|
||||
|
||||
img_height, img_width = image.shape[:2]
|
||||
font_size = min([img_height, img_width]) * 0.0006
|
||||
text_thickness = int(min([img_height, img_width]) * 0.001)
|
||||
|
||||
# det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
|
||||
|
||||
# Draw bounding boxes and labels of detections
|
||||
for class_id, box, score in zip(class_ids, boxes, scores):
|
||||
color = colors[class_id]
|
||||
|
||||
draw_box(det_img, box, color)
|
||||
|
||||
label = class_names[class_id]
|
||||
caption = f"{label} {int(score * 100)}%"
|
||||
draw_text(det_img, caption, box, color, font_size, text_thickness)
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
def draw_box(
|
||||
image: np.ndarray,
|
||||
box: np.ndarray,
|
||||
color: tuple[int, int, int] = (0, 0, 255),
|
||||
thickness: int = 2,
|
||||
) -> np.ndarray:
|
||||
x1, y1, x2, y2 = box.astype(int)
|
||||
return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
|
||||
|
||||
|
||||
def draw_text(
|
||||
image: np.ndarray,
|
||||
text: str,
|
||||
box: np.ndarray,
|
||||
color: tuple[int, int, int] = (0, 0, 255),
|
||||
font_size: float = 0.001,
|
||||
text_thickness: int = 2,
|
||||
) -> np.ndarray:
|
||||
x1, y1, x2, y2 = box.astype(int)
|
||||
(tw, th), _ = cv2.getTextSize(
|
||||
text=text,
|
||||
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
|
||||
fontScale=font_size,
|
||||
thickness=text_thickness,
|
||||
)
|
||||
th = int(th * 1.2)
|
||||
|
||||
cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
|
||||
|
||||
return cv2.putText(
|
||||
image,
|
||||
text,
|
||||
(x1, y1),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_size,
|
||||
(255, 255, 255),
|
||||
text_thickness,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
|
||||
def draw_masks(
|
||||
image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
|
||||
) -> np.ndarray:
|
||||
mask_img = image.copy()
|
||||
|
||||
# Draw bounding boxes and labels of detections
|
||||
for box, class_id in zip(boxes, classes):
|
||||
color = colors[class_id]
|
||||
|
||||
x1, y1, x2, y2 = box.astype(int)
|
||||
|
||||
# Draw fill rectangle in mask image
|
||||
cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
|
||||
|
||||
return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
|
||||
65
demo/video_out.py
Normal file
65
demo/video_out.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
from twilio.rest import Client
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def generation(input_video):
|
||||
cap = cv2.VideoCapture(input_video)
|
||||
|
||||
iterating = True
|
||||
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
# stop once the video ends so we do not flip a None frame
if not iterating:
break
|
||||
|
||||
# flip frame vertically
|
||||
frame = cv2.flip(frame, 0)
|
||||
display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
yield display_frame
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Video Streaming (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_video = gr.Video(sources="upload")
|
||||
with gr.Column():
|
||||
output_video = WebRTC(
|
||||
label="Video Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="video",
|
||||
)
|
||||
output_video.stream(
|
||||
fn=generation,
|
||||
inputs=[input_video],
|
||||
outputs=[output_video],
|
||||
trigger=input_video.upload,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
54
demo/video_out_stream.py
Normal file
54
demo/video_out_stream.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
from twilio.rest import Client
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def generation():
|
||||
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
|
||||
cap = cv2.VideoCapture(url)
|
||||
iterating = True
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
# stop once the video ends so we do not yield a None frame
if not iterating:
break
|
||||
yield frame
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Video Streaming (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
output_video = WebRTC(
|
||||
label="Video Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="video",
|
||||
)
|
||||
button = gr.Button("Start", variant="primary")
|
||||
output_video.stream(
|
||||
fn=generation, inputs=None, outputs=[output_video], trigger=button.click
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
100
demo/video_send_output.py
Normal file
100
demo/video_send_output.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import gradio as gr
|
||||
from gradio_webrtc import AdditionalOutputs, WebRTC
|
||||
from huggingface_hub import hf_hub_download
|
||||
from inference import YOLOv10
|
||||
from twilio.rest import Client
|
||||
|
||||
# Configure the root logger to WARNING to suppress debug messages from other libraries
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# Create a file handler that captures debug logs
|
||||
console_handler = logging.FileHandler("gradio_webrtc.log")
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Create a formatter
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Configure the logger for your specific library
|
||||
logger = logging.getLogger("gradio_webrtc")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
|
||||
)
|
||||
|
||||
model = YOLOv10(model_file)
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
if account_sid and auth_token:
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
else:
|
||||
rtc_configuration = None
|
||||
|
||||
|
||||
def detection(frame, conf_threshold=0.3):
|
||||
print("frame.shape", frame.shape)
|
||||
frame = cv2.flip(frame, 0)
|
||||
return AdditionalOutputs(1)
|
||||
|
||||
|
||||
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
|
||||
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
YOLOv10 Webcam Stream (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
gr.HTML(
|
||||
"""
|
||||
<h3 style='text-align: center'>
|
||||
<a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
|
||||
</h3>
|
||||
"""
|
||||
)
|
||||
with gr.Column(elem_classes=["my-column"]):
|
||||
with gr.Group(elem_classes=["my-group"]):
|
||||
image = WebRTC(
|
||||
label="Stream", rtc_configuration=rtc_configuration,
|
||||
mode="send",
|
||||
track_constraints={"width": {"exact": 800},
|
||||
"height": {"exact": 600},
|
||||
"aspectRatio": {"exact": 1.33333}
|
||||
},
|
||||
rtp_params={"degradationPreference": "maintain-resolution"}
|
||||
)
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
number = gr.Number()
|
||||
|
||||
image.stream(
|
||||
fn=detection, inputs=[image, conf_threshold], outputs=[image], time_limit=10
|
||||
)
|
||||
image.on_additional_outputs(lambda n: n, outputs=number)
|
||||
|
||||
demo.launch()
|
||||
160
docs/advanced-configuration.md
Normal file
160
docs/advanced-configuration.md
Normal file
@@ -0,0 +1,160 @@
|
||||
## Track Constraints
|
||||
|
||||
You can specify the `track_constraints` parameter to control how the data is streamed to the server. The full documentation on track constraints is [here](https://developer.mozilla.org/en-US/docs/Web/API/MediaTrackConstraints#constraints).
|
||||
|
||||
For example, you can control the size of the frames captured from the webcam like so:
|
||||
|
||||
```python
|
||||
track_constraints = {
|
||||
"width": {"exact": 500},
|
||||
"height": {"exact": 500},
|
||||
"frameRate": {"ideal": 30},
|
||||
}
|
||||
webrtc = WebRTC(track_constraints=track_constraints,
|
||||
modality="video",
|
||||
mode="send-receive")
|
||||
```
|
||||
|
||||
|
||||
!!! warning
|
||||
|
||||
WebRTC may not strictly enforce your constraints. For example, it may rescale your video
(lowering its resolution) in order to maintain the desired (or reach a better) frame rate. If you
really want to enforce height, width and resolution constraints, use the `rtp_params` parameter and set `"degradationPreference": "maintain-resolution"`.
|
||||
|
||||
```python
|
||||
image = WebRTC(
|
||||
label="Stream",
|
||||
mode="send",
|
||||
track_constraints=track_constraints,
|
||||
rtp_params={"degradationPreference": "maintain-resolution"}
|
||||
)
|
||||
```
|
||||
|
||||
## The RTC Configuration
|
||||
|
||||
You can configure how the connection is created on the client by passing an `rtc_configuration` parameter to the `WebRTC` component constructor.
|
||||
See the list of available arguments [here](https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection#configuration).
|
||||
|
||||
When deploying on a remote server, an `rtc_configuration` parameter must be passed in. See [Deployment](/deployment).
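For example, a minimal configuration that only points the browser at a public STUN server (a sketch that is usually enough for local experiments; remote deployments typically also need TURN servers, as described on the Deployment page) could look like this:

```python
from gradio_webrtc import WebRTC

# Minimal sketch: a single public STUN server (stun.l.google.com is a well-known
# public Google STUN endpoint). For remote deployments you will usually need
# TURN servers as well.
rtc_configuration = {
    "iceServers": [{"urls": "stun:stun.l.google.com:19302"}],
}

webrtc = WebRTC(rtc_configuration=rtc_configuration, modality="video", mode="send-receive")
```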
|
||||
|
||||
## Reply on Pause Voice-Activity-Detection
|
||||
|
||||
The `ReplyOnPause` class runs a Voice Activity Detection (VAD) algorithm to determine when a user has stopped speaking.
|
||||
|
||||
1. First, the algorithm determines when the user has started speaking.
|
||||
2. Then it groups the audio into chunks.
|
||||
3. On each chunk, we determine the length of human speech in the chunk.
|
||||
4. If the length of human speech is below a threshold, a pause is detected.
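In rough pseudocode, the detection loop looks something like the sketch below. This is only an illustration, not the library's actual implementation; `estimate_speech_seconds` is a hypothetical stand-in for the underlying VAD model, and the option names mirror the `AlgoOptions` parameters described next.

```python
def detect_pause(audio_chunks, options):
    """Simplified sketch of the ReplyOnPause loop (not the real implementation)."""
    started_talking = False
    buffer = []
    for chunk in audio_chunks:  # each chunk spans options.audio_chunk_duration seconds
        buffer.append(chunk)
        speech_seconds = estimate_speech_seconds(chunk)  # hypothetical VAD helper
        if not started_talking and speech_seconds > options.started_talking_threshold:
            started_talking = True  # the user has started speaking
        elif started_talking and speech_seconds < options.speech_threshold:
            return buffer  # pause detected: hand the buffered audio to your generator
```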
|
||||
|
||||
The following parameters control this algorithm:
|
||||
|
||||
```python
|
||||
from gradio_webrtc import AlgoOptions, ReplyOnPause, WebRTC
|
||||
|
||||
options = AlgoOptions(audio_chunk_duration=0.6, # (1)
|
||||
started_talking_threshold=0.2, # (2)
|
||||
speech_threshold=0.1, # (3)
|
||||
)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(...)
|
||||
audio.stream(ReplyOnPause(..., algo_options=options)
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. This is the length (in seconds) of audio chunks.
|
||||
2. If the chunk has more than 0.2 seconds of speech, the user started talking.
|
||||
3. If, after the user started speaking, there is a chunk with less than 0.1 seconds of speech, the user stopped speaking.
|
||||
|
||||
|
||||
## Stream Handler Input Audio
|
||||
|
||||
You can configure the sampling rate of the audio passed to the `ReplyOnPause` or `StreamHandler` instance with the `input_sampling_rate` parameter. The current default is `48000`.
|
||||
|
||||
```python
|
||||
from gradio_webrtc import ReplyOnPause, WebRTC
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(...)
|
||||
audio.stream(ReplyOnPause(..., input_sampling_rate=24000)
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
|
||||
## Stream Handler Output Audio
|
||||
|
||||
You can configure the output audio chunk size of `ReplyOnPause` (and any `StreamHandler`)
|
||||
with the `output_sample_rate` and `output_frame_size` parameters.
|
||||
|
||||
The following code (which uses the default values of these parameters) states that each output chunk will be a frame of 960 samples at a sample rate of `24,000` Hz, so each chunk corresponds to `0.04` seconds of audio.
|
||||
|
||||
```python
|
||||
from gradio_webrtc import ReplyOnPause, WebRTC
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(...)
|
||||
audio.stream(ReplyOnPause(..., output_sample_rate=24000, output_frame_size=960)
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
!!! tip
|
||||
|
||||
In general it is best to leave these settings untouched. In some cases,
|
||||
lowering the output_frame_size can yield smoother audio playback.
|
||||
|
||||
|
||||
## Audio Icon
|
||||
|
||||
You can display an icon of your choice instead of the default wave animation for audio streaming.
|
||||
Pass any local path or URL to an image (svg, png, jpeg) to the component's `icon` parameter. This will display the icon as a circular button. When audio is sent or received (depending on the `mode` parameter), a pulse animation will emanate from the button.
|
||||
|
||||
You can control the button color and pulse color with `icon_button_color` and `pulse_color` parameters. They can take any valid css color.
|
||||
|
||||
=== "Code"
|
||||
``` python
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
icon="phone-solid.svg",
|
||||
)
|
||||
```
|
||||
<img src="https://github.com/user-attachments/assets/fd2e70a3-1698-4805-a8cb-9b7b3bcf2198">
|
||||
=== "Code Custom colors"
|
||||
``` python
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
rtc_configuration=rtc_configuration,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
icon="phone-solid.svg",
|
||||
icon_button_color="black",
|
||||
pulse_color="black",
|
||||
)
|
||||
```
|
||||
<img src="https://github.com/user-attachments/assets/39e9bb0b-53fb-448e-be44-d37f6785b4b6">
|
||||
|
||||
|
||||
## Changing the Button Text
|
||||
|
||||
You can supply a `button_labels` dictionary to change the text displayed in the `Start`, `Stop` and `Waiting` buttons that are displayed in the UI.
|
||||
The keys must be `"start"`, `"stop"`, and `"waiting"`.
|
||||
|
||||
``` python
|
||||
webrtc = WebRTC(
|
||||
label="Video Chat",
|
||||
modality="audio-video",
|
||||
mode="send-receive",
|
||||
button_labels={"start": "Start Talking to Gemini"}
|
||||
)
|
||||
```
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/04be0b95-189c-4b4b-b8cc-1eb598529dd3" />
|
||||
1
docs/bolt.svg
Normal file
1
docs/bolt.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e8eaed"><path d="m422-232 207-248H469l29-227-185 267h139l-30 208ZM320-80l40-280H160l360-520h80l-40 320h240L400-80h-80Zm151-390Z"/></svg>
|
||||
|
After Width: | Height: | Size: 235 B |
172
docs/cookbook.md
Normal file
172
docs/cookbook.md
Normal file
@@ -0,0 +1,172 @@
|
||||
<div class="grid cards" markdown>
|
||||
|
||||
- :speaking_head:{ .lg .middle }:eyes:{ .lg .middle } __Gemini Audio Video Chat__
|
||||
|
||||
---
|
||||
|
||||
Stream BOTH your webcam video and audio feeds to Google Gemini. You can also upload images to augment your conversation!
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/9636dc97-4fee-46bb-abb8-b92e69c08c71" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Google Gemini Real Time Voice API__
|
||||
|
||||
---
|
||||
|
||||
Talk to Gemini in real time using Google's voice API.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/da8c8a2a-5d99-4ac7-8927-0f7812e4146f" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/gemini-voice)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/gemini-voice/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __OpenAI Real Time Voice API__
|
||||
|
||||
---
|
||||
|
||||
Talk to ChatGPT in real time using OpenAI's voice API.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/41a63376-43ec-496a-9b31-4f067d3903d6" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/openai-realtime-voice)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/openai-realtime-voice/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Hello Llama: Stop Word Detection__
|
||||
|
||||
---
|
||||
|
||||
A code editor built with Llama 3.3 70b that is triggered by the phrase "Hello Llama".
|
||||
Build a Siri-like coding assistant in 100 lines of code!
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/3e10cb15-ff1b-4b17-b141-ff0ad852e613" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/hey-llama-code-editor/blob/main/app.py)
|
||||
|
||||
- :robot:{ .lg .middle } __Llama Code Editor__
|
||||
|
||||
---
|
||||
|
||||
Create and edit HTML pages with just your voice! Powered by SambaNova systems.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/a09647f1-33e1-4154-a5a3-ffefda8a736a" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/llama-code-editor)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/llama-code-editor/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Audio Input/Output with mini-omni2__
|
||||
|
||||
---
|
||||
|
||||
Build a GPT-4o like experience with mini-omni2, an audio-native LLM.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/58c06523-fc38-4f5f-a4ba-a02a28e7fa9e" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/mini-omni2-webrtc/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Talk to Claude__
|
||||
|
||||
---
|
||||
|
||||
Use the Anthropic and Play.Ht APIs to have an audio conversation with Claude.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/650bc492-798e-4995-8cef-159e1cfc2185" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-claude)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-claude/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Kyutai Moshi__
|
||||
|
||||
---
|
||||
|
||||
Kyutai's moshi is a novel speech-to-speech model for modeling human conversations.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/becc7a13-9e89-4a19-9df2-5fb1467a0137" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py)
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Talk to Ultravox__
|
||||
|
||||
---
|
||||
|
||||
Talk to Fixie.AI's audio-native Ultravox LLM with the transformers library.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/e6e62482-518c-4021-9047-9da14cd82be1" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-ultravox/blob/main/app.py)
|
||||
|
||||
|
||||
- :speaking_head:{ .lg .middle } __Talk to Llama 3.2 3b__
|
||||
|
||||
---
|
||||
|
||||
Use the Lepton API to make Llama 3.2 talk back to you!
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/3ee37a6b-0892-45f5-b801-73188fdfad9a" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/llama-3.2-3b-voice-webrtc/blob/main/app.py)
|
||||
|
||||
|
||||
- :robot:{ .lg .middle } __Talk to Qwen2-Audio__
|
||||
|
||||
---
|
||||
|
||||
Qwen2-Audio is a SOTA audio-to-text LLM developed by Alibaba.
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/c821ad86-44cc-4d0c-8dc4-8c02ad1e5dc8" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-qwen-webrtc/blob/main/app.py)
|
||||
|
||||
|
||||
- :camera:{ .lg .middle } __Yolov10 Object Detection__
|
||||
|
||||
---
|
||||
|
||||
Run the Yolov10 model on a user webcam stream in real time!
|
||||
|
||||
<video width=98% src="https://github.com/user-attachments/assets/c90d8c9d-d2d5-462e-9e9b-af969f2ea73c" controls style="text-align: center"></video>
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n/blob/main/app.py)
|
||||
|
||||
- :camera:{ .lg .middle } __Video Object Detection with RT-DETR__
|
||||
|
||||
---
|
||||
|
||||
Upload a video and stream out frames with detected objects (powered by the RT-DETR model).
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc/blob/main/app.py)
|
||||
|
||||
- :speaker:{ .lg .middle } __Text-to-Speech with Parler__
|
||||
|
||||
---
|
||||
|
||||
Stream out audio generated by Parler TTS!
|
||||
|
||||
[:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc)
|
||||
|
||||
[:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc/blob/main/app.py)
|
||||
|
||||
|
||||
</div>
|
||||
165
docs/deployment.md
Normal file
165
docs/deployment.md
Normal file
@@ -0,0 +1,165 @@
|
||||
When deploying in a cloud environment (like Hugging Face Spaces, EC2, etc), you need to set up a TURN server to relay the WebRTC traffic.
|
||||
|
||||
## Community Server
|
||||
|
||||
Hugging Face graciously provides a TURN server for the community.
|
||||
In order to use it, you first need to create a Hugging Face account at [huggingface.co](https://huggingface.co/).
|
||||
|
||||
Then navigate to this [space](https://huggingface.co/spaces/freddyaboulton/turn-server-login) and follow the instructions on the page. You just have to click the "Log in" button and then the "Sign Up" button.
|
||||
|
||||

|
||||
|
||||
Then you can use the `get_hf_turn_credentials` helper to get your credentials:
|
||||
|
||||
```python
|
||||
from gradio_webrtc import get_hf_turn_credentials, WebRTC
|
||||
|
||||
# Pass a valid access token for your Hugging Face account
|
||||
# or set the HF_TOKEN environment variable
|
||||
credentials = get_hf_turn_credentials(token=None)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
webrtc = WebRTC(rtc_configuration=credentials)
|
||||
...
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
!!! warning
|
||||
|
||||
This is a shared resource so we make no latency/availability guarantees.
|
||||
For more robust options, see the Twilio and self-hosting options below.
|
||||
|
||||
|
||||
## Twilio API
|
||||
|
||||
The easiest way to do this is to use a service like Twilio.
|
||||
|
||||
Create a **free** [account](https://login.twilio.com/u/signup) and then install the `twilio` package with pip (`pip install twilio`). You can then connect from the WebRTC component like so:
|
||||
|
||||
```python
|
||||
from twilio.rest import Client
|
||||
import os
|
||||
|
||||
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
|
||||
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
|
||||
|
||||
client = Client(account_sid, auth_token)
|
||||
|
||||
token = client.tokens.create()
|
||||
|
||||
rtc_configuration = {
|
||||
"iceServers": token.ice_servers,
|
||||
"iceTransportPolicy": "relay",
|
||||
}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
...
|
||||
rtc = WebRTC(rtc_configuration=rtc_configuration, ...)
|
||||
...
|
||||
```
|
||||
|
||||
!!! tip "Automatic Login"
|
||||
|
||||
You can log in automatically with the `get_twilio_turn_credentials` helper
|
||||
|
||||
```python
|
||||
from gradio_webrtc import get_twilio_turn_credentials
|
||||
|
||||
# Will automatically read the TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
|
||||
# env variables but you can also pass in the tokens as parameters
|
||||
rtc_configuration = get_twilio_turn_credentials()
|
||||
```
|
||||
|
||||
## Self Hosting
|
||||
|
||||
We have developed a script that can automatically deploy a TURN server to Amazon Web Services (AWS). You can follow the instructions [here](https://github.com/freddyaboulton/turn-server-deploy) or use the guide below.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Clone the following [repository](https://github.com/freddyaboulton/turn-server-deploy) and install the `aws` cli if you have not done so already (`pip install awscli`).
|
||||
|
||||
Log into your AWS account and create an IAM user with the following permissions:
|
||||
|
||||
- [AWSCloudFormationFullAccess](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/policies/details/arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FAWSCloudFormationFullAccess)
|
||||
- [AmazonEC2FullAccess](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/policies/details/arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FAmazonEC2FullAccess)
|
||||
|
||||
|
||||
Create a key pair for this user and write down the "access key" and "secret access key". Then log into the aws cli with these credentials (`aws configure`).
|
||||
|
||||
Finally, create an ec2 keypair (replace `your-key-name` with the name you want to give it).
|
||||
|
||||
```
|
||||
aws ec2 create-key-pair --key-name your-key-name --query 'KeyMaterial' --output text > your-key-name.pem
|
||||
```
|
||||
|
||||
### Running the script
|
||||
|
||||
Open the `parameters.json` file and fill in the correct values for all the parameters:
|
||||
|
||||
- `KeyName`: The key file we just created, e.g. `your-key-name` (omit `.pem`).
|
||||
- `TurnUserName`: The username needed to connect to the server.
|
||||
- `TurnPassword`: The password needed to connect to the server.
|
||||
- `InstanceType`: One of the following values `t3.micro`, `t3.small`, `t3.medium`, `c4.large`, `c5.large`.
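For example, `parameters.json` uses the standard AWS CLI parameter-file format with the parameter names listed above; the values below are placeholders that you should replace with your own:

```json
[
  {"ParameterKey": "KeyName", "ParameterValue": "your-key-name"},
  {"ParameterKey": "TurnUserName", "ParameterValue": "my-turn-username"},
  {"ParameterKey": "TurnPassword", "ParameterValue": "my-turn-password"},
  {"ParameterKey": "InstanceType", "ParameterValue": "t3.micro"}
]
```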
|
||||
|
||||
|
||||
Then run the deployment script:
|
||||
|
||||
```bash
|
||||
aws cloudformation create-stack \
|
||||
--stack-name turn-server \
|
||||
--template-body file://deployment.yml \
|
||||
--parameters file://parameters.json \
|
||||
--capabilities CAPABILITY_IAM
|
||||
```
|
||||
|
||||
You can then wait for the stack to come up with:
|
||||
|
||||
```bash
|
||||
aws cloudformation wait stack-create-complete \
|
||||
--stack-name turn-server
|
||||
```
|
||||
|
||||
Next, grab your EC2 server's public ip with:
|
||||
|
||||
```
|
||||
aws cloudformation describe-stacks \
|
||||
--stack-name turn-server \
|
||||
--query 'Stacks[0].Outputs' > server-info.json
|
||||
```
|
||||
|
||||
The `server-info.json` file will have the server's public IP and public DNS:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"OutputKey": "PublicIP",
|
||||
"OutputValue": "35.173.254.80",
|
||||
"Description": "Public IP address of the TURN server"
|
||||
},
|
||||
{
|
||||
"OutputKey": "PublicDNS",
|
||||
"OutputValue": "ec2-35-173-254-80.compute-1.amazonaws.com",
|
||||
"Description": "Public DNS name of the TURN server"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Finally, you can connect to your EC2 server from the gradio WebRTC component via the `rtc_configuration` argument:
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
rtc_configuration = {
|
||||
"iceServers": [
|
||||
{
|
||||
"urls": "turn:35.173.254.80:80",
|
||||
"username": "<my-username>",
|
||||
"credential": "<my-password>"
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
webrtc = WebRTC(rtc_configuration=rtc_configuration)
|
||||
```
|
||||
67
docs/faq.md
Normal file
67
docs/faq.md
Normal file
@@ -0,0 +1,67 @@
|
||||
## Demo does not work when deploying to the cloud
|
||||
|
||||
Make sure you are using a TURN server. See [deployment](/deployment).
|
||||
|
||||
## Recorded input audio sounds muffled during output audio playback
|
||||
|
||||
By default, the microphone is [configured](https://github.com/freddyaboulton/gradio-webrtc/blob/903f1f70bd586f638ad3b5a3940c7a8ec70ad1f5/backend/gradio_webrtc/webrtc.py#L575) to do echoCancellation.
|
||||
This is what's causing the recorded audio to sound muffled when the streamed audio starts playing.
|
||||
You can disable this via the `track_constraints` parameter (see [advanced configuration](./advanced-configuration)) with the following code:
|
||||
|
||||
```python
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
track_constraints={
|
||||
"echoCancellation": False,
|
||||
"noiseSuppression": {"exact": True},
|
||||
"autoGainControl": {"exact": True},
|
||||
"sampleRate": {"ideal": 24000},
|
||||
"sampleSize": {"ideal": 16},
|
||||
"channelCount": {"exact": 1},
|
||||
},
|
||||
rtc_configuration=None,
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
```
|
||||
|
||||
## How to raise errors in the UI
|
||||
|
||||
You can raise `WebRTCError` in order for an error message to show up on the user's screen. This is similar to how `gr.Error` works.
|
||||
|
||||
Here is a simple example:
|
||||
|
||||
```python
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file(
|
||||
"/Users/freddy/sources/gradio/demo/audio_debugger/cantina.wav"
|
||||
)
|
||||
yield (
|
||||
segment.frame_rate,
|
||||
np.array(segment.get_array_of_samples()).reshape(1, -1),
|
||||
)
|
||||
time.sleep(3.5)
|
||||
raise WebRTCError("This is a test error")
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
)
|
||||
num_steps = gr.Slider(
|
||||
label="Number of Steps",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
value=5,
|
||||
)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio], trigger=button.click
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
30
docs/index.md
Normal file
30
docs/index.md
Normal file
@@ -0,0 +1,30 @@
|
||||
<h1 style='text-align: center; margin-bottom: 1rem; color: white;'> Gradio WebRTC ⚡️ </h1>
|
||||
|
||||
<div style="display: flex; flex-direction: row; justify-content: center">
|
||||
<img style="display: block; padding-right: 5px; height: 20px;" alt="Static Badge" src="https://img.shields.io/pypi/v/gradio_webrtc">
|
||||
<a href="https://github.com/freddyaboulton/gradio-webrtc" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/github-white?logo=github&logoColor=black"></a>
|
||||
</div>
|
||||
|
||||
<h3 style='text-align: center'>
|
||||
Stream video and audio in real time with Gradio using WebRTC.
|
||||
</h3>
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install gradio_webrtc
|
||||
```
|
||||
|
||||
To use built-in pause detection (see [ReplyOnPause](/user-guide/#reply-on-pause)), install the `vad` extra:
|
||||
|
||||
```bash
|
||||
pip install gradio_webrtc[vad]
|
||||
```
|
||||
|
||||
For stop word detection (see [ReplyOnStopWords](/user-guide/#reply-on-stopwords)), install the `stopword` extra:
|
||||
```bash
|
||||
pip install gradio_webrtc[stopword]
|
||||
```
|
||||
|
||||
## Examples
|
||||
See the [cookbook](/cookbook)
|
||||
505
docs/user-guide.md
Normal file
505
docs/user-guide.md
Normal file
@@ -0,0 +1,505 @@
|
||||
# User Guide
|
||||
|
||||
To get started with WebRTC streams, all that's needed is to import the `WebRTC` component from this package and implement its `stream` event.
|
||||
|
||||
This page will show how to do so with simple code examples.
|
||||
For complete implementations of common tasks, see the [cookbook](/cookbook).
|
||||
|
||||
## Audio Streaming
|
||||
|
||||
### Reply on Pause
|
||||
|
||||
Typically, you want to run an AI model that generates audio when the user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class
|
||||
and passing it to the `stream` event of the `WebRTC` component.
|
||||
|
||||
=== "Code"
|
||||
``` py title="ReplyOnPause"
|
||||
import gradio as gr
|
||||
import numpy as np
from gradio_webrtc import WebRTC, ReplyOnPause
|
||||
|
||||
def response(audio: tuple[int, np.ndarray]): # (1)
|
||||
"""This function must yield audio frames"""
|
||||
...
|
||||
for numpy_array in generated_audio:
|
||||
yield (sampling_rate, numpy_array, "mono") # (2)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Chat (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
mode="send-receive", # (3)
|
||||
modality="audio",
|
||||
)
|
||||
audio.stream(fn=ReplyOnPause(response),
|
||||
inputs=[audio], outputs=[audio], # (4)
|
||||
time_limit=60) # (5)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.
|
||||
|
||||
2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio array). Each numpy audio array must have a shape of (1, num_samples).
|
||||
|
||||
3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.
|
||||
|
||||
4. The `WebRTC` component must be the first input and output component.
|
||||
|
||||
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_limit` is 1 (the default), only one conversation will be handled at a time.
|
||||
=== "Notes"
|
||||
1. The python generator will receive the **entire** audio up until the user stopped. It will be a tuple of the form (sampling_rate, numpy array of audio). The array will have a shape of (1, num_samples). You can also pass in additional input components.
|
||||
|
||||
2. The generator must yield audio chunks as a tuple of (sampling_rate, numpy audio arrays). Each numpy audio array must have a shape of (1, num_samples).
|
||||
|
||||
3. The `mode` and `modality` arguments must be set to `"send-receive"` and `"audio"`.
|
||||
|
||||
4. The `WebRTC` component must be the first input and output component.
|
||||
|
||||
5. Set a `time_limit` to control how long a conversation will last. If the `concurrency_limit` is 1 (the default), only one conversation will be handled at a time.
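
For reference, here is a minimal sketch of a generator that yields audio in the shape `ReplyOnPause` expects. The sine-tone generation is purely illustrative; the part that matters is yielding `(sampling_rate, array)` tuples where each array has shape `(1, num_samples)` (the optional layout string shown in the example above can also be appended).

```python
import numpy as np


def response(audio: tuple[int, np.ndarray]):
    sampling_rate, _ = audio
    # Illustrative only: a one-second 440 Hz tone, emitted in 100 ms chunks.
    t = np.linspace(0, 1, sampling_rate, endpoint=False)
    tone = (0.2 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
    chunk_size = sampling_rate // 10
    for start in range(0, len(tone), chunk_size):
        chunk = tone[start:start + chunk_size].reshape(1, -1)  # shape (1, num_samples)
        yield (sampling_rate, chunk)
```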
|
||||
|
||||
|
||||
### Reply On Stopwords
|
||||
|
||||
You can configure your AI model to run whenever a set of "stop words" are detected, like "Hey Siri" or "computer", with the `ReplyOnStopWords` class.
|
||||
|
||||
The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter.
|
||||
|
||||
=== "Code"
|
||||
``` py title="ReplyOnStopWords"
|
||||
import gradio as gr
import numpy as np
|
||||
from gradio_webrtc import WebRTC, ReplyOnStopWords
|
||||
|
||||
def response(audio: tuple[int, np.ndarray]):
|
||||
"""This function must yield audio frames"""
|
||||
...
|
||||
for numpy_array in generated_audio:
|
||||
yield (sampling_rate, numpy_array, "mono")
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Chat (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
audio.stream(ReplyOnStopWords(response,
input_sample_rate=16000,
stop_words=["computer"]), # (1)
inputs=[audio],
outputs=[audio], time_limit=90,
concurrency_limit=10)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
|
||||
|
||||
=== "Notes"
|
||||
1. The `stop_words` can be single words or pairs of words. Be sure to include common misspellings of your word for more robust detection, e.g. "llama", "lamma". In my experience, it's best to use two very distinct words like "ok computer" or "hello iris".
|
||||
|
||||
### Stream Handler
|
||||
|
||||
`ReplyOnPause` is an implementation of a `StreamHandler`. The `StreamHandler` is a low-level
|
||||
abstraction that gives you arbitrary control over how the input audio stream and output audio stream are created. The following example echoes back the user's audio.
|
||||
|
||||
=== "Code"
|
||||
``` py title="Stream Handler"
|
||||
import gradio as gr
import numpy as np
|
||||
from gradio_webrtc import WebRTC, StreamHandler
|
||||
from queue import Queue
|
||||
|
||||
class EchoHandler(StreamHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.queue = Queue()
|
||||
|
||||
def receive(self, frame: tuple[int, np.ndarray]) -> None: # (1)
|
||||
self.queue.put(frame)
|
||||
|
||||
def emit(self) -> tuple[int, np.ndarray] | None: # (2)
|
||||
return self.queue.get()
|
||||
|
||||
def copy(self) -> StreamHandler:
|
||||
return EchoHandler()
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
audio = WebRTC(
|
||||
mode="send-receive",
|
||||
modality="audio",
|
||||
)
|
||||
|
||||
audio.stream(fn=EchoHandler(),
|
||||
inputs=[audio], outputs=[audio],
|
||||
time_limit=15)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
||||
|
||||
=== "Notes"
|
||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
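
To make the second note concrete, here is a hedged sketch of an `emit` that never blocks: it drains a standard-library queue and returns `None` whenever no frame is buffered yet. It assumes the same `StreamHandler` interface as the echo example above.

```python
import queue

import numpy as np
from gradio_webrtc import StreamHandler


class NonBlockingEchoHandler(StreamHandler):
    def __init__(self) -> None:
        super().__init__()
        self.queue: queue.Queue = queue.Queue()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        self.queue.put(frame)

    def emit(self) -> tuple[int, np.ndarray] | None:
        try:
            return self.queue.get_nowait()  # return immediately if a frame is ready
        except queue.Empty:
            return None                     # nothing buffered yet; do not block

    def copy(self) -> "NonBlockingEchoHandler":
        return NonBlockingEchoHandler()
```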
|
||||
|
||||
|
||||
### Async Stream Handlers
|
||||
|
||||
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive` and `emit` are now defined with `async def`.
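
As a minimal sketch of that difference (the complete Gemini example follows below), an asynchronous echo handler could look like this; it assumes `AsyncStreamHandler` exposes the same `receive`/`emit`/`copy` trio as the synchronous class above.

```python
import asyncio

import numpy as np
from gradio_webrtc import AsyncStreamHandler


class AsyncEchoHandler(AsyncStreamHandler):
    def __init__(self) -> None:
        super().__init__()
        self.queue: asyncio.Queue = asyncio.Queue()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        self.queue.put_nowait(frame)

    async def emit(self) -> tuple[int, np.ndarray]:
        return await self.queue.get()

    def copy(self) -> "AsyncEchoHandler":
        return AsyncEchoHandler()
```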
|
||||
|
||||
Here is a complete example of using `AsyncStreamHandler` with the Google Gemini real-time API:
|
||||
|
||||
=== "Code"
|
||||
``` py title="AsyncStreamHandler"
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from google import genai
|
||||
from gradio_webrtc import (
|
||||
AsyncStreamHandler,
|
||||
WebRTC,
|
||||
async_aggregate_bytes_to_16bit,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
|
||||
class GeminiHandler(AsyncStreamHandler):
|
||||
def __init__(
|
||||
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
|
||||
) -> None:
|
||||
super().__init__(
|
||||
expected_layout,
|
||||
output_sample_rate,
|
||||
output_frame_size,
|
||||
input_sample_rate=16000,
|
||||
)
|
||||
self.client: genai.Client | None = None
|
||||
self.input_queue = asyncio.Queue()
|
||||
self.output_queue = asyncio.Queue()
|
||||
self.quit = asyncio.Event()
|
||||
|
||||
def copy(self) -> "GeminiHandler":
|
||||
return GeminiHandler(
|
||||
expected_layout=self.expected_layout,
|
||||
output_sample_rate=self.output_sample_rate,
|
||||
output_frame_size=self.output_frame_size,
|
||||
)
|
||||
|
||||
async def stream(self):
|
||||
while not self.quit.is_set():
|
||||
audio = await self.input_queue.get()
|
||||
yield audio
|
||||
|
||||
async def connect(self, api_key: str):
|
||||
client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
|
||||
config = {"response_modalities": ["AUDIO"]}
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
async for audio in session.start_stream(
|
||||
stream=self.stream(), mime_type="audio/pcm"
|
||||
):
|
||||
if audio.data:
|
||||
yield audio.data
|
||||
|
||||
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
_, array = frame
|
||||
array = array.squeeze()
|
||||
audio_message = base64.b64encode(array.tobytes()).decode("UTF-8")
|
||||
self.input_queue.put_nowait(audio_message)
|
||||
|
||||
async def generator(self):
|
||||
async for audio_response in async_aggregate_bytes_to_16bit(
|
||||
self.connect(api_key=self.latest_args[1])
|
||||
):
|
||||
self.output_queue.put_nowait(audio_response)
|
||||
|
||||
async def emit(self):
|
||||
if not self.args_set.is_set():
|
||||
await self.wait_for_args()
|
||||
asyncio.create_task(self.generator())
|
||||
|
||||
array = await self.output_queue.get()
|
||||
return (self.output_sample_rate, array)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.quit.set()
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<div style='text-align: center'>
|
||||
<h1>Gen AI SDK Voice Chat</h1>
|
||||
<p>Speak with Gemini using real-time audio streaming</p>
|
||||
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
with gr.Row() as api_key_row:
|
||||
api_key = gr.Textbox(
|
||||
label="API Key",
|
||||
placeholder="Enter your API Key",
|
||||
value=os.getenv("GOOGLE_API_KEY", ""),
|
||||
type="password",
|
||||
)
|
||||
with gr.Row(visible=False) as row:
|
||||
webrtc = WebRTC(
|
||||
label="Audio",
|
||||
modality="audio",
|
||||
mode="send-receive",
|
||||
rtc_configuration=get_twilio_turn_credentials(),
|
||||
pulse_color="rgb(35, 157, 225)",
|
||||
icon_button_color="rgb(35, 157, 225)",
|
||||
icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
|
||||
)
|
||||
|
||||
webrtc.stream(
|
||||
GeminiHandler(),
|
||||
inputs=[webrtc, api_key],
|
||||
outputs=[webrtc],
|
||||
time_limit=90,
|
||||
concurrency_limit=2,
|
||||
)
|
||||
api_key.submit(
|
||||
lambda: (gr.update(visible=False), gr.update(visible=True)),
|
||||
None,
|
||||
[api_key_row, row],
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
### Accessing Other Component Values from a StreamHandler
|
||||
|
||||
In the Gemini demo above, you'll notice that we have the user input their Google API key. This is stored in a `gr.Textbox` component.
|
||||
We can access the value of this component via the `latest_args` property of the `StreamHandler`. `latest_args` is a list storing the values of each component in the WebRTC `stream` event's `inputs` parameter. The value of the `WebRTC` component is at index 0 and is always the dummy string `__webrtc_value__`.
|
||||
|
||||
In order to fetch the latest value from the user, however, we first `await self.wait_for_args()`. In a synchronous `StreamHandler`, we would call `self.wait_for_args_sync()` instead.
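
Here is a hedged sketch of the synchronous pattern. It assumes the attributes shown in the Gemini example (`args_set`, `latest_args`, `wait_for_args_sync`); index 0 of `latest_args` is the `__webrtc_value__` placeholder and index 1 is the text box, matching the order of the `inputs` list.

```python
import numpy as np
from gradio_webrtc import StreamHandler


class KeyedHandler(StreamHandler):
    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        ...  # buffer or process the incoming frame

    def emit(self) -> tuple[int, np.ndarray] | None:
        if not self.args_set.is_set():
            self.wait_for_args_sync()       # block once until the UI values arrive
        api_key = self.latest_args[1]       # index 0 is always "__webrtc_value__"
        ...  # use api_key to call your service and return audio when ready
        return None

    def copy(self) -> "KeyedHandler":
        return KeyedHandler()
```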
|
||||
|
||||
|
||||
### Server-To-Client Only
|
||||
|
||||
To stream only from the server to the client, implement a python generator and pass it to the component's `stream` event. The stream event must also specify a `trigger` corresponding to a UI interaction that starts the stream. In this case, it's a button click.
|
||||
|
||||
=== "Code"
|
||||
|
||||
``` py title="Server-To-Client"
|
||||
import gradio as gr
import numpy as np
|
||||
from gradio_webrtc import WebRTC
|
||||
from pydub import AudioSegment
|
||||
|
||||
def generation(num_steps):
|
||||
for _ in range(num_steps):
|
||||
segment = AudioSegment.from_file("audio_file.wav")
|
||||
array = np.array(segment.get_array_of_samples()).reshape(1, -1)
|
||||
yield (segment.frame_rate, array)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
audio = WebRTC(label="Stream", mode="receive", # (1)
|
||||
modality="audio")
|
||||
num_steps = gr.Slider(label="Number of Steps", minimum=1,
|
||||
maximum=10, step=1, value=5)
|
||||
button = gr.Button("Generate")
|
||||
|
||||
audio.stream(
|
||||
fn=generation, inputs=[num_steps], outputs=[audio],
|
||||
trigger=button.click # (2)
|
||||
)

demo.launch()
|
||||
```
|
||||
|
||||
1. Set `mode="receive"` to only receive audio from the server.
|
||||
2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.
|
||||
=== "Notes"
|
||||
1. Set `mode="receive"` to only receive audio from the server.
|
||||
2. The `stream` event must take a `trigger` that corresponds to the gradio event that starts the stream. In this case, it's the button click.
|
||||
|
||||
## Video Streaming
|
||||
|
||||
### Input/Output Streaming
|
||||
Set up a video Input/Output stream to continuously receive webcam frames from the user and run an arbitrary python function to return a modified frame.
|
||||
|
||||
=== "Code"
|
||||
|
||||
``` py title="Input/Output Streaming"
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3): # (1)
|
||||
... your detection code here ...
|
||||
return modified_frame # (2)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
image = WebRTC(label="Stream", mode="send-receive", modality="video") # (3)
|
||||
conf_threshold = gr.Slider(
|
||||
label="Confidence Threshold",
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.30,
|
||||
)
|
||||
image.stream(
|
||||
fn=detection,
|
||||
inputs=[image, conf_threshold], # (4)
|
||||
outputs=[image], time_limit=10
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. The webcam frame will be represented as a numpy array of shape (height, width, RGB).
|
||||
2. The function must return a numpy array. It can take arbitrary values from other components.
|
||||
3. Set the `modality="video"` and `mode="send-receive"`
|
||||
4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
|
||||
=== "Notes"
|
||||
1. The webcam frame will be represented as a numpy array of shape (height, width, RGB).
|
||||
2. The function must return a numpy array. It can take arbitrary values from other components.
|
||||
3. Set the `modality="video"` and `mode="send-receive"`
|
||||
4. The `inputs` parameter should be a list where the first element is the WebRTC component. The only output allowed is the WebRTC component.
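
If you want something concrete to drop into the placeholder above, here is a hedged sketch that only annotates each frame with OpenCV instead of running a real detector; it keeps the frame format described in the notes (a numpy array of shape (height, width, 3)).

```python
import cv2
import numpy as np


def detection(image: np.ndarray, conf_threshold: float = 0.3) -> np.ndarray:
    # Stand-in for a real model: draw the current threshold on the frame and return it.
    frame = image.copy()
    cv2.putText(
        frame,
        f"conf_threshold={conf_threshold:.2f}",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        (0, 255, 0),
        2,
    )
    return frame
```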
|
||||
|
||||
### Server-to-Client Only
|
||||
|
||||
Set up a server-to-client stream to stream video from an arbitrary user interaction.
|
||||
|
||||
=== "Code"
|
||||
``` py title="Server-To-Client"
|
||||
import gradio as gr
|
||||
from gradio_webrtc import WebRTC
|
||||
import cv2
|
||||
|
||||
def generation():
|
||||
url = "https://download.tsi.telecom-paristech.fr/gpac/dataset/dash/uhd/mux_sources/hevcds_720p30_2M.mp4"
|
||||
cap = cv2.VideoCapture(url)
|
||||
iterating = True
|
||||
while iterating:
|
||||
iterating, frame = cap.read()
|
||||
yield frame # (1)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
output_video = WebRTC(label="Video Stream", mode="receive", # (2)
|
||||
modality="video")
|
||||
button = gr.Button("Start", variant="primary")
|
||||
output_video.stream(
|
||||
fn=generation, inputs=None, outputs=[output_video],
|
||||
trigger=button.click # (3)
|
||||
)
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
|
||||
2. Set `mode="receive"` to only receive video from the server.
|
||||
3. The `trigger` parameter is the gradio event that will trigger the stream. In this case, it's the button click event.
|
||||
=== "Notes"
|
||||
1. The `stream` event's `fn` parameter is a generator function that yields the next frame from the video as a **numpy array**.
|
||||
2. Set `mode="receive"` to only receive video from the server.
|
||||
3. The `trigger` parameter is the gradio event that will trigger the stream. In this case, it's the button click event.
|
||||
|
||||
## Audio-Video Streaming
|
||||
|
||||
You can stream audio and video simultaneously to/from a server using `AudioVideoStreamHandler` or `AsyncAudioVideoStreamHandler`.
|
||||
They are identical to the audio `StreamHandlers` with the addition of `video_receive` and `video_emit` methods which take and return a `numpy` array, respectively.
|
||||
|
||||
Here is an example of the video handling functions for connecting with the Gemini multimodal API. In this case, we simply reflect the webcam feed back to the user, but every second we'll send the latest webcam frame (and an additional image component) to the Gemini server.
|
||||
|
||||
Please see the "Gemini Audio Video Chat" example in the [cookbook](/cookbook) for the complete code.
|
||||
|
||||
``` python title="Async Gemini Video Handling"
|
||||
|
||||
async def video_receive(self, frame: np.ndarray):
|
||||
"""Send video frames to the server"""
|
||||
if self.session:
|
||||
# send image every 1 second
|
||||
# otherwise we flood the API
|
||||
if time.time() - self.last_frame_time > 1:
|
||||
self.last_frame_time = time.time()
|
||||
await self.session.send(encode_image(frame))
|
||||
if self.latest_args[2] is not None:
|
||||
await self.session.send(encode_image(self.latest_args[2]))
|
||||
self.video_queue.put_nowait(frame)
|
||||
|
||||
async def video_emit(self) -> VideoEmitType:
|
||||
"""Return video frames to the client"""
|
||||
return await self.video_queue.get()
|
||||
```
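
For orientation, here is a hedged skeleton of an audio-video handler that simply echoes both tracks. It assumes `AsyncAudioVideoStreamHandler` combines the async audio methods shown earlier with the `video_receive`/`video_emit` pair described above; see the cookbook example for the real Gemini implementation.

```python
import asyncio

import numpy as np
from gradio_webrtc import AsyncAudioVideoStreamHandler


class EchoAVHandler(AsyncAudioVideoStreamHandler):
    def __init__(self) -> None:
        super().__init__()
        self.audio_queue: asyncio.Queue = asyncio.Queue()
        self.video_queue: asyncio.Queue = asyncio.Queue()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        self.audio_queue.put_nowait(frame)

    async def emit(self) -> tuple[int, np.ndarray]:
        return await self.audio_queue.get()

    async def video_receive(self, frame: np.ndarray) -> None:
        self.video_queue.put_nowait(frame)

    async def video_emit(self) -> np.ndarray:
        return await self.video_queue.get()

    def copy(self) -> "EchoAVHandler":
        return EchoAVHandler()
```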
|
||||
|
||||
|
||||
## Additional Outputs
|
||||
|
||||
In order to modify other components from within the WebRTC stream, you must yield an instance of `AdditionalOutputs` and add an `on_additional_outputs` event to the `WebRTC` component.
|
||||
|
||||
This is common for displaying a multimodal text/audio conversation in a Chatbot UI.
|
||||
|
||||
=== "Code"
|
||||
|
||||
``` py title="Additional Outputs"
|
||||
import gradio as gr
import numpy as np
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
|
||||
|
||||
def transcribe(audio: tuple[int, np.ndarray],
|
||||
transformers_convo: list[dict],
|
||||
gradio_convo: list[dict]):
|
||||
response = model.generate(**inputs, max_length=256)
|
||||
transformers_convo.append({"role": "assistant", "content": response})
|
||||
gradio_convo.append({"role": "assistant", "content": response})
|
||||
yield AdditionalOutputs(transformers_convo, gradio_convo) # (1)
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML(
|
||||
"""
|
||||
<h1 style='text-align: center'>
|
||||
Talk to Qwen2Audio (Powered by WebRTC ⚡️)
|
||||
</h1>
|
||||
"""
|
||||
)
|
||||
transformers_convo = gr.State(value=[])
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio = WebRTC(
|
||||
label="Stream",
|
||||
mode="send", # (2)
|
||||
modality="audio",
|
||||
)
|
||||
with gr.Column():
|
||||
transcript = gr.Chatbot(label="transcript", type="messages")
|
||||
|
||||
audio.stream(ReplyOnPause(transcribe),
|
||||
inputs=[audio, transformers_convo, transcript],
|
||||
outputs=[audio], time_limit=90)
|
||||
audio.on_additional_outputs(lambda s,a: (s,a), # (3)
|
||||
outputs=[transformers_convo, transcript],
|
||||
queue=False, show_progress="hidden")
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
1. Pass your data to `AdditionalOutputs` and yield it.
|
||||
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
|
||||
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.
|
||||
=== "Notes"
|
||||
1. Pass your data to `AdditionalOutputs` and yield it.
|
||||
2. In this case, no audio is being returned, so we set `mode="send"`. However, if we set `mode="send-receive"`, we could also yield generated audio and `AdditionalOutputs`.
|
||||
3. The `on_additional_outputs` event does not take `inputs`. It's common practice to not run this event on the queue since it is just a quick UI update.
|
||||
54
docs/utils.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Utils
|
||||
|
||||
## `audio_to_bytes`
|
||||
|
||||
Convert an audio tuple containing sample rate and numpy array data into bytes.
|
||||
Useful for sending data to external APIs from a `ReplyOnPause` handler.
|
||||
|
||||
Parameters
|
||||
```
|
||||
audio : tuple[int, np.ndarray]
|
||||
A tuple containing:
|
||||
- sample_rate (int): The audio sample rate in Hz
|
||||
- data (np.ndarray): The audio data as a numpy array
|
||||
```
|
||||
|
||||
Returns
|
||||
```
|
||||
bytes
|
||||
The audio data encoded as bytes, suitable for transmission or storage
|
||||
```
|
||||
|
||||
Example
|
||||
```python
|
||||
>>> sample_rate = 44100
|
||||
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
||||
>>> audio_tuple = (sample_rate, audio_data)
|
||||
>>> audio_bytes = audio_to_bytes(audio_tuple)
|
||||
```
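
For example, inside a `ReplyOnPause` handler you might hand the encoded bytes to a speech-to-text service. The sketch below assumes `audio_to_bytes` is importable from the package root; the transcription call is a placeholder, not a real SDK.

```python
import numpy as np
from gradio_webrtc import audio_to_bytes


def response(audio: tuple[int, np.ndarray]):
    payload = audio_to_bytes(audio)
    # Hand `payload` to your provider here, e.g. (hypothetical client):
    #   text = stt_client.transcribe(payload)
    yield audio  # echo the incoming audio back while the real API call is wired up
```

Wrap the function with `ReplyOnPause(response)` and pass it to the `stream` event as shown in the user guide.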
|
||||
|
||||
## `audio_to_file`
|
||||
|
||||
Save an audio tuple containing sample rate and numpy array data to a file.
|
||||
|
||||
Parameters
|
||||
```
|
||||
audio : tuple[int, np.ndarray]
|
||||
A tuple containing:
|
||||
- sample_rate (int): The audio sample rate in Hz
|
||||
- data (np.ndarray): The audio data as a numpy array
|
||||
```
|
||||
Returns
|
||||
```
|
||||
str
|
||||
The path to the saved audio file
|
||||
```
|
||||
Example
|
||||
```python
|
||||
>>> sample_rate = 44100
|
||||
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
||||
>>> audio_tuple = (sample_rate, audio_data)
|
||||
>>> file_path = audio_to_file(audio_tuple)
|
||||
>>> print(f"Audio saved to: {file_path}")
|
||||
```
|
||||
73
frontend/Example.svelte
Normal file
@@ -0,0 +1,73 @@
|
||||
<script lang="ts">
|
||||
import { playable } from "./shared/utils";
|
||||
import { type FileData } from "@gradio/client";
|
||||
|
||||
export let type: "gallery" | "table";
|
||||
export let selected = false;
|
||||
export let value: { video: FileData; subtitles: FileData | null } | null;
|
||||
export let loop: boolean;
|
||||
let video: HTMLVideoElement;
|
||||
|
||||
async function init(): Promise<void> {
|
||||
video.muted = true;
|
||||
video.playsInline = true;
|
||||
video.controls = false;
|
||||
video.setAttribute("muted", "");
|
||||
|
||||
await video.play();
|
||||
video.pause();
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if value}
|
||||
{#if playable()}
|
||||
<div
|
||||
class="container"
|
||||
class:table={type === "table"}
|
||||
class:gallery={type === "gallery"}
|
||||
class:selected
|
||||
>
|
||||
<video
|
||||
bind:this={video}
|
||||
on:loadeddata={init}
|
||||
on:mouseover={video.play.bind(video)}
|
||||
on:mouseout={video.pause.bind(video)}
|
||||
src={value?.video.url}
|
||||
/>
|
||||
</div>
|
||||
{:else}
|
||||
<div>{value}</div>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.container {
|
||||
flex: none;
|
||||
max-width: none;
|
||||
}
|
||||
.container :global(video) {
|
||||
width: var(--size-full);
|
||||
height: var(--size-full);
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.container:hover,
|
||||
.container.selected {
|
||||
border-color: var(--border-color-accent);
|
||||
}
|
||||
.container.table {
|
||||
margin: 0 auto;
|
||||
border: 2px solid var(--border-color-primary);
|
||||
border-radius: var(--radius-lg);
|
||||
overflow: hidden;
|
||||
width: var(--size-20);
|
||||
height: var(--size-20);
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.container.gallery {
|
||||
height: var(--size-20);
|
||||
max-height: var(--size-20);
|
||||
object-fit: cover;
|
||||
}
|
||||
</style>
|
||||
158
frontend/Index.svelte
Normal file
@@ -0,0 +1,158 @@
|
||||
<svelte:options accessors={true} />
|
||||
|
||||
<script lang="ts">
|
||||
import { Block, UploadText } from "@gradio/atoms";
|
||||
import Video from "./shared/InteractiveVideo.svelte";
|
||||
import { StatusTracker } from "@gradio/statustracker";
|
||||
import type { LoadingStatus } from "@gradio/statustracker";
|
||||
import StaticVideo from "./shared/StaticVideo.svelte";
|
||||
import StaticAudio from "./shared/StaticAudio.svelte";
|
||||
import InteractiveAudio from "./shared/InteractiveAudio.svelte";
|
||||
|
||||
export let elem_id = "";
|
||||
export let elem_classes: string[] = [];
|
||||
export let visible = true;
|
||||
export let value: string = "__webrtc_value__";
|
||||
export let button_labels: {start: string, stop: string, waiting: string};
|
||||
|
||||
export let label: string;
|
||||
export let root: string;
|
||||
export let show_label: boolean;
|
||||
export let loading_status: LoadingStatus;
|
||||
export let height: number | undefined;
|
||||
export let width: number | undefined;
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
};
|
||||
|
||||
export let container = false;
|
||||
export let scale: number | null = null;
|
||||
export let min_width: number | undefined = undefined;
|
||||
export let gradio;
|
||||
export let rtc_configuration: Object;
|
||||
export let time_limit: number | null = null;
|
||||
export let modality: "video" | "audio" | "audio-video" = "video";
|
||||
export let mode: "send-receive" | "receive" | "send" = "send-receive";
|
||||
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
export let icon: string | undefined = undefined;
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
|
||||
const on_change_cb = (msg: "change" | "tick" | any) => {
|
||||
if (msg?.type === "info" || msg?.type === "warning" || msg?.type === "error") {
|
||||
console.log("dispatching info", msg.message);
|
||||
gradio.dispatch(msg?.type === "error"? "error": "warning", msg.message);
|
||||
}
|
||||
gradio.dispatch(msg === "change" ? "state_change" : "tick");
|
||||
}
|
||||
|
||||
let dragging = false;
|
||||
|
||||
$: console.log("value", value);
|
||||
</script>
|
||||
|
||||
<Block
|
||||
{visible}
|
||||
variant={"solid"}
|
||||
border_mode={dragging ? "focus" : "base"}
|
||||
padding={false}
|
||||
{elem_id}
|
||||
{elem_classes}
|
||||
{height}
|
||||
{width}
|
||||
{container}
|
||||
{scale}
|
||||
{min_width}
|
||||
allow_overflow={false}
|
||||
>
|
||||
<StatusTracker
|
||||
autoscroll={gradio.autoscroll}
|
||||
i18n={gradio.i18n}
|
||||
{...loading_status}
|
||||
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
|
||||
/>
|
||||
|
||||
{#if mode == "receive" && modality === "video"}
|
||||
<StaticVideo
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
{rtc_configuration}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
/>
|
||||
{:else if mode == "receive" && modality === "audio"}
|
||||
<StaticAudio
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
{rtc_configuration}
|
||||
{icon}
|
||||
{icon_button_color}
|
||||
{pulse_color}
|
||||
i18n={gradio.i18n}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
|
||||
/>
|
||||
{:else if (mode === "send-receive" || mode == "send") && (modality === "video" || modality == "audio-video")}
|
||||
<Video
|
||||
bind:value={value}
|
||||
{label}
|
||||
{show_label}
|
||||
active_source={"webcam"}
|
||||
include_audio={modality === "audio-video"}
|
||||
{server}
|
||||
{rtc_configuration}
|
||||
{time_limit}
|
||||
{mode}
|
||||
{track_constraints}
|
||||
{rtp_params}
|
||||
{on_change_cb}
|
||||
{icon}
|
||||
{icon_button_color}
|
||||
{pulse_color}
|
||||
{button_labels}
|
||||
on:clear={() => gradio.dispatch("clear")}
|
||||
on:play={() => gradio.dispatch("play")}
|
||||
on:pause={() => gradio.dispatch("pause")}
|
||||
on:upload={() => gradio.dispatch("upload")}
|
||||
on:stop={() => gradio.dispatch("stop")}
|
||||
on:end={() => gradio.dispatch("end")}
|
||||
on:start_recording={() => gradio.dispatch("start_recording")}
|
||||
on:stop_recording={() => gradio.dispatch("stop_recording")}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
i18n={gradio.i18n}
|
||||
stream_handler={(...args) => gradio.client.stream(...args)}
|
||||
>
|
||||
<UploadText i18n={gradio.i18n} type="video" />
|
||||
</Video>
|
||||
{:else if (mode === "send-receive" || mode === "send") && modality === "audio"}
|
||||
<InteractiveAudio
|
||||
bind:value={value}
|
||||
{on_change_cb}
|
||||
{label}
|
||||
{show_label}
|
||||
{server}
|
||||
{rtc_configuration}
|
||||
{time_limit}
|
||||
{track_constraints}
|
||||
{mode}
|
||||
{rtp_params}
|
||||
i18n={gradio.i18n}
|
||||
{icon}
|
||||
{icon_button_color}
|
||||
{pulse_color}
|
||||
{button_labels}
|
||||
on:tick={() => gradio.dispatch("tick")}
|
||||
on:error={({ detail }) => gradio.dispatch("error", detail)}
|
||||
on:warning={({ detail }) => gradio.dispatch("warning", detail)}
|
||||
/>
|
||||
{/if}
|
||||
</Block>
|
||||
9
frontend/gradio.config.js
Normal file
@@ -0,0 +1,9 @@
|
||||
export default {
|
||||
plugins: [],
|
||||
svelte: {
|
||||
preprocess: [],
|
||||
},
|
||||
build: {
|
||||
target: "modules",
|
||||
},
|
||||
};
|
||||
5
frontend/index.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
export { default as BaseInteractiveVideo } from "./shared/InteractiveVideo.svelte";
|
||||
export { prettyBytes, playable, loaded } from "./shared/utils";
|
||||
export { default as BaseExample } from "./Example.svelte";
|
||||
import { default as Index } from "./Index.svelte";
|
||||
export default Index;
|
||||
5900
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
51
frontend/package.json
Normal file
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"name": "gradio_webrtc",
|
||||
"version": "0.11.0-beta.3",
|
||||
"description": "Gradio UI packages",
|
||||
"type": "module",
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"private": false,
|
||||
"dependencies": {
|
||||
"@ffmpeg/ffmpeg": "^0.12.10",
|
||||
"@ffmpeg/util": "^0.12.1",
|
||||
"@gradio/atoms": "0.9.2",
|
||||
"@gradio/client": "1.7.0",
|
||||
"@gradio/icons": "0.8.0",
|
||||
"@gradio/image": "0.16.4",
|
||||
"@gradio/markdown": "^0.10.3",
|
||||
"@gradio/statustracker": "0.9.1",
|
||||
"@gradio/upload": "0.13.3",
|
||||
"@gradio/utils": "0.7.0",
|
||||
"@gradio/wasm": "0.14.2",
|
||||
"hls.js": "^1.5.16",
|
||||
"mrmime": "^2.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@gradio/preview": "0.12.0",
|
||||
"prettier": "3.3.3"
|
||||
},
|
||||
"exports": {
|
||||
"./package.json": "./package.json",
|
||||
".": {
|
||||
"gradio": "./index.ts",
|
||||
"svelte": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts"
|
||||
},
|
||||
"./example": {
|
||||
"gradio": "./Example.svelte",
|
||||
"svelte": "./dist/Example.svelte",
|
||||
"types": "./dist/Example.svelte.d.ts"
|
||||
}
|
||||
},
|
||||
"peerDependencies": {
|
||||
"svelte": "^4.0.0"
|
||||
},
|
||||
"main": "index.ts",
|
||||
"main_changeset": true,
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/gradio-app/gradio.git",
|
||||
"directory": "js/video"
|
||||
}
|
||||
}
|
||||
164
frontend/shared/AudioWave.svelte
Normal file
@@ -0,0 +1,164 @@
|
||||
<script lang="ts">
|
||||
import { onDestroy } from 'svelte';
|
||||
import type {ComponentType} from 'svelte';
|
||||
|
||||
import PulsingIcon from './PulsingIcon.svelte';
|
||||
|
||||
export let numBars = 16;
|
||||
export let stream_state: "open" | "closed" | "waiting" = "closed";
|
||||
export let audio_source_callback: () => MediaStream;
|
||||
export let icon: string | undefined | ComponentType = undefined;
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
|
||||
let audioContext: AudioContext;
|
||||
let analyser: AnalyserNode;
|
||||
let dataArray: Uint8Array;
|
||||
let animationId: number;
|
||||
let pulseScale = 1;
|
||||
|
||||
$: containerWidth = icon
|
||||
? "128px"
|
||||
: `calc((var(--boxSize) + var(--gutter)) * ${numBars})`;
|
||||
|
||||
$: if(stream_state === "open") setupAudioContext();
|
||||
|
||||
onDestroy(() => {
|
||||
if (animationId) {
|
||||
cancelAnimationFrame(animationId);
|
||||
}
|
||||
if (audioContext) {
|
||||
audioContext.close();
|
||||
}
|
||||
});
|
||||
|
||||
function setupAudioContext() {
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
const source = audioContext.createMediaStreamSource(audio_source_callback());
|
||||
|
||||
source.connect(analyser);
|
||||
|
||||
analyser.fftSize = 64;
|
||||
analyser.smoothingTimeConstant = 0.8;
|
||||
dataArray = new Uint8Array(analyser.frequencyBinCount);
|
||||
|
||||
updateVisualization();
|
||||
}
|
||||
|
||||
function updateVisualization() {
|
||||
analyser.getByteFrequencyData(dataArray);
|
||||
|
||||
// Update bars
|
||||
const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
|
||||
for (let i = 0; i < bars.length; i++) {
|
||||
const barHeight = (dataArray[i] / 255) * 2;
|
||||
bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
|
||||
}
|
||||
|
||||
animationId = requestAnimationFrame(updateVisualization);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="gradio-webrtc-waveContainer">
|
||||
{#if icon}
|
||||
<div class="gradio-webrtc-icon-container">
|
||||
<div
|
||||
class="gradio-webrtc-icon"
|
||||
style:transform={`scale(${pulseScale})`}
|
||||
style:background={icon_button_color}
|
||||
>
|
||||
<PulsingIcon
|
||||
{stream_state}
|
||||
{pulse_color}
|
||||
{icon}
|
||||
{icon_button_color}
|
||||
{audio_source_callback}/>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="gradio-webrtc-boxContainer" style:width={containerWidth}>
|
||||
{#each Array(numBars) as _}
|
||||
<div class="gradio-webrtc-box"></div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.gradio-webrtc-waveContainer {
|
||||
position: relative;
|
||||
display: flex;
|
||||
min-height: 100px;
|
||||
max-height: 128px;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.gradio-webrtc-boxContainer {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
height: 64px;
|
||||
--boxSize: 8px;
|
||||
--gutter: 4px;
|
||||
}
|
||||
|
||||
.gradio-webrtc-box {
|
||||
height: 100%;
|
||||
width: var(--boxSize);
|
||||
background: var(--color-accent);
|
||||
border-radius: 8px;
|
||||
transition: transform 0.05s ease;
|
||||
}
|
||||
|
||||
.gradio-webrtc-icon-container {
|
||||
position: relative;
|
||||
width: 128px;
|
||||
height: 128px;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.gradio-webrtc-icon {
|
||||
position: relative;
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
border-radius: 50%;
|
||||
transition: transform 0.1s ease;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.icon-image {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
object-fit: contain;
|
||||
filter: brightness(0) invert(1);
|
||||
}
|
||||
|
||||
.pulse-ring {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
border-radius: 50%;
|
||||
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
opacity: 0.5;
|
||||
}
|
||||
100% {
|
||||
transform: translate(-50%, -50%) scale(var(--max-scale, 3));
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
454
frontend/shared/InteractiveAudio.svelte
Normal file
@@ -0,0 +1,454 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
BlockLabel,
|
||||
} from "@gradio/atoms";
|
||||
import type { I18nFormatter } from "@gradio/utils";
|
||||
import { createEventDispatcher } from "svelte";
|
||||
import { onMount } from "svelte";
|
||||
import { fade } from "svelte/transition";
|
||||
import { StreamingBar } from "@gradio/statustracker";
|
||||
import {
|
||||
Circle,
|
||||
Square,
|
||||
Spinner,
|
||||
Music,
|
||||
DropdownArrow,
|
||||
Microphone
|
||||
} from "@gradio/icons";
|
||||
|
||||
import { start, stop } from "./webrtc_utils";
|
||||
import { get_devices, set_available_devices } from "./stream_utils";
|
||||
import AudioWave from "./AudioWave.svelte";
|
||||
import WebcamPermissions from "./WebcamPermissions.svelte";
|
||||
|
||||
export let mode: "send-receive" | "send";
|
||||
export let value: string | null = null;
|
||||
export let label: string | undefined = undefined;
|
||||
export let show_label = true;
|
||||
export let rtc_configuration: Object | null = null;
|
||||
export let i18n: I18nFormatter;
|
||||
export let time_limit: number | null = null;
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
|
||||
export let on_change_cb: (mg: "tick" | "change") => void;
|
||||
export let icon: string | undefined = undefined;
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
export let button_labels: {start: string, stop: string, waiting: string};
|
||||
|
||||
let stopword_recognized = false;
|
||||
|
||||
let notification_sound;
|
||||
|
||||
onMount(() => {
|
||||
if (value === "__webrtc_value__") {
|
||||
notification_sound = new Audio("https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/pop-sounds.mp3");
|
||||
}
|
||||
});
|
||||
|
||||
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
|
||||
console.log("msg", msg);
|
||||
if (msg === "stopword") {
|
||||
console.log("stopword recognized");
|
||||
stopword_recognized = true;
|
||||
setTimeout(() => {
|
||||
stopword_recognized = false;
|
||||
}, 3000);
|
||||
} else {
|
||||
console.log("calling on_change_cb with msg", msg);
|
||||
on_change_cb(msg);
|
||||
}
|
||||
};
|
||||
|
||||
let options_open = false;
|
||||
|
||||
let _time_limit: number | null = null;
|
||||
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
};
|
||||
|
||||
let stream_state: "open" | "closed" | "waiting" = "closed";
|
||||
let audio_player: HTMLAudioElement;
|
||||
let pc: RTCPeerConnection;
|
||||
let _webrtc_id = null;
|
||||
let stream: MediaStream;
|
||||
let available_audio_devices: MediaDeviceInfo[];
|
||||
let selected_device: MediaDeviceInfo | null = null;
|
||||
let mic_accessed = false;
|
||||
|
||||
const audio_source_callback = () => {
|
||||
console.log("stream in callback", stream);
|
||||
if(mode==="send") return stream;
|
||||
else return audio_player.srcObject as MediaStream
|
||||
}
|
||||
|
||||
|
||||
const dispatch = createEventDispatcher<{
|
||||
tick: undefined;
|
||||
state_change: undefined;
|
||||
error: string
|
||||
play: undefined;
|
||||
stop: undefined;
|
||||
}>();
|
||||
|
||||
|
||||
async function access_mic(): Promise<void> {
|
||||
|
||||
try {
|
||||
const constraints = selected_device ? { deviceId: { exact: selected_device.deviceId }, ...track_constraints } : track_constraints;
|
||||
const stream_ = await navigator.mediaDevices.getUserMedia({ audio: constraints });
|
||||
stream = stream_;
|
||||
} catch (err) {
|
||||
if (!navigator.mediaDevices) {
|
||||
dispatch("error", i18n("audio.no_device_support"));
|
||||
return;
|
||||
}
|
||||
if (err instanceof DOMException && err.name == "NotAllowedError") {
|
||||
dispatch("error", i18n("audio.allow_recording_access"));
|
||||
return;
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
available_audio_devices = set_available_devices(await get_devices(), "audioinput");
|
||||
mic_accessed = true;
|
||||
const used_devices = stream
|
||||
.getTracks()
|
||||
.map((track) => track.getSettings()?.deviceId)[0];
|
||||
|
||||
selected_device = used_devices
|
||||
? available_audio_devices.find((device) => device.deviceId === used_devices) ||
|
||||
available_audio_devices[0]
|
||||
: available_audio_devices[0];
|
||||
}
|
||||
|
||||
async function start_stream(): Promise<void> {
|
||||
if( stream_state === "open"){
|
||||
stop(pc);
|
||||
stream_state = "closed";
|
||||
_time_limit = null;
|
||||
await access_mic();
|
||||
return;
|
||||
}
|
||||
_webrtc_id = Math.random().toString(36).substring(2);
|
||||
value = _webrtc_id;
|
||||
pc = new RTCPeerConnection(rtc_configuration);
|
||||
pc.addEventListener("connectionstatechange",
|
||||
async (event) => {
|
||||
switch(pc.connectionState) {
|
||||
case "connected":
|
||||
console.info("connected");
|
||||
stream_state = "open";
|
||||
_time_limit = time_limit;
|
||||
break;
|
||||
case "disconnected":
|
||||
console.info("closed");
|
||||
stream_state = "closed";
|
||||
_time_limit = null;
|
||||
stop(pc);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
)
|
||||
stream_state = "waiting"
|
||||
stream = null
|
||||
|
||||
try {
|
||||
await access_mic();
|
||||
} catch (err) {
|
||||
if (!navigator.mediaDevices) {
|
||||
dispatch("error", i18n("audio.no_device_support"));
|
||||
return;
|
||||
}
|
||||
if (err instanceof DOMException && err.name == "NotAllowedError") {
|
||||
dispatch("error", i18n("audio.allow_recording_access"));
|
||||
return;
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
if (stream == null) return;
|
||||
|
||||
start(stream, pc, mode === "send" ? null: audio_player, server.offer, _webrtc_id, "audio", _on_change_cb, rtp_params).then((connection) => {
|
||||
pc = connection;
|
||||
}).catch(() => {
|
||||
console.info("catching")
|
||||
dispatch("error", "Too many concurrent users. Come back later!");
|
||||
});
|
||||
}
|
||||
|
||||
function handle_click_outside(event: MouseEvent): void {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
options_open = false;
|
||||
}
|
||||
|
||||
function click_outside(node: Node, cb: any): any {
|
||||
const handle_click = (event: MouseEvent): void => {
|
||||
if (
|
||||
node &&
|
||||
!node.contains(event.target as Node) &&
|
||||
!event.defaultPrevented
|
||||
) {
|
||||
cb(event);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("click", handle_click, true);
|
||||
|
||||
return {
|
||||
destroy() {
|
||||
document.removeEventListener("click", handle_click, true);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const handle_device_change = async (event: InputEvent): Promise<void> => {
|
||||
const target = event.target as HTMLInputElement;
|
||||
const device_id = target.value;
|
||||
|
||||
stream = await navigator.mediaDevices.getUserMedia({ audio: {deviceId: { exact: device_id }, ...track_constraints }});
|
||||
selected_device =
|
||||
available_audio_devices.find(
|
||||
(device) => device.deviceId === device_id
|
||||
) || null;
|
||||
options_open = false;
|
||||
};
|
||||
|
||||
$: if(stopword_recognized){
|
||||
notification_sound.play();
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<BlockLabel
|
||||
{show_label}
|
||||
Icon={Music}
|
||||
float={false}
|
||||
label={label || i18n("audio.audio")}
|
||||
/>
|
||||
<div class="audio-container">
|
||||
<audio
|
||||
class="standard-player"
|
||||
class:hidden={value === "__webrtc_value__"}
|
||||
on:load
|
||||
bind:this={audio_player}
|
||||
on:ended={() => dispatch("stop")}
|
||||
on:play={() => dispatch("play")}
|
||||
/>
|
||||
{#if !mic_accessed}
|
||||
<div
|
||||
in:fade={{ delay: 100, duration: 200 }}
|
||||
title="grant webcam access"
|
||||
style="height: 100%"
|
||||
>
|
||||
<WebcamPermissions icon={Microphone} on:click={async () => access_mic()} />
|
||||
</div>
|
||||
{:else}
|
||||
<AudioWave {audio_source_callback} {stream_state} {icon} {icon_button_color} {pulse_color}/>
|
||||
<StreamingBar time_limit={_time_limit} />
|
||||
<div class="button-wrap" class:pulse={stopword_recognized}>
|
||||
<button
|
||||
on:click={start_stream}
|
||||
aria-label={"start stream"}
|
||||
>
|
||||
{#if stream_state === "waiting"}
|
||||
<div class="icon-with-text">
|
||||
<div class="icon color-primary" title="spinner">
|
||||
<Spinner />
|
||||
</div>
|
||||
{button_labels.waiting || i18n("audio.waiting")}
|
||||
</div>
|
||||
{:else if stream_state === "open"}
|
||||
<div class="icon-with-text">
|
||||
<div class="icon color-primary" title="stop recording">
|
||||
<Square />
|
||||
</div>
|
||||
{button_labels.stop || i18n("audio.stop")}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="icon-with-text">
|
||||
<div class="icon color-primary" title="start recording">
|
||||
<Circle />
|
||||
</div>
|
||||
{button_labels.start || i18n("audio.record")}
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
{#if stream_state === "closed"}
|
||||
<button
|
||||
class="icon"
|
||||
on:click={() => (options_open = true)}
|
||||
aria-label="select input source"
|
||||
>
|
||||
<DropdownArrow />
|
||||
</button>
|
||||
{/if}
|
||||
{#if options_open && selected_device}
|
||||
<select
|
||||
class="select-wrap"
|
||||
aria-label="select source"
|
||||
use:click_outside={handle_click_outside}
|
||||
on:change={handle_device_change}
|
||||
>
|
||||
<button
|
||||
class="inset-icon"
|
||||
on:click|stopPropagation={() => (options_open = false)}
|
||||
>
|
||||
<DropdownArrow />
|
||||
</button>
|
||||
{#if available_audio_devices.length === 0}
|
||||
<option value="">{i18n("common.no_devices")}</option>
|
||||
{:else}
|
||||
{#each available_audio_devices as device}
|
||||
<option
|
||||
value={device.deviceId}
|
||||
selected={selected_device.deviceId === device.deviceId}
|
||||
>
|
||||
{device.label}
|
||||
</option>
|
||||
{/each}
|
||||
{/if}
|
||||
</select>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
|
||||
.audio-container {
|
||||
display: flex;
|
||||
height: 100%;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
|
||||
:global(::part(wrapper)) {
|
||||
margin-bottom: var(--size-2);
|
||||
}
|
||||
|
||||
.standard-player {
|
||||
width: 100%;
|
||||
padding: var(--size-2);
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
|
||||
.button-wrap {
|
||||
margin-top: var(--size-2);
|
||||
margin-bottom: var(--size-2);
|
||||
background-color: var(--block-background-fill);
|
||||
border: 1px solid var(--border-color-primary);
|
||||
border-radius: var(--radius-xl);
|
||||
padding: var(--size-1-5);
|
||||
display: flex;
|
||||
bottom: var(--size-2);
|
||||
box-shadow: var(--shadow-drop-lg);
|
||||
border-radius: var(--radius-xl);
|
||||
line-height: var(--size-3);
|
||||
color: var(--button-secondary-text-color);
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
transform: scale(1);
|
||||
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0.7);
|
||||
}
|
||||
|
||||
70% {
|
||||
transform: scale(1.25);
|
||||
box-shadow: 0 0 0 10px rgba(var(--primary-500-rgb), 0);
|
||||
}
|
||||
|
||||
100% {
|
||||
transform: scale(1);
|
||||
box-shadow: 0 0 0 0 rgba(var(--primary-500-rgb), 0);
|
||||
}
|
||||
}
|
||||
|
||||
.pulse {
|
||||
animation: pulse 1s infinite;
|
||||
}
|
||||
|
||||
.icon-with-text {
|
||||
min-width: var(--size-16);
|
||||
align-items: center;
|
||||
margin: 0 var(--spacing-xl);
|
||||
display: flex;
|
||||
justify-content: space-evenly;
|
||||
gap: var(--size-2);
|
||||
}
|
||||
|
||||
@media (--screen-md) {
|
||||
button {
|
||||
bottom: var(--size-4);
|
||||
}
|
||||
}
|
||||
|
||||
@media (--screen-xl) {
|
||||
button {
|
||||
bottom: var(--size-8);
|
||||
}
|
||||
}
|
||||
|
||||
.icon {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.color-primary {
|
||||
fill: var(--primary-600);
|
||||
stroke: var(--primary-600);
|
||||
color: var(--primary-600);
|
||||
}
|
||||
|
||||
.select-wrap {
|
||||
-webkit-appearance: none;
|
||||
-moz-appearance: none;
|
||||
appearance: none;
|
||||
color: var(--button-secondary-text-color);
|
||||
background-color: transparent;
|
||||
width: 95%;
|
||||
font-size: var(--text-md);
|
||||
position: absolute;
|
||||
bottom: var(--size-2);
|
||||
background-color: var(--block-background-fill);
|
||||
box-shadow: var(--shadow-drop-lg);
|
||||
border-radius: var(--radius-xl);
|
||||
z-index: var(--layer-top);
|
||||
border: 1px solid var(--border-color-primary);
|
||||
text-align: left;
|
||||
line-height: var(--size-4);
|
||||
white-space: nowrap;
|
||||
text-overflow: ellipsis;
|
||||
left: 50%;
|
||||
transform: translate(-50%, 0);
|
||||
max-width: var(--size-52);
|
||||
}
|
||||
|
||||
.select-wrap > option {
|
||||
padding: 0.25rem 0.5rem;
|
||||
border-bottom: 1px solid var(--border-color-accent);
|
||||
padding-right: var(--size-8);
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.select-wrap > option:hover {
|
||||
background-color: var(--color-accent);
|
||||
}
|
||||
|
||||
.select-wrap > option:last-child {
|
||||
border: none;
|
||||
}
|
||||
</style>
|
||||
89
frontend/shared/InteractiveVideo.svelte
Normal file
@@ -0,0 +1,89 @@
|
||||
<script lang="ts">
|
||||
import { createEventDispatcher } from "svelte";
|
||||
import type { ComponentType } from "svelte";
|
||||
import type { FileData, Client } from "@gradio/client";
|
||||
import { BlockLabel } from "@gradio/atoms";
|
||||
import Webcam from "./Webcam.svelte";
|
||||
import { Video } from "@gradio/icons";
|
||||
|
||||
import type { I18nFormatter } from "@gradio/utils";
|
||||
|
||||
export let value: string = null;
|
||||
export let label: string | undefined = undefined;
|
||||
export let show_label = true;
|
||||
export let include_audio: boolean;
|
||||
export let i18n: I18nFormatter;
|
||||
export let active_source: "webcam" | "upload" = "webcam";
|
||||
export let handle_reset_value: () => void = () => {};
|
||||
export let stream_handler: Client["stream"];
|
||||
export let time_limit: number | null = null;
|
||||
export let button_labels: {start: string, stop: string, waiting: string};
|
||||
export let server: {
|
||||
offer: (body: any) => Promise<any>;
|
||||
};
|
||||
export let rtc_configuration: Object;
|
||||
export let track_constraints: MediaTrackConstraints = {};
|
||||
export let mode: "send" | "send-receive";
|
||||
export let on_change_cb: (msg: "change" | "tick") => void;
|
||||
export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
|
||||
export let icon: string | undefined | ComponentType = undefined;
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
|
||||
const dispatch = createEventDispatcher<{
|
||||
change: FileData | null;
|
||||
clear?: never;
|
||||
play?: never;
|
||||
pause?: never;
|
||||
end?: never;
|
||||
drag: boolean;
|
||||
error: string;
|
||||
upload: FileData;
|
||||
start_recording?: never;
|
||||
stop_recording?: never;
|
||||
tick: never;
|
||||
}>();
|
||||
|
||||
let dragging = false;
|
||||
$: dispatch("drag", dragging);
|
||||
|
||||
$: console.log("value", value)
|
||||
|
||||
</script>
|
||||
|
||||
<BlockLabel {show_label} Icon={Video} label={label || "Video"} />
|
||||
<div data-testid="video" class="video-container">
|
||||
<Webcam
|
||||
{rtc_configuration}
|
||||
{include_audio}
|
||||
{time_limit}
|
||||
{track_constraints}
|
||||
{mode}
|
||||
{rtp_params}
|
||||
{on_change_cb}
|
||||
{icon}
|
||||
{icon_button_color}
|
||||
{pulse_color}
|
||||
{button_labels}
|
||||
on:error
|
||||
on:start_recording
|
||||
on:stop_recording
|
||||
on:tick
|
||||
{i18n}
|
||||
stream_every={0.5}
|
||||
{server}
|
||||
bind:webrtc_id={value}
|
||||
/>
|
||||
|
||||
<!-- <SelectSource {sources} bind:active_source /> -->
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.video-container {
|
||||
display: flex;
|
||||
height: 100%;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
</style>
|
||||
151
frontend/shared/PulsingIcon.svelte
Normal file
@@ -0,0 +1,151 @@
|
||||
<script lang="ts">
|
||||
import { onDestroy } from 'svelte';
|
||||
import type {ComponentType} from 'svelte';
|
||||
|
||||
export let stream_state: "open" | "closed" | "waiting" = "closed";
|
||||
export let audio_source_callback: () => MediaStream;
|
||||
export let icon: string | ComponentType = undefined;
|
||||
export let icon_button_color: string = "var(--color-accent)";
|
||||
export let pulse_color: string = "var(--color-accent)";
|
||||
|
||||
let audioContext: AudioContext;
|
||||
let analyser: AnalyserNode;
|
||||
let dataArray: Uint8Array;
|
||||
let animationId: number;
|
||||
let pulseScale = 1;
|
||||
let pulseIntensity = 0;
|
||||
|
||||
$: if(stream_state === "open") setupAudioContext();
|
||||
|
||||
onDestroy(() => {
|
||||
if (animationId) {
|
||||
cancelAnimationFrame(animationId);
|
||||
}
|
||||
if (audioContext) {
|
||||
audioContext.close();
|
||||
}
|
||||
});
|
||||
|
||||
function setupAudioContext() {
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
const source = audioContext.createMediaStreamSource(audio_source_callback());
|
||||
|
||||
source.connect(analyser);
|
||||
|
||||
analyser.fftSize = 64;
|
||||
analyser.smoothingTimeConstant = 0.8;
|
||||
dataArray = new Uint8Array(analyser.frequencyBinCount);
|
||||
|
||||
updateVisualization();
|
||||
}
|
||||
|
||||
function updateVisualization() {
|
||||
|
||||
analyser.getByteFrequencyData(dataArray);
|
||||
|
||||
// Calculate average amplitude for pulse effect
|
||||
const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
|
||||
const normalizedAverage = average / 255;
|
||||
pulseScale = 1 + (normalizedAverage * 0.15);
|
||||
pulseIntensity = normalizedAverage;
|
||||
animationId = requestAnimationFrame(updateVisualization);
|
||||
|
||||
}
|
||||
|
||||
$: maxPulseScale = 1 + (pulseIntensity * 10); // Scale from 1x to 3x based on intensity
|
||||
|
||||
|
||||
</script>
|
||||
|
||||
<div class="gradio-webrtc-icon-wrapper">
|
||||
<div class="gradio-webrtc-pulsing-icon-container">
|
||||
{#if pulseIntensity > 0}
|
||||
{#each Array(3) as _, i}
|
||||
<div
|
||||
class="pulse-ring"
|
||||
style:background={pulse_color}
|
||||
style:animation-delay={`${i * 0.4}s`}
|
||||
style:--max-scale={maxPulseScale}
|
||||
style:opacity={0.5 * pulseIntensity}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class="gradio-webrtc-pulsing-icon"
|
||||
style:transform={`scale(${pulseScale})`}
|
||||
style:background={icon_button_color}
|
||||
>
|
||||
{#if typeof icon === "string"}
|
||||
<img
|
||||
src={icon}
|
||||
alt="Audio visualization icon"
|
||||
class="icon-image"
|
||||
/>
|
||||
{:else}
|
||||
<svelte:component this={icon} />
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.gradio-webrtc-icon-wrapper {
|
||||
position: relative;
|
||||
display: flex;
|
||||
max-height: 128px;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.gradio-webrtc-pulsing-icon-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.gradio-webrtc-pulsing-icon {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border-radius: 50%;
|
||||
transition: transform 0.1s ease;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.icon-image {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: contain;
|
||||
filter: brightness(0) invert(1);
|
||||
}
|
||||
|
||||
.pulse-ring {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border-radius: 50%;
|
||||
animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
opacity: 0.5;
|
||||
}
|
||||
100% {
|
||||
transform: translate(-50%, -50%) scale(var(--max-scale, 3));
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
135
frontend/shared/StaticAudio.svelte
Normal file
@@ -0,0 +1,135 @@
|
||||
<script lang="ts">
	import { BlockLabel, Empty } from "@gradio/atoms";
	import { Music } from "@gradio/icons";
	import type { I18nFormatter } from "@gradio/utils";
	import { createEventDispatcher, onMount } from "svelte";

	import { start, stop } from "./webrtc_utils";
	import AudioWave from "./AudioWave.svelte";

	export let value: string | null = null;
	export let label: string | undefined = undefined;
	export let show_label = true;
	export let rtc_configuration: Object | null = null;
	export let i18n: I18nFormatter;
	export let on_change_cb: (msg: "change" | "tick") => void;
	export let icon: string | undefined = undefined;
	export let icon_button_color: string = "var(--color-accent)";
	export let pulse_color: string = "var(--color-accent)";

	export let server: {
		offer: (body: any) => Promise<any>;
	};

	let stream_state: "open" | "closed" | "waiting" = "closed";
	let audio_player: HTMLAudioElement;
	let pc: RTCPeerConnection;
	let _webrtc_id = Math.random().toString(36).substring(2);

	const dispatch = createEventDispatcher<{
		tick: undefined;
		error: string;
		play: undefined;
		stop: undefined;
	}>();

	onMount(() => {
		window.setInterval(() => {
			if (stream_state == "open") {
				dispatch("tick");
			}
		}, 1000);
	});

	async function start_stream(value: string): Promise<string> {
		if (value === "start_webrtc_stream") {
			stream_state = "waiting";
			_webrtc_id = Math.random().toString(36).substring(2);
			value = _webrtc_id;
			console.log("set value to ", value);
			pc = new RTCPeerConnection(rtc_configuration);
			pc.addEventListener("connectionstatechange", async (event) => {
				switch (pc.connectionState) {
					case "connected":
						console.info("connected");
						stream_state = "open";
						break;
					case "disconnected":
						console.info("closed");
						stop(pc);
						break;
					default:
						break;
				}
			});
			let stream = null;
			start(stream, pc, audio_player, server.offer, _webrtc_id, "audio", on_change_cb)
				.then((connection) => {
					pc = connection;
				})
				.catch(() => {
					console.info("catching");
					dispatch("error", "Too many concurrent users. Come back later!");
				});
		}
		return value;
	}

	$: start_stream(value).then((val) => {
		value = val;
	});
</script>

<BlockLabel
	{show_label}
	Icon={Music}
	float={false}
	label={label || i18n("audio.audio")}
/>
<audio
	class="standard-player"
	class:hidden={true}
	on:load
	bind:this={audio_player}
	on:ended={() => dispatch("stop")}
	on:play={() => dispatch("play")}
/>
{#if value !== "__webrtc_value__"}
	<div class="audio-container">
		<AudioWave audio_source_callback={() => audio_player.srcObject} {stream_state} {icon} {icon_button_color} {pulse_color}/>
	</div>
{/if}
{#if value === "__webrtc_value__"}
	<Empty size="small">
		<Music />
	</Empty>
{/if}

<style>
	.audio-container {
		display: flex;
		height: 100%;
		flex-direction: column;
		justify-content: center;
		align-items: center;
	}

	.standard-player {
		width: 100%;
	}

	.hidden {
		display: none;
	}
</style>
119
frontend/shared/StaticVideo.svelte
Normal file
@@ -0,0 +1,119 @@
<script lang="ts">
	import { createEventDispatcher, onMount } from "svelte";
	import { BlockLabel, Empty } from "@gradio/atoms";
	import { Video } from "@gradio/icons";

	import { start, stop } from "./webrtc_utils";

	export let value: string | null = null;
	export let label: string | undefined = undefined;
	export let show_label = true;
	export let rtc_configuration: Object | null = null;
	export let on_change_cb: (msg: "change" | "tick") => void;
	export let server: {
		offer: (body: any) => Promise<any>;
	};

	let video_element: HTMLVideoElement;

	let _webrtc_id = Math.random().toString(36).substring(2);

	let pc: RTCPeerConnection;

	const dispatch = createEventDispatcher<{
		error: string;
		tick: undefined;
	}>();

	let stream_state = "closed";

	onMount(() => {
		window.setInterval(() => {
			if (stream_state == "open") {
				dispatch("tick");
			}
		}, 1000);
	});

	$: if (value === "start_webrtc_stream") {
		_webrtc_id = Math.random().toString(36).substring(2);
		value = _webrtc_id;
		pc = new RTCPeerConnection(rtc_configuration);
		pc.addEventListener("connectionstatechange", async (event) => {
			switch (pc.connectionState) {
				case "connected":
					console.log("connected");
					stream_state = "open";
					break;
				case "disconnected":
					console.log("closed");
					stop(pc);
					break;
				default:
					break;
			}
		});
		start(null, pc, video_element, server.offer, _webrtc_id, "video", on_change_cb)
			.then((connection) => {
				pc = connection;
			})
			.catch(() => {
				console.log("catching");
				dispatch("error", "Too many concurrent users. Come back later!");
			});
	}
</script>

<BlockLabel {show_label} Icon={Video} label={label || "Video"} />

{#if value === "__webrtc_value__"}
	<Empty unpadded_box={true} size="large"><Video /></Empty>
{/if}
<div class="wrap">
	<video
		class:hidden={value === "__webrtc_value__"}
		bind:this={video_element}
		autoplay={true}
		on:loadeddata={dispatch.bind(null, "loadeddata")}
		on:click={dispatch.bind(null, "click")}
		on:play={dispatch.bind(null, "play")}
		on:pause={dispatch.bind(null, "pause")}
		on:ended={dispatch.bind(null, "ended")}
		on:mouseover={dispatch.bind(null, "mouseover")}
		on:mouseout={dispatch.bind(null, "mouseout")}
		on:focus={dispatch.bind(null, "focus")}
		on:blur={dispatch.bind(null, "blur")}
		on:load
		data-testid={$$props["data-testid"]}
		crossorigin="anonymous"
	>
		<track kind="captions" />
	</video>
</div>

<style>
	.hidden {
		display: none;
	}

	.wrap {
		position: relative;
		background-color: var(--background-fill-secondary);
		height: var(--size-full);
		width: var(--size-full);
		border-radius: var(--radius-xl);
	}
	.wrap :global(video) {
		height: var(--size-full);
		width: var(--size-full);
	}
</style>
434
frontend/shared/Webcam.svelte
Normal file
@@ -0,0 +1,434 @@
<script lang="ts">
	import { createEventDispatcher, onMount } from "svelte";
	import type { ComponentType } from "svelte";
	import {
		Circle,
		Square,
		DropdownArrow,
		Spinner,
		Microphone as Mic
	} from "@gradio/icons";
	import type { I18nFormatter } from "@gradio/utils";
	import { StreamingBar } from "@gradio/statustracker";
	import WebcamPermissions from "./WebcamPermissions.svelte";
	import { fade } from "svelte/transition";
	import {
		get_devices,
		get_video_stream,
		set_available_devices
	} from "./stream_utils";
	import { start, stop } from "./webrtc_utils";
	import PulsingIcon from "./PulsingIcon.svelte";

	let video_source: HTMLVideoElement;
	let available_video_devices: MediaDeviceInfo[] = [];
	let selected_device: MediaDeviceInfo | null = null;
	let _time_limit: number | null = null;
	export let time_limit: number | null = null;
	let stream_state: "open" | "waiting" | "closed" = "closed";
	export let on_change_cb: (msg: "tick" | "change") => void;
	export let mode: "send-receive" | "send";
	const _webrtc_id = Math.random().toString(36).substring(2);
	export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
	export let icon: string | undefined | ComponentType = undefined;
	export let icon_button_color: string = "var(--color-accent)";
	export let pulse_color: string = "var(--color-accent)";
	export let button_labels: {start: string, stop: string, waiting: string};

	export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
		state: "open" | "closed" | "waiting"
	) => {
		if (state === "closed") {
			_time_limit = null;
			stream_state = "closed";
		} else if (state === "waiting") {
			stream_state = "waiting";
		} else {
			stream_state = "open";
		}
	};

	let canvas: HTMLCanvasElement;
	export let track_constraints: MediaTrackConstraints | null = null;
	export let rtc_configuration: Object;
	export let stream_every = 1;
	export let server: {
		offer: (body: any) => Promise<any>;
	};

	export let include_audio: boolean;
	export let i18n: I18nFormatter;

	const dispatch = createEventDispatcher<{
		tick: undefined;
		error: string;
		start_recording: undefined;
		stop_recording: undefined;
		close_stream: undefined;
	}>();

	onMount(() => (canvas = document.createElement("canvas")));

	const handle_device_change = async (event: InputEvent): Promise<void> => {
		const target = event.target as HTMLInputElement;
		const device_id = target.value;

		await get_video_stream(include_audio, video_source, device_id, track_constraints).then(
			async (local_stream) => {
				stream = local_stream;
				selected_device =
					available_video_devices.find(
						(device) => device.deviceId === device_id
					) || null;
				options_open = false;
			}
		);
	};

	async function access_webcam(): Promise<void> {
		try {
			get_video_stream(include_audio, video_source, null, track_constraints)
				.then(async (local_stream) => {
					webcam_accessed = true;
					available_video_devices = await get_devices();
					stream = local_stream;
				})
				.then(() => set_available_devices(available_video_devices))
				.then((devices) => {
					available_video_devices = devices;

					const used_devices = stream
						.getTracks()
						.map((track) => track.getSettings()?.deviceId)[0];

					selected_device = used_devices
						? devices.find((device) => device.deviceId === used_devices) ||
							available_video_devices[0]
						: available_video_devices[0];
				});

			if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
				dispatch("error", i18n("image.no_webcam_support"));
			}
		} catch (err) {
			if (err instanceof DOMException && err.name == "NotAllowedError") {
				dispatch("error", i18n("image.allow_webcam_access"));
			} else {
				throw err;
			}
		}
	}

	let recording = false;
	let stream: MediaStream;

	let webcam_accessed = false;
	let pc: RTCPeerConnection;
	export let webrtc_id;

	async function start_webrtc(): Promise<void> {
		if (stream_state === "closed") {
			pc = new RTCPeerConnection(rtc_configuration);
			pc.addEventListener("connectionstatechange", async (event) => {
				switch (pc.connectionState) {
					case "connected":
						stream_state = "open";
						_time_limit = time_limit;
						break;
					case "disconnected":
						stream_state = "closed";
						_time_limit = null;
						stop(pc);
						await access_webcam();
						break;
					default:
						break;
				}
			});
			stream_state = "waiting";
			webrtc_id = Math.random().toString(36).substring(2);
			start(stream, pc, mode === "send" ? null : video_source, server.offer, webrtc_id, "video", on_change_cb, rtp_params)
				.then((connection) => {
					pc = connection;
				})
				.catch(() => {
					console.info("catching");
					stream_state = "closed";
					dispatch("error", "Too many concurrent users. Come back later!");
				});
		} else {
			stop(pc);
			stream_state = "closed";
			_time_limit = null;
			await access_webcam();
		}
	}

	let options_open = false;

	export function click_outside(node: Node, cb: any): any {
		const handle_click = (event: MouseEvent): void => {
			if (
				node &&
				!node.contains(event.target as Node) &&
				!event.defaultPrevented
			) {
				cb(event);
			}
		};

		document.addEventListener("click", handle_click, true);

		return {
			destroy() {
				document.removeEventListener("click", handle_click, true);
			}
		};
	}

	function handle_click_outside(event: MouseEvent): void {
		event.preventDefault();
		event.stopPropagation();
		options_open = false;
	}

	const audio_source_callback = () => video_source.srcObject as MediaStream;
</script>

<div class="wrap">
	<StreamingBar time_limit={_time_limit} />
	{#if stream_state === "open" && include_audio}
		<div class="audio-indicator">
			<PulsingIcon
				stream_state={stream_state}
				audio_source_callback={audio_source_callback}
				icon={icon || Mic}
				icon_button_color={icon_button_color}
				pulse_color={pulse_color}
			/>
		</div>
	{/if}
	<!-- svelte-ignore a11y-media-has-caption -->
	<!-- need to suppress for video streaming https://github.com/sveltejs/svelte/issues/5967 -->
	<video
		bind:this={video_source}
		class:hide={!webcam_accessed}
		class:flip={(stream_state != "open") || (stream_state === "open" && include_audio)}
		autoplay={true}
		playsinline={true}
	/>
	<!-- svelte-ignore a11y-missing-attribute -->
	{#if !webcam_accessed}
		<div
			in:fade={{ delay: 100, duration: 200 }}
			title="grant webcam access"
			style="height: 100%"
		>
			<WebcamPermissions on:click={async () => access_webcam()} />
		</div>
	{:else}
		<div class="button-wrap">
			<button
				on:click={start_webrtc}
				aria-label={"start stream"}
			>
				{#if stream_state === "waiting"}
					<div class="icon-with-text">
						<div class="icon color-primary" title="spinner">
							<Spinner />
						</div>
						{button_labels.waiting || i18n("audio.waiting")}
					</div>
				{:else if stream_state === "open"}
					<div class="icon-with-text">
						<div class="icon color-primary" title="stop recording">
							<Square />
						</div>
						{button_labels.stop || i18n("audio.stop")}
					</div>
				{:else}
					<div class="icon-with-text">
						<div class="icon color-primary" title="start recording">
							<Circle />
						</div>
						{button_labels.start || i18n("audio.record")}
					</div>
				{/if}
			</button>
			{#if !recording}
				<button
					class="icon"
					on:click={() => (options_open = true)}
					aria-label="select input source"
				>
					<DropdownArrow />
				</button>
			{/if}
		</div>
		{#if options_open && selected_device}
			<select
				class="select-wrap"
				aria-label="select source"
				use:click_outside={handle_click_outside}
				on:change={handle_device_change}
			>
				<button
					class="inset-icon"
					on:click|stopPropagation={() => (options_open = false)}
				>
					<DropdownArrow />
				</button>
				{#if available_video_devices.length === 0}
					<option value="">{i18n("common.no_devices")}</option>
				{:else}
					{#each available_video_devices as device}
						<option
							value={device.deviceId}
							selected={selected_device.deviceId === device.deviceId}
						>
							{device.label}
						</option>
					{/each}
				{/if}
			</select>
		{/if}
	{/if}
</div>

<style>
	.wrap {
		position: relative;
		width: var(--size-full);
		height: var(--size-full);
	}

	.hide {
		display: none;
	}

	video {
		width: var(--size-full);
		height: var(--size-full);
		object-fit: cover;
	}

	.button-wrap {
		position: absolute;
		background-color: var(--block-background-fill);
		border: 1px solid var(--border-color-primary);
		border-radius: var(--radius-xl);
		padding: var(--size-1-5);
		display: flex;
		bottom: var(--size-2);
		left: 50%;
		transform: translate(-50%, 0);
		box-shadow: var(--shadow-drop-lg);
		border-radius: var(--radius-xl);
		line-height: var(--size-3);
		color: var(--button-secondary-text-color);
	}

	.icon-with-text {
		min-width: var(--size-16);
		align-items: center;
		margin: 0 var(--spacing-xl);
		display: flex;
		justify-content: space-evenly;
		/* Add gap between icon and text */
		gap: var(--size-2);
	}

	.audio-indicator {
		position: absolute;
		top: var(--size-2);
		right: var(--size-2);
		z-index: var(--layer-2);
		height: var(--size-5);
		width: var(--size-5);
	}

	@media (--screen-md) {
		button {
			bottom: var(--size-4);
		}
	}

	@media (--screen-xl) {
		button {
			bottom: var(--size-8);
		}
	}

	.icon {
		width: 18px;
		height: 18px;
		display: flex;
		justify-content: space-between;
		align-items: center;
	}

	.color-primary {
		fill: var(--primary-600);
		stroke: var(--primary-600);
		color: var(--primary-600);
	}

	.flip {
		transform: scaleX(-1);
	}

	.select-wrap {
		-webkit-appearance: none;
		-moz-appearance: none;
		appearance: none;
		color: var(--button-secondary-text-color);
		background-color: transparent;
		width: 95%;
		font-size: var(--text-md);
		position: absolute;
		bottom: var(--size-2);
		background-color: var(--block-background-fill);
		box-shadow: var(--shadow-drop-lg);
		border-radius: var(--radius-xl);
		z-index: var(--layer-top);
		border: 1px solid var(--border-color-primary);
		text-align: left;
		line-height: var(--size-4);
		white-space: nowrap;
		text-overflow: ellipsis;
		left: 50%;
		transform: translate(-50%, 0);
		max-width: var(--size-52);
	}

	.select-wrap > option {
		padding: 0.25rem 0.5rem;
		border-bottom: 1px solid var(--border-color-accent);
		padding-right: var(--size-8);
		text-overflow: ellipsis;
		overflow: hidden;
	}

	.select-wrap > option:hover {
		background-color: var(--color-accent);
	}

	.select-wrap > option:last-child {
		border: none;
	}

	.inset-icon {
		position: absolute;
		top: 5px;
		right: -6.5px;
		width: var(--size-10);
		height: var(--size-5);
		opacity: 0.8;
	}

	@media (--screen-md) {
		.wrap {
			font-size: var(--text-lg);
		}
	}
</style>
49
frontend/shared/WebcamPermissions.svelte
Normal file
@@ -0,0 +1,49 @@
<script lang="ts">
	import { Webcam } from "@gradio/icons";
	import { createEventDispatcher } from "svelte";

	export let icon = Webcam;
	$: text = icon === Webcam ? "Click to Access Webcam" : "Click to Access Microphone";

	const dispatch = createEventDispatcher<{
		click: undefined;
	}>();
</script>

<button style:height="100%" on:click={() => dispatch("click")}>
	<div class="wrap">
		<span class="icon-wrap">
			<svelte:component this={icon} />
		</span>
		{text}
	</div>
</button>

<style>
	button {
		cursor: pointer;
		width: var(--size-full);
	}

	.wrap {
		display: flex;
		flex-direction: column;
		justify-content: center;
		align-items: center;
		min-height: var(--size-60);
		color: var(--block-label-text-color);
		height: 100%;
		padding-top: var(--size-3);
	}

	.icon-wrap {
		width: 30px;
		margin-bottom: var(--spacing-lg);
	}

	@media (--screen-md) {
		.wrap {
			font-size: var(--text-lg);
		}
	}
</style>
1
frontend/shared/index.ts
Normal file
@@ -0,0 +1 @@
export { default as Video } from "./Video.svelte";
53
frontend/shared/stream_utils.ts
Normal file
@@ -0,0 +1,53 @@
export function get_devices(): Promise<MediaDeviceInfo[]> {
	return navigator.mediaDevices.enumerateDevices();
}

export function handle_error(error: string): void {
	throw new Error(error);
}

export function set_local_stream(
	local_stream: MediaStream | null,
	video_source: HTMLVideoElement,
): void {
	video_source.srcObject = local_stream;
	video_source.muted = true;
	video_source.play();
}

export async function get_video_stream(
	include_audio: boolean,
	video_source: HTMLVideoElement,
	device_id?: string,
	track_constraints?: MediaTrackConstraints,
): Promise<MediaStream> {
	const fallback_constraints = track_constraints || {
		width: { ideal: 500 },
		height: { ideal: 500 },
	};

	const constraints = {
		video: device_id
			? { deviceId: { exact: device_id }, ...fallback_constraints }
			: fallback_constraints,
		audio: include_audio,
	};

	return navigator.mediaDevices
		.getUserMedia(constraints)
		.then((local_stream: MediaStream) => {
			set_local_stream(local_stream, video_source);
			return local_stream;
		});
}

export function set_available_devices(
	devices: MediaDeviceInfo[],
	kind: "videoinput" | "audioinput" = "videoinput",
): MediaDeviceInfo[] {
	const cameras = devices.filter(
		(device: MediaDeviceInfo) => device.kind === kind,
	);

	return cameras;
}
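A minimal usage sketch for the helpers above (illustrative, not part of the commit; the `pick_camera` name and the element wiring are assumptions):

```ts
// Hypothetical consumer of stream_utils: request a camera, attach it to a <video>,
// and list the available camera inputs so the user can switch devices later.
import { get_devices, get_video_stream, set_available_devices } from "./stream_utils";

async function pick_camera(video_source: HTMLVideoElement): Promise<MediaStream> {
	const stream = await get_video_stream(false, video_source); // video only, default device
	const cameras = set_available_devices(await get_devices(), "videoinput");
	console.debug("available cameras", cameras.map((c) => c.label));
	return stream;
}
```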
146
frontend/shared/utils.ts
Normal file
@@ -0,0 +1,146 @@
import { toBlobURL } from "@ffmpeg/util";
import { FFmpeg } from "@ffmpeg/ffmpeg";
import { lookup } from "mrmime";

export const prettyBytes = (bytes: number): string => {
	let units = ["B", "KB", "MB", "GB", "TB", "PB"];
	let i = 0;
	while (bytes > 1024) {
		bytes /= 1024;
		i++;
	}
	let unit = units[i];
	return bytes.toFixed(1) + " " + unit;
};

export const playable = (): boolean => {
	// TODO: Fix this
	// let video_element = document.createElement("video");
	// let mime_type = mime.lookup(filename);
	// return video_element.canPlayType(mime_type) != "";
	return true; // FIX BEFORE COMMIT - mime import causing issues
};

export function loaded(
	node: HTMLVideoElement,
	{ autoplay }: { autoplay: boolean },
): any {
	async function handle_playback(): Promise<void> {
		if (!autoplay) return;
		await node.play();
	}

	node.addEventListener("loadeddata", handle_playback);

	return {
		destroy(): void {
			node.removeEventListener("loadeddata", handle_playback);
		},
	};
}

export default async function loadFfmpeg(): Promise<FFmpeg> {
	const ffmpeg = new FFmpeg();
	const baseURL = "https://unpkg.com/@ffmpeg/core@0.12.4/dist/esm";

	await ffmpeg.load({
		coreURL: await toBlobURL(`${baseURL}/ffmpeg-core.js`, "text/javascript"),
		wasmURL: await toBlobURL(`${baseURL}/ffmpeg-core.wasm`, "application/wasm"),
	});

	return ffmpeg;
}

export function blob_to_data_url(blob: Blob): Promise<string> {
	return new Promise((fulfill, reject) => {
		let reader = new FileReader();
		reader.onerror = reject;
		reader.onload = () => fulfill(reader.result as string);
		reader.readAsDataURL(blob);
	});
}

export async function trimVideo(
	ffmpeg: FFmpeg,
	startTime: number,
	endTime: number,
	videoElement: HTMLVideoElement,
): Promise<any> {
	const videoUrl = videoElement.src;
	const mimeType = lookup(videoElement.src) || "video/mp4";
	const blobUrl = await toBlobURL(videoUrl, mimeType);
	const response = await fetch(blobUrl);
	const vidBlob = await response.blob();
	const type = getVideoExtensionFromMimeType(mimeType) || "mp4";
	const inputName = `input.${type}`;
	const outputName = `output.${type}`;

	try {
		if (startTime === 0 && endTime === 0) {
			return vidBlob;
		}

		await ffmpeg.writeFile(
			inputName,
			new Uint8Array(await vidBlob.arrayBuffer()),
		);

		let command = [
			"-i",
			inputName,
			...(startTime !== 0 ? ["-ss", startTime.toString()] : []),
			...(endTime !== 0 ? ["-to", endTime.toString()] : []),
			"-c:a",
			"copy",
			outputName,
		];

		await ffmpeg.exec(command);
		const outputData = await ffmpeg.readFile(outputName);
		const outputBlob = new Blob([outputData], {
			type: `video/${type}`,
		});

		return outputBlob;
	} catch (error) {
		console.error("Error initializing FFmpeg:", error);
		return vidBlob;
	}
}

const getVideoExtensionFromMimeType = (mimeType: string): string | null => {
	const videoMimeToExtensionMap: { [key: string]: string } = {
		"video/mp4": "mp4",
		"video/webm": "webm",
		"video/ogg": "ogv",
		"video/quicktime": "mov",
		"video/x-msvideo": "avi",
		"video/x-matroska": "mkv",
		"video/mpeg": "mpeg",
		"video/3gpp": "3gp",
		"video/3gpp2": "3g2",
		"video/h261": "h261",
		"video/h263": "h263",
		"video/h264": "h264",
		"video/jpeg": "jpgv",
		"video/jpm": "jpm",
		"video/mj2": "mj2",
		"video/mpv": "mpv",
		"video/vnd.ms-playready.media.pyv": "pyv",
		"video/vnd.uvvu.mp4": "uvu",
		"video/vnd.vivo": "viv",
		"video/x-f4v": "f4v",
		"video/x-fli": "fli",
		"video/x-flv": "flv",
		"video/x-m4v": "m4v",
		"video/x-ms-asf": "asf",
		"video/x-ms-wm": "wm",
		"video/x-ms-wmv": "wmv",
		"video/x-ms-wmx": "wmx",
		"video/x-ms-wvx": "wvx",
		"video/x-sgi-movie": "movie",
		"video/x-smv": "smv",
	};

	return videoMimeToExtensionMap[mimeType] || null;
};
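A minimal usage sketch for `loadFfmpeg` and `trimVideo` above (illustrative, not part of the commit; the function name and the five-second window are assumptions):

```ts
// Hypothetical consumer of utils.ts: load ffmpeg.wasm once, then trim the first
// five seconds of the <video> element's source and report the resulting size.
import loadFfmpeg, { trimVideo, prettyBytes } from "./utils";

async function trim_first_five_seconds(video: HTMLVideoElement): Promise<Blob> {
	const ffmpeg = await loadFfmpeg(); // downloads the ffmpeg.wasm core from unpkg
	const clip = await trimVideo(ffmpeg, 0, 5, video); // keep 0s–5s, audio copied as-is
	console.debug("trimmed clip size:", prettyBytes(clip.size));
	return clip;
}
```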
184
frontend/shared/webrtc_utils.ts
Normal file
@@ -0,0 +1,184 @@
export function createPeerConnection(pc, node) {
	// register some listeners to help debugging
	pc.addEventListener(
		"icegatheringstatechange",
		() => {
			console.debug(pc.iceGatheringState);
		},
		false,
	);

	pc.addEventListener(
		"iceconnectionstatechange",
		() => {
			console.debug(pc.iceConnectionState);
		},
		false,
	);

	pc.addEventListener(
		"signalingstatechange",
		() => {
			console.debug(pc.signalingState);
		},
		false,
	);

	// connect audio / video from server to local
	pc.addEventListener("track", (evt) => {
		console.debug("track event listener");
		if (node && node.srcObject !== evt.streams[0]) {
			console.debug("streams", evt.streams);
			node.srcObject = evt.streams[0];
			console.debug("node.srcObject", node.srcObject);
			if (evt.track.kind === "audio") {
				node.volume = 1.0; // Ensure volume is up
				node.muted = false;
				node.autoplay = true;
				// Attempt to play (needed for some browsers)
				node.play().catch((e) => console.debug("Autoplay failed:", e));
			}
		}
	});

	return pc;
}

export async function start(
	stream,
	pc: RTCPeerConnection,
	node,
	server_fn,
	webrtc_id,
	modality: "video" | "audio" = "video",
	on_change_cb: (msg: "change" | "tick") => void = () => {},
	rtp_params = {},
) {
	pc = createPeerConnection(pc, node);
	const data_channel = pc.createDataChannel("text");

	data_channel.onopen = () => {
		console.debug("Data channel is open");
		data_channel.send("handshake");
	};

	data_channel.onmessage = (event) => {
		console.debug("Received message:", event.data);
		let event_json;
		try {
			event_json = JSON.parse(event.data);
		} catch (e) {
			console.debug("Error parsing JSON");
		}
		console.debug("event_json", event_json);
		if (
			event.data === "change" ||
			event.data === "tick" ||
			event.data === "stopword" ||
			event_json?.type === "warning" ||
			event_json?.type === "error"
		) {
			console.debug(`${event.data} event received`);
			on_change_cb(event_json ?? event.data);
		}
	};

	if (stream) {
		stream.getTracks().forEach(async (track) => {
			console.debug("Track stream callback", track);
			const sender = pc.addTrack(track, stream);
			const params = sender.getParameters();
			const updated_params = { ...params, ...rtp_params };
			await sender.setParameters(updated_params);
			console.debug("sender params", sender.getParameters());
		});
	} else {
		console.debug("Creating transceiver!");
		pc.addTransceiver(modality, { direction: "recvonly" });
	}

	await negotiate(pc, server_fn, webrtc_id);
	return pc;
}

function make_offer(server_fn: any, body): Promise<object> {
	return new Promise((resolve, reject) => {
		server_fn(body).then((data) => {
			console.debug("data", data);
			if (data?.status === "failed") {
				console.debug("rejecting");
				reject("error");
			}
			resolve(data);
		});
	});
}

async function negotiate(
	pc: RTCPeerConnection,
	server_fn: any,
	webrtc_id: string,
): Promise<void> {
	return pc
		.createOffer()
		.then((offer) => {
			return pc.setLocalDescription(offer);
		})
		.then(() => {
			// wait for ICE gathering to complete
			return new Promise<void>((resolve) => {
				console.debug("ice gathering state", pc.iceGatheringState);
				if (pc.iceGatheringState === "complete") {
					resolve();
				} else {
					const checkState = () => {
						if (pc.iceGatheringState === "complete") {
							console.debug("ice complete");
							pc.removeEventListener("icegatheringstatechange", checkState);
							resolve();
						}
					};
					pc.addEventListener("icegatheringstatechange", checkState);
				}
			});
		})
		.then(() => {
			var offer = pc.localDescription;
			return make_offer(server_fn, {
				sdp: offer.sdp,
				type: offer.type,
				webrtc_id: webrtc_id,
			});
		})
		.then((response) => {
			return response;
		})
		.then((answer) => {
			return pc.setRemoteDescription(answer);
		});
}

export function stop(pc: RTCPeerConnection) {
	console.debug("Stopping peer connection");
	// close transceivers
	if (pc.getTransceivers) {
		pc.getTransceivers().forEach((transceiver) => {
			if (transceiver.stop) {
				transceiver.stop();
			}
		});
	}

	// close local audio / video
	if (pc.getSenders()) {
		pc.getSenders().forEach((sender) => {
			console.debug("sender", sender);
			if (sender.track && sender.track.stop) sender.track.stop();
		});
	}

	// close peer connection
	setTimeout(() => {
		pc.close();
	}, 500);
}
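A minimal receive-only usage sketch for `start`/`stop` above (illustrative, not part of the commit; the `/webrtc/offer` endpoint and fetch wrapper are assumptions — the Svelte components in this repo pass Gradio's `server.offer` instead):

```ts
// Hypothetical receive-only consumer of webrtc_utils: negotiate an audio-only
// connection and attach the remote track to an <audio> element.
import { start, stop } from "./webrtc_utils";

async function open_receive_only_audio(audio_el: HTMLAudioElement): Promise<RTCPeerConnection> {
	// Assumed signaling endpoint; adapt to however offers reach the backend.
	const send_offer = (body: any) =>
		fetch("/webrtc/offer", {
			method: "POST",
			headers: { "Content-Type": "application/json" },
			body: JSON.stringify(body),
		}).then((r) => r.json());

	const webrtc_id = Math.random().toString(36).substring(2);
	let pc = new RTCPeerConnection();
	// No local stream: start() adds a recvonly audio transceiver and negotiates.
	pc = await start(null, pc, audio_el, send_offer, webrtc_id, "audio");
	return pc;
}

// Later, stop(pc) stops senders/transceivers and closes the connection.
```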
40
mkdocs.yml
Normal file
@@ -0,0 +1,40 @@
site_name: Gradio WebRTC
site_url: https://sitename.example
repo_name: gradio-webrtc
repo_url: https://github.com/freddyaboulton/gradio-webrtc
theme:
  name: material
  palette:
    scheme: slate
    primary: black
    accent: yellow
  features:
    - content.code.copy
    - content.code.annotate
  logo: bolt.svg
  favicon: bolt.svg
nav:
  - Home: index.md
  - User Guide: user-guide.md
  - Cookbook: cookbook.md
  - Deployment: deployment.md
  - Advanced Configuration: advanced-configuration.md
  - Utils: utils.md
  - Frequently Asked Questions: faq.md
markdown_extensions:
  - pymdownx.highlight:
      anchor_linenums: true
      line_spans: __span
      pygments_lang_class: true
  - pymdownx.inlinehilite
  - pymdownx.snippets
  - pymdownx.superfences
  - pymdownx.tabbed:
      alternate_style: true
  - attr_list
  - md_in_html
  - pymdownx.emoji:
      emoji_index: !!python/name:material.extensions.emoji.twemoji
      emoji_generator: !!python/name:material.extensions.emoji.to_svg
  - admonition
  - pymdownx.details
56
pyproject.toml
Normal file
@@ -0,0 +1,56 @@
[build-system]
requires = [
  "hatchling",
  "hatch-requirements-txt",
  "hatch-fancy-pypi-readme>=22.5.0",
]
build-backend = "hatchling.build"

[project]
name = "gradio_webrtc"
version = "0.0.29"
description = "Stream images in realtime with webrtc"
readme = "README.md"
license = "apache-2.0"
requires-python = ">=3.10"
authors = [{ name = "Freddy Boulton", email = "YOUREMAIL@domain.com" }]
keywords = ["gradio-custom-component", "gradio-template-Video", "streaming", "webrtc", "realtime"]
# Add dependencies here
dependencies = ["gradio>=4.0,<6.0", "aiortc"]
classifiers = [
  'Development Status :: 3 - Alpha',
  'Operating System :: OS Independent',
  'Programming Language :: Python :: 3',
  'Programming Language :: Python :: 3 :: Only',
  'Programming Language :: Python :: 3.9',
  'Programming Language :: Python :: 3.10',
  'Programming Language :: Python :: 3.11',
  'Programming Language :: Python :: 3.12',
  'Topic :: Scientific/Engineering',
  'Topic :: Scientific/Engineering :: Artificial Intelligence',
  'Topic :: Scientific/Engineering :: Visualization',
]

# The repository and space URLs are optional, but recommended.
# Adding a repository URL will create a badge in the auto-generated README that links to the repository.
# Adding a space URL will create a badge in the auto-generated README that links to the space.
# This will make it easy for people to find your deployed demo or source code when they
# encounter your project in the wild.

# [project.urls]
# repository = "your github repository"
# space = "your space url"

[project.optional-dependencies]
dev = ["build", "twine"]
vad = ["librosa", "onnxruntime"]
stopword = ["silero", "librosa", "onnxruntime"]

[tool.hatch.build]
artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"]

[tool.hatch.build.targets.wheel]
packages = ["/backend/gradio_webrtc"]

[tool.ruff]
target-version = "py310"