From e87c4d49e855e6281aa5f3011051f5c7f74ba535 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 23 Oct 2024 15:52:48 -0700 Subject: [PATCH] final touches --- backend/gradio_webrtc/utils.py | 82 ++++++++++++++++++++------------- backend/gradio_webrtc/webrtc.py | 53 +++++++++------------ pyproject.toml | 2 +- 3 files changed, 75 insertions(+), 62 deletions(-) diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index f3579ca..16b75c3 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -1,7 +1,6 @@ import asyncio import fractions import logging -import threading from typing import Callable import av @@ -12,11 +11,10 @@ logger = logging.getLogger(__name__) AUDIO_PTIME = 0.020 -def player_worker_decode( - loop, +async def player_worker_decode( next_frame: Callable, queue: asyncio.Queue, - thread_quit: threading.Event, + thread_quit: asyncio.Event, quit_on_none: bool = False, sample_rate: int = 48000, frame_size: int = int(48000 * AUDIO_PTIME), @@ -31,32 +29,54 @@ def player_worker_decode( ) while not thread_quit.is_set(): - frame = next_frame() - if frame is None: - if quit_on_none: - asyncio.run_coroutine_threadsafe(queue.put(None), loop) + try: + async with asyncio.timeout(5): + # Get next frame + frame = await next_frame() + + if frame is None: + if quit_on_none: + await queue.put(None) + break + continue + + if len(frame) == 2: + sample_rate, audio_array = frame + layout = "mono" + elif len(frame) == 3: + sample_rate, audio_array, layout = frame + + logger.debug( + "received array with shape %s sample rate %s layout %s", + audio_array.shape, + sample_rate, + layout, + ) + format = "s16" if audio_array.dtype == "int16" else "fltp" + + # Convert to audio frame and resample + # This runs in the same timeout context + frame = av.AudioFrame.from_ndarray( + audio_array, format=format, layout=layout + ) + frame.sample_rate = sample_rate + + for processed_frame in audio_resampler.resample(frame): + processed_frame.pts = audio_samples + processed_frame.time_base = audio_time_base + audio_samples += processed_frame.samples + await queue.put(processed_frame) + logger.debug("Queue size utils.py: %s", queue.qsize()) + + except TimeoutError: + logger.warning( + "Timeout in frame processing cycle after %s seconds - resetting", 5 + ) continue + except Exception as e: + import traceback - if len(frame) == 2: - sample_rate, audio_array = frame - layout = "mono" - elif len(frame) == 3: - sample_rate, audio_array, layout = frame - - logger.debug( - "received array with shape %s sample rate %s layout %s", - audio_array.shape, - sample_rate, - layout, - ) - format = "s16" if audio_array.dtype == "int16" else "fltp" - - frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout) # type: ignore - frame.sample_rate = sample_rate - for frame in audio_resampler.resample(frame): - # fix timestamps - frame.pts = audio_samples - frame.time_base = audio_time_base - audio_samples += frame.samples - asyncio.run_coroutine_threadsafe(queue.put(frame), loop) - logger.debug("Queue size utils.py: %s", queue.qsize()) + exec = traceback.format_exc() + logger.debug("traceback %s", exec) + logger.error("Error processing frame: %s", str(e)) + continue diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index 3d9a9cc..647b6f2 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import functools import logging import threading import time @@ -148,8 +149,7 @@ class AudioCallback(AudioStreamTrack): self.current_timestamp = 0 self.latest_args: str | list[Any] = "not_set" self.queue = asyncio.Queue() - self.thread_quit = threading.Event() - self.__thread = None + self.thread_quit = asyncio.Event() self._start: float | None = None self.has_started = False self.last_timestamp = 0 @@ -165,26 +165,26 @@ class AudioCallback(AudioStreamTrack): self.event_handler.receive, (frame.sample_rate, numpy_array) ) except MediaStreamError: - logger.debug("MediaStreamError in process_input_frames") - break + logger.debug("MediaStreamError in process_input_frames") + break def start(self): if not self.has_started: + loop = asyncio.get_running_loop() + callable = functools.partial( + loop.run_in_executor, None, self.event_handler.emit + ) asyncio.create_task(self.process_input_frames()) - self.__thread = threading.Thread( - target=player_worker_decode, - daemon=False, - args=( - asyncio.get_event_loop(), - self.event_handler.emit, + asyncio.create_task( + player_worker_decode( + callable, self.queue, self.thread_quit, False, self.event_handler.output_sample_rate, self.event_handler.output_frame_size, - ), + ) ) - self.__thread.start() self.has_started = True async def recv(self): @@ -220,9 +220,6 @@ class AudioCallback(AudioStreamTrack): def stop(self): logger.debug("audio callback stop") self.thread_quit.set() - if self.__thread is not None: - self.__thread.join() - self.__thread = None super().stop() @@ -284,8 +281,8 @@ class ServerToClientAudio(AudioStreamTrack): self.latest_args: str | list[Any] = "not_set" self.args_set = threading.Event() self.queue = asyncio.Queue() - self.thread_quit = threading.Event() - self.__thread = None + self.thread_quit = asyncio.Event() + self.has_started = False self._start: float | None = None super().__init__() @@ -301,20 +298,18 @@ class ServerToClientAudio(AudioStreamTrack): self.thread_quit.set() def start(self): - if self.__thread is None: - self.__thread = threading.Thread( - name="generator-runner", - target=player_worker_decode, - daemon=True, - args=( - asyncio.get_event_loop(), - self.next, + if not self.has_started: + loop = asyncio.get_running_loop() + callable = functools.partial(loop.run_in_executor, None, self.next) + asyncio.create_task( + player_worker_decode( + callable, self.queue, self.thread_quit, True, - ), + ) ) - self.__thread.start() + self.has_started = True async def recv(self): try: @@ -344,10 +339,8 @@ class ServerToClientAudio(AudioStreamTrack): logger.debug("traceback %s", exec) def stop(self): + logger.debug("audio-to-client stop callback") self.thread_quit.set() - if self.__thread is not None: - self.__thread.join() - self.__thread = None super().stop() diff --git a/pyproject.toml b/pyproject.toml index b764a23..3aaed1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.6a2" +version = "0.0.6a3" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0"