mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git

commit: make code
@@ -1,3 +1,3 @@
-from .webrtc import WebRTC, StreamHandler
+from .webrtc import StreamHandler, WebRTC
 
 __all__ = ["StreamHandler", "WebRTC"]
@@ -2,7 +2,6 @@ import asyncio
 import fractions
 import logging
 import threading
 import time
 from typing import Callable
 
 import av
@@ -15,35 +14,44 @@ AUDIO_PTIME = 0.020
 
 def player_worker_decode(
     loop,
-    next: Callable,
+    next_frame: Callable,
     queue: asyncio.Queue,
     throttle_playback: bool,
     thread_quit: threading.Event,
+    quit_on_none: bool = False,
+    sample_rate: int = 48000,
+    frame_size: int = int(48000 * AUDIO_PTIME),
 ):
-    audio_sample_rate = 48000
     audio_samples = 0
-    audio_time_base = fractions.Fraction(1, audio_sample_rate)
-    audio_resampler = av.AudioResampler(
+    audio_time_base = fractions.Fraction(1, sample_rate)
+    audio_resampler = av.AudioResampler(  # type: ignore
         format="s16",
         layout="stereo",
-        rate=audio_sample_rate,
-        frame_size=int(audio_sample_rate * AUDIO_PTIME),
+        rate=sample_rate,
+        frame_size=frame_size,
     )
 
     frame_time = None
     start_time = time.time()
 
     while not thread_quit.is_set():
-        frame = next()
-        logger.debug("emitted %s", frame)
         # read up to 1 second ahead
         if throttle_playback:
             elapsed_time = time.time() - start_time
             if frame_time and frame_time > elapsed_time + 1:
                 time.sleep(0.1)
-        sample_rate, audio_array = frame
+        frame = next_frame()
+        if frame is None:
+            if quit_on_none:
+                asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+            continue
+
+        if len(frame) == 2:
+            sample_rate, audio_array = frame
+            layout = "mono"
+        elif len(frame) == 3:
+            sample_rate, audio_array, layout = frame
+
+        logger.debug(
+            "received array with shape %s sample rate %s layout %s",
+            audio_array.shape,
+            sample_rate,
+            layout,
+        )
         format = "s16" if audio_array.dtype == "int16" else "fltp"
-        frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="stereo")
+        frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout=layout)  # type: ignore
         frame.sample_rate = sample_rate
         for frame in audio_resampler.resample(frame):
             # fix timestamps
@@ -51,5 +59,4 @@ def player_worker_decode(
             frame.time_base = audio_time_base
             audio_samples += frame.samples
-
             frame_time = frame.time
             asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
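
Note on the new contract (my reading of this hunk, not documented upstream):
next_frame must return None when no audio is ready, a (sample_rate, array)
pair for mono audio, or a (sample_rate, array, layout) triple. A minimal
conforming callable, with a purely hypothetical tone generator:

    import numpy as np

    def make_next_frame(sample_rate: int = 24000, frame_size: int = 960):
        """Hypothetical next_frame factory: fixed-size chunks of a 440 Hz tone."""
        t = 0

        def next_frame():
            nonlocal t
            if t >= sample_rate:  # stop after one second of audio
                return None
            n = np.arange(t, t + frame_size)
            t += frame_size
            tone = 0.2 * np.sin(2 * np.pi * 440 * n / sample_rate) * 32767
            audio = tone.astype(np.int16).reshape(1, -1)  # (channels, samples)
            return (sample_rate, audio)  # 2-tuple, so layout defaults to "mono"

        return next_frame
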
@@ -3,24 +3,25 @@
 from __future__ import annotations
 
 import asyncio
-from abc import ABC, abstractmethod
 import logging
 import threading
 import time
 import traceback
+from abc import ABC, abstractmethod
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Generator, Literal, Sequence, cast
 
 import anyio.to_thread
 import av
 import numpy as np
 from aiortc import (
     AudioStreamTrack,
+    MediaStreamTrack,
     RTCPeerConnection,
     RTCSessionDescription,
     VideoStreamTrack,
-    MediaStreamTrack,
 )
-from aiortc.contrib.media import MediaRelay, AudioFrame, VideoFrame  # type: ignore
+from aiortc.contrib.media import AudioFrame, MediaRelay, VideoFrame  # type: ignore
 from aiortc.mediastreams import MediaStreamError
 from gradio import wasm_utils
 from gradio.components.base import Component, server
@@ -104,6 +105,27 @@ class VideoCallback(VideoStreamTrack):
 
 
+class StreamHandler(ABC):
+    def __init__(
+        self,
+        expected_layout: Literal["mono", "stereo"] = "mono",
+        output_sample_rate: int = 24000,
+        output_frame_size: int = 960,
+    ) -> None:
+        self.expected_layout = expected_layout
+        self.output_sample_rate = output_sample_rate
+        self.output_frame_size = output_frame_size
+        self._resampler = None
+
+    def resample(self, frame: AudioFrame) -> Generator[AudioFrame, None, None]:
+        if self._resampler is None:
+            self._resampler = av.AudioResampler(  # type: ignore
+                format="s16",
+                layout=self.expected_layout,
+                rate=frame.sample_rate,
+                frame_size=frame.samples,
+            )
+        yield from self._resampler.resample(frame)
+
+    @abstractmethod
+    def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
+        pass
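
For orientation, a sketch of how the new ABC is meant to be subclassed, judging
from this diff alone: receive gets (sample_rate, array) tuples already resampled
to expected_layout, and AudioCallback.start below also pulls an emit method off
the handler to feed player_worker_decode, so the emit here is an assumption
reconstructed from that call site:

    import numpy as np

    class EchoHandler(StreamHandler):
        """Hypothetical handler that buffers input audio and plays it back."""

        def __init__(self) -> None:
            super().__init__(expected_layout="mono", output_sample_rate=24000)
            self.buffer: list[tuple[int, np.ndarray]] = []

        def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
            # Called from process_input_frames with (sample_rate, ndarray).
            self.buffer.append(frame)  # type: ignore

        def emit(self) -> tuple[int, np.ndarray] | None:
            # Polled by player_worker_decode; None means "nothing to play yet".
            return self.buffer.pop(0) if self.buffer else None
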
@@ -124,24 +146,27 @@ class AudioCallback(AudioStreamTrack):
         self.track = track
         self.event_handler = event_handler
         self.current_timestamp = 0
-        self.latest_args = "not_set"
+        self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
         self.thread_quit = threading.Event()
         self.__thread = None
         self._start: float | None = None
         self.has_started = False
+        self.last_timestamp = 0
         super().__init__()
 
     async def process_input_frames(self) -> None:
         while not self.thread_quit.is_set():
             try:
                 frame = cast(AudioFrame, await self.track.recv())
-                numpy_array = frame.to_ndarray()
-                logger.debug("numpy array shape %s", numpy_array.shape)
-                await anyio.to_thread.run_sync(
-                    self.event_handler.receive, (frame.sample_rate, numpy_array)
-                )
-            except MediaStreamError:
+                for frame in self.event_handler.resample(frame):
+                    numpy_array = frame.to_ndarray()
+                    logger.debug("numpy array shape %s", numpy_array.shape)
+                    await anyio.to_thread.run_sync(
+                        self.event_handler.receive, (frame.sample_rate, numpy_array)
+                    )
+            except MediaStreamError as e:
+                print("MediaStreamError", e)
                 break
 
     def start(self):
@@ -154,8 +179,10 @@ class AudioCallback(AudioStreamTrack):
                     asyncio.get_event_loop(),
                    self.event_handler.emit,
                     self.queue,
                     True,
                     self.thread_quit,
+                    False,
+                    self.event_handler.output_sample_rate,
+                    self.event_handler.output_frame_size,
                 ),
             )
             self.__thread.start()
@@ -167,23 +194,25 @@ class AudioCallback(AudioStreamTrack):
                 raise MediaStreamError
 
             self.start()
-            data = await self.queue.get()
-            logger.debug("data %s", data)
-            if data is None:
-                self.stop()
-                return
+            frame = await self.queue.get()
+            logger.debug("frame %s", frame)
 
-            data_time = data.time
+            data_time = frame.time
+
+            if time.time() - self.last_timestamp > 10 * (
+                self.event_handler.output_frame_size
+                / self.event_handler.output_sample_rate
+            ):
+                self._start = None
 
             # control playback rate
-            if data_time is not None:
-                if self._start is None:
-                    self._start = time.time() - data_time
-                else:
-                    wait = self._start + data_time - time.time()
-                    await asyncio.sleep(wait)
-
-            return data
+            if self._start is None:
+                self._start = time.time() - data_time
+            else:
+                wait = self._start + data_time - time.time()
+                await asyncio.sleep(wait)
+            self.last_timestamp = time.time()
+            return frame
         except Exception as e:
             logger.debug(e)
             exec = traceback.format_exc()
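
The pacing logic above anchors _start to wall-clock time at the first frame and
sleeps each subsequent frame until its presentation time; the new guard resets
the anchor whenever more than ten output-frame durations pass with no frame
delivered (10 * output_frame_size / output_sample_rate, i.e. 0.4 s at the
960 / 24000 defaults), so playback does not rush to catch up after a stall.
The same arithmetic as a standalone sketch, for illustration only:

    import time

    def pace(frame_time: float, start: float | None) -> float:
        """Play a frame stamped frame_time at wall-clock start + frame_time."""
        if start is None:
            return time.time() - frame_time  # anchor the clock to this frame
        wait = start + frame_time - time.time()
        if wait > 0:
            time.sleep(wait)
        return start
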
@@ -210,6 +239,7 @@ class ServerToClientVideo(VideoStreamTrack):
     ) -> None:
         super().__init__()  # don't forget this!
         self.event_handler = event_handler
+        self.args_set = asyncio.Event()
         self.latest_args: str | list[Any] = "not_set"
         self.generator: Generator[Any, None, Any] | None = None
 
@@ -219,12 +249,8 @@ class ServerToClientVideo(VideoStreamTrack):
     async def recv(self):
         try:
             pts, time_base = await self.next_timestamp()
-            if self.latest_args == "not_set":
-                frame = self.array_to_frame(np.zeros((480, 640, 3), dtype=np.uint8))
-                frame.pts = pts
-                frame.time_base = time_base
-                return frame
-            elif self.generator is None:
+            await self.args_set.wait()
+            if self.generator is None:
                 self.generator = cast(
                     Generator[Any, None, Any], self.event_handler(*self.latest_args)
                 )
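
In other words, the placeholder path is gone: rather than emitting black
480x640 frames while latest_args is unset, recv() now parks on args_set, which
the component fires once the client's arguments have been stored (the
receive-mode side of that is visible in the @@ -456 hunk below).
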
@@ -255,7 +281,8 @@ class ServerToClientAudio(AudioStreamTrack):
         self.generator: Generator[Any, None, Any] | None = None
         self.event_handler = event_handler
         self.current_timestamp = 0
-        self.latest_args = "not_set"
+        self.latest_args: str | list[Any] = "not_set"
+        self.args_set = threading.Event()
         self.queue = asyncio.Queue()
         self.thread_quit = threading.Event()
         self.__thread = None
@@ -263,23 +290,15 @@ class ServerToClientAudio(AudioStreamTrack):
         super().__init__()
 
     def next(self) -> tuple[int, np.ndarray] | None:
-        import anyio
-
-        if self.latest_args == "not_set":
-            return
+        self.args_set.wait()
         if self.generator is None:
             self.generator = self.event_handler(*self.latest_args)
-        if self.generator is not None:
-            try:
-                frame = next(self.generator)
-                return frame
-            except Exception as exc:
-                if isinstance(exc, StopIteration):
-                    logger.debug("Stopping audio stream")
-                    asyncio.run_coroutine_threadsafe(
-                        self.queue.put(None), asyncio.get_event_loop()
-                    )
-                self.thread_quit.set()
+        try:
+            frame = next(self.generator)
+            return frame
+        except StopIteration:
+            pass
 
     def start(self):
         if self.__thread is None:
@@ -290,8 +309,8 @@ class ServerToClientAudio(AudioStreamTrack):
                     asyncio.get_event_loop(),
                     self.next,
                     self.queue,
                     False,
                     self.thread_quit,
+                    True,
                 ),
             )
             self.__thread.start()
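
The trailing True added here is quit_on_none: together with the next() rewrite
above, an exhausted generator now simply returns None, which
player_worker_decode converts into queue.put(None) so the consumer can shut the
track down, instead of next() signalling shutdown itself as the old code did.
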
@@ -370,6 +389,7 @@ class WebRTC(Component):
         key: int | str | None = None,
         mirror_webcam: bool = True,
         rtc_configuration: dict[str, Any] | None = None,
+        track_constraints: dict[str, Any] | None = None,
         time_limit: float | None = None,
         mode: Literal["send-receive", "receive"] = "send-receive",
         modality: Literal["video", "audio"] = "video",
@@ -412,7 +432,24 @@ class WebRTC(Component):
         self.rtc_configuration = rtc_configuration
         self.mode = mode
         self.modality = modality
-        self.event_handler: Callable | None = None
+        if track_constraints is None and modality == "audio":
+            track_constraints = {
+                "echoCancellation": True,
+                "noiseSuppression": {"exact": True},
+                "autoGainControl": {"exact": True},
+                "sampleRate": {"ideal": 24000},
+                "sampleSize": {"ideal": 16},
+                "channelCount": {"exact": 1},
+            }
+        if track_constraints is None and modality == "video":
+            track_constraints = {
+                "facingMode": "user",
+                "width": {"ideal": 500},
+                "height": {"ideal": 500},
+                "frameRate": {"ideal": 30},
+            }
+        self.track_constraints = track_constraints
+        self.event_handler: Callable | StreamHandler | None = None
         super().__init__(
             label=label,
             every=every,
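
These defaults are MediaTrackConstraints handed to the browser's getUserMedia;
overriding them from user code looks roughly like this (a sketch, and the
16 kHz override is arbitrary):

    import gradio as gr
    from gradio_webrtc import WebRTC

    with gr.Blocks() as demo:
        audio = WebRTC(
            modality="audio",
            mode="send-receive",
            # Ask the browser for 16 kHz mono input instead of the defaults.
            track_constraints={
                "sampleRate": {"ideal": 16000},
                "channelCount": {"exact": 1},
                "echoCancellation": True,
            },
        )
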
@@ -456,6 +493,7 @@ class WebRTC(Component):
             )
         elif self.mode == "receive":
             self.connections[webrtc_id].latest_args = list(args)
+            self.connections[webrtc_id].args_set.set()  # type: ignore
 
     def stream(
         self,
@@ -534,9 +572,9 @@ class WebRTC(Component):
                     "In the receive mode stream event, the trigger parameter must be provided"
                 )
             trigger(lambda: "start_webrtc_stream", inputs=None, outputs=self)
-            self.tick(
+            self.tick(  # type: ignore
                 self.set_output,
-                inputs=[self] + inputs,
+                inputs=[self] + list(inputs),
                 outputs=None,
                 concurrency_id=concurrency_id,
             )