Add ability to trigger ReplyOnPause without waiting for pause (#250)

* Add code

* Send text or audio demo
Freddy Boulton, 2025-04-03 20:19:50 -04:00 (committed by GitHub)
parent aed34825e3
commit 8dd17d3216
5 changed files with 726 additions and 2 deletions
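A minimal usage sketch of the new hook (the echo handler and the /trigger endpoint below are illustrative, not part of this commit): look up a connection's handler by its webrtc_id and call trigger_response() to start the reply immediately instead of waiting for a pause to be detected.

    from typing import cast

    from fastapi import FastAPI
    from fastrtc import ReplyOnPause, Stream


    def echo(audio):
        # Placeholder handler: yield the caller's audio straight back.
        yield audio


    stream = Stream(ReplyOnPause(echo), mode="send-receive", modality="audio")
    app = FastAPI()
    stream.mount(app)


    @app.post("/trigger")
    async def trigger(webrtc_id: str):
        # Fire the ReplyOnPause generator now instead of waiting for a pause.
        cast(ReplyOnPause, stream.handlers[webrtc_id]).trigger_response()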

View File

@@ -206,6 +206,11 @@ class ReplyOnPause(StreamHandler):
        self.event.clear()
        self.state = AppState()

    def trigger_response(self):
        self.event.set()
        if self.state.stream is None:
            self.state.stream = np.array([], dtype=np.int16)

    async def async_iterate(self, generator) -> EmitType:
        return await anext(generator)

View File

@@ -91,6 +91,7 @@ class Stream(WebRTCConnectionMixin):
        self.additional_input_components = additional_inputs
        self.additional_outputs_handler = additional_outputs_handler
        self.track_constraints = track_constraints
        self.webrtc_component: WebRTC
        self.rtc_configuration = rtc_configuration
        self._ui = self._generate_default_ui(ui_args)
        self._ui.launch = self._wrap_gradio_launch(self._ui.launch)
@@ -234,6 +235,7 @@ class Stream(WebRTCConnectionMixin):
                        mode="receive",
                        modality="video",
                    )
                    self.webrtc_component = output_video
                    for component in additional_output_components:
                        if component not in same_components:
                            component.render()
@@ -284,6 +286,7 @@ class Stream(WebRTCConnectionMixin):
                        mode="send",
                        modality="video",
                    )
                    self.webrtc_component = output_video
                    for component in additional_output_components:
                        if component not in same_components:
                            component.render()
@@ -339,7 +342,7 @@ class Stream(WebRTCConnectionMixin):
                    for component in additional_output_components:
                        if component not in same_components:
                            component.render()
                    self.webrtc_component = image
                    image.stream(
                        fn=self.event_handler,
                        inputs=[image] + additional_input_components,
@@ -391,6 +394,7 @@ class Stream(WebRTCConnectionMixin):
                        pulse_color=ui_args.get("pulse_color"),
                        icon_radius=ui_args.get("icon_radius"),
                    )
                    self.webrtc_component = output_video
                    for component in additional_output_components:
                        if component not in same_components:
                            component.render()
@@ -442,6 +446,7 @@ class Stream(WebRTCConnectionMixin):
                        pulse_color=ui_args.get("pulse_color"),
                        icon_radius=ui_args.get("icon_radius"),
                    )
                    self.webrtc_component = image
                    for component in additional_input_components:
                        if component not in same_components:
                            component.render()
@@ -496,6 +501,7 @@ class Stream(WebRTCConnectionMixin):
                        pulse_color=ui_args.get("pulse_color"),
                        icon_radius=ui_args.get("icon_radius"),
                    )
                    self.webrtc_component = image
                    for component in additional_input_components:
                        if component not in same_components:
                            component.render()
@@ -553,6 +559,7 @@ class Stream(WebRTCConnectionMixin):
                        pulse_color=ui_args.get("pulse_color"),
                        icon_radius=ui_args.get("icon_radius"),
                    )
                    self.webrtc_component = image
                    for component in additional_input_components:
                        if component not in same_components:
                            component.render()

View File

@@ -73,7 +73,7 @@ class WebRTCConnectionMixin:
        self.connections = defaultdict(list)
        self.data_channels = {}
        self.additional_outputs = defaultdict(OutputQueue)
        self.handlers = {}
        self.handlers: dict[str, HandlerType] = {}
        self.connection_timeouts = defaultdict(asyncio.Event)
        # These attributes should be set by subclasses:
        self.concurrency_limit: int | None

View File

@@ -0,0 +1,173 @@
import base64
import json
import os
from pathlib import Path
from typing import cast

import gradio as gr
import huggingface_hub
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_stt_model,
    get_twilio_turn_credentials,
)
from gradio.utils import get_space
from pydantic import BaseModel

load_dotenv()

curr_dir = Path(__file__).parent

client = huggingface_hub.InferenceClient(
    api_key=os.environ.get("SAMBANOVA_API_KEY"),
    provider="sambanova",
)

stt_model = get_stt_model()

def response(
    audio: tuple[int, np.ndarray],
    gradio_chatbot: list[dict] | None = None,
    conversation_state: list[dict] | None = None,
    textbox: str | None = None,
):
    gradio_chatbot = gradio_chatbot or []
    conversation_state = conversation_state or []

    print("chatbot", gradio_chatbot)
    # Use the typed text when the textbox is filled; otherwise transcribe the audio.
    if textbox:
        text = textbox
    else:
        text = stt_model.stt(audio)
    sample_rate, array = audio
    gradio_chatbot.append({"role": "user", "content": text})
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

    conversation_state.append({"role": "user", "content": text})
    request = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=conversation_state,  # type: ignore
        temperature=0.1,
        top_p=0.1,
    )
    response = {"role": "assistant", "content": request.choices[0].message.content}
    conversation_state.append(response)
    gradio_chatbot.append(response)
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

chatbot = gr.Chatbot(type="messages", value=[])
state = gr.State(value=[])
textbox = gr.Textbox(value="", interactive=True)

stream = Stream(
    ReplyOnPause(
        response,  # type: ignore
        input_sample_rate=16000,
    ),
    mode="send",
    modality="audio",
    additional_inputs=[
        chatbot,
        state,
        textbox,
    ],
    additional_outputs=[chatbot, state],
    additional_outputs_handler=lambda *a: (a[2], a[3]),
    concurrency_limit=20 if get_space() else 5,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
)


def trigger_response(webrtc_id: str):
    # Fire the response immediately instead of waiting for a pause in the audio.
    cast(ReplyOnPause, stream.webrtc_component.handlers[webrtc_id]).trigger_response()
    return ""


with stream.ui as demo:
    button = gr.Button("Send")
    button.click(
        trigger_response,
        inputs=[stream.webrtc_component],
        outputs=[textbox],
    )

stream.ui = demo

app = FastAPI()
stream.mount(app)

class Message(BaseModel):
    role: str
    content: str


class InputData(BaseModel):
    webrtc_id: str
    chatbot: list[Message]
    state: list[Message]
    textbox: str


@app.get("/")
async def _():
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (curr_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


@app.post("/input_hook")
async def _(data: InputData):
    body = data.model_dump()
    stream.set_input(data.webrtc_id, body["chatbot"], body["state"], body["textbox"])
    cast(ReplyOnPause, stream.handlers[data.webrtc_id]).trigger_response()


def audio_to_base64(file_path):
    audio_format = "wav"
    with open(file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
    return f"data:audio/{audio_format};base64,{encoded_audio}"


@app.get("/outputs")
async def _(webrtc_id: str):
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            chatbot = output.args[0]
            state = output.args[1]
            user_message = chatbot[-1]["content"]
            data = {
                "message": state[-1],
                "audio": (
                    audio_to_base64(user_message["path"])
                    if isinstance(user_message, dict) and "path" in user_message
                    else None
                ),
            }
            yield f"event: output\ndata: {json.dumps(data)}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")

if __name__ == "__main__":
    import os

    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        raise ValueError("Phone mode not supported")
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)

View File

@@ -0,0 +1,539 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Talk to Sambanova</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            background-color: #f8f9fa;
            color: #1a1a1a;
            margin: 0;
            padding: 20px;
            height: 100vh;
            box-sizing: border-box;
        }

        .container {
            max-width: 800px;
            margin: 0 auto;
            height: 80%;
        }

        .logo {
            text-align: center;
            margin-bottom: 40px;
        }

        .chat-container {
            background: white;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
            padding: 20px;
            height: 90%;
            box-sizing: border-box;
            display: flex;
            flex-direction: column;
        }

        .chat-messages {
            flex-grow: 1;
            overflow-y: auto;
            margin-bottom: 20px;
            padding: 10px;
        }

        .message {
            margin-bottom: 20px;
            padding: 12px;
            border-radius: 8px;
            font-size: 14px;
            line-height: 1.5;
        }

        .message.user {
            background-color: #e9ecef;
            margin-left: 20%;
        }

        .message.assistant {
            background-color: #f1f3f5;
            margin-right: 20%;
        }

        .controls {
            text-align: center;
            margin-top: 20px;
        }

        button {
            background-color: #0066cc;
            color: white;
            border: none;
            padding: 12px 24px;
            font-family: inherit;
            font-size: 14px;
            cursor: pointer;
            transition: all 0.3s;
            border-radius: 4px;
            font-weight: 500;
        }

        button:hover {
            background-color: #0052a3;
        }

        #audio-output {
            display: none;
        }
        .icon-with-spinner {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 12px;
            min-width: 180px;
        }

        .spinner {
            width: 20px;
            height: 20px;
            border: 2px solid #ffffff;
            border-top-color: transparent;
            border-radius: 50%;
            animation: spin 1s linear infinite;
            flex-shrink: 0;
        }

        @keyframes spin {
            to {
                transform: rotate(360deg);
            }
        }

        .pulse-container {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 12px;
            min-width: 180px;
        }

        .pulse-circle {
            width: 20px;
            height: 20px;
            border-radius: 50%;
            background-color: #ffffff;
            opacity: 0.2;
            flex-shrink: 0;
            transform: translateX(-0%) scale(var(--audio-level, 1));
            transition: transform 0.1s ease;
        }

        /* Styles for typing indicator */
        .typing-indicator {
            padding: 8px;
            background-color: #f1f3f5;
            border-radius: 8px;
            margin-bottom: 10px;
            display: none;
        }

        .dots {
            display: inline-flex;
            gap: 4px;
        }

        .dot {
            width: 8px;
            height: 8px;
            background-color: #0066cc;
            border-radius: 50%;
            animation: pulse 1.5s infinite;
            opacity: 0.5;
        }

        .dot:nth-child(2) {
            animation-delay: 0.5s;
        }

        .dot:nth-child(3) {
            animation-delay: 1s;
        }

        @keyframes pulse {
            0%,
            100% {
                opacity: 0.5;
                transform: scale(1);
            }

            50% {
                opacity: 1;
                transform: scale(1.2);
            }
        }

        /* Styles for toast notifications */
        .toast {
            position: fixed;
            top: 20px;
            left: 50%;
            transform: translateX(-50%);
            padding: 16px 24px;
            border-radius: 4px;
            font-size: 14px;
            z-index: 1000;
            display: none;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
        }

        .toast.error {
            background-color: #f44336;
            color: white;
        }

        .toast.warning {
            background-color: #ffd700;
            color: black;
        }

        /* Styles for text input */
        .text-input-container {
            display: flex;
            margin-top: 10px;
            gap: 10px;
        }

        #text-input {
            flex-grow: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            font-family: inherit;
            font-size: 14px;
        }

        .text-input-container button {
            padding: 10px 15px;
        }
    </style>
</head>
<body>
    <!-- Toast element for error and warning notifications -->
    <div id="error-toast" class="toast"></div>
    <div class="container">
        <div class="logo">
            <h1>Talk to Sambanova 🗣️</h1>
            <h2 style="font-size: 1.2em; color: #666; margin-top: 10px;">Speak to Llama 3.2 powered by the Sambanova API</h2>
        </div>
        <div class="chat-container">
            <div class="chat-messages" id="chat-messages"></div>
            <div class="typing-indicator" id="typing-indicator">
                <div class="dots">
                    <div class="dot"></div>
                    <div class="dot"></div>
                    <div class="dot"></div>
                </div>
            </div>
            <!-- Text input form -->
            <form id="text-input-form" class="text-input-container">
                <input type="text" id="text-input" placeholder="Type your message..." />
                <button type="submit">Send</button>
            </form>
        </div>
        <div class="controls">
            <button id="start-button">Start Conversation</button>
        </div>
    </div>
    <audio id="audio-output"></audio>
    <script>
        let peerConnection;
        let webrtc_id;
        const startButton = document.getElementById('start-button');
        const chatMessages = document.getElementById('chat-messages');
        let audioLevel = 0;
        let animationFrame;
        let audioContext, analyser, audioSource;
        let messages = [];
        let eventSource;

        function updateButtonState() {
            const button = document.getElementById('start-button');
            if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
                button.innerHTML = `
                    <div class="icon-with-spinner">
                        <div class="spinner"></div>
                        <span>Connecting...</span>
                    </div>
                `;
            } else if (peerConnection && peerConnection.connectionState === 'connected') {
                button.innerHTML = `
                    <div class="pulse-container">
                        <div class="pulse-circle"></div>
                        <span>Stop Conversation</span>
                    </div>
                `;
            } else {
                button.innerHTML = 'Start Conversation';
            }
        }

        function setupAudioVisualization(stream) {
            audioContext = new (window.AudioContext || window.webkitAudioContext)();
            analyser = audioContext.createAnalyser();
            audioSource = audioContext.createMediaStreamSource(stream);
            audioSource.connect(analyser);
            analyser.fftSize = 64;
            const dataArray = new Uint8Array(analyser.frequencyBinCount);

            function updateAudioLevel() {
                analyser.getByteFrequencyData(dataArray);
                const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
                audioLevel = average / 255;
                const pulseCircle = document.querySelector('.pulse-circle');
                if (pulseCircle) {
                    pulseCircle.style.setProperty('--audio-level', 1 + audioLevel);
                }
                animationFrame = requestAnimationFrame(updateAudioLevel);
            }
            updateAudioLevel();
        }

        function showError(message) {
            const toast = document.getElementById('error-toast');
            toast.textContent = message;
            toast.className = 'toast error';
            toast.style.display = 'block';

            // Hide toast after 5 seconds
            setTimeout(() => {
                toast.style.display = 'none';
            }, 5000);
        }

        function handleMessage(event) {
            const eventJson = JSON.parse(event.data);
            const typingIndicator = document.getElementById('typing-indicator');
            const textInput = document.getElementById('text-input');

            if (eventJson.type === "error") {
                showError(eventJson.message);
            } else if (eventJson.type === "send_input") {
                fetch('/input_hook', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        webrtc_id: webrtc_id,
                        chatbot: messages,
                        state: messages,
                        textbox: textInput.value
                    })
                });
            } else if (eventJson.type === "log") {
                if (eventJson.data === "pause_detected") {
                    typingIndicator.style.display = 'block';
                    chatMessages.scrollTop = chatMessages.scrollHeight;
                } else if (eventJson.data === "response_starting") {
                    typingIndicator.style.display = 'none';
                }
            }
        }
        async function setupWebRTC() {
            const config = __RTC_CONFIGURATION__;
            peerConnection = new RTCPeerConnection(config);

            const timeoutId = setTimeout(() => {
                const toast = document.getElementById('error-toast');
                toast.textContent = "Connection is taking longer than usual. Are you on a VPN?";
                toast.className = 'toast warning';
                toast.style.display = 'block';

                // Hide warning after 5 seconds
                setTimeout(() => {
                    toast.style.display = 'none';
                }, 5000);
            }, 5000);

            try {
                const stream = await navigator.mediaDevices.getUserMedia({
                    audio: true
                });
                setupAudioVisualization(stream);
                stream.getTracks().forEach(track => {
                    peerConnection.addTrack(track, stream);
                });

                const dataChannel = peerConnection.createDataChannel('text');
                dataChannel.onmessage = handleMessage;

                const offer = await peerConnection.createOffer();
                await peerConnection.setLocalDescription(offer);

                peerConnection.onicecandidate = ({ candidate }) => {
                    if (candidate) {
                        console.debug("Sending ICE candidate", candidate);
                        fetch('/webrtc/offer', {
                            method: 'POST',
                            headers: { 'Content-Type': 'application/json' },
                            body: JSON.stringify({
                                candidate: candidate.toJSON(),
                                webrtc_id: webrtc_id,
                                type: "ice-candidate",
                            })
                        })
                    }
                };

                peerConnection.addEventListener('connectionstatechange', () => {
                    console.log('connectionstatechange', peerConnection.connectionState);
                    if (peerConnection.connectionState === 'connected') {
                        clearTimeout(timeoutId);
                        const toast = document.getElementById('error-toast');
                        toast.style.display = 'none';
                    }
                    updateButtonState();
                });

                webrtc_id = Math.random().toString(36).substring(7);

                const response = await fetch('/webrtc/offer', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sdp: peerConnection.localDescription.sdp,
                        type: peerConnection.localDescription.type,
                        webrtc_id: webrtc_id
                    })
                });

                const serverResponse = await response.json();
                if (serverResponse.status === 'failed') {
                    showError(serverResponse.meta.error === 'concurrency_limit_reached'
                        ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
                        : serverResponse.meta.error);
                    stop();
                    return;
                }
                await peerConnection.setRemoteDescription(serverResponse);

                eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
                eventSource.addEventListener("output", (event) => {
                    const eventJson = JSON.parse(event.data);
                    console.log(eventJson);
                    messages.push(eventJson.message);
                    addMessage(eventJson.message.role, eventJson.audio ?? eventJson.message.content);
                });
            } catch (err) {
                clearTimeout(timeoutId);
                console.error('Error setting up WebRTC:', err);
                showError('Failed to establish connection. Please try again.');
                stop();
            }
        }
        function addMessage(role, content) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', role);
            if (role === 'user' && content.startsWith("data:audio/wav;base64,")) {
                // Create audio element for user messages
                const audio = document.createElement('audio');
                audio.controls = true;
                audio.src = content;
                messageDiv.appendChild(audio);
            } else {
                // Text content for assistant messages
                messageDiv.textContent = content;
            }
            chatMessages.appendChild(messageDiv);
            chatMessages.scrollTop = chatMessages.scrollHeight;
        }

        function stop() {
            if (eventSource) {
                eventSource.close();
                eventSource = null;
            }
            if (animationFrame) {
                cancelAnimationFrame(animationFrame);
            }
            if (audioContext) {
                audioContext.close();
                audioContext = null;
                analyser = null;
                audioSource = null;
            }
            if (peerConnection) {
                if (peerConnection.getTransceivers) {
                    peerConnection.getTransceivers().forEach(transceiver => {
                        if (transceiver.stop) {
                            transceiver.stop();
                        }
                    });
                }
                if (peerConnection.getSenders) {
                    peerConnection.getSenders().forEach(sender => {
                        if (sender.track && sender.track.stop) sender.track.stop();
                    });
                }
                peerConnection.close();
            }
            updateButtonState();
            audioLevel = 0;
        }

        startButton.addEventListener('click', () => {
            if (!peerConnection || peerConnection.connectionState !== 'connected') {
                setupWebRTC();
            } else {
                stop();
            }
        });

        // Event listener for the text input form
        document.getElementById('text-input-form').addEventListener('submit', function (e) {
            e.preventDefault();
            const textInput = document.getElementById('text-input');
            if (textInput.value.trim() !== '') {
                fetch('/input_hook', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        webrtc_id: webrtc_id,
                        chatbot: messages,
                        state: messages,
                        textbox: textInput.value
                    })
                });
                // Clear the input after submission
                textInput.value = '';
            }
        });
    </script>
</body>

</html>