mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
sync code of fastrtc; add text support through the data channel, fix a Safari connection problem, and support chat without camera or mic
456 lines
13 KiB
Python
456 lines
13 KiB
Python
import asyncio
|
|
import fractions
|
|
import functools
|
|
import inspect
|
|
import io
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
import traceback
|
|
from contextvars import ContextVar
|
|
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
|
|
|
import av
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
from pydub import AudioSegment
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Audio packet time: duration of one audio frame in seconds (20 ms),
# used to derive the default resampler frame_size below.
AUDIO_PTIME = 0.020
|
|
|
|
|
|
class Message(TypedDict):
    """Envelope for a message exchanged over the WebRTC data channel."""

    type: str  # message kind, e.g. "error" / "warning" / "log" (see create_message)
    data: Any  # payload; structure depends on the message type
|
|
|
|
class AudioChunk(TypedDict):
    """A span within an audio stream.

    NOTE(review): the unit of ``start``/``end`` (samples vs. milliseconds) is
    not established anywhere in this file — confirm against callers.
    """

    start: int  # inclusive start position of the chunk
    end: int  # end position of the chunk
|
|
|
|
|
|
class AdditionalOutputs:
    """Wrapper marking extra values a handler wants delivered to the client.

    Handlers return this (alone or as the last element of a 2-tuple); see
    ``split_output`` for how it is separated from the audio payload.
    """

    def __init__(self, *args) -> None:
        # Values are stored verbatim and consumed downstream via
        # set_additional_outputs (see player_worker_decode).
        self.args = args
|
|
|
|
|
|
class DataChannel(Protocol):
    """Structural type for any object exposing a synchronous ``send(str)``."""

    def send(self, message: str) -> None: ...
|
|
|
|
|
|
def create_message(
|
|
type: Literal[
|
|
"send_input",
|
|
"fetch_output",
|
|
"stopword",
|
|
"error",
|
|
"warning",
|
|
"log",
|
|
],
|
|
data: list[Any] | str,
|
|
) -> str:
|
|
return json.dumps({"type": type, "data": data})
|
|
|
|
|
|
# Per-context handle to the active WebRTC data channel, letting helpers such
# as _send_log reach the client without explicit parameter plumbing.
# None when no channel has been registered for the current context.
current_channel: ContextVar[DataChannel | None] = ContextVar(
    "current_channel", default=None
)
|
|
|
|
|
|
def _send_log(message: str, type: str) -> None:
    """Best-effort delivery of a log/warning/error message to the client.

    Sends ``{"type": type, "message": message}`` over the data channel bound
    to the current context; silently does nothing when no channel is set.
    """

    async def _send(channel: DataChannel) -> None:
        # Wrapped in a coroutine so delivery can be scheduled onto an event
        # loop; DataChannel.send itself is synchronous.
        channel.send(
            json.dumps(
                {
                    "type": type,
                    "message": message,
                }
            )
        )

    if channel := current_channel.get():
        try:
            # A loop is running in this thread: hand the coroutine to it
            # (fire-and-forget; the returned future is intentionally ignored).
            loop = asyncio.get_running_loop()
            asyncio.run_coroutine_threadsafe(_send(channel), loop)
        except RuntimeError:
            # No running loop in this thread: run the coroutine to completion.
            asyncio.run(_send(channel))
|
|
|
|
|
|
def Warning(  # noqa: N802
    message: str = "Warning issued.",
):
    """
    Send a warning message that is displayed in the UI of the application.

    Parameters
    ----------
    message : str
        The warning message to send

    Returns
    -------
    None
    """
    _send_log(message, "warning")
|
|
|
|
|
|
class WebRTCError(Exception):
    """Exception that also notifies the connected client via the data channel."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Pushing the error to the UI is a deliberate side effect of
        # constructing the exception.
        _send_log(message, "error")
|
|
|
|
|
|
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
    """Separate a handler's return value into ``(payload, additional_outputs)``.

    Accepts a bare ``AdditionalOutputs``, a bare audio tuple
    ``(sample_rate, ndarray[, layout])``, a ``(payload, AdditionalOutputs)``
    pair, or any other value (passed through with no outputs).

    Raises
    ------
    ValueError
        If a tuple is neither bare audio nor a valid 2-element pair ending
        in an ``AdditionalOutputs`` instance.
    """
    if isinstance(data, AdditionalOutputs):
        return None, data
    if not isinstance(data, tuple):
        return data, None
    # A (sample_rate, ndarray[, layout]) tuple is bare audio, not outputs.
    if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
        return data, None
    if len(data) != 2:
        raise ValueError(
            "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
        )
    if not isinstance(data[-1], AdditionalOutputs):
        raise ValueError(
            "The last element of the tuple must be an instance of AdditionalOutputs."
        )
    return data[0], cast(AdditionalOutputs, data[1])
|
|
|
|
|
|
async def player_worker_decode(
    next_frame: Callable,
    queue: asyncio.Queue,
    thread_quit: asyncio.Event,
    channel: Callable[[], DataChannel | None] | None,
    set_additional_outputs: Callable | None,
    quit_on_none: bool = False,
    sample_rate: int = 48000,
    frame_size: int = int(48000 * AUDIO_PTIME),
):
    """Decode audio produced by ``next_frame`` and feed it into ``queue``.

    Each cycle awaits ``next_frame()`` (60 s timeout), splits off any
    ``AdditionalOutputs``, converts the ``(sample_rate, array[, layout])``
    tuple into an ``av.AudioFrame``, resamples to stereo s16 at
    ``sample_rate`` Hz, stamps pts/time_base, and enqueues the result.

    Parameters
    ----------
    next_frame : Callable
        Zero-argument coroutine function producing the next frame (or None).
    queue : asyncio.Queue
        Destination for the resampled ``av.AudioFrame`` objects.
    thread_quit : asyncio.Event
        Set externally to stop the loop.
    channel : Callable[[], DataChannel | None] | None
        Accessor for the data channel used to announce additional outputs.
    set_additional_outputs : Callable | None
        Sink invoked with any ``AdditionalOutputs`` split off the result.
    quit_on_none : bool
        When True, a ``None`` frame enqueues ``None`` and ends the loop.
    sample_rate : int
        Output sample rate in Hz.
    frame_size : int
        Samples per resampled output frame.
    """
    audio_samples = 0
    audio_time_base = fractions.Fraction(1, sample_rate)
    audio_resampler = av.AudioResampler(  # type: ignore
        format="s16",
        layout="stereo",
        rate=sample_rate,
        frame_size=frame_size,
    )

    while not thread_quit.is_set():
        try:
            # Get next frame
            frame, outputs = split_output(
                await asyncio.wait_for(next_frame(), timeout=60)
            )
            if (
                isinstance(outputs, AdditionalOutputs)
                and set_additional_outputs
                and channel
                and channel()
            ):
                set_additional_outputs(outputs)
                cast(DataChannel, channel()).send(create_message("fetch_output", []))

            if frame is None:
                if quit_on_none:
                    await queue.put(None)
                    break
                continue

            # BUG FIX: the original used `and` here, so malformed tuples
            # slipped past validation and non-tuples raised a raw TypeError
            # when `frame[1]` was evaluated.
            if not isinstance(frame, tuple) or not isinstance(frame[1], np.ndarray):
                raise WebRTCError(
                    "The frame must be a tuple containing a sample rate and a numpy array."
                )

            if len(frame) == 2:
                sample_rate, audio_array = frame
                layout = "mono"
            elif len(frame) == 3:
                sample_rate, audio_array, layout = frame
            else:
                # Guard against other tuple lengths, which previously left
                # audio_array/layout unbound and crashed below with NameError.
                raise WebRTCError(
                    "The frame must be a tuple containing a sample rate and a numpy array."
                )

            logger.debug(
                "received array with shape %s sample rate %s layout %s",
                audio_array.shape,  # type: ignore
                sample_rate,
                layout,  # type: ignore
            )
            # int16 maps to packed s16; anything else is treated as planar float.
            sample_format = "s16" if audio_array.dtype == "int16" else "fltp"

            if audio_array.ndim == 1:
                # PyAV expects a (channels, samples) layout.
                audio_array = audio_array.reshape(1, -1)

            # Convert to an av.AudioFrame and resample (still inside the
            # timeout-protected cycle).
            av_frame = av.AudioFrame.from_ndarray(  # type: ignore
                audio_array,  # type: ignore
                format=sample_format,
                layout=layout,  # type: ignore
            )
            av_frame.sample_rate = sample_rate

            for processed_frame in audio_resampler.resample(av_frame):
                # Maintain a monotonically increasing pts across frames.
                processed_frame.pts = audio_samples
                processed_frame.time_base = audio_time_base
                audio_samples += processed_frame.samples
                await queue.put(processed_frame)

        except (TimeoutError, asyncio.TimeoutError):
            logger.warning(
                "Timeout in frame processing cycle after %s seconds - resetting", 60
            )
            continue
        except Exception as e:
            # BUG FIX: the original passed %-style format strings to print(),
            # which does not interpolate, and shadowed the `exec` builtin.
            # Use the module logger with lazy %-args instead.
            logger.error("traceback %s", traceback.format_exc())
            logger.error("Error processing frame: %s", str(e))
            if isinstance(e, WebRTCError):
                raise
            continue
|
|
|
|
|
|
def audio_to_bytes(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> bytes:
    """
    Encode an audio tuple of sample rate and numpy samples as MP3 bytes.

    Parameters
    ----------
    audio : tuple[int, np.ndarray]
        A tuple of (sample_rate in Hz, mono audio samples).

    Returns
    -------
    bytes
        The MP3-encoded audio, suitable for transmission or storage.

    Example
    -------
    >>> sample_rate = 44100
    >>> audio_data = np.array([0.1, -0.2, 0.3])  # Example audio samples
    >>> audio_bytes = audio_to_bytes((sample_rate, audio_data))
    """
    rate, samples = audio
    # pydub infers the PCM format from the raw bytes plus width/rate/channels.
    segment = AudioSegment(
        samples.tobytes(),
        frame_rate=rate,
        sample_width=samples.dtype.itemsize,
        channels=1,
    )
    encoded = io.BytesIO()
    segment.export(encoded, format="mp3")
    return encoded.getvalue()
|
|
|
|
|
|
def audio_to_file(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
    """
    Write an audio tuple of sample rate and numpy samples to an MP3 file.

    Parameters
    ----------
    audio : tuple[int, np.ndarray]
        A tuple of (sample_rate in Hz, mono audio samples).

    Returns
    -------
    str
        The path to the saved audio file. The file is NOT auto-deleted;
        the caller owns its lifetime.

    Example
    -------
    >>> sample_rate = 44100
    >>> audio_data = np.array([0.1, -0.2, 0.3])  # Example audio samples
    >>> file_path = audio_to_file((sample_rate, audio_data))
    >>> print(f"Audio saved to: {file_path}")
    """
    encoded = audio_to_bytes(audio)
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
        tmp.write(encoded)
        return tmp.name
|
|
|
|
|
|
def audio_to_float32(
|
|
audio: tuple[int, NDArray[np.int16 | np.float32]],
|
|
) -> NDArray[np.float32]:
|
|
"""
|
|
Convert an audio tuple containing sample rate (int16) and numpy array data to float32.
|
|
|
|
Parameters
|
|
----------
|
|
audio : tuple[int, np.ndarray]
|
|
A tuple containing:
|
|
- sample_rate (int): The audio sample rate in Hz
|
|
- data (np.ndarray): The audio data as a numpy array
|
|
|
|
Returns
|
|
-------
|
|
np.ndarray
|
|
The audio data as a numpy array with dtype float32
|
|
|
|
Example
|
|
-------
|
|
>>> sample_rate = 44100
|
|
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
|
>>> audio_tuple = (sample_rate, audio_data)
|
|
>>> audio_float32 = audio_to_float32(audio_tuple)
|
|
"""
|
|
return audio[1].astype(np.float32) / 32768.0
|
|
|
|
|
|
def audio_to_int16(
|
|
audio: tuple[int, NDArray[np.int16 | np.float32]],
|
|
) -> NDArray[np.int16]:
|
|
"""
|
|
Convert an audio tuple containing sample rate and numpy array data to int16.
|
|
|
|
Parameters
|
|
----------
|
|
audio : tuple[int, np.ndarray]
|
|
A tuple containing:
|
|
- sample_rate (int): The audio sample rate in Hz
|
|
- data (np.ndarray): The audio data as a numpy array
|
|
|
|
Returns
|
|
-------
|
|
np.ndarray
|
|
The audio data as a numpy array with dtype int16
|
|
|
|
Example
|
|
-------
|
|
>>> sample_rate = 44100
|
|
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
|
|
>>> audio_tuple = (sample_rate, audio_data)
|
|
>>> audio_int16 = audio_to_int16(audio_tuple)
|
|
"""
|
|
if audio[1].dtype == np.int16:
|
|
return audio[1] # type: ignore
|
|
elif audio[1].dtype == np.float32:
|
|
# Convert float32 to int16 by scaling to the int16 range
|
|
return (audio[1] * 32767.0).astype(np.int16)
|
|
else:
|
|
raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
|
|
|
|
|
|
def aggregate_bytes_to_16bit(chunks_iterator):
    """
    Aggregate raw byte chunks into 16-bit audio sample arrays.

    Carries any trailing odd byte forward and prepends it to the next chunk,
    so samples split across chunk boundaries are reassembled correctly.

    Parameters
    ----------
    chunks_iterator : Iterator[bytes]
        An iterator of byte chunks to aggregate

    Returns
    -------
    Iterator[NDArray[np.int16]]
        Arrays of shape (1, n_samples), one per input chunk that yielded
        at least one complete sample.
    """
    pending = b""
    for chunk in chunks_iterator:
        buffered = pending + chunk
        # Keep only whole 2-byte samples; the remainder waits for more data.
        usable = (len(buffered) // 2) * 2
        complete, pending = buffered[:usable], buffered[usable:]
        if complete:
            yield np.frombuffer(complete, dtype=np.int16).reshape(1, -1)
|
|
|
|
|
|
async def async_aggregate_bytes_to_16bit(chunks_iterator):
    """
    Aggregate raw byte chunks into 16-bit audio sample arrays (async).

    Carries any trailing odd byte forward and prepends it to the next chunk,
    so samples split across chunk boundaries are reassembled correctly.

    Parameters
    ----------
    chunks_iterator : AsyncIterator[bytes]
        An async iterator of byte chunks to aggregate

    Returns
    -------
    AsyncIterator[NDArray[np.int16]]
        Arrays of shape (1, n_samples), one per input chunk that yielded
        at least one complete sample.
    """
    pending = b""

    async for chunk in chunks_iterator:
        buffered = pending + chunk
        # Keep only whole 2-byte samples; the remainder waits for more data.
        usable = (len(buffered) // 2) * 2
        complete, pending = buffered[:usable], buffered[usable:]
        if complete:
            yield np.frombuffer(complete, dtype=np.int16).reshape(1, -1)
|
|
|
|
|
|
def webrtc_error_handler(func):
    """Decorator to catch exceptions and raise WebRTCError with stacktrace."""

    def _reraise(exc: Exception):
        # Shared failure path: print the stacktrace, pass WebRTCError through
        # unchanged, and wrap anything else in a WebRTCError.
        traceback.print_exc()
        if isinstance(exc, WebRTCError):
            raise exc
        raise WebRTCError(str(exc)) from exc

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            _reraise(e)

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            _reraise(e)

    # Choose the wrapper matching the wrapped function's flavor.
    if inspect.iscoroutinefunction(func):
        return async_wrapper
    return sync_wrapper
|
|
|
|
|
|
async def wait_for_item(queue: asyncio.Queue, timeout: float = 0.1) -> Any:
|
|
"""
|
|
Wait for an item from an asyncio.Queue with a timeout.
|
|
|
|
This function attempts to retrieve an item from the queue using asyncio.wait_for.
|
|
If the timeout is reached, it returns None.
|
|
|
|
This is useful to avoid blocking `emit` when the queue is empty.
|
|
"""
|
|
|
|
try:
|
|
return await asyncio.wait_for(queue.get(), timeout=timeout)
|
|
except (TimeoutError, asyncio.TimeoutError):
|
|
return None
|
|
|
|
def parse_json_safely(str: str):
    """Parse JSON text without raising.

    Returns ``(parsed_value, None)`` on success or ``(None, error)`` when the
    text is not valid JSON.

    NOTE(review): the parameter name shadows the builtin ``str``; renaming it
    would change the keyword-call interface, so it is kept as-is.
    """
    try:
        return json.loads(str), None
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e.msg}")
        return None, e