Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)
fastrtc/__init__.py:

@@ -25,6 +25,7 @@ from .tracks import (
 )
 from .utils import (
     AdditionalOutputs,
+    CloseStream,
     Warning,
     WebRTCError,
     aggregate_bytes_to_16bit,
@@ -75,4 +76,5 @@ __all__ = [
     "get_silero_model",
     "SileroVadOptions",
     "VideoStreamHandler",
+    "CloseStream",
 ]
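With CloseStream exported from the package root, a handler can end a stream programmatically instead of waiting for the client to disconnect. A minimal sketch; the frame counter and cutoff are illustrative, and the Stream wiring assumes fastrtc's documented top-level API:

    import numpy as np
    from fastrtc import CloseStream, Stream

    frames_seen = 0

    def flip_until_limit(image: np.ndarray):
        """Flip incoming frames, then close the stream after 300 of them."""
        global frames_seen
        frames_seen += 1
        if frames_seen > 300:  # illustrative cutoff, not part of the diff
            return CloseStream("Demo time limit reached")
        return np.flipud(image)

    stream = Stream(handler=flip_until_limit, modality="video", mode="send-receive")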
File diff suppressed because one or more lines are too long
fastrtc/tracks.py:

@@ -37,6 +37,7 @@ from numpy import typing as npt
 
 from fastrtc.utils import (
     AdditionalOutputs,
+    CloseStream,
     DataChannel,
     WebRTCError,
     create_message,
@@ -54,9 +55,14 @@ VideoNDArray: TypeAlias = Union[
 ]
 
 VideoEmitType = (
-    VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs
+    VideoNDArray
+    | tuple[VideoNDArray, AdditionalOutputs]
+    | tuple[VideoNDArray, CloseStream]
+    | AdditionalOutputs
+    | CloseStream
 )
-VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
+VideoEventGenerator = Generator[VideoEmitType, None, None]
+VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType | VideoEventGenerator]
 
 
 @dataclass
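The widened union also underpins the new VideoEventGenerator alias: a handler may now return a generator that yields any VideoEmitType member. A sketch of the shapes the new type admits; the handler logic is illustrative:

    import numpy as np
    from numpy import typing as npt
    from fastrtc import AdditionalOutputs, CloseStream

    def rotate_then_close(frame: npt.ArrayLike):
        """Yields each legal VideoEmitType shape in turn."""
        array = np.asarray(frame, dtype=np.uint8)
        yield np.rot90(array)                                # bare VideoNDArray
        yield array, AdditionalOutputs({"note": "rotated"})  # (frame, AdditionalOutputs)
        yield array, CloseStream("last frame")               # (frame, CloseStream), new in this diff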
@@ -172,6 +178,12 @@ class VideoCallback(VideoStreamTrack):
 
             args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
             array, outputs = split_output(self.event_handler(*args))
+            if isinstance(outputs, CloseStream):
+                cast(DataChannel, self.channel).send(
+                    create_message("end_stream", outputs.msg)
+                )
+                self.stop()
+                return None
             if (
                 isinstance(outputs, AdditionalOutputs)
                 and self.set_additional_outputs
@@ -444,6 +456,12 @@ class VideoStreamHandler_(VideoCallback):
         ):
             self.set_additional_outputs(outputs)
             self.channel.send(create_message("fetch_output", []))
+        if isinstance(outputs, CloseStream):
+            cast(DataChannel, self.channel).send(
+                create_message("end_stream", outputs.msg)
+            )
+            self.stop()
+            return
         if array is None and self.mode == "send":
             return
 
@@ -586,6 +604,12 @@ class AudioCallback(AudioStreamTrack):
             await self.start()
 
             frame = await self.queue.get()
+            if isinstance(frame, CloseStream):
+                cast(DataChannel, self.channel).send(
+                    create_message("end_stream", frame.msg)
+                )
+                self.stop()
+                return
             logger.debug("frame %s", frame)
 
             data_time = frame.time
@@ -675,6 +699,12 @@ class ServerToClientVideo(VideoStreamTrack):
         )
         try:
             next_array, outputs = split_output(next(self.generator))
+            if isinstance(outputs, CloseStream):
+                cast(DataChannel, self.channel).send(
+                    create_message("end_stream", outputs.msg)
+                )
+                self.stop()
+                return
             if (
                 isinstance(outputs, AdditionalOutputs)
                 and self.set_additional_outputs
@@ -770,6 +800,12 @@ class ServerToClientAudio(AudioStreamTrack):
 
         await self.start()
         data = await self.queue.get()
+        if isinstance(data, CloseStream):
+            cast(DataChannel, self.channel).send(
+                create_message("end_stream", data.msg)
+            )
+            self.stop()
+            return
         if data is None:
             self.stop()
             return
 
fastrtc/utils.py:

@@ -32,6 +32,11 @@ class AdditionalOutputs:
         self.args = args
 
 
+class CloseStream:
+    def __init__(self, msg: str = "Stream closed") -> None:
+        self.msg = msg
+
+
 class DataChannel(Protocol):
     def send(self, message: str) -> None: ...
 
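The class is a plain message carrier; its msg travels to the client in the end_stream message shown in the track changes above. For example:

    from fastrtc import CloseStream

    CloseStream().msg                # "Stream closed" (the default)
    CloseStream("User hung up").msg  # "User hung up"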
@@ -39,6 +44,7 @@ class DataChannel(Protocol):
 def create_message(
     type: Literal[
         "send_input",
+        "end_stream",
         "fetch_output",
         "stopword",
         "error",
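Registering "end_stream" in the Literal lets the tracks above emit the new message type without tripping type checkers. A hedged sketch of what the data channel then carries, assuming create_message JSON-encodes a {"type": ..., "data": ...} pair (the exact wire format is an assumption; only the ("end_stream", msg) call appears in the diff):

    from fastrtc.utils import create_message

    payload = create_message("end_stream", "Stream closed")
    # e.g. '{"type": "end_stream", "data": "Stream closed"}'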
@@ -98,9 +104,13 @@ class WebRTCError(Exception):
         _send_log(message, "error")
 
 
-def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
+def split_output(
+    data: tuple | Any,
+) -> tuple[Any, AdditionalOutputs | CloseStream | None]:
     if isinstance(data, AdditionalOutputs):
         return None, data
+    if isinstance(data, CloseStream):
+        return None, data
     if isinstance(data, tuple):
         # handle the bare audio case
         if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
@@ -109,11 +119,11 @@ def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
             raise ValueError(
                 "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
             )
-        if not isinstance(data[-1], AdditionalOutputs):
+        if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
             raise ValueError(
                 "The last element of the tuple must be an instance of AdditionalOutputs."
             )
-        return data[0], cast(AdditionalOutputs, data[1])
+        return data[0], cast(AdditionalOutputs | CloseStream, data[1])
     return data, None
 
 
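Taken together, the two hunks give split_output a uniform contract: CloseStream is routed to the second slot exactly like AdditionalOutputs, whether bare or as the tail of a tuple. A quick sketch of the resulting behavior (the bare-audio return is inferred from the part of the function outside these hunks):

    import numpy as np
    from fastrtc.utils import AdditionalOutputs, CloseStream, split_output

    frame = np.zeros((4, 4))
    split_output(frame)                      # (frame, None)
    split_output(CloseStream("done"))        # (None, CloseStream)
    split_output(AdditionalOutputs("meta"))  # (None, AdditionalOutputs)
    split_output((frame, CloseStream()))     # (frame, CloseStream)
    split_output((16000, np.zeros(160, dtype=np.int16)))  # bare audio: passed through as data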
@@ -152,6 +162,8 @@ async def player_worker_decode(
                cast(DataChannel, channel()).send(create_message("fetch_output", []))
 
            if frame is None:
+                if isinstance(outputs, CloseStream):
+                    await queue.put(outputs)
                if quit_on_none:
                    await queue.put(None)
                    break
@@ -203,7 +215,8 @@ async def player_worker_decode(
                    processed_frame.time_base = audio_time_base
                    audio_samples += processed_frame.samples
                    await queue.put(processed_frame)
-
+            if isinstance(outputs, CloseStream):
+                await queue.put(outputs)
        except (TimeoutError, asyncio.TimeoutError):
            logger.warning(
                "Timeout in frame processing cycle after %s seconds - resetting", 60
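End to end, the decode worker now forwards a CloseStream through its frame queue instead of dropping it, and ServerToClientAudio.recv (shown earlier) turns it into an end_stream message before stopping the track. A self-contained sketch of that sentinel pattern, standing in for the real worker/track pair:

    import asyncio

    from fastrtc.utils import CloseStream, create_message

    async def demo() -> None:
        queue: asyncio.Queue = asyncio.Queue()

        async def producer() -> None:
            # player_worker_decode side: frames, then the CloseStream sentinel
            await queue.put("audio-frame")
            await queue.put(CloseStream("EOF"))

        async def consumer() -> None:
            # ServerToClientAudio.recv side: notify the client, then stop
            while True:
                item = await queue.get()
                if isinstance(item, CloseStream):
                    print(create_message("end_stream", item.msg))
                    return  # self.stop() in the real track
                print("play", item)

        await asyncio.gather(producer(), consumer())

    asyncio.run(demo())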