From d710c06210a9a25142c386f7b6c4cb99699b353c Mon Sep 17 00:00:00 2001 From: Freddy Boulton <41651716+freddyaboulton@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:42:22 -0400 Subject: [PATCH] Fix openai demo (#279) --- demo/talk_to_openai/app.py | 23 ++++++++++++++++++++--- demo/talk_to_openai/index.html | 18 ++++++++++++------ 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/demo/talk_to_openai/app.py b/demo/talk_to_openai/app.py index d3c84b4..bc0a37d 100644 --- a/demo/talk_to_openai/app.py +++ b/demo/talk_to_openai/app.py @@ -50,15 +50,32 @@ class OpenAIHandler(AsyncStreamHandler): model="gpt-4o-mini-realtime-preview-2024-12-17" ) as conn: await conn.session.update( - session={"turn_detection": {"type": "server_vad"}} + session={ + "turn_detection": {"type": "server_vad"}, + "input_audio_transcription": { + "model": "whisper-1", + "language": "en", + }, + } ) self.connection = conn async for event in self.connection: # Handle interruptions if event.type == "input_audio_buffer.speech_started": self.clear_queue() + if ( + event.type + == "conversation.item.input_audio_transcription.completed" + ): + await self.output_queue.put( + AdditionalOutputs({"role": "user", "content": event.transcript}) + ) if event.type == "response.audio_transcript.done": - await self.output_queue.put(AdditionalOutputs(event)) + await self.output_queue.put( + AdditionalOutputs( + {"role": "assistant", "content": event.transcript} + ) + ) if event.type == "response.audio.delta": await self.output_queue.put( ( @@ -124,7 +141,7 @@ def _(webrtc_id: str): import json async for output in stream.output_stream(webrtc_id): - s = json.dumps({"role": "assistant", "content": output.args[0].transcript}) + s = json.dumps(output.args[0]) yield f"event: output\ndata: {s}\n\n" return StreamingResponse(output_stream(), media_type="text/event-stream") diff --git a/demo/talk_to_openai/index.html b/demo/talk_to_openai/index.html index 6e08820..bad8825 100644 --- a/demo/talk_to_openai/index.html +++ b/demo/talk_to_openai/index.html @@ -45,20 +45,26 @@ .message { margin-bottom: 20px; - padding: 12px; - border-radius: 4px; + padding: 12px 16px; + border-radius: 8px; font-size: 16px; line-height: 1.5; + max-width: 70%; + clear: both; } .message.user { - background-color: #1a1a1a; - margin-left: 20%; + background-color: #2c2c2c; + float: right; + border-bottom-right-radius: 2px; + border: 1px solid #404040; } .message.assistant { background-color: #262626; - margin-right: 20%; + float: left; + border-bottom-left-radius: 2px; + border: 1px solid #333; } .controls { @@ -435,7 +441,7 @@ const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id); eventSource.addEventListener("output", (event) => { const eventJson = JSON.parse(event.data); - addMessage("assistant", eventJson.content); + addMessage(eventJson.role, eventJson.content); }); } catch (err) {