mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Raise errors automatically (#69)
* Add auto errors * change code --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
This commit is contained in:
@@ -25,6 +25,7 @@ from .utils import (
|
||||
audio_to_bytes,
|
||||
audio_to_file,
|
||||
audio_to_float32,
|
||||
wait_for_item,
|
||||
)
|
||||
from .webrtc import (
|
||||
WebRTC,
|
||||
@@ -58,4 +59,5 @@ __all__ = [
|
||||
"Warning",
|
||||
"get_tts_model",
|
||||
"KokoroTTSOptions",
|
||||
"wait_for_item",
|
||||
]
|
||||
|
||||
@@ -233,3 +233,4 @@ class ReplyOnPause(StreamHandler):
|
||||
traceback.print_exc()
|
||||
logger.debug("Error in ReplyOnPause: %s", e)
|
||||
self.reset()
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""gr.WebRTC() component."""
|
||||
"""WebRTC tracks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -277,9 +277,7 @@ class StreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def emit(
|
||||
self,
|
||||
) -> EmitType:
|
||||
def emit(self) -> EmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -296,9 +294,7 @@ class AsyncStreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def emit(
|
||||
self,
|
||||
) -> EmitType:
|
||||
async def emit(self) -> EmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -312,15 +308,13 @@ class AsyncStreamHandler(StreamHandlerBase):
|
||||
StreamHandlerImpl = StreamHandler | AsyncStreamHandler
|
||||
|
||||
|
||||
class AudioVideoStreamHandler(StreamHandlerBase):
|
||||
class AudioVideoStreamHandler(StreamHandler):
|
||||
@abstractmethod
|
||||
def video_receive(self, frame: VideoFrame) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def video_emit(
|
||||
self,
|
||||
) -> VideoEmitType:
|
||||
def video_emit(self) -> VideoEmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -328,15 +322,13 @@ class AudioVideoStreamHandler(StreamHandlerBase):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncAudioVideoStreamHandler(StreamHandlerBase):
|
||||
class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
|
||||
@abstractmethod
|
||||
async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def video_emit(
|
||||
self,
|
||||
) -> VideoEmitType:
|
||||
async def video_emit(self) -> VideoEmitType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -6,7 +6,9 @@ import logging
|
||||
import tempfile
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
||||
|
||||
import functools
|
||||
import traceback
|
||||
import inspect
|
||||
import av
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -353,3 +355,47 @@ async def async_aggregate_bytes_to_16bit(chunks_iterator):
|
||||
if to_process:
|
||||
audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
|
||||
yield audio_array
|
||||
|
||||
|
||||
def webrtc_error_handler(func):
|
||||
"""Decorator to catch exceptions and raise WebRTCError with stacktrace."""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, WebRTCError):
|
||||
raise e
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
|
||||
"""
|
||||
Wait for an item from an asyncio.Queue with a timeout.
|
||||
|
||||
This function attempts to retrieve an item from the queue using asyncio.wait_for.
|
||||
If the timeout is reached, it returns None.
|
||||
|
||||
This is useful to avoid blocking `emit` when the queue is empty.
|
||||
"""
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(queue.get(), timeout=timeout)
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
return None
|
||||
|
||||
@@ -37,6 +37,7 @@ from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
create_message,
|
||||
webrtc_error_handler,
|
||||
)
|
||||
|
||||
Track = (
|
||||
@@ -148,8 +149,16 @@ class WebRTCConnectionMixin:
|
||||
|
||||
if isinstance(self.event_handler, StreamHandlerBase):
|
||||
handler = self.event_handler.copy()
|
||||
handler.emit = webrtc_error_handler(handler.emit)
|
||||
handler.receive = webrtc_error_handler(handler.receive)
|
||||
handler.start_up = webrtc_error_handler(handler.start_up)
|
||||
handler.shutdown = webrtc_error_handler(handler.shutdown)
|
||||
if hasattr(handler, "video_receive"):
|
||||
handler.video_receive = webrtc_error_handler(handler.video_receive)
|
||||
if hasattr(handler, "video_emit"):
|
||||
handler.video_emit = webrtc_error_handler(handler.video_emit)
|
||||
else:
|
||||
handler = cast(Callable, self.event_handler)
|
||||
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
||||
|
||||
self.handlers[body["webrtc_id"]] = handler
|
||||
|
||||
|
||||
@@ -62,44 +62,28 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
|
||||
api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
|
||||
)
|
||||
config = {"response_modalities": ["AUDIO"]}
|
||||
try:
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
self.session = session
|
||||
print("set session")
|
||||
while not self.quit.is_set():
|
||||
turn = self.session.receive()
|
||||
async for response in turn:
|
||||
if data := response.data:
|
||||
audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
|
||||
self.audio_queue.put_nowait(audio)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
self.session = session
|
||||
print("set session")
|
||||
while not self.quit.is_set():
|
||||
turn = self.session.receive()
|
||||
async for response in turn:
|
||||
if data := response.data:
|
||||
audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
|
||||
self.audio_queue.put_nowait(audio)
|
||||
|
||||
async def video_receive(self, frame: np.ndarray):
|
||||
try:
|
||||
print("out")
|
||||
if self.session:
|
||||
print("here")
|
||||
# send image every 1 second
|
||||
print(time.time() - self.last_frame_time)
|
||||
if time.time() - self.last_frame_time > 1:
|
||||
self.last_frame_time = time.time()
|
||||
print("sending image")
|
||||
await self.session.send(input=encode_image(frame))
|
||||
print("sent image")
|
||||
if self.latest_args[1] is not None:
|
||||
print("sending image2")
|
||||
await self.session.send(input=encode_image(self.latest_args[1]))
|
||||
print("sent image2")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
if self.session:
|
||||
# send image every 1 second
|
||||
print(time.time() - self.last_frame_time)
|
||||
if time.time() - self.last_frame_time > 1:
|
||||
self.last_frame_time = time.time()
|
||||
await self.session.send(input=encode_image(frame))
|
||||
if self.latest_args[1] is not None:
|
||||
await self.session.send(input=encode_image(self.latest_args[1]))
|
||||
|
||||
traceback.print_exc()
|
||||
self.video_queue.put_nowait(frame)
|
||||
|
||||
async def video_emit(self):
|
||||
@@ -110,13 +94,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
|
||||
array = array.squeeze()
|
||||
audio_message = encode_audio(array)
|
||||
if self.session:
|
||||
try:
|
||||
await self.session.send(input=audio_message)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
await self.session.send(input=audio_message)
|
||||
|
||||
async def emit(self):
|
||||
array = await self.audio_queue.get()
|
||||
|
||||
@@ -13,7 +13,6 @@ from fastrtc import (
|
||||
AdditionalOutputs,
|
||||
ReplyOnStopWords,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
get_stt_model,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
@@ -39,30 +38,23 @@ def response(
|
||||
):
|
||||
gradio_chatbot = gradio_chatbot or []
|
||||
conversation_state = conversation_state or []
|
||||
try:
|
||||
text = model.stt(audio)
|
||||
print("STT in handler", text)
|
||||
sample_rate, array = audio
|
||||
gradio_chatbot.append(
|
||||
{"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
|
||||
)
|
||||
yield AdditionalOutputs(gradio_chatbot, conversation_state)
|
||||
text = model.stt(audio)
|
||||
print("STT in handler", text)
|
||||
sample_rate, array = audio
|
||||
gradio_chatbot.append(
|
||||
{"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
|
||||
)
|
||||
yield AdditionalOutputs(gradio_chatbot, conversation_state)
|
||||
|
||||
conversation_state.append({"role": "user", "content": text})
|
||||
conversation_state.append({"role": "user", "content": text})
|
||||
|
||||
request = client.chat.completions.create(
|
||||
model="Meta-Llama-3.2-3B-Instruct",
|
||||
messages=conversation_state, # type: ignore
|
||||
temperature=0.1,
|
||||
top_p=0.1,
|
||||
)
|
||||
response = {"role": "assistant", "content": request.choices[0].message.content}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(str(e) + "\n" + traceback.format_exc())
|
||||
request = client.chat.completions.create(
|
||||
model="Meta-Llama-3.2-3B-Instruct",
|
||||
messages=conversation_state, # type: ignore
|
||||
temperature=0.1,
|
||||
top_p=0.1,
|
||||
)
|
||||
response = {"role": "assistant", "content": request.choices[0].message.content}
|
||||
|
||||
conversation_state.append(response)
|
||||
gradio_chatbot.append(response)
|
||||
|
||||
@@ -10,7 +10,6 @@ from fastrtc import (
|
||||
AdditionalOutputs,
|
||||
ReplyOnPause,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
get_stt_model,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
@@ -30,42 +29,36 @@ def response(
|
||||
audio: tuple[int, NDArray[np.int16 | np.float32]],
|
||||
chatbot: list[dict] | None = None,
|
||||
):
|
||||
try:
|
||||
chatbot = chatbot or []
|
||||
messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
|
||||
start = time.time()
|
||||
text = stt_model.stt(audio)
|
||||
print("transcription", time.time() - start)
|
||||
print("prompt", text)
|
||||
chatbot.append({"role": "user", "content": text})
|
||||
yield AdditionalOutputs(chatbot)
|
||||
messages.append({"role": "user", "content": text})
|
||||
response_text = (
|
||||
groq_client.chat.completions.create(
|
||||
model="llama-3.1-8b-instant",
|
||||
max_tokens=512,
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
chatbot = chatbot or []
|
||||
messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
|
||||
start = time.time()
|
||||
text = stt_model.stt(audio)
|
||||
print("transcription", time.time() - start)
|
||||
print("prompt", text)
|
||||
chatbot.append({"role": "user", "content": text})
|
||||
yield AdditionalOutputs(chatbot)
|
||||
messages.append({"role": "user", "content": text})
|
||||
response_text = (
|
||||
groq_client.chat.completions.create(
|
||||
model="llama-3.1-8b-instant",
|
||||
max_tokens=512,
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
chatbot.append({"role": "assistant", "content": response_text})
|
||||
chatbot.append({"role": "assistant", "content": response_text})
|
||||
|
||||
for chunk in tts_client.text_to_speech.convert_as_stream(
|
||||
text=response_text, # type: ignore
|
||||
voice_id="JBFqnCBsd6RMkjVDRZzb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
output_format="pcm_24000",
|
||||
):
|
||||
audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
|
||||
yield (24000, audio_array)
|
||||
yield AdditionalOutputs(chatbot)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(traceback.format_exc())
|
||||
for chunk in tts_client.text_to_speech.convert_as_stream(
|
||||
text=response_text, # type: ignore
|
||||
voice_id="JBFqnCBsd6RMkjVDRZzb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
output_format="pcm_24000",
|
||||
):
|
||||
audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
|
||||
yield (24000, audio_array)
|
||||
yield AdditionalOutputs(chatbot)
|
||||
|
||||
|
||||
chatbot = gr.Chatbot(type="messages")
|
||||
|
||||
@@ -5,7 +5,7 @@ import cv2
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastrtc import Stream, WebRTCError, get_twilio_turn_credentials
|
||||
from fastrtc import Stream, get_twilio_turn_credentials
|
||||
from gradio.utils import get_space
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -26,16 +26,10 @@ model = YOLOv10(model_file)
|
||||
|
||||
|
||||
def detection(image, conf_threshold=0.3):
|
||||
try:
|
||||
image = cv2.resize(image, (model.input_width, model.input_height))
|
||||
print("conf_threshold", conf_threshold)
|
||||
new_image = model.detect_objects(image, conf_threshold)
|
||||
return cv2.resize(new_image, (500, 500))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(str(e))
|
||||
image = cv2.resize(image, (model.input_width, model.input_height))
|
||||
print("conf_threshold", conf_threshold)
|
||||
new_image = model.detect_objects(image, conf_threshold)
|
||||
return cv2.resize(new_image, (500, 500))
|
||||
|
||||
|
||||
stream = Stream(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import subprocess
|
||||
|
||||
subprocess.run(["pip", "install", "fastrtc==0.0.3.post7"])
|
||||
subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
@@ -15,10 +15,9 @@ from fastrtc import (
|
||||
AsyncStreamHandler,
|
||||
Stream,
|
||||
get_twilio_turn_credentials,
|
||||
WebRTCError,
|
||||
audio_to_float32,
|
||||
wait_for_item,
|
||||
)
|
||||
from fastapi import FastAPI
|
||||
from phonic.client import PhonicSTSClient, get_voices
|
||||
|
||||
load_dotenv()
|
||||
@@ -42,47 +41,38 @@ class PhonicHandler(AsyncStreamHandler):
|
||||
async def start_up(self):
|
||||
await self.wait_for_args()
|
||||
voice_id = self.latest_args[1]
|
||||
try:
|
||||
async with PhonicSTSClient(STS_URI, API_KEY) as client:
|
||||
self.client = client
|
||||
sts_stream = client.sts( # type: ignore
|
||||
input_format="pcm_44100",
|
||||
output_format="pcm_44100",
|
||||
system_prompt="You are a helpful voice assistant. Respond conversationally.",
|
||||
# welcome_message="Hello! I'm your voice assistant. How can I help you today?",
|
||||
voice_id=voice_id,
|
||||
)
|
||||
async for message in sts_stream:
|
||||
message_type = message.get("type")
|
||||
if message_type == "audio_chunk":
|
||||
audio_b64 = message["audio"]
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
await self.output_queue.put(
|
||||
(SAMPLE_RATE, np.frombuffer(audio_bytes, dtype=np.int16))
|
||||
)
|
||||
if text := message.get("text"):
|
||||
msg = {"role": "assistant", "content": text}
|
||||
await self.output_queue.put(AdditionalOutputs(msg))
|
||||
elif message_type == "input_text":
|
||||
msg = {"role": "user", "content": message["text"]}
|
||||
async with PhonicSTSClient(STS_URI, API_KEY) as client:
|
||||
self.client = client
|
||||
sts_stream = client.sts( # type: ignore
|
||||
input_format="pcm_44100",
|
||||
output_format="pcm_44100",
|
||||
system_prompt="You are a helpful voice assistant. Respond conversationally.",
|
||||
# welcome_message="Hello! I'm your voice assistant. How can I help you today?",
|
||||
voice_id=voice_id,
|
||||
)
|
||||
async for message in sts_stream:
|
||||
message_type = message.get("type")
|
||||
if message_type == "audio_chunk":
|
||||
audio_b64 = message["audio"]
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
await self.output_queue.put(
|
||||
(SAMPLE_RATE, np.frombuffer(audio_bytes, dtype=np.int16))
|
||||
)
|
||||
if text := message.get("text"):
|
||||
msg = {"role": "assistant", "content": text}
|
||||
await self.output_queue.put(AdditionalOutputs(msg))
|
||||
except Exception as e:
|
||||
raise WebRTCError(f"Error starting up: {e}")
|
||||
elif message_type == "input_text":
|
||||
msg = {"role": "user", "content": message["text"]}
|
||||
await self.output_queue.put(AdditionalOutputs(msg))
|
||||
|
||||
async def emit(self):
|
||||
try:
|
||||
return await self.output_queue.get()
|
||||
except Exception as e:
|
||||
raise WebRTCError(f"Error emitting: {e}")
|
||||
return await wait_for_item(self.output_queue)
|
||||
|
||||
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
try:
|
||||
if not self.client:
|
||||
return
|
||||
audio_float32 = audio_to_float32(frame)
|
||||
await self.client.send_audio(audio_float32) # type: ignore
|
||||
except Exception as e:
|
||||
raise WebRTCError(f"Error sending audio: {e}")
|
||||
if not self.client:
|
||||
return
|
||||
audio_float32 = audio_to_float32(frame)
|
||||
await self.client.send_audio(audio_float32) # type: ignore
|
||||
|
||||
async def shutdown(self):
|
||||
if self.client:
|
||||
@@ -122,9 +112,6 @@ stream = Stream(
|
||||
with stream.ui:
|
||||
state.change(lambda s: s, inputs=state, outputs=chatbot)
|
||||
|
||||
app = FastAPI()
|
||||
stream.mount(app)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if (mode := os.getenv("MODE")) == "UI":
|
||||
stream.ui.launch(server_port=7860)
|
||||
|
||||
@@ -14,7 +14,6 @@ from fastrtc import (
|
||||
AdditionalOutputs,
|
||||
ReplyOnPause,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
get_tts_model,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
@@ -38,41 +37,36 @@ def response(
|
||||
audio: tuple[int, np.ndarray],
|
||||
chatbot: list[dict] | None = None,
|
||||
):
|
||||
try:
|
||||
chatbot = chatbot or []
|
||||
messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
|
||||
prompt = groq_client.audio.transcriptions.create(
|
||||
file=("audio-file.mp3", audio_to_bytes(audio)),
|
||||
model="whisper-large-v3-turbo",
|
||||
response_format="verbose_json",
|
||||
).text
|
||||
chatbot = chatbot or []
|
||||
messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
|
||||
prompt = groq_client.audio.transcriptions.create(
|
||||
file=("audio-file.mp3", audio_to_bytes(audio)),
|
||||
model="whisper-large-v3-turbo",
|
||||
response_format="verbose_json",
|
||||
).text
|
||||
chatbot.append({"role": "user", "content": prompt})
|
||||
yield AdditionalOutputs(chatbot)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
response = claude_client.messages.create(
|
||||
model="claude-3-5-haiku-20241022",
|
||||
max_tokens=512,
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
response_text = " ".join(
|
||||
block.text # type: ignore
|
||||
for block in response.content
|
||||
if getattr(block, "type", None) == "text"
|
||||
)
|
||||
chatbot.append({"role": "assistant", "content": response_text})
|
||||
|
||||
print("prompt", prompt)
|
||||
chatbot.append({"role": "user", "content": prompt})
|
||||
start = time.time()
|
||||
|
||||
print("starting tts", start)
|
||||
for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
|
||||
print("chunk", i, time.time() - start)
|
||||
yield chunk
|
||||
print("finished tts", time.time() - start)
|
||||
yield AdditionalOutputs(chatbot)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
response = claude_client.messages.create(
|
||||
model="claude-3-5-haiku-20241022",
|
||||
max_tokens=512,
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
response_text = " ".join(
|
||||
block.text # type: ignore
|
||||
for block in response.content
|
||||
if getattr(block, "type", None) == "text"
|
||||
)
|
||||
chatbot.append({"role": "assistant", "content": response_text})
|
||||
|
||||
start = time.time()
|
||||
|
||||
print("starting tts", start)
|
||||
for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
|
||||
print("chunk", i, time.time() - start)
|
||||
yield chunk
|
||||
print("finished tts", time.time() - start)
|
||||
yield AdditionalOutputs(chatbot)
|
||||
except Exception as e:
|
||||
raise WebRTCError(str(e))
|
||||
|
||||
|
||||
chatbot = gr.Chatbot(type="messages")
|
||||
|
||||
@@ -13,8 +13,8 @@ from fastapi.responses import HTMLResponse
|
||||
from fastrtc import (
|
||||
AsyncStreamHandler,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
get_twilio_turn_credentials,
|
||||
wait_for_item,
|
||||
)
|
||||
from google import genai
|
||||
from google.genai.types import (
|
||||
@@ -68,13 +68,12 @@ class GeminiHandler(AsyncStreamHandler):
|
||||
api_key, voice_name = self.latest_args[1:]
|
||||
else:
|
||||
api_key, voice_name = None, "Puck"
|
||||
try:
|
||||
client = genai.Client(
|
||||
api_key=api_key or os.getenv("GEMINI_API_KEY"),
|
||||
http_options={"api_version": "v1alpha"},
|
||||
)
|
||||
except Exception as e:
|
||||
raise WebRTCError(str(e))
|
||||
|
||||
client = genai.Client(
|
||||
api_key=api_key or os.getenv("GEMINI_API_KEY"),
|
||||
http_options={"api_version": "v1alpha"},
|
||||
)
|
||||
|
||||
config = LiveConnectConfig(
|
||||
response_modalities=["AUDIO"], # type: ignore
|
||||
speech_config=SpeechConfig(
|
||||
@@ -85,18 +84,15 @@ class GeminiHandler(AsyncStreamHandler):
|
||||
)
|
||||
),
|
||||
)
|
||||
try:
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
async for audio in session.start_stream(
|
||||
stream=self.stream(), mime_type="audio/pcm"
|
||||
):
|
||||
if audio.data:
|
||||
array = np.frombuffer(audio.data, dtype=np.int16)
|
||||
self.output_queue.put_nowait(array)
|
||||
except Exception as e:
|
||||
raise WebRTCError(str(e))
|
||||
async with client.aio.live.connect(
|
||||
model="gemini-2.0-flash-exp", config=config
|
||||
) as session:
|
||||
async for audio in session.start_stream(
|
||||
stream=self.stream(), mime_type="audio/pcm"
|
||||
):
|
||||
if audio.data:
|
||||
array = np.frombuffer(audio.data, dtype=np.int16)
|
||||
self.output_queue.put_nowait((self.output_sample_rate, array))
|
||||
|
||||
async def stream(self) -> AsyncGenerator[bytes, None]:
|
||||
while not self.quit.is_set():
|
||||
@@ -112,13 +108,11 @@ class GeminiHandler(AsyncStreamHandler):
|
||||
audio_message = encode_audio(array)
|
||||
self.input_queue.put_nowait(audio_message)
|
||||
|
||||
async def emit(self) -> tuple[int, np.ndarray]:
|
||||
array = await self.output_queue.get()
|
||||
return (self.output_sample_rate, array)
|
||||
async def emit(self) -> tuple[int, np.ndarray] | None:
|
||||
return await wait_for_item(self.output_queue)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.quit.set()
|
||||
self.args_set.clear()
|
||||
|
||||
|
||||
stream = Stream(
|
||||
|
||||
@@ -13,8 +13,8 @@ from fastrtc import (
|
||||
AdditionalOutputs,
|
||||
AsyncStreamHandler,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
get_twilio_turn_credentials,
|
||||
wait_for_item,
|
||||
)
|
||||
from gradio.utils import get_space
|
||||
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
|
||||
@@ -47,60 +47,41 @@ class OpenAIHandler(AsyncStreamHandler):
|
||||
):
|
||||
"""Connect to realtime API. Run forever in separate thread to keep connection open."""
|
||||
self.client = openai.AsyncOpenAI()
|
||||
try:
|
||||
async with self.client.beta.realtime.connect(
|
||||
model="gpt-4o-mini-realtime-preview-2024-12-17"
|
||||
) as conn:
|
||||
await conn.session.update(
|
||||
session={"turn_detection": {"type": "server_vad"}}
|
||||
)
|
||||
self.connection = conn
|
||||
async for event in self.connection:
|
||||
if event.type == "response.audio_transcript.done":
|
||||
await self.output_queue.put(AdditionalOutputs(event))
|
||||
if event.type == "response.audio.delta":
|
||||
await self.output_queue.put(
|
||||
(
|
||||
self.output_sample_rate,
|
||||
np.frombuffer(
|
||||
base64.b64decode(event.delta), dtype=np.int16
|
||||
).reshape(1, -1),
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(str(traceback.format_exc()))
|
||||
async with self.client.beta.realtime.connect(
|
||||
model="gpt-4o-mini-realtime-preview-2024-12-17"
|
||||
) as conn:
|
||||
await conn.session.update(
|
||||
session={"turn_detection": {"type": "server_vad"}}
|
||||
)
|
||||
self.connection = conn
|
||||
async for event in self.connection:
|
||||
if event.type == "response.audio_transcript.done":
|
||||
await self.output_queue.put(AdditionalOutputs(event))
|
||||
if event.type == "response.audio.delta":
|
||||
await self.output_queue.put(
|
||||
(
|
||||
self.output_sample_rate,
|
||||
np.frombuffer(
|
||||
base64.b64decode(event.delta), dtype=np.int16
|
||||
).reshape(1, -1),
|
||||
),
|
||||
)
|
||||
|
||||
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
if not self.connection:
|
||||
return
|
||||
try:
|
||||
_, array = frame
|
||||
array = array.squeeze()
|
||||
audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
|
||||
await self.connection.input_audio_buffer.append(audio=audio_message) # type: ignore
|
||||
except Exception as e:
|
||||
# print traceback
|
||||
print(f"Error in receive: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(str(traceback.format_exc()))
|
||||
_, array = frame
|
||||
array = array.squeeze()
|
||||
audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
|
||||
await self.connection.input_audio_buffer.append(audio=audio_message) # type: ignore
|
||||
|
||||
async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
|
||||
return await self.output_queue.get()
|
||||
|
||||
def reset_state(self):
|
||||
"""Reset connection state for new recording session"""
|
||||
self.connection = None
|
||||
self.args_set.clear()
|
||||
return await wait_for_item(self.output_queue)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.reset_state()
|
||||
self.connection = None
|
||||
|
||||
|
||||
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
|
||||
|
||||
@@ -49,20 +49,15 @@ def response(
|
||||
|
||||
conversation_state.append({"role": "user", "content": text})
|
||||
|
||||
try:
|
||||
request = client.chat.completions.create(
|
||||
model="Meta-Llama-3.2-3B-Instruct",
|
||||
messages=conversation_state, # type: ignore
|
||||
temperature=0.1,
|
||||
top_p=0.1,
|
||||
)
|
||||
response = {"role": "assistant", "content": request.choices[0].message.content}
|
||||
raise WebRTCError("test")
|
||||
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise WebRTCError(traceback.format_exc())
|
||||
request = client.chat.completions.create(
|
||||
model="Meta-Llama-3.2-3B-Instruct",
|
||||
messages=conversation_state, # type: ignore
|
||||
temperature=0.1,
|
||||
top_p=0.1,
|
||||
)
|
||||
response = {"role": "assistant", "content": request.choices[0].message.content}
|
||||
|
||||
conversation_state.append(response)
|
||||
gradio_chatbot.append(response)
|
||||
|
||||
@@ -10,7 +10,6 @@ from fastrtc import (
|
||||
AdditionalOutputs,
|
||||
ReplyOnPause,
|
||||
Stream,
|
||||
WebRTCError,
|
||||
audio_to_bytes,
|
||||
get_twilio_turn_credentials,
|
||||
)
|
||||
@@ -26,15 +25,12 @@ groq_client = AsyncClient()
|
||||
|
||||
|
||||
async def transcribe(audio: tuple[int, np.ndarray]):
|
||||
try:
|
||||
transcript = await groq_client.audio.transcriptions.create(
|
||||
file=("audio-file.mp3", audio_to_bytes(audio)),
|
||||
model="whisper-large-v3-turbo",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
yield AdditionalOutputs(transcript.text)
|
||||
except Exception as e:
|
||||
raise WebRTCError(str(e))
|
||||
transcript = await groq_client.audio.transcriptions.create(
|
||||
file=("audio-file.mp3", audio_to_bytes(audio)),
|
||||
model="whisper-large-v3-turbo",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
yield AdditionalOutputs(transcript.text)
|
||||
|
||||
|
||||
stream = Stream(
|
||||
|
||||
1
docs/CNAME
Normal file
1
docs/CNAME
Normal file
@@ -0,0 +1 @@
|
||||
fastrtc.org
|
||||
@@ -63,56 +63,29 @@ The `Stream` has three main methods:
|
||||
=== "LLM Voice Chat"
|
||||
|
||||
```py
|
||||
from fastrtc import (
|
||||
ReplyOnPause, AdditionalOutputs, Stream,
|
||||
audio_to_bytes, aggregate_bytes_to_16bit
|
||||
import os
|
||||
|
||||
from fastrtc import (ReplyOnPause, Stream, get_stt_model, get_tts_model)
|
||||
from openai import OpenAI
|
||||
|
||||
sambanova_client = OpenAI(
|
||||
api_key=os.getenv("SAMBANOVA_API_KEY"), base_url="https://api.sambanova.ai/v1"
|
||||
)
|
||||
import gradio as gr
|
||||
from groq import Groq
|
||||
import anthropic
|
||||
from elevenlabs import ElevenLabs
|
||||
stt_model = get_stt_model()
|
||||
tts_model = get_tts_model()
|
||||
|
||||
groq_client = Groq()
|
||||
claude_client = anthropic.Anthropic()
|
||||
tts_client = ElevenLabs()
|
||||
|
||||
|
||||
# See "Talk to Claude" in Cookbook for an example of how to keep
|
||||
# track of the chat history.
|
||||
def response(
|
||||
audio: tuple[int, np.ndarray],
|
||||
):
|
||||
prompt = groq_client.audio.transcriptions.create(
|
||||
file=("audio-file.mp3", audio_to_bytes(audio)),
|
||||
model="whisper-large-v3-turbo",
|
||||
response_format="verbose_json",
|
||||
).text
|
||||
response = claude_client.messages.create(
|
||||
model="claude-3-5-haiku-20241022",
|
||||
max_tokens=512,
|
||||
def echo(audio):
|
||||
prompt = stt_model.stt(audio)
|
||||
response = sambanova_client.chat.completions.create(
|
||||
model="Meta-Llama-3.2-3B-Instruct",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=200,
|
||||
)
|
||||
response_text = " ".join(
|
||||
block.text
|
||||
for block in response.content
|
||||
if getattr(block, "type", None) == "text"
|
||||
)
|
||||
iterator = tts_client.text_to_speech.convert_as_stream(
|
||||
text=response_text,
|
||||
voice_id="JBFqnCBsd6RMkjVDRZzb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
output_format="pcm_24000"
|
||||
|
||||
)
|
||||
for chunk in aggregate_bytes_to_16bit(iterator):
|
||||
audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
|
||||
yield (24000, audio_array)
|
||||
prompt = response.choices[0].message.content
|
||||
for audio_chunk in tts_model.stream_tts_sync(prompt):
|
||||
yield audio_chunk
|
||||
|
||||
stream = Stream(
|
||||
modality="audio",
|
||||
mode="send-receive",
|
||||
handler=ReplyOnPause(response),
|
||||
)
|
||||
stream = Stream(ReplyOnPause(echo), modality="audio", mode="send-receive")
|
||||
```
|
||||
|
||||
=== "Webcam Stream"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# Audio Streaming
|
||||
|
||||
## Reply On Pause
|
||||
|
||||
@@ -133,18 +132,21 @@ The API is similar to `ReplyOnPause` with the addition of a `stop_words` paramet
|
||||
```
|
||||
|
||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
|
||||
3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
|
||||
4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
|
||||
=== "Notes"
|
||||
1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
|
||||
2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
|
||||
3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
|
||||
4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
|
||||
|
||||
!!! tip
|
||||
See this [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for a complete example of a more complex stream handler.
|
||||
|
||||
!!! warning
|
||||
The `emit` method should not block. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
|
||||
|
||||
## Async Stream Handlers
|
||||
|
||||
It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive`, `emit`, and `start_up` are now defined with `async def`.
|
||||
|
||||
@@ -97,4 +97,27 @@ Example
|
||||
>>> print(chunk)
|
||||
```
|
||||
|
||||
## `wait_for_item`
|
||||
|
||||
Wait for an item from an asyncio.Queue with a timeout.
|
||||
|
||||
Parameters
|
||||
```
|
||||
queue : asyncio.Queue
|
||||
The queue to wait for an item from
|
||||
timeout : float
|
||||
The timeout in seconds
|
||||
```
|
||||
Returns
|
||||
```
|
||||
Any
|
||||
The item from the queue or None if the timeout is reached
|
||||
```
|
||||
|
||||
Example
|
||||
```python
|
||||
>>> queue = asyncio.Queue()
|
||||
>>> queue.put_nowait(1)
|
||||
>>> item = await wait_for_item(queue)
|
||||
>>> print(item)
|
||||
```
|
||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "fastrtc"
|
||||
version = "0.0.4"
|
||||
version = "0.0.4.post1"
|
||||
description = "The realtime communication library for Python"
|
||||
readme = "README.md"
|
||||
license = "apache-2.0"
|
||||
|
||||
Reference in New Issue
Block a user