Files
gradio-webrtc/demo/gemini_audio_video/app.py
Freddy Boulton 853d6a06b5 Rebrand to FastRTC (#60)
* Add code

* add code

* add code

* Rename messages

* rename

* add code

* Add demo

* docs + demos + bug fixes

* add code

* styles

* user guide

* Styles

* Add code

* misc docs updates

* print nit

* whisper + pr

* url for images

* whisper update

* Fix bugs

* remove demo files

* version number

* Fix pypi readme

* Fix

* demos

* Add llama code editor

* Update llama code editor and object detection cookbook

* Add more cookbook demos

* add code

* Fix links for PR deploys

* add code

* Fix the install

* add tts

* TTS docs

* Typo

* Pending bubbles for reply on pause

* Stream redesign (#63)

* better error handling

* Websocket error handling

* add code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>

* remove docs from dist

* Some docs typos

* more typos

* upload changes + docs

* docs

* better phone

* update docs

* add code

* Make demos better

* fix docs + websocket start_up

* remove mention of FastAPI app

* fastphone tweaks

* add code

* ReplyOnStopWord fixes

* Fix cookbook

* Fix pypi readme

* add code

* bump versions

* sambanova cookbook

* Fix tags

* Llm voice chat

* kyutai tag

* Add error message to all index.html

* STT module uses Moonshine

* Not required from typing extensions

* fix llm voice chat

* Add vpn warning

* demo fixes

* demos

* Add more ui args and gemini audio-video

* update cookbook

* version 9

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
2025-02-24 01:13:42 -05:00

208 lines
6.7 KiB
Python

import asyncio
import base64
import os
import time
from io import BytesIO
import gradio as gr
from gradio.utils import get_space
import numpy as np
from google import genai
from dotenv import load_dotenv
from fastrtc import (
AsyncAudioVideoStreamHandler,
Stream,
get_twilio_turn_credentials,
WebRTC,
)
from PIL import Image
# Load variables from a local .env file into the process environment;
# GEMINI_API_KEY is read from the environment when the handler connects.
load_dotenv()
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples in a base64 PCM payload for the server.

    The array's underlying bytes are base64-encoded unchanged and tagged
    with the ``audio/pcm`` MIME type.
    """
    raw_bytes = data.tobytes()
    payload = base64.b64encode(raw_bytes).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": payload}
def encode_image(data: np.ndarray) -> dict:
    """Compress an image array to JPEG and wrap it as a base64 payload.

    The array is converted to a PIL image, saved as JPEG into an in-memory
    buffer, and the resulting bytes are base64-encoded and tagged with the
    ``image/jpeg`` MIME type.
    """
    buffer = BytesIO()
    with buffer:
        Image.fromarray(data).save(buffer, "JPEG")
        jpeg_bytes = buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": encoded}
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay live audio and video frames between the client and Gemini.

    Incoming audio chunks are forwarded to a Gemini live session, incoming
    video frames are forwarded at most once per second (and echoed back to
    the client unchanged), and audio returned by Gemini is queued for
    playback.
    """

    def __init__(
        self,
    ) -> None:
        # NOTE(review): "mono" is presumably the expected audio layout of the
        # base handler - confirm against the FastRTC handler signature.
        super().__init__(
            "mono",
            output_sample_rate=24000,
            output_frame_size=480,
            input_sample_rate=16000,
        )
        # Audio arrays received from Gemini, consumed by emit().
        self.audio_queue = asyncio.Queue()
        # Video frames received from the client, consumed by video_emit().
        self.video_queue = asyncio.Queue()
        # Set by shutdown() to stop the receive loop in start_up().
        self.quit = asyncio.Event()
        # Live Gemini session; None whenever no connection is open.
        self.session = None
        # time.time() of the last frame forwarded to Gemini (0 = never).
        self.last_frame_time = 0

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with its own queues and session state."""
        return GeminiHandler()

    async def start_up(self):
        """Open the Gemini live session and queue the audio it returns.

        Runs until ``quit`` is set or the connection fails. Errors are
        printed rather than raised, and ``session`` is reset to ``None``
        once the connection is gone so the send paths stop using it.
        """
        client = genai.Client(
            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
        )
        config = {"response_modalities": ["AUDIO"]}
        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                self.session = session
                print("set session")
                while not self.quit.is_set():
                    # Use the local handle: shutdown() may detach
                    # self.session while this loop is still running.
                    turn = session.receive()
                    async for response in turn:
                        if data := response.data:
                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                            self.audio_queue.put_nowait(audio)
        except Exception:
            import traceback

            traceback.print_exc()
        finally:
            # The connection is closed past this point; never send on it again.
            self.session = None

    def _additional_image(self):
        """Return the second entry of ``latest_args``, or None if absent.

        NOTE(review): index 1 is presumably the value of the extra
        ``gr.Image`` input wired up by the UI - confirm against the caller.
        """
        args = getattr(self, "latest_args", None)
        if isinstance(args, (list, tuple)) and len(args) > 1:
            return args[1]
        return None

    async def video_receive(self, frame: np.ndarray):
        """Forward a video frame to Gemini (throttled) and echo it back.

        At most one frame per second is sent to the session, followed by the
        additional image input when one is available. Send failures are
        printed and never prevent the frame from being echoed to the client.
        """
        try:
            session = self.session
            now = time.time()
            # send image every 1 second
            if session and now - self.last_frame_time > 1:
                self.last_frame_time = now
                await session.send(input=encode_image(frame))
                extra_image = self._additional_image()
                if extra_image is not None:
                    await session.send(input=encode_image(extra_image))
        except Exception:
            import traceback

            traceback.print_exc()
        self.video_queue.put_nowait(frame)

    async def video_emit(self):
        """Return the next queued video frame, waiting until one exists."""
        return await self.video_queue.get()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one audio chunk to Gemini; a no-op without a session."""
        session = self.session
        if not session:
            # Nothing to send to yet, so skip the encoding work entirely.
            return
        _, array = frame
        audio_message = encode_audio(array.squeeze())
        try:
            await session.send(input=audio_message)
        except Exception:
            import traceback

            traceback.print_exc()

    async def emit(self):
        """Return the next Gemini audio chunk as ``(sample_rate, array)``."""
        array = await self.audio_queue.get()
        return (self.output_sample_rate, array)

    async def shutdown(self) -> None:
        """Stop the receive loop and close the Gemini connection, if open."""
        session = self.session
        if not session:
            return
        # Detach first so receive()/video_receive() stop sending while the
        # socket is being closed.
        self.session = None
        self.quit.set()
        # NOTE(review): relies on the session's private websocket attribute;
        # switch to a public close API if the installed SDK provides one.
        await session._websocket.close()
        self.quit.clear()
# Stream configuration for the audio-video Gemini demo.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio-video",
    mode="send-receive",
    # gradio's get_space() returns the Space ID (or None when not running on
    # Spaces), never the literal "spaces", so test its truthiness; comparing
    # against "spaces" meant TURN credentials were never requested.
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    # Limit session length only when running on Spaces.
    time_limit=90 if get_space() else None,
    additional_inputs=[
        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(35, 157, 225)",
        "icon_button_color": "rgb(35, 157, 225)",
        "title": "Gemini Audio Video Chat",
    },
)
# Size cap for the video component (max-height needs an explicit unit;
# a bare "600" is not a valid CSS length and the declaration is ignored).
css = """
#video-source {max-width: 600px !important; max-height: 600px !important;}
"""
# Custom Gradio UI: a header, the WebRTC video chat component, and an extra
# image input whose value is passed to the handler alongside the stream.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
<div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
<div style="background-color: var(--block-background-fill); border-radius: 8px">
<img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
</div>
<div>
<h1>Gen AI SDK Voice Chat</h1>
<p>Speak with Gemini using real-time audio + video streaming</p>
<p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
<p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
</div>
</div>
"""
    )
    with gr.Row() as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Video Chat",
                modality="audio-video",
                mode="send-receive",
                elem_id="video-source",
                # get_space() returns the Space ID or None, never "spaces",
                # so request TURN credentials whenever it is truthy.
                rtc_configuration=get_twilio_turn_credentials()
                if get_space()
                else None,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(35, 157, 225)",
                icon_button_color="rgb(35, 157, 225)",
            )
        with gr.Column():
            image_input = gr.Image(
                label="Image", type="numpy", sources=["upload", "clipboard"]
            )
    # Registered inside the Blocks context so the event is attached to `demo`.
    webrtc.stream(
        GeminiHandler(),
        inputs=[webrtc, image_input],
        outputs=[webrtc],
        time_limit=60 if get_space() else None,
        concurrency_limit=2 if get_space() else None,
    )
# Serve the custom Blocks layout as the stream's UI.
stream.ui = demo
if __name__ == "__main__":
    # MODE=PHONE is rejected; every other value (including "UI" and an unset
    # variable) launches the Gradio UI on port 7860.
    mode = os.getenv("MODE")
    if mode == "PHONE":
        raise ValueError("Phone mode not supported for this demo")
    stream.ui.launch(server_port=7860)