diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index fe3e1bb..f0a0c6d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
       - name: Run linters
         run: |
           pip install ruff pyright
-          pip install -e .
+          pip install -e .[dev]
           ruff check .
           ruff format --check --diff .
           pyright
@@ -35,5 +35,5 @@ jobs:
         run: |
           python -m pip install -U pip
           pip install .[dev]
-          python -m pytest -s test
+          python -m pytest --capture=no
         shell: bash
diff --git a/backend/fastrtc/pause_detection/silero.py b/backend/fastrtc/pause_detection/silero.py
index d6012b0..51ff55f 100644
--- a/backend/fastrtc/pause_detection/silero.py
+++ b/backend/fastrtc/pause_detection/silero.py
@@ -8,7 +8,7 @@ import numpy as np
 from huggingface_hub import hf_hub_download
 from numpy.typing import NDArray
 
-from ..utils import AudioChunk
+from ..utils import AudioChunk, audio_to_float32
 from .protocol import PauseDetectionModel
 
 logger = logging.getLogger(__name__)
@@ -274,8 +274,7 @@ class SileroVADModel:
         sampling_rate, audio_ = audio
         logger.debug("VAD audio shape input: %s", audio_.shape)
         try:
-            if audio_.dtype != np.float32:
-                audio_ = audio_.astype(np.float32) / 32768.0
+            audio_ = audio_to_float32(audio_)
             sr = 16000
             if sr != sampling_rate:
                 try:
diff --git a/backend/fastrtc/reply_on_stopwords.py b/backend/fastrtc/reply_on_stopwords.py
index 723e61f..e0b7c4e 100644
--- a/backend/fastrtc/reply_on_stopwords.py
+++ b/backend/fastrtc/reply_on_stopwords.py
@@ -161,7 +161,7 @@ class ReplyOnStopWords(ReplyOnPause):
 
         if duration >= self.algo_options.audio_chunk_duration:
             if not state.stop_word_detected:
-                audio_f32 = audio_to_float32((sampling_rate, audio))
+                audio_f32 = audio_to_float32(audio)
                 audio_rs = librosa.resample(
                     audio_f32, orig_sr=sampling_rate, target_sr=16000
                 )
diff --git a/backend/fastrtc/speech_to_text/stt_.py b/backend/fastrtc/speech_to_text/stt_.py
index f8d7a6c..c4c31d8 100644
--- a/backend/fastrtc/speech_to_text/stt_.py
+++ b/backend/fastrtc/speech_to_text/stt_.py
@@ -32,8 +32,7 @@ class MoonshineSTT(STTModel):
 
     def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
         sr, audio_np = audio  # type: ignore
-        if audio_np.dtype == np.int16:
-            audio_np = audio_to_float32(audio)
+        audio_np = audio_to_float32(audio_np)
         if sr != 16000:
             audio_np: NDArray[np.float32] = librosa.resample(
                 audio_np, orig_sr=sr, target_sr=16000
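The call sites above (silero.py, reply_on_stopwords.py, stt_.py) all converge on the same convert-then-resample pattern once the dtype handling moves into the helper. A minimal sketch of that shared pattern, assuming the `audio_to_float32` signature introduced in this PR and `librosa` installed; `normalize_and_resample` is a hypothetical name, not a function in this PR:

```python
import librosa
import numpy as np
from fastrtc.utils import audio_to_float32


def normalize_and_resample(
    audio: np.ndarray, orig_sr: int, target_sr: int = 16000
) -> np.ndarray:
    # audio_to_float32 is a no-op for float32 input and scales int16 by
    # 1/32768, so callers no longer need their own dtype checks here.
    audio_f32 = audio_to_float32(audio)
    if orig_sr != target_sr:
        audio_f32 = librosa.resample(audio_f32, orig_sr=orig_sr, target_sr=target_sr)
    return audio_f32
```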
diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py
index 22a7c6c..93b5728 100644
--- a/backend/fastrtc/utils.py
+++ b/backend/fastrtc/utils.py
@@ -7,6 +7,7 @@ import json
 import logging
 import tempfile
 import traceback
+import warnings
 from collections.abc import Callable, Coroutine
 from contextvars import ContextVar
 from dataclasses import dataclass
@@ -211,7 +212,7 @@ async def player_worker_decode(
                     first_sample_rate = sample_rate
 
                 if format == "s16":
-                    audio_array = audio_to_float32((sample_rate, audio_array))
+                    audio_array = audio_to_float32(audio_array)
 
                 if first_sample_rate != sample_rate:
                     audio_array = librosa.resample(
@@ -319,17 +320,15 @@ def audio_to_file(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
 
 
 def audio_to_float32(
-    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
 ) -> NDArray[np.float32]:
     """
-    Convert an audio tuple containing sample rate (int16) and numpy array data to float32.
+    Convert audio data to float32.
 
     Parameters
     ----------
-    audio : tuple[int, np.ndarray]
-        A tuple containing:
-        - sample_rate (int): The audio sample rate in Hz
-        - data (np.ndarray): The audio data as a numpy array
+    audio : np.ndarray
+        The audio data as a numpy array
 
     Returns
     -------
@@ -338,26 +337,39 @@ def audio_to_float32(
 
     Example
     -------
-    >>> sample_rate = 44100
-    >>> audio_data = np.array([0.1, -0.2, 0.3])  # Example audio samples
-    >>> audio_tuple = (sample_rate, audio_data)
-    >>> audio_float32 = audio_to_float32(audio_tuple)
+    >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
+    >>> audio_float32 = audio_to_float32(audio_data)
     """
-    return audio[1].astype(np.float32) / 32768.0
+    if isinstance(audio, tuple):
+        warnings.warn(
+            UserWarning(
+                "Passing a (sr, audio) tuple to audio_to_float32() is deprecated "
+                "and will be removed in a future release. Pass only the audio array."
+            ),
+            stacklevel=2,  # So that the warning points to the user's code
+        )
+        _sr, audio = audio
+
+    if audio.dtype == np.int16:
+        # Divide by 32768.0 so that the values are in the range [-1.0, 1.0).
+        # 1.0 can actually never be reached because the int16 range is [-32768, 32767].
+        return audio.astype(np.float32) / 32768.0
+    elif audio.dtype == np.float32:
+        return audio  # type: ignore
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
 
 
 def audio_to_int16(
-    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
 ) -> NDArray[np.int16]:
     """
-    Convert an audio tuple containing sample rate and numpy array data to int16.
+    Convert audio data to int16.
 
     Parameters
     ----------
-    audio : tuple[int, np.ndarray]
-        A tuple containing:
-        - sample_rate (int): The audio sample rate in Hz
-        - data (np.ndarray): The audio data as a numpy array
+    audio : np.ndarray
+        The audio data as a numpy array
 
     Returns
     -------
@@ -366,18 +378,27 @@ def audio_to_int16(
 
     Example
     -------
-    >>> sample_rate = 44100
     >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
-    >>> audio_tuple = (sample_rate, audio_data)
-    >>> audio_int16 = audio_to_int16(audio_tuple)
+    >>> audio_int16 = audio_to_int16(audio_data)
     """
-    if audio[1].dtype == np.int16:
-        return audio[1]  # type: ignore
-    elif audio[1].dtype == np.float32:
-        # Convert float32 to int16 by scaling to the int16 range
-        return (audio[1] * 32767.0).astype(np.int16)
+    if isinstance(audio, tuple):
+        warnings.warn(
+            UserWarning(
+                "Passing a (sr, audio) tuple to audio_to_int16() is deprecated "
+                "and will be removed in a future release. Pass only the audio array."
+            ),
+            stacklevel=2,  # So that the warning points to the user's code
+        )
+        _sr, audio = audio
+
+    if audio.dtype == np.int16:
+        return audio  # type: ignore
+    elif audio.dtype == np.float32:
+        # Convert float32 to int16 by scaling to the int16 range.
+        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
+        return (audio * 32767.0).astype(np.int16)
     else:
-        raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
 
 
 def aggregate_bytes_to_16bit(chunks_iterator):
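A quick numeric check of the asymmetric scale factors above (divide by 32768 on the way to float32, multiply by 32767 on the way back), assuming the two helpers as defined in this PR; round-tripping the negative extreme is lossy by one count, by design:

```python
import numpy as np
from fastrtc.utils import audio_to_float32, audio_to_int16

pcm = np.array([-32768, 0, 32767], dtype=np.int16)
f32 = audio_to_float32(pcm)  # [-1.0, 0.0, 0.99996948...]; 1.0 is never reached
back = audio_to_int16(f32)   # [-32767, 0, 32766]; scaling by 32767 cannot overflow
assert back.dtype == np.int16
assert int(back[0]) == -32767  # -32768 -> -1.0 -> -32767: lossy but safe
```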
diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py
index 4ee1e87..48f74e4 100644
--- a/backend/fastrtc/websocket.py
+++ b/backend/fastrtc/websocket.py
@@ -12,7 +12,14 @@ from fastapi import WebSocket
 from fastapi.websockets import WebSocketDisconnect, WebSocketState
 
 from .tracks import AsyncStreamHandler, StreamHandlerImpl
-from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
+from .utils import (
+    AdditionalOutputs,
+    CloseStream,
+    DataChannel,
+    audio_to_float32,
+    audio_to_int16,
+    split_output,
+)
 
 
 class WebSocketDataChannel(DataChannel):
@@ -31,14 +38,12 @@
 def convert_to_mulaw(
     audio_data: np.ndarray, original_rate: int, target_rate: int
 ) -> bytes:
     """Convert audio data to 8kHz mu-law format"""
-
-    if audio_data.dtype != np.float32:
-        audio_data = audio_data.astype(np.float32) / 32768.0
+    audio_data = audio_to_float32(audio_data)
 
     if original_rate != target_rate:
         audio_data = librosa.resample(audio_data, orig_sr=original_rate, target_sr=8000)
 
-    audio_data = (audio_data * 32768).astype(np.int16)
+    audio_data = audio_to_int16(audio_data)
 
     return audioop.lin2ulaw(audio_data, 2)  # type: ignore
@@ -122,14 +127,13 @@ class WebSocketHandler:
                     )
 
                     if self.stream_handler.input_sample_rate != 8000:
-                        audio_array = audio_array.astype(np.float32) / 32768.0
+                        audio_array = audio_to_float32(audio_array)
                         audio_array = librosa.resample(
                             audio_array,
                             orig_sr=8000,
                             target_sr=self.stream_handler.input_sample_rate,
                         )
-                        audio_array = (audio_array * 32768).astype(np.int16)
-
+                        audio_array = audio_to_int16(audio_array)
                     try:
                         if isinstance(self.stream_handler, AsyncStreamHandler):
                             await self.stream_handler.receive(
diff --git a/demo/nextjs_voice_chat/backend/server.py b/demo/nextjs_voice_chat/backend/server.py
index f0a1eb1..0dec077 100644
--- a/demo/nextjs_voice_chat/backend/server.py
+++ b/demo/nextjs_voice_chat/backend/server.py
@@ -1,6 +1,6 @@
 import fastapi
 from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
-from fastrtc.utils import audio_to_bytes
+from fastrtc.utils import audio_to_bytes, audio_to_float32
 from openai import OpenAI
 import logging
 import time
@@ -78,8 +78,8 @@ def echo(audio):
     )
 
     for audio_chunk in audio_stream:
-        audio_array = (
-            np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        audio_array = audio_to_float32(
+            np.frombuffer(audio_chunk, dtype=np.int16)
         )
         yield (24000, audio_array)
 
diff --git a/pyproject.toml b/pyproject.toml
index 0d0b91f..314a5b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,6 +83,7 @@ artifacts = ["/backend/fastrtc/templates", "*.pyi"]
 packages = ["/backend/fastrtc"]
 
 [tool.pytest.ini_options]
+testpaths = ["test/"]
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "function"
 
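For context on the mu-law path that `convert_to_mulaw` and the receive loop above sit on: telephone-style media frames carry 8 kHz mu-law bytes, which must become 16-bit linear PCM before the helpers in this PR apply. A sketch of the decode direction, assuming Python < 3.13 (the stdlib `audioop` module used here is deprecated since 3.11 and removed in 3.13); `mulaw_frame_to_int16` is a hypothetical helper, not part of this PR:

```python
import audioop

import numpy as np


def mulaw_frame_to_int16(frame: bytes) -> np.ndarray:
    # width=2 tells audioop to emit 16-bit linear samples, matching the
    # int16 arrays that audio_to_float32 expects.
    pcm_bytes = audioop.ulaw2lin(frame, 2)
    return np.frombuffer(pcm_bytes, dtype=np.int16)
```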
diff --git a/test/test_utils.py b/test/test_utils.py
new file mode 100644
index 0000000..61076db
--- /dev/null
+++ b/test/test_utils.py
@@ -0,0 +1,61 @@
+import numpy as np
+import pytest
+from fastrtc.utils import audio_to_float32, audio_to_int16
+
+
+def test_audio_to_float32_valid_int16():
+    audio = np.array([-32768, 0, 32767], dtype=np.int16)
+    expected = np.array([-1.0, 0.0, 32767 / 32768.0], dtype=np.float32)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_almost_equal(result, expected)
+
+
+def test_audio_to_float32_valid_float32():
+    audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_equal(result, audio)
+
+
+def test_audio_to_float32_empty_array():
+    audio = np.array([], dtype=np.int16)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_equal(result, np.array([], dtype=np.float32))
+
+
+def test_audio_to_float32_invalid_dtype():
+    audio = np.array([1, 2, 3], dtype=np.int32)
+    with pytest.raises(TypeError, match="Unsupported audio data type"):
+        audio_to_float32(audio)  # type: ignore
+
+
+def test_audio_to_int16_valid_float32():
+    audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    expected = np.array([-32767, 0, 32767], dtype=np.int16)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_audio_to_int16_valid_int16():
+    audio = np.array([-32768, 0, 32767], dtype=np.int16)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, audio)
+
+
+def test_audio_to_int16_empty_array():
+    audio = np.array([], dtype=np.float32)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, np.array([], dtype=np.int16))
+
+
+def test_audio_to_int16_invalid_dtype():
+    audio = np.array([1, 2, 3], dtype=np.int32)
+    with pytest.raises(TypeError, match="Unsupported audio data type"):
+        audio_to_int16(audio)  # type: ignore
+
+
+def test_legacy_arguments():
+    result = audio_to_float32((16000, np.zeros(10, dtype=np.int16)))
+    np.testing.assert_array_equal(result, np.zeros(10, dtype=np.float32))
+
+    result = audio_to_int16((16000, np.zeros(10, dtype=np.float32)))
+    np.testing.assert_array_equal(result, np.zeros(10, dtype=np.int16))
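A possible follow-up to `test_legacy_arguments` (not part of this PR): since the tuple form now warns, the warning itself can be pinned down with `pytest.warns` so a future removal of the legacy path does not regress silently. The `match` pattern assumes the warning text introduced above:

```python
import numpy as np
import pytest
from fastrtc.utils import audio_to_float32, audio_to_int16


def test_legacy_arguments_warn():
    # Both helpers should emit the deprecation UserWarning for tuple input.
    with pytest.warns(UserWarning, match="deprecated"):
        audio_to_float32((16000, np.zeros(10, dtype=np.int16)))

    with pytest.warns(UserWarning, match="deprecated"):
        audio_to_int16((16000, np.zeros(10, dtype=np.float32)))
```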