Additional outputs tweaks + fix track constraints (#28)

* code

* add code

* add code
This commit is contained in:
Freddy Boulton
2024-12-03 15:32:43 -05:00
committed by GitHub
parent 65d0ba023f
commit c85c117576
10 changed files with 91 additions and 53 deletions

View File

@@ -167,10 +167,12 @@ class StreamHandler(ABC):
expected_layout: Literal["mono", "stereo"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 960,
input_sample_rate: int = 48000,
) -> None:
self.expected_layout = expected_layout
self.output_sample_rate = output_sample_rate
self.output_frame_size = output_frame_size
self.input_sample_rate = input_sample_rate
self.latest_args: str | list[Any] = "not_set"
self._resampler = None
self._channel: DataChannel | None = None
@@ -191,6 +193,9 @@ class StreamHandler(ABC):
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
def shutdown(self):
pass
@abstractmethod
def copy(self) -> "StreamHandler":
pass
@@ -200,17 +205,23 @@ class StreamHandler(ABC):
self._resampler = av.AudioResampler( # type: ignore
format="s16",
layout=self.expected_layout,
rate=frame.sample_rate,
rate=self.input_sample_rate,
frame_size=frame.samples,
)
yield from self._resampler.resample(frame)
@abstractmethod
def receive(self, frame: tuple[int, np.ndarray] | np.ndarray) -> None:
def receive(self, frame: tuple[int, np.ndarray]) -> None:
pass
@abstractmethod
def emit(self) -> None:
def emit(
self,
) -> (
tuple[int, np.ndarray]
| AdditionalOutputs
| tuple[tuple[int, np.ndarray], AdditionalOutputs]
):
pass
@@ -313,6 +324,9 @@ class AudioCallback(AudioStreamTrack):
self.thread_quit.set()
super().stop()
def shutdown(self):
self.event_handler.shutdown()
class ServerToClientVideo(VideoStreamTrack):
"""
@@ -489,7 +503,7 @@ class WebRTC(Component):
str, VideoCallback | ServerToClientVideo | ServerToClientAudio | AudioCallback
] = {}
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, AdditionalOutputs] = {}
additional_outputs: dict[str, list[AdditionalOutputs]] = {}
EVENTS = ["tick", "state_change"]
@@ -517,6 +531,7 @@ class WebRTC(Component):
time_limit: float | None = None,
mode: Literal["send-receive", "receive", "send"] = "send-receive",
modality: Literal["video", "audio"] = "video",
rtp_params: dict[str, Any] | None = None,
):
"""
Parameters:
@@ -538,15 +553,12 @@ class WebRTC(Component):
render: if False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.
mirror_webcam: if True webcam will be mirrored. Default is True.
include_audio: whether the component should record/retain the audio track for a video. By default, audio is excluded for webcam videos and included for uploaded videos.
autoplay: whether to automatically play the video when the component is used as an output. Note: browsers will not autoplay video files if the user has not interacted with the page yet.
show_share_button: if True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
show_download_button: if True, will show a download icon in the corner of the component that allows user to download the output. If False, icon does not appear. By default, it will be True for output components and False for input components.
min_length: the minimum length of video (in seconds) that the user can pass into the prediction function. If None, there is no minimum length.
max_length: the maximum length of video (in seconds) that the user can pass into the prediction function. If None, there is no maximum length.
loop: if True, the video will loop when it reaches the end and continue playing from the beginning.
streaming: when used set as an output, takes video chunks yielded from the backend and combines them into one streaming video output. Each chunk should be a video file with a .ts extension using an h.264 encoding. Mp4 files are also accepted but they will be converted to h.264 encoding.
watermark: an image file to be included as a watermark on the video. The image is not scaled and is displayed on the bottom right of the video. Valid formats for the image are: jpeg, png.
rtc_configuration: WebRTC configuration options. See https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/RTCPeerConnection . If running the demo on a remote server, you will need to specify a rtc_configuration. See https://freddyaboulton.github.io/gradio-webrtc/deployment/
track_constraints: Media track constraints for WebRTC. For example, to set video height, width use {"width": {"exact": 800}, "height": {"exact": 600}, "aspectRatio": {"exact": 1.33333}}
time_limit: Maximum duration in seconds for recording.
mode: WebRTC mode - "send-receive", "receive", or "send".
modality: Type of media - "video" or "audio".
rtp_params: See https://developer.mozilla.org/en-US/docs/Web/API/RTCRtpSender/setParameters. If you are changing the video resolution, you can set this to {"degradationPreference": "maintain-framerate"} to keep the frame rate consistent.
"""
self.time_limit = time_limit
self.height = height
@@ -556,6 +568,7 @@ class WebRTC(Component):
self.rtc_configuration = rtc_configuration
self.mode = mode
self.modality = modality
self.rtp_params = rtp_params or {}
if track_constraints is None and modality == "audio":
track_constraints = {
"echoCancellation": True,
@@ -595,7 +608,9 @@ class WebRTC(Component):
self, webrtc_id: str
) -> Callable[[AdditionalOutputs], None]:
def set_outputs(outputs: AdditionalOutputs):
self.additional_outputs[webrtc_id] = outputs
if webrtc_id not in self.additional_outputs:
self.additional_outputs[webrtc_id] = []
self.additional_outputs[webrtc_id].append(outputs)
return set_outputs
@@ -638,8 +653,12 @@ class WebRTC(Component):
inputs = list(inputs)
def handler(webrtc_id: str, *args):
if webrtc_id in self.additional_outputs:
return fn(*args, *self.additional_outputs[webrtc_id].args) # type: ignore
if (
webrtc_id in self.additional_outputs
and len(self.additional_outputs[webrtc_id]) > 0
):
next_outputs = self.additional_outputs[webrtc_id].pop(0)
return fn(*args, *next_outputs.args) # type: ignore
return (
tuple([None for _ in range(len(outputs))])
if isinstance(outputs, Iterable)
@@ -655,6 +674,7 @@ class WebRTC(Component):
concurrency_id=concurrency_id,
show_progress=show_progress,
queue=queue,
trigger_mode="multiple",
)
def stream(
@@ -748,6 +768,8 @@ class WebRTC(Component):
def clean_up(self, webrtc_id: str):
connection = self.connections.pop(webrtc_id, None)
if isinstance(connection, AudioCallback):
connection.event_handler.shutdown()
self.additional_outputs.pop(webrtc_id, None)
self.data_channels.pop(webrtc_id, None)
return connection