diff --git a/backend/fastrtc/__init__.py b/backend/fastrtc/__init__.py
index d966237..6068799 100644
--- a/backend/fastrtc/__init__.py
+++ b/backend/fastrtc/__init__.py
@@ -25,6 +25,7 @@ from .utils import (
     audio_to_bytes,
     audio_to_file,
     audio_to_float32,
+    wait_for_item,
 )
 from .webrtc import (
     WebRTC,
@@ -58,4 +59,5 @@ __all__ = [
     "Warning",
     "get_tts_model",
     "KokoroTTSOptions",
+    "wait_for_item",
 ]
diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py
index 60b8794..7e93b57 100644
--- a/backend/fastrtc/reply_on_pause.py
+++ b/backend/fastrtc/reply_on_pause.py
@@ -233,3 +233,4 @@ class ReplyOnPause(StreamHandler):
             traceback.print_exc()
             logger.debug("Error in ReplyOnPause: %s", e)
             self.reset()
+            raise e
diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py
index 91104b8..5bddbd6 100644
--- a/backend/fastrtc/tracks.py
+++ b/backend/fastrtc/tracks.py
@@ -1,4 +1,4 @@
-"""gr.WebRTC() component."""
+"""WebRTC tracks."""
 
 from __future__ import annotations
 
@@ -277,9 +277,7 @@ class StreamHandler(StreamHandlerBase):
         pass
 
     @abstractmethod
-    def emit(
-        self,
-    ) -> EmitType:
+    def emit(self) -> EmitType:
         pass
 
     @abstractmethod
@@ -296,9 +294,7 @@ class AsyncStreamHandler(StreamHandlerBase):
         pass
 
     @abstractmethod
-    async def emit(
-        self,
-    ) -> EmitType:
+    async def emit(self) -> EmitType:
         pass
 
     @abstractmethod
@@ -312,15 +308,13 @@ class AsyncStreamHandler(StreamHandlerBase):
         pass
 
 
 StreamHandlerImpl = StreamHandler | AsyncStreamHandler
 
 
-class AudioVideoStreamHandler(StreamHandlerBase):
+class AudioVideoStreamHandler(StreamHandler):
     @abstractmethod
     def video_receive(self, frame: VideoFrame) -> None:
         pass
 
     @abstractmethod
-    def video_emit(
-        self,
-    ) -> VideoEmitType:
+    def video_emit(self) -> VideoEmitType:
         pass
 
     @abstractmethod
@@ -328,15 +322,13 @@ class AudioVideoStreamHandler(StreamHandlerBase):
         pass
 
 
-class AsyncAudioVideoStreamHandler(StreamHandlerBase):
+class AsyncAudioVideoStreamHandler(AsyncStreamHandler):
     @abstractmethod
     async def video_receive(self, frame: npt.NDArray[np.float32]) -> None:
         pass
 
     @abstractmethod
-    async def video_emit(
-        self,
-    ) -> VideoEmitType:
+    async def video_emit(self) -> VideoEmitType:
         pass
 
     @abstractmethod
diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py
index 4c736e2..1788574 100644
--- a/backend/fastrtc/utils.py
+++ b/backend/fastrtc/utils.py
@@ -6,7 +6,9 @@ import logging
 import tempfile
 from contextvars import ContextVar
 from typing import Any, Callable, Literal, Protocol, TypedDict, cast
-
+import functools
+import traceback
+import inspect
 import av
 import numpy as np
 from numpy.typing import NDArray
@@ -353,3 +355,47 @@ async def async_aggregate_bytes_to_16bit(chunks_iterator):
     if to_process:
         audio_array = np.frombuffer(to_process, dtype=np.int16).reshape(1, -1)
         yield audio_array
+
+
+def webrtc_error_handler(func):
+    """Decorator to catch exceptions and raise WebRTCError with stacktrace."""
+
+    @functools.wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        try:
+            return await func(*args, **kwargs)
+        except Exception as e:
+            traceback.print_exc()
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                raise WebRTCError(str(e)) from e
+
+    @functools.wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            traceback.print_exc()
+            if isinstance(e, WebRTCError):
+                raise e
+            else:
+                raise WebRTCError(str(e)) from e
+
+    return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
+
+
+async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
+    """
+    Wait for an item from an asyncio.Queue with a timeout.
+
+    This function attempts to retrieve an item from the queue using asyncio.wait_for.
+    If the timeout is reached, it returns None.
+
+    This is useful to avoid blocking `emit` when the queue is empty.
+    """
+
+    try:
+        return await asyncio.wait_for(queue.get(), timeout=timeout)
+    except (TimeoutError, asyncio.TimeoutError):
+        return None
diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py
index 0f2e0e5..8407bf0 100644
--- a/backend/fastrtc/webrtc_connection_mixin.py
+++ b/backend/fastrtc/webrtc_connection_mixin.py
@@ -37,6 +37,7 @@ from fastrtc.utils import (
     AdditionalOutputs,
     DataChannel,
     create_message,
+    webrtc_error_handler,
 )
 
 Track = (
@@ -148,8 +149,16 @@ class WebRTCConnectionMixin:
 
         if isinstance(self.event_handler, StreamHandlerBase):
             handler = self.event_handler.copy()
+            handler.emit = webrtc_error_handler(handler.emit)
+            handler.receive = webrtc_error_handler(handler.receive)
+            handler.start_up = webrtc_error_handler(handler.start_up)
+            handler.shutdown = webrtc_error_handler(handler.shutdown)
+            if hasattr(handler, "video_receive"):
+                handler.video_receive = webrtc_error_handler(handler.video_receive)
+            if hasattr(handler, "video_emit"):
+                handler.video_emit = webrtc_error_handler(handler.video_emit)
         else:
-            handler = cast(Callable, self.event_handler)
+            handler = webrtc_error_handler(cast(Callable, self.event_handler))
 
         self.handlers[body["webrtc_id"]] = handler
diff --git a/demo/gemini_audio_video/app.py b/demo/gemini_audio_video/app.py
index 7fbe44d..77b2d51 100644
--- a/demo/gemini_audio_video/app.py
+++ b/demo/gemini_audio_video/app.py
@@ -62,44 +62,28 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
             api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
         )
         config = {"response_modalities": ["AUDIO"]}
-        try:
-            async with client.aio.live.connect(
-                model="gemini-2.0-flash-exp", config=config
-            ) as session:
-                self.session = session
-                print("set session")
-                while not self.quit.is_set():
-                    turn = self.session.receive()
-                    async for response in turn:
-                        if data := response.data:
-                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
-                            self.audio_queue.put_nowait(audio)
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp", config=config
+        ) as session:
+            self.session = session
+            print("set session")
+            while not self.quit.is_set():
+                turn = self.session.receive()
+                async for response in turn:
+                    if data := response.data:
+                        audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
+                        self.audio_queue.put_nowait(audio)
 
     async def video_receive(self, frame: np.ndarray):
-        try:
-            print("out")
-            if self.session:
-                print("here")
-                # send image every 1 second
-                print(time.time() - self.last_frame_time)
-                if time.time() - self.last_frame_time > 1:
-                    self.last_frame_time = time.time()
-                    print("sending image")
-                    await self.session.send(input=encode_image(frame))
-                    print("sent image")
-                    if self.latest_args[1] is not None:
-                        print("sending image2")
-                        await self.session.send(input=encode_image(self.latest_args[1]))
-                        print("sent image2")
-        except Exception as e:
-            print(e)
-            import traceback
+        if self.session:
+            # send image every 1 second
+            print(time.time() - self.last_frame_time)
+            if time.time() - self.last_frame_time > 1:
+                self.last_frame_time = time.time()
+                await self.session.send(input=encode_image(frame))
+                if self.latest_args[1] is not None:
+                    await self.session.send(input=encode_image(self.latest_args[1]))
 
-            traceback.print_exc()
         self.video_queue.put_nowait(frame)
 
     async def video_emit(self):
@@ -110,13 +94,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         array = array.squeeze()
         audio_message = encode_audio(array)
         if self.session:
-            try:
-                await self.session.send(input=audio_message)
-            except Exception as e:
-                print(e)
-                import traceback
-
-                traceback.print_exc()
+            await self.session.send(input=audio_message)
 
     async def emit(self):
         array = await self.audio_queue.get()
diff --git a/demo/hello_computer/app.py b/demo/hello_computer/app.py
index 0f5a1dc..9143c12 100644
--- a/demo/hello_computer/app.py
+++ b/demo/hello_computer/app.py
@@ -13,7 +13,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnStopWords,
     Stream,
-    WebRTCError,
     get_stt_model,
     get_twilio_turn_credentials,
 )
@@ -39,30 +38,23 @@ def response(
 ):
     gradio_chatbot = gradio_chatbot or []
     conversation_state = conversation_state or []
-    try:
-        text = model.stt(audio)
-        print("STT in handler", text)
-        sample_rate, array = audio
-        gradio_chatbot.append(
-            {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
-        )
-        yield AdditionalOutputs(gradio_chatbot, conversation_state)
+    text = model.stt(audio)
+    print("STT in handler", text)
+    sample_rate, array = audio
+    gradio_chatbot.append(
+        {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
+    )
+    yield AdditionalOutputs(gradio_chatbot, conversation_state)
 
-        conversation_state.append({"role": "user", "content": text})
+    conversation_state.append({"role": "user", "content": text})
 
-        request = client.chat.completions.create(
-            model="Meta-Llama-3.2-3B-Instruct",
-            messages=conversation_state,  # type: ignore
-            temperature=0.1,
-            top_p=0.1,
-        )
-        response = {"role": "assistant", "content": request.choices[0].message.content}
-
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(str(e) + "\n" + traceback.format_exc())
+    request = client.chat.completions.create(
+        model="Meta-Llama-3.2-3B-Instruct",
+        messages=conversation_state,  # type: ignore
+        temperature=0.1,
+        top_p=0.1,
+    )
+    response = {"role": "assistant", "content": request.choices[0].message.content}
 
     conversation_state.append(response)
     gradio_chatbot.append(response)
diff --git a/demo/llm_voice_chat/app.py b/demo/llm_voice_chat/app.py
index 19e9634..0c304d6 100644
--- a/demo/llm_voice_chat/app.py
+++ b/demo/llm_voice_chat/app.py
@@ -10,7 +10,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     get_stt_model,
     get_twilio_turn_credentials,
 )
@@ -30,42 +29,36 @@ def response(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     chatbot: list[dict] | None = None,
 ):
-    try:
-        chatbot = chatbot or []
-        messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
-        start = time.time()
-        text = stt_model.stt(audio)
-        print("transcription", time.time() - start)
-        print("prompt", text)
-        chatbot.append({"role": "user", "content": text})
-        yield AdditionalOutputs(chatbot)
-        messages.append({"role": "user", "content": text})
-        response_text = (
-            groq_client.chat.completions.create(
-                model="llama-3.1-8b-instant",
-                max_tokens=512,
-                messages=messages,  # type: ignore
-            )
-            .choices[0]
-            .message.content
+    chatbot = chatbot or []
+    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
+    start = time.time()
+    text = stt_model.stt(audio)
+    print("transcription", time.time() - start)
+    print("prompt", text)
+    chatbot.append({"role": "user", "content": text})
+    yield AdditionalOutputs(chatbot)
+    messages.append({"role": "user", "content": text})
+    response_text = (
+        groq_client.chat.completions.create(
+            model="llama-3.1-8b-instant",
+            max_tokens=512,
+            messages=messages,  # type: ignore
         )
+        .choices[0]
+        .message.content
+    )
 
-        chatbot.append({"role": "assistant", "content": response_text})
+    chatbot.append({"role": "assistant", "content": response_text})
 
-        for chunk in tts_client.text_to_speech.convert_as_stream(
-            text=response_text,  # type: ignore
-            voice_id="JBFqnCBsd6RMkjVDRZzb",
-            model_id="eleven_multilingual_v2",
-            output_format="pcm_24000",
-        ):
-            audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
-            yield (24000, audio_array)
-        yield AdditionalOutputs(chatbot)
-    except Exception:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(traceback.format_exc())
+    for chunk in tts_client.text_to_speech.convert_as_stream(
+        text=response_text,  # type: ignore
+        voice_id="JBFqnCBsd6RMkjVDRZzb",
+        model_id="eleven_multilingual_v2",
+        output_format="pcm_24000",
+    ):
+        audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
+        yield (24000, audio_array)
+    yield AdditionalOutputs(chatbot)
 
 
 chatbot = gr.Chatbot(type="messages")
diff --git a/demo/object_detection/app.py b/demo/object_detection/app.py
index 419a766..06bd1ac 100644
--- a/demo/object_detection/app.py
+++ b/demo/object_detection/app.py
@@ -5,7 +5,7 @@ import cv2
 import gradio as gr
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse
-from fastrtc import Stream, WebRTCError, get_twilio_turn_credentials
+from fastrtc import Stream, get_twilio_turn_credentials
 from gradio.utils import get_space
 from huggingface_hub import hf_hub_download
 from pydantic import BaseModel, Field
@@ -26,16 +26,10 @@ model = YOLOv10(model_file)
 
 
 def detection(image, conf_threshold=0.3):
-    try:
-        image = cv2.resize(image, (model.input_width, model.input_height))
-        print("conf_threshold", conf_threshold)
-        new_image = model.detect_objects(image, conf_threshold)
-        return cv2.resize(new_image, (500, 500))
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(str(e))
+    image = cv2.resize(image, (model.input_width, model.input_height))
+    print("conf_threshold", conf_threshold)
+    new_image = model.detect_objects(image, conf_threshold)
+    return cv2.resize(new_image, (500, 500))
 
 
 stream = Stream(
diff --git a/demo/phonic_chat/app.py b/demo/phonic_chat/app.py
index 6404b1d..69f1240 100644
--- a/demo/phonic_chat/app.py
+++ b/demo/phonic_chat/app.py
@@ -1,6 +1,6 @@
 import subprocess
 
-subprocess.run(["pip", "install", "fastrtc==0.0.3.post7"])
+subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
 
 import asyncio
 import base64
@@ -15,10 +15,9 @@ from fastrtc import (
     AsyncStreamHandler,
     Stream,
     get_twilio_turn_credentials,
-    WebRTCError,
     audio_to_float32,
+    wait_for_item,
 )
-from fastapi import FastAPI
 from phonic.client import PhonicSTSClient, get_voices
 
 load_dotenv()
@@ -42,47 +41,38 @@ class PhonicHandler(AsyncStreamHandler):
     async def start_up(self):
         await self.wait_for_args()
         voice_id = self.latest_args[1]
-        try:
-            async with PhonicSTSClient(STS_URI, API_KEY) as client:
-                self.client = client
-                sts_stream = client.sts(  # type: ignore
-                    input_format="pcm_44100",
-                    output_format="pcm_44100",
-                    system_prompt="You are a helpful voice assistant. Respond conversationally.",
-                    # welcome_message="Hello! I'm your voice assistant. How can I help you today?",
-                    voice_id=voice_id,
-                )
-                async for message in sts_stream:
-                    message_type = message.get("type")
-                    if message_type == "audio_chunk":
-                        audio_b64 = message["audio"]
-                        audio_bytes = base64.b64decode(audio_b64)
-                        await self.output_queue.put(
-                            (SAMPLE_RATE, np.frombuffer(audio_bytes, dtype=np.int16))
-                        )
-                        if text := message.get("text"):
-                            msg = {"role": "assistant", "content": text}
-                            await self.output_queue.put(AdditionalOutputs(msg))
-                    elif message_type == "input_text":
-                        msg = {"role": "user", "content": message["text"]}
+        async with PhonicSTSClient(STS_URI, API_KEY) as client:
+            self.client = client
+            sts_stream = client.sts(  # type: ignore
+                input_format="pcm_44100",
+                output_format="pcm_44100",
+                system_prompt="You are a helpful voice assistant. Respond conversationally.",
+                # welcome_message="Hello! I'm your voice assistant. How can I help you today?",
+                voice_id=voice_id,
+            )
+            async for message in sts_stream:
+                message_type = message.get("type")
+                if message_type == "audio_chunk":
+                    audio_b64 = message["audio"]
+                    audio_bytes = base64.b64decode(audio_b64)
+                    await self.output_queue.put(
+                        (SAMPLE_RATE, np.frombuffer(audio_bytes, dtype=np.int16))
+                    )
+                    if text := message.get("text"):
+                        msg = {"role": "assistant", "content": text}
                         await self.output_queue.put(AdditionalOutputs(msg))
-        except Exception as e:
-            raise WebRTCError(f"Error starting up: {e}")
+                elif message_type == "input_text":
+                    msg = {"role": "user", "content": message["text"]}
+                    await self.output_queue.put(AdditionalOutputs(msg))
 
     async def emit(self):
-        try:
-            return await self.output_queue.get()
-        except Exception as e:
-            raise WebRTCError(f"Error emitting: {e}")
+        return await wait_for_item(self.output_queue)
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        try:
-            if not self.client:
-                return
-            audio_float32 = audio_to_float32(frame)
-            await self.client.send_audio(audio_float32)  # type: ignore
-        except Exception as e:
-            raise WebRTCError(f"Error sending audio: {e}")
+        if not self.client:
+            return
+        audio_float32 = audio_to_float32(frame)
+        await self.client.send_audio(audio_float32)  # type: ignore
 
     async def shutdown(self):
         if self.client:
@@ -122,9 +112,6 @@ stream = Stream(
 with stream.ui:
     state.change(lambda s: s, inputs=state, outputs=chatbot)
 
-app = FastAPI()
-stream.mount(app)
-
 if __name__ == "__main__":
     if (mode := os.getenv("MODE")) == "UI":
         stream.ui.launch(server_port=7860)
diff --git a/demo/talk_to_claude/app.py b/demo/talk_to_claude/app.py
index 13abde0..703efe9 100644
--- a/demo/talk_to_claude/app.py
+++ b/demo/talk_to_claude/app.py
@@ -14,7 +14,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     get_tts_model,
     get_twilio_turn_credentials,
 )
@@ -38,41 +37,36 @@ def response(
     audio: tuple[int, np.ndarray],
     chatbot: list[dict] | None = None,
 ):
-    try:
-        chatbot = chatbot or []
-        messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
-        prompt = groq_client.audio.transcriptions.create(
-            file=("audio-file.mp3", audio_to_bytes(audio)),
-            model="whisper-large-v3-turbo",
-            response_format="verbose_json",
-        ).text
+    chatbot = chatbot or []
+    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
+    prompt = groq_client.audio.transcriptions.create(
+        file=("audio-file.mp3", audio_to_bytes(audio)),
+        model="whisper-large-v3-turbo",
+        response_format="verbose_json",
+    ).text
+    chatbot.append({"role": "user", "content": prompt})
+    yield AdditionalOutputs(chatbot)
+    messages.append({"role": "user", "content": prompt})
+    response = claude_client.messages.create(
+        model="claude-3-5-haiku-20241022",
+        max_tokens=512,
+        messages=messages,  # type: ignore
+    )
+    response_text = " ".join(
+        block.text  # type: ignore
+        for block in response.content
+        if getattr(block, "type", None) == "text"
+    )
+    chatbot.append({"role": "assistant", "content": response_text})
 
-        print("prompt", prompt)
-        chatbot.append({"role": "user", "content": prompt})
+    start = time.time()
+
+    print("starting tts", start)
+    for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
+        print("chunk", i, time.time() - start)
+        yield chunk
+    print("finished tts", time.time() - start)
     yield AdditionalOutputs(chatbot)
-        messages.append({"role": "user", "content": prompt})
-        response = claude_client.messages.create(
-            model="claude-3-5-haiku-20241022",
-            max_tokens=512,
-            messages=messages,  # type: ignore
-        )
-        response_text = " ".join(
-            block.text  # type: ignore
-            for block in response.content
-            if getattr(block, "type", None) == "text"
-        )
-        chatbot.append({"role": "assistant", "content": response_text})
-
-        start = time.time()
-
-        print("starting tts", start)
-        for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
-            print("chunk", i, time.time() - start)
-            yield chunk
-        print("finished tts", time.time() - start)
-        yield AdditionalOutputs(chatbot)
-    except Exception as e:
-        raise WebRTCError(str(e))
 
 
 chatbot = gr.Chatbot(type="messages")
diff --git a/demo/talk_to_gemini/app.py b/demo/talk_to_gemini/app.py
index 23e1da7..db5dcff 100644
--- a/demo/talk_to_gemini/app.py
+++ b/demo/talk_to_gemini/app.py
@@ -13,8 +13,8 @@ from fastapi.responses import HTMLResponse
 from fastrtc import (
     AsyncStreamHandler,
     Stream,
-    WebRTCError,
     get_twilio_turn_credentials,
+    wait_for_item,
 )
 from google import genai
 from google.genai.types import (
@@ -68,13 +68,12 @@ class GeminiHandler(AsyncStreamHandler):
             api_key, voice_name = self.latest_args[1:]
         else:
             api_key, voice_name = None, "Puck"
-        try:
-            client = genai.Client(
-                api_key=api_key or os.getenv("GEMINI_API_KEY"),
-                http_options={"api_version": "v1alpha"},
-            )
-        except Exception as e:
-            raise WebRTCError(str(e))
+
+        client = genai.Client(
+            api_key=api_key or os.getenv("GEMINI_API_KEY"),
+            http_options={"api_version": "v1alpha"},
+        )
+
         config = LiveConnectConfig(
             response_modalities=["AUDIO"],  # type: ignore
             speech_config=SpeechConfig(
@@ -85,18 +84,15 @@ class GeminiHandler(AsyncStreamHandler):
                 )
             ),
         )
-        try:
-            async with client.aio.live.connect(
-                model="gemini-2.0-flash-exp", config=config
-            ) as session:
-                async for audio in session.start_stream(
-                    stream=self.stream(), mime_type="audio/pcm"
-                ):
-                    if audio.data:
-                        array = np.frombuffer(audio.data, dtype=np.int16)
-                        self.output_queue.put_nowait(array)
-        except Exception as e:
-            raise WebRTCError(str(e))
+        async with client.aio.live.connect(
+            model="gemini-2.0-flash-exp", config=config
+        ) as session:
+            async for audio in session.start_stream(
+                stream=self.stream(), mime_type="audio/pcm"
+            ):
+                if audio.data:
+                    array = np.frombuffer(audio.data, dtype=np.int16)
+                    self.output_queue.put_nowait((self.output_sample_rate, array))
 
     async def stream(self) -> AsyncGenerator[bytes, None]:
         while not self.quit.is_set():
@@ -112,13 +108,11 @@ class GeminiHandler(AsyncStreamHandler):
         audio_message = encode_audio(array)
         self.input_queue.put_nowait(audio_message)
 
-    async def emit(self) -> tuple[int, np.ndarray]:
-        array = await self.output_queue.get()
-        return (self.output_sample_rate, array)
+    async def emit(self) -> tuple[int, np.ndarray] | None:
+        return await wait_for_item(self.output_queue)
 
     def shutdown(self) -> None:
         self.quit.set()
-        self.args_set.clear()
 
 
 stream = Stream(
diff --git a/demo/talk_to_openai/app.py b/demo/talk_to_openai/app.py
index 23e336b..e60ec1f 100644
--- a/demo/talk_to_openai/app.py
+++ b/demo/talk_to_openai/app.py
@@ -13,8 +13,8 @@ from fastrtc import (
     AdditionalOutputs,
     AsyncStreamHandler,
     Stream,
-    WebRTCError,
     get_twilio_turn_credentials,
+    wait_for_item,
 )
 from gradio.utils import get_space
 from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
@@ -47,60 +47,41 @@ class OpenAIHandler(AsyncStreamHandler):
     ):
         """Connect to realtime API. Run forever in separate thread to keep connection open."""
         self.client = openai.AsyncOpenAI()
-        try:
-            async with self.client.beta.realtime.connect(
-                model="gpt-4o-mini-realtime-preview-2024-12-17"
-            ) as conn:
-                await conn.session.update(
-                    session={"turn_detection": {"type": "server_vad"}}
-                )
-                self.connection = conn
-                async for event in self.connection:
-                    if event.type == "response.audio_transcript.done":
-                        await self.output_queue.put(AdditionalOutputs(event))
-                    if event.type == "response.audio.delta":
-                        await self.output_queue.put(
-                            (
-                                self.output_sample_rate,
-                                np.frombuffer(
-                                    base64.b64decode(event.delta), dtype=np.int16
-                                ).reshape(1, -1),
-                            ),
-                        )
-        except Exception:
-            import traceback
-
-            traceback.print_exc()
-            raise WebRTCError(str(traceback.format_exc()))
+        async with self.client.beta.realtime.connect(
+            model="gpt-4o-mini-realtime-preview-2024-12-17"
+        ) as conn:
+            await conn.session.update(
+                session={"turn_detection": {"type": "server_vad"}}
+            )
+            self.connection = conn
+            async for event in self.connection:
+                if event.type == "response.audio_transcript.done":
+                    await self.output_queue.put(AdditionalOutputs(event))
+                if event.type == "response.audio.delta":
+                    await self.output_queue.put(
+                        (
+                            self.output_sample_rate,
+                            np.frombuffer(
+                                base64.b64decode(event.delta), dtype=np.int16
+                            ).reshape(1, -1),
+                        ),
+                    )
 
     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
         if not self.connection:
             return
-        try:
-            _, array = frame
-            array = array.squeeze()
-            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
-            await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore
-        except Exception as e:
-            # print traceback
-            print(f"Error in receive: {str(e)}")
-            import traceback
-
-            traceback.print_exc()
-            raise WebRTCError(str(traceback.format_exc()))
+        _, array = frame
+        array = array.squeeze()
+        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
+        await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore
 
     async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
-        return await self.output_queue.get()
-
-    def reset_state(self):
-        """Reset connection state for new recording session"""
-        self.connection = None
-        self.args_set.clear()
+        return await wait_for_item(self.output_queue)
 
     async def shutdown(self) -> None:
         if self.connection:
             await self.connection.close()
-            self.reset_state()
+            self.connection = None
 
 
 def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
diff --git a/demo/talk_to_sambanova/app.py b/demo/talk_to_sambanova/app.py
index 603e453..3c2bb1d 100644
--- a/demo/talk_to_sambanova/app.py
+++ b/demo/talk_to_sambanova/app.py
@@ -49,20 +49,15 @@ def response(
 
     conversation_state.append({"role": "user", "content": text})
 
-    try:
-        request = client.chat.completions.create(
-            model="Meta-Llama-3.2-3B-Instruct",
-            messages=conversation_state,  # type: ignore
-            temperature=0.1,
-            top_p=0.1,
-        )
-        response = {"role": "assistant", "content": request.choices[0].message.content}
+    raise WebRTCError("test")
 
-    except Exception:
-        import traceback
-
-        traceback.print_exc()
-        raise WebRTCError(traceback.format_exc())
+    request = client.chat.completions.create(
+        model="Meta-Llama-3.2-3B-Instruct",
+        messages=conversation_state,  # type: ignore
+        temperature=0.1,
+        top_p=0.1,
+    )
+    response = {"role": "assistant", "content": request.choices[0].message.content}
 
     conversation_state.append(response)
     gradio_chatbot.append(response)
diff --git a/demo/whisper_realtime/app.py b/demo/whisper_realtime/app.py
index 445df75..74b5c83 100644
--- a/demo/whisper_realtime/app.py
+++ b/demo/whisper_realtime/app.py
@@ -10,7 +10,6 @@ from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
-    WebRTCError,
     audio_to_bytes,
     get_twilio_turn_credentials,
 )
@@ -26,15 +25,12 @@ groq_client = AsyncClient()
 
 
 async def transcribe(audio: tuple[int, np.ndarray]):
-    try:
-        transcript = await groq_client.audio.transcriptions.create(
-            file=("audio-file.mp3", audio_to_bytes(audio)),
-            model="whisper-large-v3-turbo",
-            response_format="verbose_json",
-        )
-        yield AdditionalOutputs(transcript.text)
-    except Exception as e:
-        raise WebRTCError(str(e))
+    transcript = await groq_client.audio.transcriptions.create(
+        file=("audio-file.mp3", audio_to_bytes(audio)),
+        model="whisper-large-v3-turbo",
+        response_format="verbose_json",
+    )
+    yield AdditionalOutputs(transcript.text)
 
 
 stream = Stream(
diff --git a/docs/CNAME b/docs/CNAME
new file mode 100644
index 0000000..a35e970
--- /dev/null
+++ b/docs/CNAME
@@ -0,0 +1 @@
+fastrtc.org
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 8a490bd..554e01c 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -63,56 +63,29 @@ The `Stream` has three main methods:
 === "LLM Voice Chat"
 
     ```py
-    from fastrtc import (
-        ReplyOnPause, AdditionalOutputs, Stream,
-        audio_to_bytes, aggregate_bytes_to_16bit
+    import os
+
+    from fastrtc import (ReplyOnPause, Stream, get_stt_model, get_tts_model)
+    from openai import OpenAI
+
+    sambanova_client = OpenAI(
+        api_key=os.getenv("SAMBANOVA_API_KEY"), base_url="https://api.sambanova.ai/v1"
     )
-    import gradio as gr
-    from groq import Groq
-    import anthropic
-    from elevenlabs import ElevenLabs
+    stt_model = get_stt_model()
+    tts_model = get_tts_model()
 
-    groq_client = Groq()
-    claude_client = anthropic.Anthropic()
-    tts_client = ElevenLabs()
-
-
-    # See "Talk to Claude" in Cookbook for an example of how to keep
-    # track of the chat history.
-    def response(
-        audio: tuple[int, np.ndarray],
-    ):
-        prompt = groq_client.audio.transcriptions.create(
-            file=("audio-file.mp3", audio_to_bytes(audio)),
-            model="whisper-large-v3-turbo",
-            response_format="verbose_json",
-        ).text
-        response = claude_client.messages.create(
-            model="claude-3-5-haiku-20241022",
-            max_tokens=512,
+    def echo(audio):
+        prompt = stt_model.stt(audio)
+        response = sambanova_client.chat.completions.create(
+            model="Meta-Llama-3.2-3B-Instruct",
             messages=[{"role": "user", "content": prompt}],
+            max_tokens=200,
         )
-        response_text = " ".join(
-            block.text
-            for block in response.content
-            if getattr(block, "type", None) == "text"
-        )
-        iterator = tts_client.text_to_speech.convert_as_stream(
-            text=response_text,
-            voice_id="JBFqnCBsd6RMkjVDRZzb",
-            model_id="eleven_multilingual_v2",
-            output_format="pcm_24000"
-
-        )
-        for chunk in aggregate_bytes_to_16bit(iterator):
-            audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
-            yield (24000, audio_array)
+        prompt = response.choices[0].message.content
+        for audio_chunk in tts_model.stream_tts_sync(prompt):
+            yield audio_chunk
 
-    stream = Stream(
-        modality="audio",
-        mode="send-receive",
-        handler=ReplyOnPause(response),
-    )
+    stream = Stream(ReplyOnPause(echo), modality="audio", mode="send-receive")
     ```
 
 === "Webcam Stream"
diff --git a/docs/userguide/audio.md b/docs/userguide/audio.md
index 797ce33..0a6d25f 100644
--- a/docs/userguide/audio.md
+++ b/docs/userguide/audio.md
@@ -1,4 +1,3 @@
-# Audio Streaming
 
 ## Reply On Pause
 
@@ -133,18 +132,21 @@ The API is similar to `ReplyOnPause` with the addition of a `stop_words` paramet
     ```
 
     1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
-    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
+    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
     3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
     4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
 
 === "Notes"
 
     1. The `StreamHandler` class implements three methods: `receive`, `emit` and `copy`. The `receive` method is called when a new frame is received from the client, and the `emit` method returns the next frame to send to the client. The `copy` method is called at the beginning of the stream to ensure each user has a unique stream handler.
-    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`.
+    2. The `emit` method SHOULD NOT block. If a frame is not ready to be sent, the method should return `None`. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
    3. The `shutdown` method is called when the stream is closed. It should be used to clean up any resources.
     4. The `start_up` method is called when the stream is first created. It should be used to initialize any resources. See [Talk To OpenAI](https://huggingface.co/spaces/fastrtc/talk-to-openai-gradio) or [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for an example of a `StreamHandler` that uses the `start_up` method to connect to an API.
 
 !!! tip
     See this [Talk To Gemini](https://huggingface.co/spaces/fastrtc/talk-to-gemini-gradio) for a complete example of a more complex stream handler.
 
+!!! warning
+    The `emit` method should not block. If you need to wait for a frame, use [`wait_for_item`](../../utils#wait_for_item) from the `utils` module.
+
 ## Async Stream Handlers
 
 It is also possible to create asynchronous stream handlers. This is very convenient for accessing async APIs from major LLM developers, like Google and OpenAI. The main difference is that `receive`, `emit`, and `start_up` are now defined with `async def`.
diff --git a/docs/utils.md b/docs/utils.md
index e7ada85..cb64ab6 100644
--- a/docs/utils.md
+++ b/docs/utils.md
@@ -97,4 +97,27 @@ Example
 >>> print(chunk)
 ```
 
+## `wait_for_item`
+
+Wait for an item from an asyncio.Queue with a timeout.
+
+Parameters
+```
+queue : asyncio.Queue
+    The queue to wait for an item from
+timeout : float
+    The timeout in seconds
+```
+Returns
+```
+Any
+    The item from the queue or None if the timeout is reached
+```
+
+Example
+```python
+>>> queue = asyncio.Queue()
+>>> queue.put_nowait(1)
+>>> item = await wait_for_item(queue)
+>>> print(item)
+```
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 718d92b..624ad26 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fastrtc"
-version = "0.0.4"
+version = "0.0.4.post1"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "apache-2.0"
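For context on how the new `wait_for_item` helper introduced by this diff is meant to be used, here is a minimal, illustrative sketch of an `AsyncStreamHandler` whose `emit` drains an output queue without blocking. It is not part of the diff above; the `EchoHandler` name, the bare `super().__init__()` call, and the no-op `start_up`/`shutdown` bodies are assumptions for the example.

```python
import asyncio

import numpy as np
from fastrtc import AsyncStreamHandler, wait_for_item


class EchoHandler(AsyncStreamHandler):
    """Illustrative handler that queues received audio and emits it back."""

    def __init__(self) -> None:
        super().__init__()
        self.output_queue: asyncio.Queue = asyncio.Queue()

    async def start_up(self) -> None:
        # Nothing to initialize in this sketch.
        pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        # Buffer incoming frames; emit drains the queue.
        await self.output_queue.put(frame)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        # Returns None if no frame arrives within the default 0.1 s timeout,
        # so emit never blocks the event loop.
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        pass

    def copy(self) -> "EchoHandler":
        return EchoHandler()
```

Because `wait_for_item` returns `None` when the queue stays empty past the timeout, `emit` can be polled repeatedly even while an upstream connection is still warming up, which is why the demos above swap `output_queue.get()` for it.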