final touches

This commit is contained in:
freddyaboulton
2024-10-23 15:52:48 -07:00
parent 1688502e99
commit e87c4d49e8
3 changed files with 75 additions and 62 deletions

View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import fractions import fractions
import logging import logging
import threading
from typing import Callable from typing import Callable
import av import av
@@ -12,11 +11,10 @@ logger = logging.getLogger(__name__)
AUDIO_PTIME = 0.020 AUDIO_PTIME = 0.020
def player_worker_decode( async def player_worker_decode(
loop,
next_frame: Callable, next_frame: Callable,
queue: asyncio.Queue, queue: asyncio.Queue,
thread_quit: threading.Event, thread_quit: asyncio.Event,
quit_on_none: bool = False, quit_on_none: bool = False,
sample_rate: int = 48000, sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME), frame_size: int = int(48000 * AUDIO_PTIME),
@@ -31,32 +29,54 @@ def player_worker_decode(
) )
while not thread_quit.is_set(): while not thread_quit.is_set():
frame = next_frame() try:
if frame is None: async with asyncio.timeout(5):
if quit_on_none: # Get next frame
asyncio.run_coroutine_threadsafe(queue.put(None), loop) frame = await next_frame()
if frame is None:
if quit_on_none:
await queue.put(None)
break
continue
if len(frame) == 2:
sample_rate, audio_array = frame
layout = "mono"
elif len(frame) == 3:
sample_rate, audio_array, layout = frame
logger.debug(
"received array with shape %s sample rate %s layout %s",
audio_array.shape,
sample_rate,
layout,
)
format = "s16" if audio_array.dtype == "int16" else "fltp"
# Convert to audio frame and resample
# This runs in the same timeout context
frame = av.AudioFrame.from_ndarray(
audio_array, format=format, layout=layout
)
frame.sample_rate = sample_rate
for processed_frame in audio_resampler.resample(frame):
processed_frame.pts = audio_samples
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
logger.debug("Queue size utils.py: %s", queue.qsize())
except TimeoutError:
logger.warning(
"Timeout in frame processing cycle after %s seconds - resetting", 5
)
continue continue
except Exception as e:
import traceback
if len(frame) == 2: exec = traceback.format_exc()
sample_rate, audio_array = frame logger.debug("traceback %s", exec)
layout = "mono" logger.error("Error processing frame: %s", str(e))
elif len(frame) == 3: continue
sample_rate, audio_array, layout = frame
logger.debug(
"received array with shape %s sample rate %s layout %s",
audio_array.shape,
sample_rate,
layout,
)
format = "s16" if audio_array.dtype == "int16" else "fltp"
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout) # type: ignore
frame.sample_rate = sample_rate
for frame in audio_resampler.resample(frame):
# fix timestamps
frame.pts = audio_samples
frame.time_base = audio_time_base
audio_samples += frame.samples
asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
logger.debug("Queue size utils.py: %s", queue.qsize())

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import logging import logging
import threading import threading
import time import time
@@ -148,8 +149,7 @@ class AudioCallback(AudioStreamTrack):
self.current_timestamp = 0 self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.thread_quit = threading.Event() self.thread_quit = asyncio.Event()
self.__thread = None
self._start: float | None = None self._start: float | None = None
self.has_started = False self.has_started = False
self.last_timestamp = 0 self.last_timestamp = 0
@@ -165,26 +165,26 @@ class AudioCallback(AudioStreamTrack):
self.event_handler.receive, (frame.sample_rate, numpy_array) self.event_handler.receive, (frame.sample_rate, numpy_array)
) )
except MediaStreamError: except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames") logger.debug("MediaStreamError in process_input_frames")
break break
def start(self): def start(self):
if not self.has_started: if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(
loop.run_in_executor, None, self.event_handler.emit
)
asyncio.create_task(self.process_input_frames()) asyncio.create_task(self.process_input_frames())
self.__thread = threading.Thread( asyncio.create_task(
target=player_worker_decode, player_worker_decode(
daemon=False, callable,
args=(
asyncio.get_event_loop(),
self.event_handler.emit,
self.queue, self.queue,
self.thread_quit, self.thread_quit,
False, False,
self.event_handler.output_sample_rate, self.event_handler.output_sample_rate,
self.event_handler.output_frame_size, self.event_handler.output_frame_size,
), )
) )
self.__thread.start()
self.has_started = True self.has_started = True
async def recv(self): async def recv(self):
@@ -220,9 +220,6 @@ class AudioCallback(AudioStreamTrack):
def stop(self): def stop(self):
logger.debug("audio callback stop") logger.debug("audio callback stop")
self.thread_quit.set() self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop() super().stop()
@@ -284,8 +281,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.latest_args: str | list[Any] = "not_set" self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event() self.args_set = threading.Event()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.thread_quit = threading.Event() self.thread_quit = asyncio.Event()
self.__thread = None self.has_started = False
self._start: float | None = None self._start: float | None = None
super().__init__() super().__init__()
@@ -301,20 +298,18 @@ class ServerToClientAudio(AudioStreamTrack):
self.thread_quit.set() self.thread_quit.set()
def start(self): def start(self):
if self.__thread is None: if not self.has_started:
self.__thread = threading.Thread( loop = asyncio.get_running_loop()
name="generator-runner", callable = functools.partial(loop.run_in_executor, None, self.next)
target=player_worker_decode, asyncio.create_task(
daemon=True, player_worker_decode(
args=( callable,
asyncio.get_event_loop(),
self.next,
self.queue, self.queue,
self.thread_quit, self.thread_quit,
True, True,
), )
) )
self.__thread.start() self.has_started = True
async def recv(self): async def recv(self):
try: try:
@@ -344,10 +339,8 @@ class ServerToClientAudio(AudioStreamTrack):
logger.debug("traceback %s", exec) logger.debug("traceback %s", exec)
def stop(self): def stop(self):
logger.debug("audio-to-client stop callback")
self.thread_quit.set() self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop() super().stop()

View File

@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "gradio_webrtc" name = "gradio_webrtc"
version = "0.0.6a2" version = "0.0.6a3"
description = "Stream images in realtime with webrtc" description = "Stream images in realtime with webrtc"
readme = "README.md" readme = "README.md"
license = "apache-2.0" license = "apache-2.0"