From 074e9c9345b6917fc76d5ab639fa7c9dc41acc49 Mon Sep 17 00:00:00 2001
From: Freddy Boulton <41651716+freddyaboulton@users.noreply.github.com>
Date: Tue, 22 Apr 2025 14:40:19 -0400
Subject: [PATCH] Fix websocket interruption (#291)

* Code

* Fix

* add code

* interruptions

* Add code

* code

* Add code

* Add code

* code

---
 backend/fastrtc/stream.py             |   9 ++
 backend/fastrtc/websocket.py          |  25 ++-
 demo/qwen_phone_chat/README.md        |  14 ++
 demo/qwen_phone_chat/app.py           | 217 ++++++++++++++++++++++++++
 demo/qwen_phone_chat/requirements.txt |   2 +
 demo/talk_to_openai/app.py            |   5 +-
 demo/webrtc_vs_websocket/app.py       |   2 +-
 justfile                              |  32 ++++
 pyproject.toml                        |   2 +-
 9 files changed, 287 insertions(+), 21 deletions(-)
 create mode 100644 demo/qwen_phone_chat/README.md
 create mode 100644 demo/qwen_phone_chat/app.py
 create mode 100644 demo/qwen_phone_chat/requirements.txt

diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py
index 87b84bb..fc320ca 100644
--- a/backend/fastrtc/stream.py
+++ b/backend/fastrtc/stream.py
@@ -928,6 +928,7 @@ class Stream(WebRTCConnectionMixin):
                 json={"url": host},
                 headers={"Authorization": token or get_token() or ""},
             )
+            r.raise_for_status()
         except Exception:
             URL = "https://fastrtc-fastphone.hf.space"
             r = httpx.post(
@@ -936,6 +937,14 @@
                 headers={"Authorization": token or get_token() or ""},
             )
             r.raise_for_status()
+        if r.status_code == 202:
+            print(
+                click.style("INFO", fg="orange")
+                + ":\t You have "
+                + "run out of your quota"
+            )
+            return
+
         data = r.json()
         code = f"{data['code']}"
         phone_number = data["phone"]
diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py
index 4100208..f2b182d 100644
--- a/backend/fastrtc/websocket.py
+++ b/backend/fastrtc/websocket.py
@@ -81,17 +81,11 @@ class WebSocketHandler:
         self._graceful_shutdown_task: asyncio.Task | None = None
 
     def _clear_queue(self):
-        old_queue = self.queue
-        self.queue = asyncio.Queue()
-        logger.debug("clearing queue")
         i = 0
-        while not old_queue.empty():
-            try:
-                old_queue.get_nowait()
-                i += 1
-            except asyncio.QueueEmpty:
-                break
-        logger.debug("popped %d items from queue", i)
+        while not self.queue.empty():
+            self.queue.get_nowait()
+            i += 1
+        logger.debug("websocket: popped %d items from queue", i)
 
     def set_args(self, args: list[Any]):
         self.stream_handler.set_args(args)
@@ -260,8 +254,8 @@ class WebSocketHandler:
     async def _emit_loop(self):
         try:
             while not self.quit.is_set():
+                wait_duration = 0.02
                 output = await self.queue.get()
-
                 if output is not None:
                     frame, output = split_output(output)
                     if isinstance(output, AdditionalOutputs):
@@ -279,6 +273,7 @@
                         if self.stream_handler.phone_mode
                         else self.stream_handler.output_sample_rate
                     )
+                    duration = np.atleast_2d(frame[1]).shape[1] / frame[0]
                     mulaw_audio = convert_to_mulaw(
                         frame[1],
                         frame[0],
@@ -287,9 +282,6 @@
                     audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
 
                     if self.websocket and self.stream_id:
-                        sample_rate, audio_array = frame[:2]
-                        duration = len(audio_array) / sample_rate
-
                         self.playing_durations.append(duration)
 
                         payload = {
@@ -298,9 +290,10 @@
                         }
                         if self.stream_handler.phone_mode:
                             payload["streamSid"] = self.stream_id
+                            # yield audio slightly faster than real-time
+                            wait_duration = 0.75 * duration
                         await self.websocket.send_json(payload)
-
-                await asyncio.sleep(0.02)
+                await asyncio.sleep(wait_duration)
 
         except asyncio.CancelledError:
             logger.debug("Emit loop cancelled")
diff --git a/demo/qwen_phone_chat/README.md b/demo/qwen_phone_chat/README.md
new file mode 100644
index 0000000..da5e482
--- /dev/null
+++ b/demo/qwen_phone_chat/README.md
@@ -0,0 +1,14 @@
+---
+title: Qwen Phone Chat
+emoji: 📞
+colorFrom: pink
+colorTo: green
+sdk: gradio
+sdk_version: 5.25.2
+app_file: app.py
+pinned: false
+license: mit
+short_description: Talk with Qwen 2.5 Omni over the Phone
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/demo/qwen_phone_chat/app.py b/demo/qwen_phone_chat/app.py
new file mode 100644
index 0000000..959cb73
--- /dev/null
+++ b/demo/qwen_phone_chat/app.py
@@ -0,0 +1,217 @@
+import asyncio
+import base64
+import json
+import os
+import secrets
+from pathlib import Path
+
+import gradio as gr
+import numpy as np
+from dotenv import load_dotenv
+from fastapi import FastAPI, Request
+from fastapi.responses import HTMLResponse
+from fastrtc import (
+    AdditionalOutputs,
+    AsyncStreamHandler,
+    Stream,
+    get_cloudflare_turn_credentials_async,
+    wait_for_item,
+)
+from websockets.asyncio.client import connect
+
+load_dotenv()
+
+cur_dir = Path(__file__).parent
+
+API_KEY = os.getenv("MODELSCOPE_API_KEY", "")
+API_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=qwen-omni-turbo-realtime-2025-03-26"
+VOICES = ["Chelsie", "Serena", "Ethan", "Cherry"]
+headers = {"Authorization": "Bearer " + API_KEY}
+
+
+class QwenOmniHandler(AsyncStreamHandler):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__(
+            expected_layout="mono",
+            output_sample_rate=24_000,
+            input_sample_rate=16_000,
+        )
+        self.connection = None
+        self.output_queue = asyncio.Queue()
+
+    def copy(self):
+        return QwenOmniHandler()
+
+    @staticmethod
+    def msg_id() -> str:
+        return f"event_{secrets.token_hex(10)}"
+
+    async def start_up(
+        self,
+    ):
+        """Connect to realtime API.
+        Run forever in separate thread to keep connection open."""
+        voice_id = "Serena"
+        print("voice_id", voice_id)
+        async with connect(
+            API_URL,
+            additional_headers=headers,
+        ) as conn:
+            self.client = conn
+            await conn.send(
+                json.dumps(
+                    {
+                        "event_id": self.msg_id(),
+                        "type": "session.update",
+                        "session": {
+                            "modalities": [
+                                "text",
+                                "audio",
+                            ],
+                            "voice": voice_id,
+                            "input_audio_format": "pcm16",
+                        },
+                    }
+                )
+            )
+            self.connection = conn
+            try:
+                async for data in self.connection:
+                    event = json.loads(data)
+                    print("event", event["type"])
+                    if "type" not in event:
+                        continue
+                    # Handle interruptions
+                    if event["type"] == "input_audio_buffer.speech_started":
+                        self.clear_queue()
+                    if event["type"] == "response.audio.delta":
+                        print("putting output")
+                        await self.output_queue.put(
+                            (
+                                self.output_sample_rate,
+                                np.frombuffer(
+                                    base64.b64decode(event["delta"]), dtype=np.int16
+                                ).reshape(1, -1),
+                            ),
+                        )
+            except Exception as e:
+                print("error", e)
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        if not self.connection:
+            return
+        _, array = frame
+        array = array.squeeze()
+        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
+        try:
+            await self.connection.send(
+                json.dumps(
+                    {
+                        "event_id": self.msg_id(),
+                        "type": "input_audio_buffer.append",
+                        "audio": audio_message,
+                    }
+                )
+            )
+        except Exception as e:
+            print("error", e)
+
+    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
+        return await wait_for_item(self.output_queue)
+
+    async def shutdown(self) -> None:
+        if self.connection:
+            await self.connection.close()
+            self.connection = None
+
+
+voice = gr.Dropdown(choices=VOICES, value=VOICES[0], type="value", label="Voice")
+stream = Stream(
+    QwenOmniHandler(),
+    mode="send-receive",
+    modality="audio",
+    additional_inputs=[voice],
+    additional_outputs=None,
+    rtc_configuration=get_cloudflare_turn_credentials_async,
+    concurrency_limit=20,
+    time_limit=180,
+)
+
+app = FastAPI()
+
+
+@app.post("/telephone/incoming")
+async def handle_incoming_call(request: Request):
+    """
+    Handle incoming telephone calls (e.g., via Twilio).
+
+    Generates TwiML instructions to connect the incoming call to the
+    WebSocket handler (`/telephone/handler`) for audio streaming.
+
+    Args:
+        request: The FastAPI Request object for the incoming call webhook.
+
+    Returns:
+        An HTMLResponse containing the TwiML instructions as XML.
+    """
+    from twilio.twiml.voice_response import Connect, VoiceResponse
+
+    if len(stream.connections) > (stream.concurrency_limit or 20):
+        response = VoiceResponse()
+        response.say("Qwen is busy, please try again later!")
+        return HTMLResponse(content=str(response), media_type="application/xml")
+
+    response = VoiceResponse()
+    response.say("Connecting to Qwen")
+    connect = Connect()
+    print("request.url.hostname", request.url.hostname)
+    connect.stream(url=f"wss://{request.url.hostname}/telephone/handler")
+    response.append(connect)
+    response.say("The call has been disconnected.")
+    return HTMLResponse(content=str(response), media_type="application/xml")
+
+
+stream.mount(app)
+
+
+@app.get("/")
+async def _():
+    html_content = """
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <title>Qwen Phone Chat</title>
+    </head>
+    <body>
+        <h1>Qwen Phone Chat</h1>
+        <p>Call +1 (877) 853-7936</p>
+    </body>
+    </html>
+    """
+    return HTMLResponse(content=html_content)
+
+
+if __name__ == "__main__":
+    # stream.fastphone(host="0.0.0.0", port=7860)
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=7860)
diff --git a/demo/qwen_phone_chat/requirements.txt b/demo/qwen_phone_chat/requirements.txt
new file mode 100644
index 0000000..c249d2b
--- /dev/null
+++ b/demo/qwen_phone_chat/requirements.txt
@@ -0,0 +1,2 @@
+fastrtc
+websockets>=14.0
\ No newline at end of file
diff --git a/demo/talk_to_openai/app.py b/demo/talk_to_openai/app.py
index bc0a37d..3d05992 100644
--- a/demo/talk_to_openai/app.py
+++ b/demo/talk_to_openai/app.py
@@ -17,7 +17,6 @@ from fastrtc import (
     wait_for_item,
 )
 from gradio.utils import get_space
-from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
 
 load_dotenv()
 
@@ -103,8 +102,8 @@ class OpenAIHandler(AsyncStreamHandler):
         self.connection = None
 
 
-def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
-    chatbot.append({"role": "assistant", "content": response.transcript})
+def update_chatbot(chatbot: list[dict], response: dict):
+    chatbot.append(response)
     return chatbot
 
 
diff --git a/demo/webrtc_vs_websocket/app.py b/demo/webrtc_vs_websocket/app.py
index bf98a35..a7c7479 100644
--- a/demo/webrtc_vs_websocket/app.py
+++ b/demo/webrtc_vs_websocket/app.py
@@ -76,7 +76,7 @@ def response(
     )
     for chunk in aggregate_bytes_to_16bit(iterator):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
-        yield (24000, audio_array, "mono")
+        yield (24000, audio_array)
 
 
 chatbot = gr.Chatbot(type="messages")
diff --git a/justfile b/justfile
index ee0a33d..d51a6d3 100644
--- a/justfile
+++ b/justfile
@@ -42,6 +42,38 @@ publish:
     print(f"Uploading {latest_wheel}")
     os.system(f"twine upload {latest_wheel}")
 
+# Upload the latest wheel to HF space with a random ID
+publish-dev:
+    #!/usr/bin/env python
+    import glob
+    import os
+    import uuid
+    import subprocess
+
+    # Find all wheel files in dist directory
+    wheels = glob.glob('dist/*.whl')
+    if not wheels:
+        print("No wheel files found in dist directory")
+        exit(1)
+
+    # Sort by creation time to get the latest
+    latest_wheel = max(wheels, key=os.path.getctime)
+    wheel_name = os.path.basename(latest_wheel)
+
+    # Generate random ID
+    random_id = str(uuid.uuid4())[:8]
+
+    # Define the HF path
+    hf_space = "freddyaboulton/bucket"
+    hf_path = f"wheels/fastrtc/{random_id}/"
+
+    # Upload to Hugging Face space
+    cmd = f"huggingface-cli upload {hf_space} {latest_wheel} {hf_path}{wheel_name} --repo-type dataset"
+    subprocess.run(cmd, shell=True, check=True)
+
+    # Print the URL
+    print(f"Wheel uploaded successfully!")
+    print(f"URL: https://huggingface.co/datasets/{hf_space}/resolve/main/{hf_path}{wheel_name}")
 
 # Build the package
 build:
diff --git a/pyproject.toml b/pyproject.toml
index 9d6aaac..bf036bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fastrtc"
-version = "0.0.22.rc2"
+version = "0.0.22.rc5"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "MIT"