Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)

Commit: Raise errors automatically (#69)

* Add auto errors
* change code

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
@@ -25,6 +25,7 @@ from .utils import (
     audio_to_bytes,
     audio_to_file,
     audio_to_float32,
+    wait_for_item,
 )
 from .webrtc import (
     WebRTC,
@@ -58,4 +59,5 @@ __all__ = [
     "Warning",
     "get_tts_model",
     "KokoroTTSOptions",
+    "wait_for_item",
 ]
@@ -233,3 +233,4 @@ class ReplyOnPause(StreamHandler):
             traceback.print_exc()
             logger.debug("Error in ReplyOnPause: %s", e)
             self.reset()
+            raise e
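With this change, `ReplyOnPause` re-raises the error after logging it instead of swallowing it, so a failure inside a user callback can reach the client (the connection mixin below wraps handlers in `webrtc_error_handler`, which converts it to a `WebRTCError`). A minimal, hedged sketch of what a handler author can now rely on (names illustrative, not part of this commit):

```python
from fastrtc import ReplyOnPause, Stream, WebRTCError


def reply(audio):
    # Raising here is now enough: ReplyOnPause re-raises the exception and the
    # library surfaces it to the browser as a WebRTCError instead of only logging it.
    raise WebRTCError("speech-to-text failed")
    yield audio  # unreachable; keeps the handler a generator


stream = Stream(handler=ReplyOnPause(reply), modality="audio", mode="send-receive")
```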
@@ -1,4 +1,4 @@
-"""gr.WebRTC() component."""
+"""WebRTC tracks."""
 
 from __future__ import annotations
 
@@ -277,9 +277,7 @@ class StreamHandler(StreamHandlerBase):
         pass
 
     @abstractmethod
-    def emit(
-        self,
-    ) -> EmitType:
+    def emit(self) -> EmitType:
         pass
 
     @abstractmethod
@@ -296,9 +294,7 @@ class AsyncStreamHandler(StreamHandlerBase):
         pass
 
     @abstractmethod
-    async def emit(
-        self,
-    ) -> EmitType:
+    async def emit(self) -> EmitType:
         pass
 
     @abstractmethod
@@ -312,15 +308,13 @@ class AsyncStreamHandler(StreamHandlerBase):
 StreamHandlerImpl = StreamHandler | AsyncStreamHandler
 
 
-class AudioVideoStreamHandler(StreamHandlerBase):
+class AudioVideoStreamHandler(StreamHandler):
     @abstractmethod
     def video_receive(self, frame: VideoFrame) -> None:
         pass
 
     @abstractmethod
-    def video_emit(
-        self,
-    ) -> VideoEmitType:
+    def video_emit(self) -> VideoEmitType:
         pass
 
     @abstractmethod
@@ -328,15 +322,13 @@ class AudioVideoStreamHandler(StreamHandlerBase):
         pass
 
 
-class AsyncAudioVideoStreamHandler(StreamHandlerBase):
+class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
     @abstractmethod
     async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
         pass
 
     @abstractmethod
-    async def video_emit(
-        self,
-    ) -> VideoEmitType:
+    async def video_emit(self) -> VideoEmitType:
         pass
 
     @abstractmethod
@@ -6,7 +6,9 @@ import logging
 import tempfile
 from contextvars import ContextVar
 from typing import Any, Callable, Literal, Protocol, TypedDict, cast
+import functools
+import traceback
+import inspect
 
 import av
 import numpy as np
 from numpy.typing import NDArray
@@ -353,3 +355,47 @@ async def async_aggregate_bytes_to_16bit(chunks_iterator):
     if to_process:
         audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
         yield audio_array
+
+
+def webrtc_error_handler(func):
+    """Decorator to catch exceptions and raise WebRTCError with stacktrace."""
+
+    @functools.wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        try:
+            return await func(*args, **kwargs)
+        except Exception as e:
+            traceback.print_exc()
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                raise WebRTCError(str(e)) from e
+
+    @functools.wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            traceback.print_exc()
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                raise WebRTCError(str(e)) from e
+
+    return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
+
+
+async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
+    """
+    Wait for an item from an asyncio.Queue with a timeout.
+
+    This function attempts to retrieve an item from the queue using asyncio.wait_for.
+    If the timeout is reached, it returns None.
+
+    This is useful to avoid blocking `emit` when the queue is empty.
+    """
+    try:
+        return await asyncio.wait_for(queue.get(), timeout=timeout)
+    except (TimeoutError, asyncio.TimeoutError):
+        return None
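To illustrate the decorator added above: any exception escaping the wrapped callable is printed via `traceback.print_exc()` and re-raised as a `WebRTCError`, while an existing `WebRTCError` passes through unchanged. A small hypothetical usage sketch (the `flaky_handler` name is made up):

```python
from fastrtc import WebRTCError
from fastrtc.utils import webrtc_error_handler


def flaky_handler(audio):
    raise ValueError("bad frame")  # stand-in for a real failure


wrapped = webrtc_error_handler(flaky_handler)
try:
    wrapped(None)
except WebRTCError as err:
    print(err)  # "bad frame"; the original traceback was already printed
```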
@@ -37,6 +37,7 @@ from fastrtc.utils import (
     AdditionalOutputs,
     DataChannel,
     create_message,
+    webrtc_error_handler,
 )
 
 Track = (
@@ -148,8 +149,16 @@ class WebRTCConnectionMixin:
 
         if isinstance(self.event_handler, StreamHandlerBase):
             handler = self.event_handler.copy()
+            handler.emit = webrtc_error_handler(handler.emit)
+            handler.receive = webrtc_error_handler(handler.receive)
+            handler.start_up = webrtc_error_handler(handler.start_up)
+            handler.shutdown = webrtc_error_handler(handler.shutdown)
+            if hasattr(handler, "video_receive"):
+                handler.video_receive = webrtc_error_handler(handler.video_receive)
+            if hasattr(handler, "video_emit"):
+                handler.video_emit = webrtc_error_handler(handler.video_emit)
         else:
-            handler = cast(Callable, self.event_handler)
+            handler = webrtc_error_handler(cast(Callable, self.event_handler))
 
         self.handlers[body["webrtc_id"]] = handler
 
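Because the mixin now wraps `emit`, `receive`, `start_up`, `shutdown` (and `video_receive`/`video_emit` when present) with `webrtc_error_handler` at connection time, handler code no longer needs its own try/except plumbing; the rest of this commit deletes exactly that boilerplate from the demos. A hedged before/after sketch (function and helper names illustrative):

```python
# Before: demos wrapped every handler body and re-raised manually.
def detection_before(image, conf_threshold=0.3):
    try:
        return run_model(image, conf_threshold)  # run_model is illustrative
    except Exception as e:
        raise WebRTCError(str(e))


# After: just write the body; the webrtc_error_handler wrapping applied above
# turns any uncaught exception into a WebRTCError automatically.
def detection_after(image, conf_threshold=0.3):
    return run_model(image, conf_threshold)
```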
@@ -62,44 +62,28 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
             api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
         )
         config = {"response_modalities": ["AUDIO"]}
-        try:
         async with client.aio.live.connect(
             model="gemini-2.0-flash-exp", config=config
         ) as session:
             self.session = session
             print("set session")
             while not self.quit.is_set():
                 turn = self.session.receive()
                 async for response in turn:
                     if data := response.data:
                         audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                         self.audio_queue.put_nowait(audio)
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
 
     async def video_receive(self, frame: np.ndarray):
-        try:
-            print("out")
         if self.session:
-            print("here")
             # send image every 1 second
             print(time.time() - self.last_frame_time)
             if time.time() - self.last_frame_time > 1:
                 self.last_frame_time = time.time()
-                print("sending image")
                 await self.session.send(input=encode_image(frame))
-                print("sent image")
                 if self.latest_args[1] is not None:
-                    print("sending image2")
                     await self.session.send(input=encode_image(self.latest_args[1]))
-                    print("sent image2")
-        except Exception as e:
-            print(e)
-            import traceback
-
-            traceback.print_exc()
         self.video_queue.put_nowait(frame)
 
     async def video_emit(self):
@@ -110,13 +94,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         array = array.squeeze()
         audio_message = encode_audio(array)
         if self.session:
-            try:
-                await self.session.send(input=audio_message)
-            except Exception as e:
-                print(e)
-                import traceback
-
-                traceback.print_exc()
+            await self.session.send(input=audio_message)
 
     async def emit(self):
         array = await self.audio_queue.get()
@@ -13,7 +13,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnStopWords,
     Stream,
-    WebRTCError,
     get_stt_model,
     get_twilio_turn_credentials,
 )
@@ -39,30 +38,23 @@ def response(
 ):
     gradio_chatbot = gradio_chatbot or []
     conversation_state = conversation_state or []
-    try:
     text = model.stt(audio)
     print("STT in handler", text)
     sample_rate, array = audio
     gradio_chatbot.append(
         {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
     )
     yield AdditionalOutputs(gradio_chatbot, conversation_state)
 
     conversation_state.append({"role": "user", "content": text})
 
     request = client.chat.completions.create(
         model="Meta-Llama-3.2-3B-Instruct",
         messages=conversation_state,  # type: ignore
         temperature=0.1,
         top_p=0.1,
     )
     response = {"role": "assistant", "content": request.choices[0].message.content}
 
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(str(e) + "\n" + traceback.format_exc())
 
     conversation_state.append(response)
     gradio_chatbot.append(response)
@@ -10,7 +10,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     get_stt_model,
     get_twilio_turn_credentials,
 )
@@ -30,42 +29,36 @@ def response(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     chatbot: list[dict] | None = None,
 ):
-    try:
     chatbot = chatbot or []
     messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
     start = time.time()
     text = stt_model.stt(audio)
     print("transcription", time.time() - start)
     print("prompt", text)
     chatbot.append({"role": "user", "content": text})
     yield AdditionalOutputs(chatbot)
     messages.append({"role": "user", "content": text})
     response_text = (
         groq_client.chat.completions.create(
             model="llama-3.1-8b-instant",
             max_tokens=512,
             messages=messages,  # type: ignore
         )
         .choices[0]
         .message.content
     )
 
     chatbot.append({"role": "assistant", "content": response_text})
 
     for chunk in tts_client.text_to_speech.convert_as_stream(
         text=response_text,  # type: ignore
         voice_id="JBFqnCBsd6RMkjVDRZzb",
         model_id="eleven_multilingual_v2",
         output_format="pcm_24000",
     ):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
         yield (24000, audio_array)
     yield AdditionalOutputs(chatbot)
-    except Exception:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(traceback.format_exc())
 
 
 chatbot = gr.Chatbot(type="messages")
@@ -5,7 +5,7 @@ import cv2
 import gradio as gr
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse
-from fastrtc import Stream, WebRTCError, get_twilio_turn_credentials
+from fastrtc import Stream, get_twilio_turn_credentials
 from gradio.utils import get_space
 from huggingface_hub import hf_hub_download
 from pydantic import BaseModel, Field
@@ -26,16 +26,10 @@ model = YOLOv10(model_file)
 
 
 def detection(image, conf_threshold=0.3):
-    try:
     image = cv2.resize(image, (model.input_width, model.input_height))
     print("conf_threshold", conf_threshold)
     new_image = model.detect_objects(image, conf_threshold)
     return cv2.resize(new_image, (500, 500))
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(str(e))
 
 
 stream = Stream(
@@ -1,6 +1,6 @@
 import subprocess
 
-subprocess.run(["pip", "install", "fastrtc==0.0.3.post7"])
+subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
 
 import asyncio
 import base64
@@ -15,10 +15,9 @@ from fastrtc import (
     AsyncStreamHandler,
     Stream,
     get_twilio_turn_credentials,
-    WebRTCError,
     audio_to_float32,
+    wait_for_item,
 )
-from fastapi import FastAPI
 from phonic.client import PhonicSTSClient, get_voices
 
 load_dotenv()
@@ -42,47 +41,38 @@ class PhonicHandler(AsyncStreamHandler):
     async def start_up(self):
         await self.wait_for_args()
         voice_id = self.latest_args[1]
-        try:
         async with PhonicSTSClient(STS_URI, API_KEY) as client:
             self.client = client
             sts_stream = client.sts(  # type: ignore
                 input_format="pcm_44100",
                 output_format="pcm_44100",
                 system_prompt="You are a helpful voice assistant. Respond conversationally.",
                 # welcome_message="Hello! I'm your voice assistant. How can I help you today?",
                 voice_id=voice_id,
             )
             async for message in sts_stream:
                 message_type = message.get("type")
                 if message_type == "audio_chunk":
                     audio_b64 = message["audio"]
                     audio_bytes = base64.b64decode(audio_b64)
                     await self.output_queue.put(
                         (SAMPLE_RATE, np.frombuffer(audio_bytes, dtype=np.int16))
                     )
                     if text := message.get("text"):
                         msg = {"role": "assistant", "content": text}
                         await self.output_queue.put(AdditionalOutputs(msg))
                 elif message_type == "input_text":
                     msg = {"role": "user", "content": message["text"]}
                     await self.output_queue.put(AdditionalOutputs(msg))
-        except Exception as e:
-            raise WebRTCError(f"Error starting up: {e}")
 
     async def emit(self):
-        try:
-            return await self.output_queue.get()
-        except Exception as e:
-            raise WebRTCError(f"Error emitting: {e}")
+        return await wait_for_item(self.output_queue)
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        try:
         if not self.client:
             return
         audio_float32 = audio_to_float32(frame)
         await self.client.send_audio(audio_float32)  # type: ignore
-        except Exception as e:
-            raise WebRTCError(f"Error sending audio: {e}")
 
     async def shutdown(self):
         if self.client:
@@ -122,9 +112,6 @@ stream = Stream(
 with stream.ui:
     state.change(lambda s: s, inputs=state, outputs=chatbot)
 
-app = FastAPI()
-stream.mount(app)
-
 if __name__ == "__main__":
     if (mode := os.getenv("MODE")) == "UI":
         stream.ui.launch(server_port=7860)
@@ -14,7 +14,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     get_tts_model,
     get_twilio_turn_credentials,
 )
@@ -38,41 +37,36 @@ def response(
     audio: tuple[int, np.ndarray],
     chatbot: list[dict] | None = None,
 ):
-    try:
     chatbot = chatbot or []
     messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
     prompt = groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     ).text
-        print("prompt", prompt)
     chatbot.append({"role": "user", "content": prompt})
     yield AdditionalOutputs(chatbot)
     messages.append({"role": "user", "content": prompt})
     response = claude_client.messages.create(
         model="claude-3-5-haiku-20241022",
         max_tokens=512,
         messages=messages,  # type: ignore
     )
     response_text = " ".join(
         block.text  # type: ignore
         for block in response.content
         if getattr(block, "type", None) == "text"
     )
     chatbot.append({"role": "assistant", "content": response_text})
 
     start = time.time()
 
     print("starting tts", start)
     for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
         print("chunk", i, time.time() - start)
         yield chunk
     print("finished tts", time.time() - start)
     yield AdditionalOutputs(chatbot)
-    except Exception as e:
-        raise WebRTCError(str(e))
 
 
 chatbot = gr.Chatbot(type="messages")
@@ -13,8 +13,8 @@ from fastapi.responses import HTMLResponse
 from fastrtc import (
     AsyncStreamHandler,
     Stream,
-    WebRTCError,
     get_twilio_turn_credentials,
+    wait_for_item,
 )
 from google import genai
 from google.genai.types import (
@@ -68,13 +68,12 @@ class GeminiHandler(AsyncStreamHandler):
             api_key, voice_name = self.latest_args[1:]
         else:
             api_key, voice_name = None, "Puck"
-        try:
         client = genai.Client(
             api_key=api_key or os.getenv("GEMINI_API_KEY"),
             http_options={"api_version": "v1alpha"},
         )
-        except Exception as e:
-            raise WebRTCError(str(e))
         config = LiveConnectConfig(
             response_modalities=["AUDIO"],  # type: ignore
             speech_config=SpeechConfig(
@@ -85,18 +84,15 @@ class GeminiHandler(AsyncStreamHandler):
                 )
             ),
         )
-        try:
         async with client.aio.live.connect(
             model="gemini-2.0-flash-exp", config=config
         ) as session:
             async for audio in session.start_stream(
                 stream=self.stream(), mime_type="audio/pcm"
             ):
                 if audio.data:
                     array = np.frombuffer(audio.data, dtype=np.int16)
-                    self.output_queue.put_nowait(array)
-        except Exception as e:
-            raise WebRTCError(str(e))
+                    self.output_queue.put_nowait((self.output_sample_rate, array))
 
     async def stream(self) -> AsyncGenerator[bytes, None]:
         while not self.quit.is_set():
@@ -112,13 +108,11 @@ class GeminiHandler(AsyncStreamHandler):
         audio_message = encode_audio(array)
         self.input_queue.put_nowait(audio_message)
 
-    async def emit(self) -> tuple[int, np.ndarray]:
-        array = await self.output_queue.get()
-        return (self.output_sample_rate, array)
+    async def emit(self) -> tuple[int, np.ndarray] | None:
+        return await wait_for_item(self.output_queue)
 
     def shutdown(self) -> None:
         self.quit.set()
-        self.args_set.clear()
 
 
 stream = Stream(
@@ -13,8 +13,8 @@ from fastrtc import (
     AdditionalOutputs,
     AsyncStreamHandler,
     Stream,
-    WebRTCError,
     get_twilio_turn_credentials,
+    wait_for_item,
 )
 from gradio.utils import get_space
 from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
@@ -47,60 +47,41 @@ class OpenAIHandler(AsyncStreamHandler):
     ):
         """Connect to realtime API. Run forever in separate thread to keep connection open."""
         self.client = openai.AsyncOpenAI()
-        try:
         async with self.client.beta.realtime.connect(
             model="gpt-4o-mini-realtime-preview-2024-12-17"
         ) as conn:
             await conn.session.update(
                 session={"turn_detection": {"type": "server_vad"}}
             )
             self.connection = conn
             async for event in self.connection:
                 if event.type == "response.audio_transcript.done":
                     await self.output_queue.put(AdditionalOutputs(event))
                 if event.type == "response.audio.delta":
                     await self.output_queue.put(
                         (
                             self.output_sample_rate,
                             np.frombuffer(
                                 base64.b64decode(event.delta), dtype=np.int16
                             ).reshape(1, -1),
                         ),
                     )
-        except Exception:
-            import traceback
-
-            traceback.print_exc()
-            raise WebRTCError(str(traceback.format_exc()))
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
         if not self.connection:
             return
-        try:
         _, array = frame
         array = array.squeeze()
         audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
         await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore
-        except Exception as e:
-            # print traceback
-            print(f"Error in receive: {str(e)}")
-            import traceback
-
-            traceback.print_exc()
-            raise WebRTCError(str(traceback.format_exc()))
 
     async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
-        return await self.output_queue.get()
+        return await wait_for_item(self.output_queue)
 
-    def reset_state(self):
-        """Reset connection state for new recording session"""
-        self.connection = None
-        self.args_set.clear()
-
     async def shutdown(self) -> None:
         if self.connection:
             await self.connection.close()
-            self.reset_state()
+            self.connection = None
 
 
 def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
@@ -49,20 +49,15 @@ def response(
 
     conversation_state.append({"role": "user", "content": text})
 
-    try:
-        request = client.chat.completions.create(
-            model="Meta-Llama-3.2-3B-Instruct",
-            messages=conversation_state,  # type: ignore
-            temperature=0.1,
-            top_p=0.1,
-        )
-        response = {"role": "assistant", "content": request.choices[0].message.content}
-
-    except Exception:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(traceback.format_exc())
+    raise WebRTCError("test")
+    request = client.chat.completions.create(
+        model="Meta-Llama-3.2-3B-Instruct",
+        messages=conversation_state,  # type: ignore
+        temperature=0.1,
+        top_p=0.1,
+    )
+    response = {"role": "assistant", "content": request.choices[0].message.content}
 
     conversation_state.append(response)
     gradio_chatbot.append(response)
@@ -10,7 +10,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     audio_to_bytes,
     get_twilio_turn_credentials,
 )
@@ -26,15 +25,12 @@ groq_client = AsyncClient()
 
 
 async def transcribe(audio: tuple[int, np.ndarray]):
-    try:
     transcript = await groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     )
     yield AdditionalOutputs(transcript.text)
-    except Exception as e:
-        raise WebRTCError(str(e))
 
 
 stream = Stream(
docs/CNAME (new file)
@@ -0,0 +1 @@
+fastrtc.org
@@ -63,56 +63,29 @@ The `Stream` has three main methods:
 === "LLM Voice Chat"
 
     ```py
-    from fastrtc import (
-        ReplyOnPause, AdditionalOutputs, Stream,
-        audio_to_bytes, aggregate_bytes_to_16bit
-    )
-    import gradio as gr
-    from groq import Groq
-    import anthropic
-    from elevenlabs import ElevenLabs
-
-    groq_client = Groq()
-    claude_client = anthropic.Anthropic()
-    tts_client = ElevenLabs()
-
-
-    # See "Talk to Claude" in Cookbook for an example of how to keep
-    # track of the chat history.
-    def response(
-        audio: tuple[int, np.ndarray],
-    ):
-        prompt = groq_client.audio.transcriptions.create(
-            file=("audio-file.mp3", audio_to_bytes(audio)),
-            model="whisper-large-v3-turbo",
-            response_format="verbose_json",
-        ).text
-        response = claude_client.messages.create(
-            model="claude-3-5-haiku-20241022",
-            max_tokens=512,
-            messages=[{"role": "user", "content": prompt}],
-        )
-        response_text = " ".join(
-            block.text
-            for block in response.content
-            if getattr(block, "type", None) == "text"
-        )
-        iterator = tts_client.text_to_speech.convert_as_stream(
-            text=response_text,
-            voice_id="JBFqnCBsd6RMkjVDRZzb",
-            model_id="eleven_multilingual_v2",
-            output_format="pcm_24000"
-        )
-        for chunk in aggregate_bytes_to_16bit(iterator):
-            audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
-            yield (24000, audio_array)
-
-    stream = Stream(
-        modality="audio",
-        mode="send-receive",
-        handler=ReplyOnPause(response),
-    )
+    import os
+
+    from fastrtc import (ReplyOnPause, Stream, get_stt_model, get_tts_model)
+    from openai import OpenAI
+
+    sambanova_client = OpenAI(
+        api_key=os.getenv("SAMBANOVA_API_KEY"), base_url="https://api.sambanova.ai/v1"
+    )
+    stt_model = get_stt_model()
+    tts_model = get_tts_model()
+
+
+    def echo(audio):
+        prompt = stt_model.stt(audio)
+        response = sambanova_client.chat.completions.create(
+            model="Meta-Llama-3.2-3B-Instruct",
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=200,
+        )
+        prompt = response.choices[0].message.content
+        for audio_chunk in tts_model.stream_tts_sync(prompt):
+            yield audio_chunk
+
+    stream = Stream(ReplyOnPause(echo), modality="audio", mode="send-receive")
     ```
 
 === "Webcam Stream"
@@ -1,4 +1,3 @@
-# Audio Streaming
 
 ## Reply On Pause
 
@@ -133,18 +132,21 @@ The API is similar to `ReplyOnPause` with the addition of a `stop_words` parameter
     ```
 
     1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
-    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
+    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
     3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
    4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
 === "Notes"
     1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
-    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
+    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
     3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
     4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
 
 !!! tip
     See this [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for a complete example of a more complex stream handler.
 
+!!! warning
+    The `emit` method should not block. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
+
 ## Async Stream Handlers
 
 It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive`, `emit`, and `start_up` are now defined with `async def`.
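As a concrete illustration of the updated note, a non-blocking `emit` can simply delegate to `wait_for_item`, returning `None` when nothing is queued within the timeout. A minimal sketch, assuming the usual `AsyncStreamHandler` method set (exact base-class constructor arguments may differ):

```python
import asyncio

import numpy as np
from fastrtc import AsyncStreamHandler, wait_for_item


class EchoHandler(AsyncStreamHandler):
    def __init__(self):
        super().__init__()
        self.queue: asyncio.Queue = asyncio.Queue()

    async def start_up(self):
        pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        self.queue.put_nowait(frame)  # buffer incoming audio frames

    async def emit(self):
        # Never blocks the event loop: returns None if no frame arrives in time.
        return await wait_for_item(self.queue)

    async def shutdown(self):
        pass

    def copy(self):
        return EchoHandler()
```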
@@ -97,4 +97,27 @@ Example
 >>> print(chunk)
 ```
 
+## `wait_for_item`
+
+Wait for an item from an asyncio.Queue with a timeout.
+
+Parameters
+```
+queue : asyncio.Queue
+    The queue to wait for an item from
+timeout : float
+    The timeout in seconds
+```
+Returns
+```
+Any
+    The item from the queue or None if the timeout is reached
+```
+
+Example
+```python
+>>> queue = asyncio.Queue()
+>>> queue.put_nowait(1)
+>>> item = await wait_for_item(queue)
+>>> print(item)
+```
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fastrtc"
-version = "0.0.4"
+version = "0.0.4.post1"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "apache-2.0"