make code

This commit is contained in:
freddyaboulton
2024-10-22 16:24:21 -07:00
parent cff6073df0
commit e7f3e63c79
20 changed files with 427 additions and 156 deletions

View File

@@ -1,3 +1,3 @@
from .webrtc import WebRTC, StreamHandler
from .webrtc import StreamHandler, WebRTC
__all__ = ["StreamHandler", "WebRTC"]

View File

@@ -2,7 +2,6 @@ import asyncio
import fractions
import logging
import threading
import time
from typing import Callable
import av
@@ -15,35 +14,44 @@ AUDIO_PTIME = 0.020
def player_worker_decode(
loop,
next: Callable,
next_frame: Callable,
queue: asyncio.Queue,
throttle_playback: bool,
thread_quit: threading.Event,
quit_on_none: bool = False,
sample_rate: int = 48000,
frame_size: int = int(48000 * AUDIO_PTIME),
):
audio_sample_rate = 48000
audio_samples = 0
audio_time_base = fractions.Fraction(1, audio_sample_rate)
audio_resampler = av.AudioResampler(
audio_time_base = fractions.Fraction(1, sample_rate)
audio_resampler = av.AudioResampler( # type: ignore
format="s16",
layout="stereo",
rate=audio_sample_rate,
frame_size=int(audio_sample_rate * AUDIO_PTIME),
rate=sample_rate,
frame_size=frame_size,
)
frame_time = None
start_time = time.time()
while not thread_quit.is_set():
frame = next()
logger.debug("emitted %s", frame)
# read up to 1 second ahead
if throttle_playback:
elapsed_time = time.time() - start_time
if frame_time and frame_time > elapsed_time + 1:
time.sleep(0.1)
sample_rate, audio_array = frame
frame = next_frame()
if frame is None:
if quit_on_none:
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
continue
if len(frame) == 2:
sample_rate, audio_array = frame
layout = "mono"
elif len(frame) == 3:
sample_rate, audio_array, layout = frame
logger.debug(
"received array with shape %s sample rate %s layout %s",
audio_array.shape,
sample_rate,
layout,
)
format = "s16" if audio_array.dtype == "int16" else "fltp"
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="stereo")
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout) # type: ignore
frame.sample_rate = sample_rate
for frame in audio_resampler.resample(frame):
# fix timestamps
@@ -51,5 +59,4 @@ def player_worker_decode(
frame.time_base = audio_time_base
audio_samples += frame.samples
frame_time = frame.time
asyncio.run_coroutine_threadsafe(queue.put(frame), loop)

View File

@@ -3,24 +3,25 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
import logging
import threading
import time
import traceback
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
import anyio.to_thread
import av
import numpy as np
from aiortc import (
AudioStreamTrack,
MediaStreamTrack,
RTCPeerConnection,
RTCSessionDescription,
VideoStreamTrack,
MediaStreamTrack,
)
from aiortc.contrib.media import MediaRelay, AudioFrame, VideoFrame # type: ignore
from aiortc.contrib.media import AudioFrame, MediaRelay, VideoFrame # type: ignore
from aiortc.mediastreams import MediaStreamError
from gradio import wasm_utils
from gradio.components.base import Component, server
@@ -104,6 +105,27 @@ class VideoCallback(VideoStreamTrack):
class StreamHandler(ABC):
def __init__(
self,
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 960,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self._resampler = None
def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
if self._resampler is None:
self._resampler = av.AudioResampler( # type: ignore
format="s16",
layout=self.expected_layout,
rate=frame.sample_rate,
frame_size=frame.samples,
)
yield from self._resampler.resample(frame)
@abstractmethod
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
pass
@@ -124,24 +146,27 @@ class AudioCallback(AudioStreamTrack):
self.track = track
self.event_handler = event_handler
self.current_timestamp = 0
self.latest_args = "not_set"
self.latest_args: str | list[Any] = "not_set"
self.queue = asyncio.Queue()
self.thread_quit = threading.Event()
self.__thread = None
self._start: float | None = None
self.has_started = False
self.last_timestamp = 0
super().__init__()
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
frame = cast(AudioFrame, await self.track.recv())
numpy_array = frame.to_ndarray()
logger.debug("numpy array shape %s", numpy_array.shape)
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
for frame in self.event_handler.resample(frame):
numpy_array = frame.to_ndarray()
logger.debug("numpy array shape %s", numpy_array.shape)
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError as e:
print("MediaStreamError", e)
break
def start(self):
@@ -154,8 +179,10 @@ class AudioCallback(AudioStreamTrack):
asyncio.get_event_loop(),
self.event_handler.emit,
self.queue,
True,
self.thread_quit,
False,
self.event_handler.output_sample_rate,
self.event_handler.output_frame_size,
),
)
self.__thread.start()
@@ -167,23 +194,25 @@ class AudioCallback(AudioStreamTrack):
raise MediaStreamError
self.start()
data = await self.queue.get()
logger.debug("data %s", data)
if data is None:
self.stop()
return
frame = await self.queue.get()
logger.debug("frame %s", frame)
data_time = data.time
data_time = frame.time
if time.time() - self.last_timestamp > 10 * (
self.event_handler.output_frame_size
/ self.event_handler.output_sample_rate
):
self._start = None
# control playback rate
if data_time is not None:
if self._start is None:
self._start = time.time() - data_time
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
return data
if self._start is None:
self._start = time.time() - data_time
else:
wait = self._start + data_time - time.time()
await asyncio.sleep(wait)
self.last_timestamp = time.time()
return frame
except Exception as e:
logger.debug(e)
exec = traceback.format_exc()
@@ -210,6 +239,7 @@ class ServerToClientVideo(VideoStreamTrack):
) -> None:
super().__init__() # don't forget this!
self.event_handler = event_handler
self.args_set = asyncio.Event()
self.latest_args: str | list[Any] = "not_set"
self.generator: Generator[Any, None, Any] | None = None
@@ -219,12 +249,8 @@ class ServerToClientVideo(VideoStreamTrack):
async def recv(self):
try:
pts, time_base = await self.next_timestamp()
if self.latest_args == "not_set":
frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
frame.pts = pts
frame.time_base = time_base
return frame
elif self.generator is None:
await self.args_set.wait()
if self.generator is None:
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
@@ -255,7 +281,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
self.current_timestamp = 0
self.latest_args = "not_set"
self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event()
self.queue = asyncio.Queue()
self.thread_quit = threading.Event()
self.__thread = None
@@ -263,23 +290,15 @@ class ServerToClientAudio(AudioStreamTrack):
super().__init__()
def next(self) -> tuple[int, np.ndarray] | None:
import anyio
if self.latest_args == "not_set":
return
self.args_set.wait()
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
try:
frame = next(self.generator)
return frame
except Exception as exc:
if isinstance(exc, StopIteration):
logger.debug("Stopping audio stream")
asyncio.run_coroutine_threadsafe(
self.queue.put(None), asyncio.get_event_loop()
)
self.thread_quit.set()
except StopIteration:
pass
def start(self):
if self.__thread is None:
@@ -290,8 +309,8 @@ class ServerToClientAudio(AudioStreamTrack):
asyncio.get_event_loop(),
self.next,
self.queue,
False,
self.thread_quit,
True,
),
)
self.__thread.start()
@@ -370,6 +389,7 @@ class WebRTC(Component):
key: int | str | None = None,
mirror_webcam: bool = True,
rtc_configuration: dict[str, Any] | None = None,
track_constraints: dict[str, Any] | None = None,
time_limit: float | None = None,
mode: Literal["send-receive", "receive"] = "send-receive",
modality: Literal["video", "audio"] = "video",
@@ -412,7 +432,24 @@ class WebRTC(Component):
self.rtc_configuration = rtc_configuration
self.mode = mode
self.modality = modality
self.event_handler: Callable | None = None
if track_constraints is None and modality == "audio":
track_constraints = {
"echoCancellation": True,
"noiseSuppression": {"exact": True},
"autoGainControl": {"exact": True},
"sampleRate": {"ideal": 24000},
"sampleSize": {"ideal": 16},
"channelCount": {"exact": 1},
}
if track_constraints is None and modality == "video":
track_constraints = {
"facingMode": "user",
"width": {"ideal": 500},
"height": {"ideal": 500},
"frameRate": {"ideal": 30},
}
self.track_constraints = track_constraints
self.event_handler: Callable | StreamHandler | None = None
super().__init__(
label=label,
every=every,
@@ -456,6 +493,7 @@ class WebRTC(Component):
)
elif self.mode == "receive":
self.connections[webrtc_id].latest_args = list(args)
self.connections[webrtc_id].args_set.set() # type: ignore
def stream(
self,
@@ -534,9 +572,9 @@ class WebRTC(Component):
"In the receive mode stream event, the trigger parameter must be provided"
)
trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
self.tick(
self.tick( # type: ignore
self.set_output,
inputs=[self] + inputs,
inputs=[self] + list(inputs),
outputs=None,
concurrency_id=concurrency_id,
)