mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
final touches
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import fractions
|
import fractions
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import av
|
import av
|
||||||
@@ -12,11 +11,10 @@ logger = logging.getLogger(__name__)
|
|||||||
AUDIO_PTIME = 0.020
|
AUDIO_PTIME = 0.020
|
||||||
|
|
||||||
|
|
||||||
def player_worker_decode(
|
async def player_worker_decode(
|
||||||
loop,
|
|
||||||
next_frame: Callable,
|
next_frame: Callable,
|
||||||
queue: asyncio.Queue,
|
queue: asyncio.Queue,
|
||||||
thread_quit: threading.Event,
|
thread_quit: asyncio.Event,
|
||||||
quit_on_none: bool = False,
|
quit_on_none: bool = False,
|
||||||
sample_rate: int = 48000,
|
sample_rate: int = 48000,
|
||||||
frame_size: int = int(48000 * AUDIO_PTIME),
|
frame_size: int = int(48000 * AUDIO_PTIME),
|
||||||
@@ -31,32 +29,54 @@ def player_worker_decode(
|
|||||||
)
|
)
|
||||||
|
|
||||||
while not thread_quit.is_set():
|
while not thread_quit.is_set():
|
||||||
frame = next_frame()
|
try:
|
||||||
if frame is None:
|
async with asyncio.timeout(5):
|
||||||
if quit_on_none:
|
# Get next frame
|
||||||
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
|
frame = await next_frame()
|
||||||
|
|
||||||
|
if frame is None:
|
||||||
|
if quit_on_none:
|
||||||
|
await queue.put(None)
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(frame) == 2:
|
||||||
|
sample_rate, audio_array = frame
|
||||||
|
layout = "mono"
|
||||||
|
elif len(frame) == 3:
|
||||||
|
sample_rate, audio_array, layout = frame
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"received array with shape %s sample rate %s layout %s",
|
||||||
|
audio_array.shape,
|
||||||
|
sample_rate,
|
||||||
|
layout,
|
||||||
|
)
|
||||||
|
format = "s16" if audio_array.dtype == "int16" else "fltp"
|
||||||
|
|
||||||
|
# Convert to audio frame and resample
|
||||||
|
# This runs in the same timeout context
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
audio_array, format=format, layout=layout
|
||||||
|
)
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
|
||||||
|
for processed_frame in audio_resampler.resample(frame):
|
||||||
|
processed_frame.pts = audio_samples
|
||||||
|
processed_frame.time_base = audio_time_base
|
||||||
|
audio_samples += processed_frame.samples
|
||||||
|
await queue.put(processed_frame)
|
||||||
|
logger.debug("Queue size utils.py: %s", queue.qsize())
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Timeout in frame processing cycle after %s seconds - resetting", 5
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
if len(frame) == 2:
|
exec = traceback.format_exc()
|
||||||
sample_rate, audio_array = frame
|
logger.debug("traceback %s", exec)
|
||||||
layout = "mono"
|
logger.error("Error processing frame: %s", str(e))
|
||||||
elif len(frame) == 3:
|
continue
|
||||||
sample_rate, audio_array, layout = frame
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"received array with shape %s sample rate %s layout %s",
|
|
||||||
audio_array.shape,
|
|
||||||
sample_rate,
|
|
||||||
layout,
|
|
||||||
)
|
|
||||||
format = "s16" if audio_array.dtype == "int16" else "fltp"
|
|
||||||
|
|
||||||
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout) # type: ignore
|
|
||||||
frame.sample_rate = sample_rate
|
|
||||||
for frame in audio_resampler.resample(frame):
|
|
||||||
# fix timestamps
|
|
||||||
frame.pts = audio_samples
|
|
||||||
frame.time_base = audio_time_base
|
|
||||||
audio_samples += frame.samples
|
|
||||||
asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
|
|
||||||
logger.debug("Queue size utils.py: %s", queue.qsize())
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -148,8 +149,7 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self.current_timestamp = 0
|
self.current_timestamp = 0
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.thread_quit = threading.Event()
|
self.thread_quit = asyncio.Event()
|
||||||
self.__thread = None
|
|
||||||
self._start: float | None = None
|
self._start: float | None = None
|
||||||
self.has_started = False
|
self.has_started = False
|
||||||
self.last_timestamp = 0
|
self.last_timestamp = 0
|
||||||
@@ -165,26 +165,26 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||||
)
|
)
|
||||||
except MediaStreamError:
|
except MediaStreamError:
|
||||||
logger.debug("MediaStreamError in process_input_frames")
|
logger.debug("MediaStreamError in process_input_frames")
|
||||||
break
|
break
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if not self.has_started:
|
if not self.has_started:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
callable = functools.partial(
|
||||||
|
loop.run_in_executor, None, self.event_handler.emit
|
||||||
|
)
|
||||||
asyncio.create_task(self.process_input_frames())
|
asyncio.create_task(self.process_input_frames())
|
||||||
self.__thread = threading.Thread(
|
asyncio.create_task(
|
||||||
target=player_worker_decode,
|
player_worker_decode(
|
||||||
daemon=False,
|
callable,
|
||||||
args=(
|
|
||||||
asyncio.get_event_loop(),
|
|
||||||
self.event_handler.emit,
|
|
||||||
self.queue,
|
self.queue,
|
||||||
self.thread_quit,
|
self.thread_quit,
|
||||||
False,
|
False,
|
||||||
self.event_handler.output_sample_rate,
|
self.event_handler.output_sample_rate,
|
||||||
self.event_handler.output_frame_size,
|
self.event_handler.output_frame_size,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
self.__thread.start()
|
|
||||||
self.has_started = True
|
self.has_started = True
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
@@ -220,9 +220,6 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
logger.debug("audio callback stop")
|
logger.debug("audio callback stop")
|
||||||
self.thread_quit.set()
|
self.thread_quit.set()
|
||||||
if self.__thread is not None:
|
|
||||||
self.__thread.join()
|
|
||||||
self.__thread = None
|
|
||||||
super().stop()
|
super().stop()
|
||||||
|
|
||||||
|
|
||||||
@@ -284,8 +281,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
self.args_set = threading.Event()
|
self.args_set = threading.Event()
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.thread_quit = threading.Event()
|
self.thread_quit = asyncio.Event()
|
||||||
self.__thread = None
|
self.has_started = False
|
||||||
self._start: float | None = None
|
self._start: float | None = None
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -301,20 +298,18 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
self.thread_quit.set()
|
self.thread_quit.set()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if self.__thread is None:
|
if not self.has_started:
|
||||||
self.__thread = threading.Thread(
|
loop = asyncio.get_running_loop()
|
||||||
name="generator-runner",
|
callable = functools.partial(loop.run_in_executor, None, self.next)
|
||||||
target=player_worker_decode,
|
asyncio.create_task(
|
||||||
daemon=True,
|
player_worker_decode(
|
||||||
args=(
|
callable,
|
||||||
asyncio.get_event_loop(),
|
|
||||||
self.next,
|
|
||||||
self.queue,
|
self.queue,
|
||||||
self.thread_quit,
|
self.thread_quit,
|
||||||
True,
|
True,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
self.__thread.start()
|
self.has_started = True
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
try:
|
try:
|
||||||
@@ -344,10 +339,8 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
logger.debug("traceback %s", exec)
|
logger.debug("traceback %s", exec)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
logger.debug("audio-to-client stop callback")
|
||||||
self.thread_quit.set()
|
self.thread_quit.set()
|
||||||
if self.__thread is not None:
|
|
||||||
self.__thread.join()
|
|
||||||
self.__thread = None
|
|
||||||
super().stop()
|
super().stop()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "gradio_webrtc"
|
name = "gradio_webrtc"
|
||||||
version = "0.0.6a2"
|
version = "0.0.6a3"
|
||||||
description = "Stream images in realtime with webrtc"
|
description = "Stream images in realtime with webrtc"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "apache-2.0"
|
license = "apache-2.0"
|
||||||
|
|||||||
Reference in New Issue
Block a user