Fix audio type conversion (#259)

* Fix conversion between audio dtypes

* Run Pytest in CI

* Add pytest tests path in pyproject.toml

* Fix usages

* Use other PR's test format (more or less)

* Support legacy arguments

* Fix pyproject.toml and test location

* Omit `test` arg in CI, given by pyproject.toml

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
Author: Václav Volhejn
Date:   2025-04-09 16:00:23 +02:00 (committed by GitHub)
Parent: fdf6bea1c6
Commit: 58bccddd93

9 changed files with 128 additions and 43 deletions
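In short, fastrtc.utils.audio_to_float32 and fastrtc.utils.audio_to_int16 now take the audio array directly and validate its dtype; the old (sample_rate, array) tuple argument still works but emits a deprecation warning. A minimal usage sketch of the new call forms (values mirror the new tests at the end of this diff; the variable names are illustrative):

    import numpy as np
    from fastrtc.utils import audio_to_float32, audio_to_int16

    pcm = np.array([-32768, 0, 32767], dtype=np.int16)

    samples = audio_to_float32(pcm)           # new form: pass the array directly
    pcm_back = audio_to_int16(samples)        # likewise for the int16 direction

    # legacy (sample_rate, array) tuple: still accepted, but warns that it is deprecated
    samples = audio_to_float32((16000, pcm))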


@@ -13,7 +13,7 @@ jobs:
       - name: Run linters
         run: |
           pip install ruff pyright
-          pip install -e .
+          pip install -e .[dev]
           ruff check .
           ruff format --check --diff .
           pyright
@@ -35,5 +35,5 @@ jobs:
         run: |
           python -m pip install -U pip
           pip install .[dev]
-          python -m pytest -s test
+          python -m pytest --capture=no
         shell: bash


@@ -8,7 +8,7 @@ import numpy as np
 from huggingface_hub import hf_hub_download
 from numpy.typing import NDArray
 
-from ..utils import AudioChunk
+from ..utils import AudioChunk, audio_to_float32
 from .protocol import PauseDetectionModel
 
 logger = logging.getLogger(__name__)
@@ -274,8 +274,7 @@ class SileroVADModel:
         sampling_rate, audio_ = audio
         logger.debug("VAD audio shape input: %s", audio_.shape)
         try:
-            if audio_.dtype != np.float32:
-                audio_ = audio_.astype(np.float32) / 32768.0
+            audio_ = audio_to_float32(audio_)
             sr = 16000
             if sr != sampling_rate:
                 try:


@@ -161,7 +161,7 @@ class ReplyOnStopWords(ReplyOnPause):
             if duration >= self.algo_options.audio_chunk_duration:
                 if not state.stop_word_detected:
-                    audio_f32 = audio_to_float32((sampling_rate, audio))
+                    audio_f32 = audio_to_float32(audio)
                     audio_rs = librosa.resample(
                         audio_f32, orig_sr=sampling_rate, target_sr=16000
                     )


@@ -32,8 +32,7 @@ class MoonshineSTT(STTModel):
     def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
         sr, audio_np = audio  # type: ignore
-        if audio_np.dtype == np.int16:
-            audio_np = audio_to_float32(audio)
+        audio_np = audio_to_float32(audio_np)
         if sr != 16000:
             audio_np: NDArray[np.float32] = librosa.resample(
                 audio_np, orig_sr=sr, target_sr=16000


@@ -7,6 +7,7 @@ import json
 import logging
 import tempfile
 import traceback
+import warnings
 from collections.abc import Callable, Coroutine
 from contextvars import ContextVar
 from dataclasses import dataclass
@@ -211,7 +212,7 @@ async def player_worker_decode(
                 first_sample_rate = sample_rate
 
             if format == "s16":
-                audio_array = audio_to_float32((sample_rate, audio_array))
+                audio_array = audio_to_float32(audio_array)
 
                 if first_sample_rate != sample_rate:
                     audio_array = librosa.resample(
@@ -319,17 +320,15 @@ def audio_to_file(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
 
 def audio_to_float32(
-    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
 ) -> NDArray[np.float32]:
     """
     Convert an audio tuple containing sample rate (int16) and numpy array data to float32.
 
     Parameters
     ----------
-    audio : tuple[int, np.ndarray]
-        A tuple containing:
-        - sample_rate (int): The audio sample rate in Hz
-        - data (np.ndarray): The audio data as a numpy array
+    audio : np.ndarray
+        The audio data as a numpy array
 
     Returns
     -------
@@ -338,26 +337,39 @@ def audio_to_float32(
     Example
     -------
-    >>> sample_rate = 44100
     >>> audio_data = np.array([0.1, -0.2, 0.3])  # Example audio samples
-    >>> audio_tuple = (sample_rate, audio_data)
-    >>> audio_float32 = audio_to_float32(audio_tuple)
+    >>> audio_float32 = audio_to_float32(audio_data)
     """
-    return audio[1].astype(np.float32) / 32768.0
+    if isinstance(audio, tuple):
+        warnings.warn(
+            UserWarning(
+                "Passing a (sr, audio) tuple to audio_to_float32() is deprecated "
+                "and will be removed in a future release. Pass only the audio array."
+            ),
+            stacklevel=2,  # So that the warning points to the user's code
+        )
+        _sr, audio = audio
+
+    if audio.dtype == np.int16:
+        # Divide by 32768.0 so that the values are in the range [-1.0, 1.0).
+        # 1.0 can actually never be reached because the int16 range is [-32768, 32767].
+        return audio.astype(np.float32) / 32768.0
+    elif audio.dtype == np.float32:
+        return audio  # type: ignore
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
 
 
 def audio_to_int16(
-    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
 ) -> NDArray[np.int16]:
     """
     Convert an audio tuple containing sample rate and numpy array data to int16.
 
     Parameters
     ----------
-    audio : tuple[int, np.ndarray]
-        A tuple containing:
-        - sample_rate (int): The audio sample rate in Hz
-        - data (np.ndarray): The audio data as a numpy array
+    audio : np.ndarray
+        The audio data as a numpy array
 
     Returns
     -------
@@ -366,18 +378,27 @@ def audio_to_int16(
     Example
     -------
-    >>> sample_rate = 44100
     >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
-    >>> audio_tuple = (sample_rate, audio_data)
-    >>> audio_int16 = audio_to_int16(audio_tuple)
+    >>> audio_int16 = audio_to_int16(audio_data)
     """
-    if audio[1].dtype == np.int16:
-        return audio[1]  # type: ignore
-    elif audio[1].dtype == np.float32:
-        # Convert float32 to int16 by scaling to the int16 range
-        return (audio[1] * 32767.0).astype(np.int16)
+    if isinstance(audio, tuple):
+        warnings.warn(
+            UserWarning(
+                "Passing a (sr, audio) tuple to audio_to_float32() is deprecated "
+                "and will be removed in a future release. Pass only the audio array."
+            ),
+            stacklevel=2,  # So that the warning points to the user's code
+        )
+        _sr, audio = audio
+
+    if audio.dtype == np.int16:
+        return audio  # type: ignore
+    elif audio.dtype == np.float32:
+        # Convert float32 to int16 by scaling to the int16 range.
+        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
+        return (audio * 32767.0).astype(np.int16)
     else:
-        raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
+        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
 
 
 def aggregate_bytes_to_16bit(chunks_iterator):
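The scale factors above are deliberately asymmetric: int16 samples are divided by 32768 so the float32 result lies in [-1.0, 1.0), while float32 samples are multiplied by 32767 so the int16 result cannot overflow. A short sketch of that behaviour, mirroring the new tests at the end of this diff (the print calls and inline expected values are illustrative only):

    import numpy as np
    from fastrtc.utils import audio_to_float32, audio_to_int16

    pcm = np.array([-32768, 0, 32767], dtype=np.int16)
    print(audio_to_float32(pcm))    # [-1.0, 0.0, 32767/32768] -- full-scale int16 lands just below 1.0

    samples = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    print(audio_to_int16(samples))  # [-32767, 0, 32767] -- scaled by 32767, no int16 overflow

    audio_to_float32(np.array([1, 2, 3], dtype=np.int32))  # any other dtype raises TypeError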


@@ -12,7 +12,14 @@ from fastapi import WebSocket
 from fastapi.websockets import WebSocketDisconnect, WebSocketState
 
 from .tracks import AsyncStreamHandler, StreamHandlerImpl
-from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
+from .utils import (
+    AdditionalOutputs,
+    CloseStream,
+    DataChannel,
+    audio_to_float32,
+    audio_to_int16,
+    split_output,
+)
 
 
 class WebSocketDataChannel(DataChannel):
@@ -31,14 +38,12 @@ def convert_to_mulaw(
     audio_data: np.ndarray, original_rate: int, target_rate: int
 ) -> bytes:
     """Convert audio data to 8kHz mu-law format"""
-    if audio_data.dtype != np.float32:
-        audio_data = audio_data.astype(np.float32) / 32768.0
+    audio_data = audio_to_float32(audio_data)
 
     if original_rate != target_rate:
         audio_data = librosa.resample(audio_data, orig_sr=original_rate, target_sr=8000)
 
-    audio_data = (audio_data * 32768).astype(np.int16)
+    audio_data = audio_to_int16(audio_data)
     return audioop.lin2ulaw(audio_data, 2)  # type: ignore
@@ -122,14 +127,13 @@ class WebSocketHandler:
                     )
                     if self.stream_handler.input_sample_rate != 8000:
-                        audio_array = audio_array.astype(np.float32) / 32768.0
+                        audio_array = audio_to_float32(audio_array)
                         audio_array = librosa.resample(
                             audio_array,
                             orig_sr=8000,
                             target_sr=self.stream_handler.input_sample_rate,
                         )
-                        audio_array = (audio_array * 32768).astype(np.int16)
+                        audio_array = audio_to_int16(audio_array)
                     try:
                         if isinstance(self.stream_handler, AsyncStreamHandler):
                             await self.stream_handler.receive(
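Both call sites in this file now express the same 8 kHz round trip through the shared helpers: decode to float32, resample with librosa, re-encode to int16. A condensed sketch of that pattern (resample_int16 is a hypothetical helper written for illustration, not part of fastrtc; the 20 ms chunk is likewise illustrative):

    import librosa
    import numpy as np
    from fastrtc.utils import audio_to_float32, audio_to_int16

    def resample_int16(pcm: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Hypothetical helper: int16 (or float32) in, resampled int16 out."""
        samples = audio_to_float32(pcm)      # normalize to float32 for librosa
        samples = librosa.resample(samples, orig_sr=orig_sr, target_sr=target_sr)
        return audio_to_int16(samples)       # back to int16 for the stream handler

    chunk_8k = np.zeros(160, dtype=np.int16)  # 20 ms of silence at 8 kHz
    chunk_16k = resample_int16(chunk_8k, orig_sr=8000, target_sr=16000)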


@@ -1,6 +1,6 @@
 import fastapi
 from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
-from fastrtc.utils import audio_to_bytes
+from fastrtc.utils import audio_to_bytes, audio_to_float32
 from openai import OpenAI
 import logging
 import time
@@ -78,8 +78,8 @@ def echo(audio):
     )
     for audio_chunk in audio_stream:
-        audio_array = (
-            np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        audio_array = audio_to_float32(
+            np.frombuffer(audio_chunk, dtype=np.int16)
         )
         yield (24000, audio_array)


@@ -83,6 +83,7 @@ artifacts = ["/backend/fastrtc/templates", "*.pyi"]
 packages = ["/backend/fastrtc"]
 
 [tool.pytest.ini_options]
+testpaths = ["test/"]
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "function"

test/test_utils.py (new file, 61 lines)

@@ -0,0 +1,61 @@
+import numpy as np
+import pytest
+
+from fastrtc.utils import audio_to_float32, audio_to_int16
+
+
+def test_audio_to_float32_valid_int16():
+    audio = np.array([-32768, 0, 32767], dtype=np.int16)
+    expected = np.array([-1.0, 0.0, 32767 / 32768.0], dtype=np.float32)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_almost_equal(result, expected)
+
+
+def test_audio_to_float32_valid_float32():
+    audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_equal(result, audio)
+
+
+def test_audio_to_float32_empty_array():
+    audio = np.array([], dtype=np.int16)
+    result = audio_to_float32(audio)
+    np.testing.assert_array_equal(result, np.array([], dtype=np.float32))
+
+
+def test_audio_to_float32_invalid_dtype():
+    audio = np.array([1, 2, 3], dtype=np.int32)
+    with pytest.raises(TypeError, match="Unsupported audio data type"):
+        audio_to_float32(audio)  # type: ignore
+
+
+def test_audio_to_int16_valid_float32():
+    audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    expected = np.array([-32767, 0, 32767], dtype=np.int16)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_audio_to_int16_valid_int16():
+    audio = np.array([-32768, 0, 32767], dtype=np.int16)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, audio)
+
+
+def test_audio_to_int16_empty_array():
+    audio = np.array([], dtype=np.float32)
+    result = audio_to_int16(audio)
+    np.testing.assert_array_equal(result, np.array([], dtype=np.int16))
+
+
+def test_audio_to_int16_invalid_dtype():
+    audio = np.array([1, 2, 3], dtype=np.int32)
+    with pytest.raises(TypeError, match="Unsupported audio data type"):
+        audio_to_int16(audio)  # type: ignore
+
+
+def test_legacy_arguments():
+    result = audio_to_float32((16000, np.zeros(10, dtype=np.int16)))
+    np.testing.assert_array_equal(result, np.zeros(10, dtype=np.float32))
+    result = audio_to_int16((16000, np.zeros(10, dtype=np.float32)))
+    np.testing.assert_array_equal(result, np.zeros(10, dtype=np.int16))