mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Sync code from fastrtc; add text support through the data channel; fix the Safari connection problem; support chat without a camera or mic.
142 lines
4.2 KiB
Python
142 lines
4.2 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import gradio as gr
|
|
import numpy as np
|
|
import openai
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
from fastrtc import (
|
|
AdditionalOutputs,
|
|
AsyncStreamHandler,
|
|
Stream,
|
|
get_twilio_turn_credentials,
|
|
wait_for_item,
|
|
)
|
|
from gradio.utils import get_space
|
|
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
|
|
|
|
# Load settings from a .env file before any client object is created.
# NOTE(review): presumably this is where the OpenAI API key comes from - confirm.
load_dotenv()

# Directory that holds this script; index.html is read from here by the "/" route.
cur_dir = Path(__file__).parent

# Sample rate (Hz) used for both the input and the output audio of the handler.
SAMPLE_RATE = 24_000
|
|
|
|
|
|
class OpenAIHandler(AsyncStreamHandler):
    """Stream handler that relays audio between a peer and the OpenAI realtime API.

    Frames passed to ``receive`` are base64-encoded and appended to the realtime
    session's input audio buffer. Events read from the session in ``start_up``
    are converted into queue items that ``emit`` hands back to the stream:
    audio deltas become ``(sample_rate, int16 array)`` tuples and finished
    transcripts are wrapped in ``AdditionalOutputs``.
    """

    def __init__(
        self,
    ) -> None:
        """Configure mono audio at ``SAMPLE_RATE`` for both directions."""
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        # Live realtime connection; ``None`` whenever no usable session exists.
        self.connection = None
        # Items produced by ``start_up`` and consumed by ``emit``.
        self.output_queue = asyncio.Queue()

    def copy(self):
        """Return a fresh, unconnected handler."""
        return OpenAIHandler()

    async def start_up(
        self,
    ):
        """Connect to the realtime API and pump its events into the output queue.

        Runs until the connection's event iterator ends or an error is raised.
        ``self.connection`` is always reset to ``None`` on the way out so that
        ``receive`` stops writing to a connection that is no longer open.
        """
        self.client = openai.AsyncOpenAI()
        try:
            async with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                await conn.session.update(
                    session={"turn_detection": {"type": "server_vad"}}
                )
                # Publish the connection only after the session is configured.
                self.connection = conn
                async for event in conn:
                    if event.type == "response.audio_transcript.done":
                        await self.output_queue.put(AdditionalOutputs(event))
                    elif event.type == "response.audio.delta":
                        # The delta is base64 text; decode it to int16 samples
                        # shaped (1, n) before queueing.
                        samples = np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1)
                        await self.output_queue.put(
                            (self.output_sample_rate, samples)
                        )
        finally:
            # The context manager has exited (or never entered): drop the
            # reference so a stale connection is never reused.
            self.connection = None

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the realtime session.

        Frames that arrive while no connection is available are dropped.
        """
        if not self.connection:
            return
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued output item via ``wait_for_item``."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the realtime connection, if one is open."""
        # Detach first so the attribute is cleared even if ``close`` raises and
        # ``receive`` stops using the connection while it is closing.
        conn, self.connection = self.connection, None
        if conn:
            await conn.close()
|
|
|
|
|
|
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Record a finished assistant transcript in the chat history.

    The ``chatbot`` list is extended in place with one assistant message whose
    content is ``response.transcript``, and the same list is returned.
    """
    assistant_turn = {"role": "assistant", "content": response.transcript}
    chatbot.append(assistant_turn)
    return chatbot
|
|
|
|
|
|
# Gradio components shared with the stream below.
chatbot = gr.Chatbot(type="messages")
latest_message = gr.Textbox(type="text", visible=False)

# TURN credentials and usage limits are applied only when get_space() is
# truthy; otherwise all three settings are left as None.
if get_space():
    _deployment_options = {
        "rtc_configuration": get_twilio_turn_credentials(),
        "concurrency_limit": 5,
        "time_limit": 90,
    }
else:
    _deployment_options = {
        "rtc_configuration": None,
        "concurrency_limit": None,
        "time_limit": None,
    }

# Bidirectional audio stream; transcripts emitted as additional outputs are
# folded into the chatbot by update_chatbot.
stream = Stream(
    OpenAIHandler(),
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    **_deployment_options,
)

# FastAPI application with the stream mounted on it.
app = FastAPI()
stream.mount(app)
|
|
|
|
|
|
@app.get("/")
async def _():
    """Serve ``index.html`` with the RTC configuration substituted in.

    The ``__RTC_CONFIGURATION__`` placeholder in the page is replaced by the
    JSON encoding of the Twilio TURN credentials when ``get_space()`` is
    truthy, and by ``null`` otherwise.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Decode explicitly as UTF-8: the bare read_text() default follows the
    # platform locale and can corrupt or reject non-ASCII characters in the page.
    html_content = (cur_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
|
|
|
|
|
|
@app.get("/outputs")
def _(webrtc_id: str):
    """Stream assistant transcripts for one connection as server-sent events.

    Each item yielded by ``stream.output_stream(webrtc_id)`` is turned into an
    ``output`` event whose data is a JSON chat message holding the transcript
    of the item's first argument.
    """

    async def event_source():
        async for item in stream.output_stream(webrtc_id):
            transcript = item.args[0].transcript
            payload = json.dumps({"role": "assistant", "content": transcript})
            yield f"event: output\ndata: {payload}\n\n"

    return StreamingResponse(event_source(), media_type="text/event-stream")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import os
|
|
|
|
if (mode := os.getenv("MODE")) == "UI":
|
|
stream.ui.launch(server_port=7860)
|
|
elif mode == "PHONE":
|
|
stream.fastphone(host="0.0.0.0", port=7860)
|
|
else:
|
|
import uvicorn
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=7860)
|