Fix issue when the audio stream mixes sample rates and numpy array data types (#188)

* Fix code

* Fix

* keep same
This commit is contained in:
Freddy Boulton
2025-03-18 18:53:47 -04:00
committed by GitHub
parent 5a196868dd
commit 44aac8d964
3 changed files with 19 additions and 9 deletions

View File

@@ -11,6 +11,7 @@ from contextvars import ContextVar
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
import av
import librosa
import numpy as np
from numpy.typing import NDArray
from pydub import AudioSegment
@@ -134,7 +135,7 @@ async def player_worker_decode(
rate=sample_rate,
frame_size=frame_size,
)
first_sample_rate = None
while not thread_quit.is_set():
try:
# Get next frame
@@ -174,19 +175,29 @@ async def player_worker_decode(
layout, # type: ignore
)
format = "s16" if audio_array.dtype == "int16" else "fltp" # type: ignore
if first_sample_rate is None:
first_sample_rate = sample_rate
if format == "s16":
audio_array = audio_to_float32((sample_rate, audio_array))
if first_sample_rate != sample_rate:
audio_array = librosa.resample(
audio_array, target_sr=first_sample_rate, orig_sr=sample_rate
)
if audio_array.ndim == 1:
audio_array = audio_array.reshape(1, -1)
# Convert to audio frame and resample
# Convert to a float planar ("fltp") audio frame and resample
# This runs in the same timeout context
frame = av.AudioFrame.from_ndarray( # type: ignore
audio_array, # type: ignore
format=format,
format="fltp",
layout=layout, # type: ignore
)
frame.sample_rate = sample_rate
frame.sample_rate = first_sample_rate
for processed_frame in audio_resampler.resample(frame):
processed_frame.pts = audio_samples
processed_frame.time_base = audio_time_base

View File

@@ -2,11 +2,10 @@ import asyncio
import base64
import json
from pathlib import Path
import sounddevice as sd
import aiohttp # pip install aiohttp
import gradio as gr
import numpy as np
import aiohttp # pip install aiohttp
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse

View File

@@ -86,7 +86,7 @@ from fastrtc import get_tts_model, Stream, ReplyOnPause
tts_client = get_tts_model()
def detection(audio: tuple[int, np.ndarray]):
def echo(audio: tuple[int, np.ndarray]):
# Implement any iterator that yields audio
# See "LLM Voice Chat" for a more complete example
yield audio
@@ -98,7 +98,7 @@ def startup():
stream = Stream(
handler=ReplyOnPause(detection, startup_fn=startup),
handler=ReplyOnPause(echo, startup_fn=startup),
modality="audio",
mode="send-receive",
ui_args={"title": "Echo Audio"},