This commit is contained in:
freddyaboulton
2024-10-28 09:59:08 -07:00
parent c051736fbb
commit d1c43edcd4
10 changed files with 225 additions and 18 deletions

View File

@@ -1,9 +1,10 @@
import asyncio
import fractions
import logging
from typing import Callable
from typing import Any, Callable, Protocol, cast
import av
import numpy as np
logger = logging.getLogger(__name__)
@@ -11,10 +12,38 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020
class AdditionalOutputs:
def __init__(self, *args) -> None:
self.args = args
class DataChannel(Protocol):
def send(self, message: str) -> None: ...
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
return data, None
if not len(data) == 2:
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], AdditionalOutputs):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)
return data[0], cast(AdditionalOutputs, data[1])
return data, None
async def player_worker_decode(
next_frame: Callable,
queue: asyncio.Queue,
thread_quit: asyncio.Event,
channel: Callable[[], DataChannel | None] | None,
set_additional_outputs: Callable | None,
quit_on_none: bool = False,
sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME),
@@ -31,7 +60,17 @@ async def player_worker_decode(
while not thread_quit.is_set():
try:
# Get next frame
frame = await asyncio.wait_for(next_frame(), timeout=60)
frame, outputs = split_output(
await asyncio.wait_for(next_frame(), timeout=60)
)
if (
isinstance(outputs, AdditionalOutputs)
and set_additional_outputs
and channel
and channel()
):
set_additional_outputs(outputs)
cast(DataChannel, channel()).send("change")
if frame is None:
if quit_on_none:
@@ -65,7 +104,7 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize())
logger.debug("Queue size utils.py: %s", queue.qsize())
except (TimeoutError, asyncio.TimeoutError):
logger.warning(