Files
gradio-webrtc/demo/talk_to_gemini/app.py
Freddy Boulton 853d6a06b5 Rebrand to FastRTC (#60)
* Add code

* add code

* add code

* Rename messages

* rename

* add code

* Add demo

* docs + demos + bug fixes

* add code

* styles

* user guide

* Styles

* Add code

* misc docs updates

* print nit

* whisper + pr

* url for images

* whisper update

* Fix bugs

* remove demo files

* version number

* Fix pypi readme

* Fix

* demos

* Add llama code editor

* Update llama code editor and object detection cookbook

* Add more cookbook demos

* add code

* Fix links for PR deploys

* add code

* Fix the install

* add tts

* TTS docs

* Typo

* Pending bubbles for reply on pause

* Stream redesign (#63)

* better error handling

* Websocket error handling

* add code

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>

* remove docs from dist

* Some docs typos

* more typos

* upload changes + docs

* docs

* better phone

* update docs

* add code

* Make demos better

* fix docs + websocket start_up

* remove mention of FastAPI app

* fastphone tweaks

* add code

* ReplyOnStopWord fixes

* Fix cookbook

* Fix pypi readme

* add code

* bump versions

* sambanova cookbook

* Fix tags

* Llm voice chat

* kyutai tag

* Add error message to all index.html

* STT module uses Moonshine

* Not required from typing extensions

* fix llm voice chat

* Add vpn warning

* demo fixes

* demos

* Add more ui args and gemini audio-video

* update cookbook

* version 9

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
2025-02-24 01:13:42 -05:00

188 lines
5.2 KiB
Python

import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
AsyncStreamHandler,
Stream,
WebRTCError,
get_twilio_turn_credentials,
)
from google import genai
from google.genai.types import (
LiveConnectConfig,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel
# Directory containing this file; the "/" route reads index.html from here.
current_dir = pathlib.Path(__file__).parent
# Load variables from a .env file (python-dotenv) so that settings such as
# GEMINI_API_KEY, read later via os.getenv, can be supplied locally.
load_dotenv()
def encode_audio(data: np.ndarray) -> str:
    """Encode audio data as base64 text to send to the server.

    Args:
        data: Array whose raw buffer (as returned by ``tobytes()``) is
            encoded. No dtype conversion is performed here.

    Returns:
        The base64 representation of the array's bytes, as a ``str``.
    """
    raw_bytes = data.tobytes()
    return base64.b64encode(raw_bytes).decode("UTF-8")
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API.

    Frames received from the client are base64-encoded and placed on
    ``input_queue``; ``start_up`` forwards them to a Gemini live session and
    places the audio it returns on ``output_queue``, from which ``emit``
    serves them back to the client.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Initialise the base handler and the per-connection queues.

        Args:
            expected_layout: Channel layout passed to the base handler.
            output_sample_rate: Sample rate reported with emitted audio.
            output_frame_size: Output frame size passed to the base handler.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Encoded microphone audio waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # Decoded audio arrays from Gemini waiting to be emitted.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to stop the upload generator in stream().
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output settings.

        The new instance gets its own queues and quit event.
        """
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def start_up(self):
        """Open a Gemini live session and pump audio until it ends.

        Outside phone mode this waits for the user-supplied arguments and
        takes ``api_key`` and ``voice_name`` from ``latest_args[1:]``; in
        phone mode no key is supplied and the voice defaults to "Puck".
        An empty/missing key falls back to the GEMINI_API_KEY environment
        variable.

        Raises:
            WebRTCError: If creating the client or running the live session
                raises any exception; the original exception is chained as
                the cause.
        """
        if not self.phone_mode:
            await self.wait_for_args()
            # NOTE(review): element 0 of latest_args is skipped — presumably
            # it is the WebRTC component's own value; confirm against fastrtc.
            api_key, voice_name = self.latest_args[1:]
        else:
            api_key, voice_name = None, "Puck"

        try:
            client = genai.Client(
                api_key=api_key or os.getenv("GEMINI_API_KEY"),
                http_options={"api_version": "v1alpha"},
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise WebRTCError(str(e)) from e

        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        )

        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                async for audio in session.start_stream(
                    stream=self.stream(), mime_type="audio/pcm"
                ):
                    if audio.data:
                        # Response bytes are interpreted as 16-bit integers.
                        array = np.frombuffer(audio.data, dtype=np.int16)
                        self.output_queue.put_nowait(array)
        except Exception as e:
            raise WebRTCError(str(e)) from e

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until ``quit`` is set.

        The queue is polled with a 0.1 s timeout so that the ``quit`` flag
        is re-checked regularly even when no audio is arriving.
        """
        # NOTE(review): receive() enqueues the base64 ``str`` produced by
        # encode_audio, while this generator is annotated as yielding
        # ``bytes`` — confirm which type the Gemini client expects.
        while not self.quit.is_set():
            try:
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one incoming frame for upload.

        Args:
            frame: ``(sample_rate, samples)``; the sample rate is ignored and
                the samples are squeezed before being base64-encoded.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray]:
        """Wait for the next audio array from Gemini and return it.

        Returns:
            ``(output_sample_rate, array)``.
        """
        array = await self.output_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Signal ``stream()`` to stop and clear the ``args_set`` event."""
        self.quit.set()
        self.args_set.clear()
# Module-level audio stream used by the UI, phone and FastAPI entry points.
# TURN credentials, the concurrency limit and the time limit are only applied
# when get_space() is truthy; otherwise they are left as None.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
    # Order matters: the handler unpacks these as (api_key, voice_name).
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Pre-fill from the environment only when get_space() is falsy.
            value="" if get_space() else os.getenv("GEMINI_API_KEY"),
        ),
        gr.Dropdown(
            label="Voice",
            choices=["Puck", "Charon", "Kore", "Fenrir", "Aoede"],
            value="Puck",
        ),
    ],
)
class InputData(BaseModel):
    """Request body accepted by the ``/input_hook`` endpoint."""

    # Identifier of the WebRTC connection the inputs belong to.
    webrtc_id: str
    # Voice name forwarded to the handler for that connection.
    voice_name: str
    # API key forwarded to the handler for that connection.
    api_key: str
# FastAPI application that hosts the routes declared below.
app = FastAPI()
# Attach the stream to the app.
# NOTE(review): presumably this registers the stream's own endpoints on the
# app — confirm against the fastrtc documentation.
stream.mount(app)
@app.post("/input_hook")
async def _(body: InputData):
    """Store the API key and voice for one WebRTC connection.

    The values are passed to ``stream.set_input`` in the order
    ``api_key, voice_name``, which is the order in which
    ``GeminiHandler.start_up`` unpacks ``latest_args[1:]``.

    Returns:
        ``{"status": "ok"}`` once the input has been set.
    """
    stream.set_input(body.webrtc_id, body.api_key, body.voice_name)
    return {"status": "ok"}
@app.get("/")
async def index():
    """Serve ``index.html`` with the RTC configuration substituted in.

    The ``__RTC_CONFIGURATION__`` placeholder in the page is replaced with
    the JSON form of the TURN credentials when get_space() is truthy, or
    with ``null`` otherwise.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Decode explicitly as UTF-8 so the page is read the same way regardless
    # of the platform's default text encoding.
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # The MODE environment variable selects the front end:
    #   "UI"    -> the stream's built-in UI
    #   "PHONE" -> the stream's fastphone entry point
    #   other   -> the FastAPI app served by uvicorn
    # All three listen on port 7860. ``os`` is already imported at the top
    # of the file, so it is not re-imported here.
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        # Imported lazily: only this branch needs uvicorn.
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)