Fix audio type conversion (#259)

* Fix conversion between audio dtypes

* Run Pytest in CI

* Add pytest tests path in pyproject.toml

* Fix usages

* Use other PR's test format (more or less)

* Support legacy arguments

* Fix pyproject.toml and test location

* Omit `test` arg in CI, given by pyproject.toml

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Václav Volhejn
2025-04-09 16:00:23 +02:00
committed by GitHub
parent fdf6bea1c6
commit 58bccddd93
9 changed files with 128 additions and 43 deletions

View File

@@ -8,7 +8,7 @@ import numpy as np
from huggingface_hub import hf_hub_download
from numpy.typing import NDArray
from ..utils import AudioChunk
from ..utils import AudioChunk, audio_to_float32
from .protocol import PauseDetectionModel
logger = logging.getLogger(__name__)
@@ -274,8 +274,7 @@ class SileroVADModel:
sampling_rate, audio_ = audio
logger.debug("VAD audio shape input: %s", audio_.shape)
try:
if audio_.dtype != np.float32:
audio_ = audio_.astype(np.float32) / 32768.0
audio_ = audio_to_float32(audio_)
sr = 16000
if sr != sampling_rate:
try:

View File

@@ -161,7 +161,7 @@ class ReplyOnStopWords(ReplyOnPause):
if duration >= self.algo_options.audio_chunk_duration:
if not state.stop_word_detected:
audio_f32 = audio_to_float32((sampling_rate, audio))
audio_f32 = audio_to_float32(audio)
audio_rs = librosa.resample(
audio_f32, orig_sr=sampling_rate, target_sr=16000
)

View File

@@ -32,8 +32,7 @@ class MoonshineSTT(STTModel):
def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
sr, audio_np = audio # type: ignore
if audio_np.dtype == np.int16:
audio_np = audio_to_float32(audio)
audio_np = audio_to_float32(audio_np)
if sr != 16000:
audio_np: NDArray[np.float32] = librosa.resample(
audio_np, orig_sr=sr, target_sr=16000

View File

@@ -7,6 +7,7 @@ import json
import logging
import tempfile
import traceback
import warnings
from collections.abc import Callable, Coroutine
from contextvars import ContextVar
from dataclasses import dataclass
@@ -211,7 +212,7 @@ async def player_worker_decode(
first_sample_rate = sample_rate
if format == "s16":
audio_array = audio_to_float32((sample_rate, audio_array))
audio_array = audio_to_float32(audio_array)
if first_sample_rate != sample_rate:
audio_array = librosa.resample(
@@ -319,17 +320,15 @@ def audio_to_file(audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
def audio_to_float32(
audio: tuple[int, NDArray[np.int16 | np.float32]],
audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
) -> NDArray[np.float32]:
"""
Convert audio data given as a numpy array (int16 or float32) to float32.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
audio : np.ndarray
The audio data as a numpy array
Returns
-------
@@ -338,26 +337,39 @@ def audio_to_float32(
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_float32 = audio_to_float32(audio_tuple)
>>> audio_float32 = audio_to_float32(audio_data)
"""
return audio[1].astype(np.float32) / 32768.0
if isinstance(audio, tuple):
warnings.warn(
UserWarning(
"Passing a (sr, audio) tuple to audio_to_float32() is deprecated "
"and will be removed in a future release. Pass only the audio array."
),
stacklevel=2, # So that the warning points to the user's code
)
_sr, audio = audio
if audio.dtype == np.int16:
# Divide by 32768.0 so that the values are in the range [-1.0, 1.0).
# 1.0 can actually never be reached because the int16 range is [-32768, 32767].
return audio.astype(np.float32) / 32768.0
elif audio.dtype == np.float32:
return audio # type: ignore
else:
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
def audio_to_int16(
audio: tuple[int, NDArray[np.int16 | np.float32]],
audio: NDArray[np.int16 | np.float32] | tuple[int, NDArray[np.int16 | np.float32]],
) -> NDArray[np.int16]:
"""
Convert audio data given as a numpy array (int16 or float32) to int16.
Parameters
----------
audio : tuple[int, np.ndarray]
A tuple containing:
- sample_rate (int): The audio sample rate in Hz
- data (np.ndarray): The audio data as a numpy array
audio : np.ndarray
The audio data as a numpy array
Returns
-------
@@ -366,18 +378,27 @@ def audio_to_int16(
Example
-------
>>> sample_rate = 44100
>>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32) # Example audio samples
>>> audio_tuple = (sample_rate, audio_data)
>>> audio_int16 = audio_to_int16(audio_tuple)
>>> audio_int16 = audio_to_int16(audio_data)
"""
if audio[1].dtype == np.int16:
return audio[1] # type: ignore
elif audio[1].dtype == np.float32:
# Convert float32 to int16 by scaling to the int16 range
return (audio[1] * 32767.0).astype(np.int16)
if isinstance(audio, tuple):
warnings.warn(
UserWarning(
"Passing a (sr, audio) tuple to audio_to_int16() is deprecated "
"and will be removed in a future release. Pass only the audio array."
),
stacklevel=2, # So that the warning points to the user's code
)
_sr, audio = audio
if audio.dtype == np.int16:
return audio # type: ignore
elif audio.dtype == np.float32:
# Convert float32 to int16 by scaling to the int16 range.
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
return (audio * 32767.0).astype(np.int16)
else:
raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
def aggregate_bytes_to_16bit(chunks_iterator):

View File

@@ -12,7 +12,14 @@ from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect, WebSocketState
from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
from .utils import (
AdditionalOutputs,
CloseStream,
DataChannel,
audio_to_float32,
audio_to_int16,
split_output,
)
class WebSocketDataChannel(DataChannel):
@@ -31,14 +38,12 @@ def convert_to_mulaw(
audio_data: np.ndarray, original_rate: int, target_rate: int
) -> bytes:
"""Convert audio data to 8kHz mu-law format"""
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32) / 32768.0
audio_data = audio_to_float32(audio_data)
if original_rate != target_rate:
audio_data = librosa.resample(audio_data, orig_sr=original_rate, target_sr=8000)
audio_data = (audio_data * 32768).astype(np.int16)
audio_data = audio_to_int16(audio_data)
return audioop.lin2ulaw(audio_data, 2) # type: ignore
@@ -122,14 +127,13 @@ class WebSocketHandler:
)
if self.stream_handler.input_sample_rate != 8000:
audio_array = audio_array.astype(np.float32) / 32768.0
audio_array = audio_to_float32(audio_array)
audio_array = librosa.resample(
audio_array,
orig_sr=8000,
target_sr=self.stream_handler.input_sample_rate,
)
audio_array = (audio_array * 32768).astype(np.int16)
audio_array = audio_to_int16(audio_array)
try:
if isinstance(self.stream_handler, AsyncStreamHandler):
await self.stream_handler.receive(