Close Stream from Backend (#222)

* Close from backend

* Add code
This commit is contained in:
Freddy Boulton
2025-03-28 20:47:34 -04:00
committed by GitHub
parent 71743acb64
commit 8ed27fba78
10 changed files with 2336 additions and 2237 deletions

View File

@@ -37,6 +37,7 @@ from numpy import typing as npt
from fastrtc.utils import (
AdditionalOutputs,
CloseStream,
DataChannel,
WebRTCError,
create_message,
@@ -54,9 +55,14 @@ VideoNDArray: TypeAlias = Union[
]
VideoEmitType = (
VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs
VideoNDArray
| tuple[VideoNDArray, AdditionalOutputs]
| tuple[VideoNDArray, CloseStream]
| AdditionalOutputs
| CloseStream
)
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
VideoEventGenerator = Generator[VideoEmitType, None, None]
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType | VideoEventGenerator]
@dataclass
@@ -172,6 +178,12 @@ class VideoCallback(VideoStreamTrack):
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return None
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
@@ -444,6 +456,12 @@ class VideoStreamHandler_(VideoCallback):
):
self.set_additional_outputs(outputs)
self.channel.send(create_message("fetch_output", []))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return
if array is None and self.mode == "send":
return
@@ -586,6 +604,12 @@ class AudioCallback(AudioStreamTrack):
await self.start()
frame = await self.queue.get()
if isinstance(frame, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", frame.msg)
)
self.stop()
return
logger.debug("frame %s", frame)
data_time = frame.time
@@ -675,6 +699,12 @@ class ServerToClientVideo(VideoStreamTrack):
)
try:
next_array, outputs = split_output(next(self.generator))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return
if (
isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs
@@ -770,6 +800,12 @@ class ServerToClientAudio(AudioStreamTrack):
await self.start()
data = await self.queue.get()
if isinstance(data, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", data.msg)
)
self.stop()
return
if data is None:
self.stop()
return