mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
[feat] update some feature
Sync code from fastrtc, add text support through the data channel, fix the Safari connection problem, and support chat without a camera or mic.
This commit is contained in:
181
demo/talk_to_gemini/app.py
Normal file
181
demo/talk_to_gemini/app.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastrtc import (
|
||||
AsyncStreamHandler,
|
||||
Stream,
|
||||
get_twilio_turn_credentials,
|
||||
wait_for_item,
|
||||
)
|
||||
from google import genai
|
||||
from google.genai.types import (
|
||||
LiveConnectConfig,
|
||||
PrebuiltVoiceConfig,
|
||||
SpeechConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
from gradio.utils import get_space
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Directory that contains this script; the "/" route reads index.html from it.
current_dir = pathlib.Path(__file__).parent

# Load settings from a .env file (python-dotenv) before anything below calls
# os.getenv, e.g. for GEMINI_API_KEY.
load_dotenv()
|
||||
|
||||
|
||||
def encode_audio(data: np.ndarray) -> str:
    """Return the raw bytes of *data* as base64 text.

    Args:
        data: Array whose underlying buffer (``data.tobytes()``) is encoded.

    Returns:
        The base64 encoding of the buffer, decoded to a ``str``.
    """
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("UTF-8")
|
||||
|
||||
|
||||
class GeminiHandler(AsyncStreamHandler):
    """Stream handler that relays audio between a WebRTC peer and Gemini.

    Incoming frames are base64-encoded and placed on ``input_queue``; audio
    returned by the Gemini live session is placed on ``output_queue`` as
    ``(sample_rate, int16 array)`` tuples and handed back through ``emit``.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Initialise the base handler and the per-connection queues.

        Args:
            expected_layout: Channel layout forwarded to the base class.
            output_sample_rate: Sample rate forwarded to the base class and
                attached to every emitted chunk.
            output_frame_size: Frame size forwarded to the base class.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Encoded microphone audio waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # Decoded Gemini audio waiting to be emitted to the peer.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to end the stream() generator.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a new handler with the same output rate and frame size."""
        settings = {
            "expected_layout": "mono",
            "output_sample_rate": self.output_sample_rate,
            "output_frame_size": self.output_frame_size,
        }
        return GeminiHandler(**settings)

    async def start_up(self):
        """Open a Gemini live session and queue the audio it returns.

        Outside phone mode this waits for the stream arguments and takes the
        API key and voice name from ``latest_args[1:]``. In phone mode no
        explicit key is used and the voice is ``"Puck"``. When the key is
        empty, the ``GEMINI_API_KEY`` environment variable is used instead.
        """
        if self.phone_mode:
            api_key, voice_name = None, "Puck"
        else:
            await self.wait_for_args()
            # Element 0 is skipped; presumably it is the audio input itself
            # -- TODO confirm against fastrtc.
            api_key, voice_name = self.latest_args[1:]

        effective_key = api_key if api_key else os.getenv("GEMINI_API_KEY")
        client = genai.Client(
            api_key=effective_key,
            http_options={"api_version": "v1alpha"},
        )

        voice = VoiceConfig(
            prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(voice_config=voice),
        )

        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            responses = session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            )
            async for audio in responses:
                if not audio.data:
                    continue
                samples = np.frombuffer(audio.data, dtype=np.int16)
                self.output_queue.put_nowait((self.output_sample_rate, samples))

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until ``shutdown`` has been called.

        The queue is polled with a 0.1 s timeout so that the ``quit`` event is
        re-checked regularly even when no audio arrives.

        NOTE(review): ``receive`` queues the ``str`` returned by
        ``encode_audio``, so the yielded items look like base64 text rather
        than ``bytes`` -- confirm the annotation.
        """
        while not self.quit.is_set():
            try:
                yield await asyncio.wait_for(self.input_queue.get(), 0.1)
            except (asyncio.TimeoutError, TimeoutError):
                # Nothing arrived in time; loop to re-check the quit flag.
                continue

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Encode an incoming ``(sample_rate, array)`` frame and queue it."""
        _rate, samples = frame
        self.input_queue.put_nowait(encode_audio(samples.squeeze()))

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the result of ``wait_for_item`` on the output queue."""
        item = await wait_for_item(self.output_queue)
        return item

    def shutdown(self) -> None:
        """Signal ``stream`` to stop by setting the ``quit`` event."""
        self.quit.set()
|
||||
|
||||
|
||||
# Bidirectional audio stream backed by GeminiHandler. The two additional
# inputs (API key, voice) are what GeminiHandler.start_up unpacks from
# latest_args[1:].
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    # TURN credentials, the concurrency cap and the time limit are only
    # applied when get_space() is truthy; otherwise they are left as None.
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Pre-filled from the environment only when get_space() is falsy.
            value="" if get_space() else os.getenv("GEMINI_API_KEY"),
        ),
        gr.Dropdown(
            label="Voice",
            choices=["Puck", "Charon", "Kore", "Fenrir", "Aoede"],
            value="Puck",
        ),
    ],
)
|
||||
|
||||
|
||||
class InputData(BaseModel):
    """Request body accepted by the ``/input_hook`` endpoint."""

    # Passed to stream.set_input as its first argument.
    webrtc_id: str
    # Forwarded to stream.set_input after the API key.
    voice_name: str
    # Forwarded to stream.set_input right after webrtc_id.
    api_key: str
|
||||
|
||||
|
||||
# FastAPI application that hosts the custom routes defined below.
app = FastAPI()

# Attach the stream to the application.
stream.mount(app)
|
||||
|
||||
|
||||
@app.post("/input_hook")
async def _(body: InputData):
    """Forward the posted API key and voice to the stream for a connection.

    Args:
        body: Parsed request body carrying ``webrtc_id``, ``api_key`` and
            ``voice_name``.

    Returns:
        ``{"status": "ok"}`` once the inputs have been handed to the stream.
    """
    # Order matters: GeminiHandler.start_up unpacks (api_key, voice_name).
    webrtc_id, api_key, voice_name = body.webrtc_id, body.api_key, body.voice_name
    stream.set_input(webrtc_id, api_key, voice_name)
    return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/")
async def index():
    """Serve the demo page with the RTC configuration injected.

    Reads ``index.html`` from the directory containing this script and
    replaces the ``__RTC_CONFIGURATION__`` placeholder with the JSON-encoded
    RTC configuration: the Twilio TURN credentials when ``get_space()`` is
    truthy, otherwise JSON ``null``.

    Returns:
        An ``HTMLResponse`` containing the rendered page.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Decode explicitly as UTF-8 so the page does not depend on the host's
    # locale encoding (e.g. cp1252 on Windows).
    template = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = template.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
if (mode := os.getenv("MODE")) == "UI":
|
||||
stream.ui.launch(server_port=7860)
|
||||
elif mode == "PHONE":
|
||||
stream.fastphone(host="0.0.0.0", port=7860)
|
||||
else:
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7860)
|
||||
Reference in New Issue
Block a user