mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
* Add code * add code * add code * Rename messages * rename * add code * Add demo * docs + demos + bug fixes * add code * styles * user guide * Styles * Add code * misc docs updates * print nit * whisper + pr * url for images * whsiper update * Fix bugs * remove demo files * version number * Fix pypi readme * Fix * demos * Add llama code editor * Update llama code editor and object detection cookbook * Add more cookbook demos * add code * Fix links for PR deploys * add code * Fix the install * add tts * TTS docs * Typo * Pending bubbles for reply on pause * Stream redesign (#63) * better error handling * Websocket error handling * add code --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local> * remove docs from dist * Some docs typos * more typos * upload changes + docs * docs * better phone * update docs * add code * Make demos better * fix docs + websocket start_up * remove mention of FastAPI app * fastphone tweaks * add code * ReplyOnStopWord fixes * Fix cookbook * Fix pypi readme * add code * bump versions * sambanova cookbook * Fix tags * Llm voice chat * kyutai tag * Add error message to all index.html * STT module uses Moonshine * Not required from typing extensions * fix llm voice chat * Add vpn warning * demo fixes * demos * Add more ui args and gemini audio-video * update cookbook * version 9 --------- Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
185 lines
6.9 KiB
Python
185 lines
6.9 KiB
Python
import asyncio
|
|
import audioop
|
|
import base64
|
|
import logging
|
|
from typing import Any, Awaitable, Callable, Optional, cast
|
|
|
|
import anyio
|
|
import librosa
|
|
import numpy as np
|
|
from fastapi import WebSocket
|
|
|
|
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
|
from .utils import AdditionalOutputs, DataChannel, split_output
|
|
|
|
|
|
class WebSocketDataChannel(DataChannel):
    """``DataChannel`` implementation that writes text frames to a websocket.

    Messages are handed to ``loop`` with ``asyncio.run_coroutine_threadsafe``.
    That makes ``send`` usable from threads other than the loop's own. The
    returned future is not kept, so delivery errors are not reported back to
    the caller.
    """

    def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
        self.websocket, self.loop = websocket, loop

    def send(self, message: str) -> None:
        """Schedule ``message`` to be sent as a text frame on ``self.loop``."""
        pending_send = self.websocket.send_text(message)
        asyncio.run_coroutine_threadsafe(pending_send, self.loop)
|
|
|
|
|
|
# Module-level logger. It is named after the module (not its file path) so it
# fits the standard dotted logger hierarchy and can be configured normally.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def convert_to_mulaw(
    audio_data: np.ndarray, original_rate: int, target_rate: int
) -> bytes:
    """Resample audio to ``target_rate`` and encode it as 8-bit mu-law.

    Args:
        audio_data: Audio samples. ``float32`` input is treated as already
            normalized to [-1.0, 1.0]. Any other dtype is treated as int16-scale
            and divided by 32768.
        original_rate: Sample rate of ``audio_data`` in Hz.
        target_rate: Sample rate, in Hz, that the encoded output should have.

    Returns:
        The mu-law bytes produced by ``audioop.lin2ulaw`` (one byte per sample).
    """
    if audio_data.dtype != np.float32:
        audio_data = audio_data.astype(np.float32) / 32768.0

    if original_rate != target_rate:
        # Resample to the requested rate. This was previously hard-coded to
        # 8000, which mislabelled the audio whenever target_rate != 8000.
        audio_data = librosa.resample(
            audio_data, orig_sr=original_rate, target_sr=target_rate
        )

    # Clip before casting: a full-scale +1.0 sample (or resampler overshoot)
    # scaled by 32768 falls outside the int16 range and would wrap around.
    pcm16 = np.clip(audio_data * 32768, -32768, 32767).astype(np.int16)

    return audioop.lin2ulaw(pcm16.tobytes(), 2)  # type: ignore
|
|
|
|
|
|
# Alias for anyio's worker-thread runner. The handler below uses it to run the
# synchronous stream handler's methods (e.g. ``receive``/``emit``) off the
# event loop thread.
run_sync = anyio.to_thread.run_sync # type: ignore
|
|
|
|
|
|
class WebSocketHandler:
    """Bridge a websocket audio stream to a stream handler.

    Incoming ``media`` events carry base64-encoded mu-law audio, which is
    treated as 8 kHz. The audio is decoded, resampled to the handler's
    ``input_sample_rate`` and passed to ``stream_handler.receive``.
    Concurrently, ``_emit_loop`` polls ``stream_handler.emit`` and sends the
    resulting audio back as mu-law ``media`` events.
    """

    def __init__(
        self,
        stream_handler: StreamHandlerImpl,
        set_handler: Callable[[str, "WebSocketHandler"], Awaitable[None]],
        clean_up: Callable[[str], None],
        additional_outputs_factory: Callable[
            [str], Callable[[AdditionalOutputs], None]
        ],
    ):
        self.stream_handler = stream_handler
        self.websocket: Optional[WebSocket] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.data_channel: Optional[WebSocketDataChannel] = None
        self._emit_task: Optional[asyncio.Task] = None
        # Created in handle_websocket; initialised here so teardown code can
        # always test it.
        self.start_up_task: Optional[asyncio.Task] = None
        self.stream_id: Optional[str] = None
        self.set_additional_outputs_factory = additional_outputs_factory
        # Bound when the "start" event arrives, because it needs the stream id.
        self.set_additional_outputs: Callable[[AdditionalOutputs], None]
        self.set_handler = set_handler
        self.quit = asyncio.Event()
        self.clean_up = clean_up
        # Ensures clean_up runs at most once ("stop" event and teardown).
        self._cleaned_up = False

    def set_args(self, args: list[Any]):
        """Forward extra per-connection arguments to the stream handler."""
        self.stream_handler.set_args(args)

    async def handle_websocket(self, websocket: WebSocket):
        """Serve one websocket connection until it stops, fails or disconnects.

        The method accepts the socket and starts two background tasks: the
        handler's ``start_up`` and the emit loop. It then processes incoming
        JSON events:

        - ``media``: incoming audio.
        - ``start``: registers the stream id via ``set_handler``.
        - ``stop``: ends the session.
        - ``ping``: answered with ``pong``.

        On exit the background tasks are cancelled, ``clean_up`` is called
        once for the registered stream id, and the socket is closed.
        """
        await websocket.accept()
        loop = asyncio.get_running_loop()
        self.loop = loop
        self.websocket = websocket
        self.data_channel = WebSocketDataChannel(websocket, loop)
        self.stream_handler._loop = loop
        self.stream_handler.set_channel(self.data_channel)
        self._emit_task = asyncio.create_task(self._emit_loop())
        if isinstance(self.stream_handler, AsyncStreamHandler):
            start_up = self.stream_handler.start_up()
        else:
            start_up = run_sync(self.stream_handler.start_up)

        self.start_up_task = asyncio.create_task(start_up)
        try:
            while not self.quit.is_set():
                message = await websocket.receive_json()
                event = message.get("event")

                if event == "media":
                    audio_array = self._decode_media(message["media"]["payload"])
                    frame = (self.stream_handler.input_sample_rate, audio_array)
                    if isinstance(self.stream_handler, AsyncStreamHandler):
                        await self.stream_handler.receive(frame)
                    else:
                        await run_sync(self.stream_handler.receive, frame)

                elif event == "start":
                    if self.stream_handler.phone_mode:
                        self.stream_id = cast(str, message["streamSid"])
                    else:
                        self.stream_id = cast(str, message["websocket_id"])
                    self.set_additional_outputs = self.set_additional_outputs_factory(
                        self.stream_id
                    )
                    await self.set_handler(self.stream_id, self)

                elif event == "stop":
                    self.quit.set()
                    self._clean_up_once()
                    return

                elif event == "ping":
                    await websocket.send_json({"event": "pong"})

        except Exception as e:
            logger.exception("Error in websocket handler %s", e)
        finally:
            self.quit.set()
            if self._emit_task:
                self._emit_task.cancel()
            if self.start_up_task:
                self.start_up_task.cancel()
            # Also covers clients that disconnect without sending "stop";
            # previously their handler was never cleaned up.
            try:
                self._clean_up_once()
            except Exception:
                logger.exception(
                    "Error cleaning up websocket stream %s", self.stream_id
                )
            try:
                await websocket.close()
            except (RuntimeError, OSError) as e:
                # Closing after the peer already went away can raise.
                logger.debug("Websocket already closed: %s", e)

    def _clean_up_once(self) -> None:
        """Call ``clean_up`` for the registered stream id, at most once."""
        if self._cleaned_up or self.stream_id is None:
            return
        self._cleaned_up = True
        self.clean_up(self.stream_id)

    def _decode_media(self, payload: str) -> np.ndarray:
        """Decode a base64 mu-law payload into int16 PCM at the handler's rate.

        The payload is decoded as 8 kHz mu-law. It is resampled when the
        handler's ``input_sample_rate`` differs from 8000.
        """
        audio_payload = base64.b64decode(payload)
        audio_array = np.frombuffer(
            audioop.ulaw2lin(audio_payload, 2), dtype=np.int16
        )
        input_rate = self.stream_handler.input_sample_rate
        if input_rate != 8000:
            audio_float = audio_array.astype(np.float32) / 32768.0
            audio_float = librosa.resample(
                audio_float, orig_sr=8000, target_sr=input_rate
            )
            # Clip so resampler overshoot cannot wrap when cast to int16.
            audio_array = np.clip(audio_float * 32768, -32768, 32767).astype(
                np.int16
            )
        return audio_array

    def _publish_additional_outputs(self, outputs: AdditionalOutputs) -> None:
        """Forward additional outputs, dropping them before "start" arrives."""
        setter = getattr(self, "set_additional_outputs", None)
        if setter is None:
            # set_additional_outputs is only bound by the "start" event.
            # Without this guard an early output raised AttributeError and
            # silently ended the emit loop.
            logger.debug("Dropping additional outputs received before 'start'")
            return
        setter(outputs)

    async def _emit_loop(self):
        """Poll the handler for output and forward it over the websocket.

        Runs until ``quit`` is set or the task is cancelled. Audio frames are
        mu-law encoded at 8 kHz in phone mode, and otherwise at the handler's
        ``output_sample_rate``. Additional outputs are passed to
        ``set_additional_outputs``.
        """
        try:
            while not self.quit.is_set():
                if isinstance(self.stream_handler, AsyncStreamHandler):
                    output = await self.stream_handler.emit()
                else:
                    output = await run_sync(self.stream_handler.emit)

                if output is not None:
                    frame, additional_outputs = split_output(output)
                    if additional_outputs is not None:
                        self._publish_additional_outputs(additional_outputs)
                    if not isinstance(frame, tuple):
                        continue

                    target_rate = (
                        8000
                        if self.stream_handler.phone_mode
                        else self.stream_handler.output_sample_rate
                    )
                    mulaw_audio = convert_to_mulaw(
                        frame[1], frame[0], target_rate=target_rate
                    )
                    audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")

                    if self.websocket and self.stream_id:
                        payload = {
                            "event": "media",
                            "media": {"payload": audio_payload},
                        }
                        if self.stream_handler.phone_mode:
                            payload["streamSid"] = self.stream_id
                        await self.websocket.send_json(payload)

                await asyncio.sleep(0.02)

        except asyncio.CancelledError:
            logger.debug("Emit loop cancelled")
        except Exception as e:
            logger.exception("Error in emit loop: %s", e)
|