gradio-webrtc/demo/talk_to_openai/app.py

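# Talk-to-OpenAI demo: streams microphone audio to the OpenAI Realtime API over
# WebRTC via fastrtc, plays the model's spoken replies back to the user, and
# surfaces the assistant transcripts in a Gradio chatbot.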

import asyncio
import base64
import json
from pathlib import Path
import gradio as gr
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    WebRTCError,
    get_twilio_turn_credentials,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
load_dotenv()
cur_dir = Path(__file__).parent
SAMPLE_RATE = 24000


class OpenAIHandler(AsyncStreamHandler):
    def __init__(
        self,
    ) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.output_queue = asyncio.Queue()

    def copy(self):
        return OpenAIHandler()

    async def start_up(
        self,
    ):
        """Connect to realtime API. Run forever in separate thread to keep connection open."""
        self.client = openai.AsyncOpenAI()
        try:
            async with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                await conn.session.update(
                    session={"turn_detection": {"type": "server_vad"}}
                )
                self.connection = conn
                async for event in self.connection:
                    if event.type == "response.audio_transcript.done":
                        # Forward the finished transcript so the chatbot UI can display it
                        await self.output_queue.put(AdditionalOutputs(event))
                    if event.type == "response.audio.delta":
                        # Decode the base64 PCM chunk and queue it for playback
                        await self.output_queue.put(
                            (
                                self.output_sample_rate,
                                np.frombuffer(
                                    base64.b64decode(event.delta), dtype=np.int16
                                ).reshape(1, -1),
                            ),
                        )
        except Exception:
            import traceback

            traceback.print_exc()
            raise WebRTCError(str(traceback.format_exc()))

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if not self.connection:
            return
        try:
            # Append the incoming microphone frame to OpenAI's input audio buffer
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore
        except Exception as e:
            # print traceback
            print(f"Error in receive: {str(e)}")
            import traceback

            traceback.print_exc()
            raise WebRTCError(str(traceback.format_exc()))

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        # Yield the next queued audio chunk or transcript event to the client
        return await self.output_queue.get()

    def reset_state(self):
        """Reset connection state for new recording session"""
        self.connection = None
        self.args_set.clear()

    async def shutdown(self) -> None:
        if self.connection:
            await self.connection.close()
            self.reset_state()


def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    chatbot.append({"role": "assistant", "content": response.transcript})
    return chatbot


chatbot = gr.Chatbot(type="messages")
latest_message = gr.Textbox(type="text", visible=False)
stream = Stream(
    OpenAIHandler(),
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
)

app = FastAPI()
stream.mount(app)


@app.get("/")
async def _():
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (cur_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


@app.get("/outputs")
def _(webrtc_id: str):
    async def output_stream():
        import json

        async for output in stream.output_stream(webrtc_id):
            s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
            yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
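

# Hedged usage sketch (not part of the original demo): a minimal client for the
# /outputs server-sent-events route above. It assumes the app is reachable at
# http://localhost:7860, that `httpx` is installed, and that `webrtc_id` matches
# an active stream started from the page served at "/".
def read_transcripts(webrtc_id: str) -> None:
    import httpx

    with httpx.stream(
        "GET",
        "http://localhost:7860/outputs",
        params={"webrtc_id": webrtc_id},
        timeout=None,
    ) as resp:
        for line in resp.iter_lines():
            # Each SSE message carries the assistant transcript as JSON on its "data:" line
            if line.startswith("data: "):
                print(json.loads(line[len("data: "):])["content"])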


if __name__ == "__main__":
    import os

    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)