mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -31,7 +31,9 @@ from fastrtc.tracks import (
|
||||
StreamHandlerBase,
|
||||
StreamHandlerImpl,
|
||||
VideoCallback,
|
||||
VideoEventHandler,
|
||||
VideoStreamHandler,
|
||||
VideoStreamHandler_,
|
||||
)
|
||||
from fastrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
@@ -41,7 +43,7 @@ from fastrtc.utils import (
|
||||
|
||||
Track = (
|
||||
VideoCallback
|
||||
| VideoStreamHandler
|
||||
| VideoStreamHandler_
|
||||
| AudioCallback
|
||||
| ServerToClientAudio
|
||||
| ServerToClientVideo
|
||||
@@ -174,6 +176,11 @@ class WebRTCConnectionMixin:
|
||||
handler.video_receive = webrtc_error_handler(handler.video_receive) # type: ignore
|
||||
if hasattr(handler, "video_emit"):
|
||||
handler.video_emit = webrtc_error_handler(handler.video_emit) # type: ignore
|
||||
elif isinstance(self.event_handler, VideoStreamHandler):
|
||||
self.event_handler.callable = cast(
|
||||
VideoEventHandler, webrtc_error_handler(self.event_handler.callable)
|
||||
)
|
||||
handler = self.event_handler
|
||||
else:
|
||||
handler = webrtc_error_handler(cast(Callable, self.event_handler))
|
||||
|
||||
@@ -208,17 +215,25 @@ class WebRTCConnectionMixin:
|
||||
handler = self.handlers[body["webrtc_id"]]
|
||||
|
||||
if self.modality == "video" and track.kind == "video":
|
||||
args = {}
|
||||
handler_ = handler
|
||||
if isinstance(handler, VideoStreamHandler):
|
||||
handler_ = handler.callable
|
||||
args["fps"] = handler.fps
|
||||
args["skip_frames"] = handler.skip_frames
|
||||
cb = VideoCallback(
|
||||
relay.subscribe(track),
|
||||
event_handler=cast(Callable, handler),
|
||||
event_handler=cast(Callable, handler_),
|
||||
set_additional_outputs=set_outputs,
|
||||
mode=cast(Literal["send", "send-receive"], self.mode),
|
||||
**args,
|
||||
)
|
||||
elif self.modality == "audio-video" and track.kind == "video":
|
||||
cb = VideoStreamHandler(
|
||||
cb = VideoStreamHandler_(
|
||||
relay.subscribe(track),
|
||||
event_handler=handler, # type: ignore
|
||||
set_additional_outputs=set_outputs,
|
||||
fps=cast(StreamHandlerImpl, handler).fps,
|
||||
)
|
||||
elif self.modality in ["audio", "audio-video"] and track.kind == "audio":
|
||||
eh = cast(StreamHandlerImpl, handler)
|
||||
@@ -245,10 +260,17 @@ class WebRTCConnectionMixin:
|
||||
|
||||
if self.mode == "receive":
|
||||
if self.modality == "video":
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
if isinstance(self.event_handler, VideoStreamHandler):
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler.callable),
|
||||
set_additional_outputs=set_outputs,
|
||||
fps=self.event_handler.fps,
|
||||
)
|
||||
else:
|
||||
cb = ServerToClientVideo(
|
||||
cast(Callable, self.event_handler),
|
||||
set_additional_outputs=set_outputs,
|
||||
)
|
||||
elif self.modality == "audio":
|
||||
cb = ServerToClientAudio(
|
||||
cast(Callable, self.event_handler),
|
||||
|
||||
Reference in New Issue
Block a user