mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Fix audio type conversion (#259)
* Fix conversion between audio dtypes * Run Pytest in CI * Add pytest tests path in pyproject.toml * Fix usages * Use other PR's test format (more or less) * Support legacy arguments * Fix pyproject.toml and test location * Omit `test` arg in CI, given by pyproject.toml --------- Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
|||||||
- name: Run linters
|
- name: Run linters
|
||||||
run: |
|
run: |
|
||||||
pip install ruff pyright
|
pip install ruff pyright
|
||||||
pip install -e .
|
pip install -e .[dev]
|
||||||
ruff check .
|
ruff check .
|
||||||
ruff format --check --diff .
|
ruff format --check --diff .
|
||||||
pyright
|
pyright
|
||||||
@@ -35,5 +35,5 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
python -m pip install -U pip
|
python -m pip install -U pip
|
||||||
pip install .[dev]
|
pip install .[dev]
|
||||||
python -m pytest -s test
|
python -m pytest --capture=no
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from ..utils import AudioChunk
|
from ..utils import AudioChunk, audio_to_float32
|
||||||
from .protocol import PauseDetectionModel
|
from .protocol import PauseDetectionModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -274,8 +274,7 @@ class SileroVADModel:
|
|||||||
sampling_rate, audio_ = audio
|
sampling_rate, audio_ = audio
|
||||||
logger.debug("VAD audio shape input: %s", audio_.shape)
|
logger.debug("VAD audio shape input: %s", audio_.shape)
|
||||||
try:
|
try:
|
||||||
if audio_.dtype != np.float32:
|
audio_ = audio_to_float32(audio_)
|
||||||
audio_ = audio_.astype(np.float32) / 32768.0
|
|
||||||
sr = 16000
|
sr = 16000
|
||||||
if sr != sampling_rate:
|
if sr != sampling_rate:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ class ReplyOnStopWords(ReplyOnPause):
|
|||||||
|
|
||||||
if duration >= self.algo_options.audio_chunk_duration:
|
if duration >= self.algo_options.audio_chunk_duration:
|
||||||
if not state.stop_word_detected:
|
if not state.stop_word_detected:
|
||||||
audio_f32 = audio_to_float32((sampling_rate, audio))
|
audio_f32 = audio_to_float32(audio)
|
||||||
audio_rs = librosa.resample(
|
audio_rs = librosa.resample(
|
||||||
audio_f32, orig_sr=sampling_rate, target_sr=16000
|
audio_f32, orig_sr=sampling_rate, target_sr=16000
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,8 +32,7 @@ class MoonshineSTT(STTModel):
|
|||||||
|
|
||||||
def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
|
def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
|
||||||
sr, audio_np = audio # type: ignore
|
sr, audio_np = audio # type: ignore
|
||||||
if audio_np.dtype == np.int16:
|
audio_np = audio_to_float32(audio_np)
|
||||||
audio_np = audio_to_float32(audio)
|
|
||||||
if sr != 16000:
|
if sr != 16000:
|
||||||
audio_np: NDArray[np.float32] = librosa.resample(
|
audio_np: NDArray[np.float32] = librosa.resample(
|
||||||
audio_np, orig_sr=sr, target_sr=16000
|
audio_np, orig_sr=sr, target_sr=16000
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
|
import warnings
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -211,7 +212,7 @@ async def player_worker_decode(
|
|||||||
first_sample_rate = sample_rate
|
first_sample_rate = sample_rate
|
||||||
|
|
||||||
if format == "s16":
|
if format == "s16":
|
||||||
audio_array = audio_to_float32((sample_rate, audio_array))
|
audio_array = audio_to_float32(audio_array)
|
||||||
|
|
||||||
if first_sample_rate != sample_rate:
|
if first_sample_rate != sample_rate:
|
||||||
audio_array = librosa.resample(
|
audio_array = librosa.resample(
|
||||||
@@ -319,17 +320,15 @@ def audio_to_file(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def audio_to_float32(
|
def audio_to_float32(
|
||||||
audio: tuple[int, NDArray[np.int16 | np.float32]],
|
audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
|
||||||
) -> NDArray[np.float32]:
|
) -> NDArray[np.float32]:
|
||||||
"""
|
"""
|
||||||
Convert an audio tuple containing sample rate (int16) and numpy array data to float32.
|
Convert int16 or float32 audio data to float32. A legacy (sample_rate, data) tuple is still accepted but deprecated.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio : tuple[int, np.ndarray]
|
audio : np.ndarray
|
||||||
A tuple containing:
|
The audio data as a numpy array
|
||||||
- sample_rate (int): The audio sample rate in Hz
|
|
||||||
- data (np.ndarray): The audio data as a numpy array
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -338,26 +337,39 @@ def audio_to_float32(
|
|||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
>>> sample_rate = 44100
|
|
||||||
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
|
>>> audio_data = np.array([1000, -2000, 3000], dtype=np.int16) # Example audio samples
|
||||||
>>> audio_tuple = (sample_rate, audio_data)
|
>>> audio_float32 = audio_to_float32(audio_data)
|
||||||
>>> audio_float32 = audio_to_float32(audio_tuple)
|
|
||||||
"""
|
"""
|
||||||
return audio[1].astype(np.float32) / 32768.0
|
if isinstance(audio, tuple):
|
||||||
|
warnings.warn(
|
||||||
|
UserWarning(
|
||||||
|
"Passing a (sr, audio) tuple to audio_to_float32() is deprecated "
|
||||||
|
"and will be removed in a future release. Pass only the audio array."
|
||||||
|
),
|
||||||
|
stacklevel=2, # So that the warning points to the user's code
|
||||||
|
)
|
||||||
|
_sr, audio = audio
|
||||||
|
|
||||||
|
if audio.dtype == np.int16:
|
||||||
|
# Divide by 32768.0 so that the values are in the range [-1.0, 1.0).
|
||||||
|
# 1.0 can actually never be reached because the int16 range is [-32768, 32767].
|
||||||
|
return audio.astype(np.float32) / 32768.0
|
||||||
|
elif audio.dtype == np.float32:
|
||||||
|
return audio # type: ignore
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
|
||||||
|
|
||||||
|
|
||||||
def audio_to_int16(
|
def audio_to_int16(
|
||||||
audio: tuple[int, NDArray[np.int16 | np.float32]],
|
audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
|
||||||
) -> NDArray[np.int16]:
|
) -> NDArray[np.int16]:
|
||||||
"""
|
"""
|
||||||
Convert an audio tuple containing sample rate and numpy array data to int16.
|
Convert int16 or float32 audio data to int16. A legacy (sample_rate, data) tuple is still accepted but deprecated.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio : tuple[int, np.ndarray]
|
audio : np.ndarray
|
||||||
A tuple containing:
|
The audio data as a numpy array
|
||||||
- sample_rate (int): The audio sample rate in Hz
|
|
||||||
- data (np.ndarray): The audio data as a numpy array
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -366,18 +378,27 @@ def audio_to_int16(
|
|||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
>>> sample_rate = 44100
|
|
||||||
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
|
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
|
||||||
>>> audio_tuple = (sample_rate, audio_data)
|
>>> audio_int16 = audio_to_int16(audio_data)
|
||||||
>>> audio_int16 = audio_to_int16(audio_tuple)
|
|
||||||
"""
|
"""
|
||||||
if audio[1].dtype == np.int16:
|
if isinstance(audio, tuple):
|
||||||
return audio[1] # type: ignore
|
warnings.warn(
|
||||||
elif audio[1].dtype == np.float32:
|
UserWarning(
|
||||||
# Convert float32 to int16 by scaling to the int16 range
|
"Passing a (sr, audio) tuple to audio_to_int16() is deprecated "
|
||||||
return (audio[1] * 32767.0).astype(np.int16)
|
"and will be removed in a future release. Pass only the audio array."
|
||||||
|
),
|
||||||
|
stacklevel=2, # So that the warning points to the user's code
|
||||||
|
)
|
||||||
|
_sr, audio = audio
|
||||||
|
|
||||||
|
if audio.dtype == np.int16:
|
||||||
|
return audio # type: ignore
|
||||||
|
elif audio.dtype == np.float32:
|
||||||
|
# Convert float32 to int16 by scaling to the int16 range.
|
||||||
|
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
|
||||||
|
return (audio * 32767.0).astype(np.int16)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
|
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
|
||||||
|
|
||||||
|
|
||||||
def aggregate_bytes_to_16bit(chunks_iterator):
|
def aggregate_bytes_to_16bit(chunks_iterator):
|
||||||
|
|||||||
@@ -12,7 +12,14 @@ from fastapi import WebSocket
|
|||||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||||
|
|
||||||
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
||||||
from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
|
from .utils import (
|
||||||
|
AdditionalOutputs,
|
||||||
|
CloseStream,
|
||||||
|
DataChannel,
|
||||||
|
audio_to_float32,
|
||||||
|
audio_to_int16,
|
||||||
|
split_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketDataChannel(DataChannel):
|
class WebSocketDataChannel(DataChannel):
|
||||||
@@ -31,14 +38,12 @@ def convert_to_mulaw(
|
|||||||
audio_data: np.ndarray, original_rate: int, target_rate: int
|
audio_data: np.ndarray, original_rate: int, target_rate: int
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Convert audio data to 8kHz mu-law format"""
|
"""Convert audio data to 8kHz mu-law format"""
|
||||||
|
audio_data = audio_to_float32(audio_data)
|
||||||
if audio_data.dtype != np.float32:
|
|
||||||
audio_data = audio_data.astype(np.float32) / 32768.0
|
|
||||||
|
|
||||||
if original_rate != target_rate:
|
if original_rate != target_rate:
|
||||||
audio_data = librosa.resample(audio_data, orig_sr=original_rate, target_sr=8000)
|
audio_data = librosa.resample(audio_data, orig_sr=original_rate, target_sr=8000)
|
||||||
|
|
||||||
audio_data = (audio_data * 32768).astype(np.int16)
|
audio_data = audio_to_int16(audio_data)
|
||||||
|
|
||||||
return audioop.lin2ulaw(audio_data, 2) # type: ignore
|
return audioop.lin2ulaw(audio_data, 2) # type: ignore
|
||||||
|
|
||||||
@@ -122,14 +127,13 @@ class WebSocketHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.stream_handler.input_sample_rate != 8000:
|
if self.stream_handler.input_sample_rate != 8000:
|
||||||
audio_array = audio_array.astype(np.float32) / 32768.0
|
audio_array = audio_to_float32(audio_array)
|
||||||
audio_array = librosa.resample(
|
audio_array = librosa.resample(
|
||||||
audio_array,
|
audio_array,
|
||||||
orig_sr=8000,
|
orig_sr=8000,
|
||||||
target_sr=self.stream_handler.input_sample_rate,
|
target_sr=self.stream_handler.input_sample_rate,
|
||||||
)
|
)
|
||||||
audio_array = (audio_array * 32768).astype(np.int16)
|
audio_array = audio_to_int16(audio_array)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||||
await self.stream_handler.receive(
|
await self.stream_handler.receive(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import fastapi
|
import fastapi
|
||||||
from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
|
from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
|
||||||
from fastrtc.utils import audio_to_bytes
|
from fastrtc.utils import audio_to_bytes, audio_to_float32
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@@ -78,8 +78,8 @@ def echo(audio):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for audio_chunk in audio_stream:
|
for audio_chunk in audio_stream:
|
||||||
audio_array = (
|
audio_array = audio_to_float32(
|
||||||
np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
|
np.frombuffer(audio_chunk, dtype=np.int16)
|
||||||
)
|
)
|
||||||
yield (24000, audio_array)
|
yield (24000, audio_array)
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ artifacts = ["/backend/fastrtc/templates", "*.pyi"]
|
|||||||
packages = ["/backend/fastrtc"]
|
packages = ["/backend/fastrtc"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["test/"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
|
||||||
|
|||||||
61
test/test_utils.py
Normal file
61
test/test_utils.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from fastrtc.utils import audio_to_float32, audio_to_int16
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_float32_valid_int16():
|
||||||
|
audio = np.array([-32768, 0, 32767], dtype=np.int16)
|
||||||
|
expected = np.array([-1.0, 0.0, 32767 / 32768.0], dtype=np.float32)
|
||||||
|
result = audio_to_float32(audio)
|
||||||
|
np.testing.assert_array_almost_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_float32_valid_float32():
|
||||||
|
audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
|
||||||
|
result = audio_to_float32(audio)
|
||||||
|
np.testing.assert_array_equal(result, audio)
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_float32_empty_array():
|
||||||
|
audio = np.array([], dtype=np.int16)
|
||||||
|
result = audio_to_float32(audio)
|
||||||
|
np.testing.assert_array_equal(result, np.array([], dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_float32_invalid_dtype():
|
||||||
|
audio = np.array([1, 2, 3], dtype=np.int32)
|
||||||
|
with pytest.raises(TypeError, match="Unsupported audio data type"):
|
||||||
|
audio_to_float32(audio) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_int16_valid_float32():
|
||||||
|
audio = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
|
||||||
|
expected = np.array([-32767, 0, 32767], dtype=np.int16)
|
||||||
|
result = audio_to_int16(audio)
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_int16_valid_int16():
|
||||||
|
audio = np.array([-32768, 0, 32767], dtype=np.int16)
|
||||||
|
result = audio_to_int16(audio)
|
||||||
|
np.testing.assert_array_equal(result, audio)
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_int16_empty_array():
|
||||||
|
audio = np.array([], dtype=np.float32)
|
||||||
|
result = audio_to_int16(audio)
|
||||||
|
np.testing.assert_array_equal(result, np.array([], dtype=np.int16))
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_int16_invalid_dtype():
|
||||||
|
audio = np.array([1, 2, 3], dtype=np.int32)
|
||||||
|
with pytest.raises(TypeError, match="Unsupported audio data type"):
|
||||||
|
audio_to_int16(audio) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_arguments():
|
||||||
|
result = audio_to_float32((16000, np.zeros(10, dtype=np.int16)))
|
||||||
|
np.testing.assert_array_equal(result, np.zeros(10, dtype=np.float32))
|
||||||
|
|
||||||
|
result = audio_to_int16((16000, np.zeros(10, dtype=np.float32)))
|
||||||
|
np.testing.assert_array_equal(result, np.zeros(10, dtype=np.int16))
|
||||||
Reference in New Issue
Block a user