Close Stream from Backend (#222)

* Close from backend

* Add code
This commit is contained in:
Freddy Boulton
2025-03-28 20:47:34 -04:00
committed by GitHub
parent 71743acb64
commit 8ed27fba78
10 changed files with 2336 additions and 2237 deletions

View File

@@ -25,6 +25,7 @@ from .tracks import (
) )
from .utils import ( from .utils import (
AdditionalOutputs, AdditionalOutputs,
CloseStream,
Warning, Warning,
WebRTCError, WebRTCError,
aggregate_bytes_to_16bit, aggregate_bytes_to_16bit,
@@ -75,4 +76,5 @@ __all__ = [
"get_silero_model", "get_silero_model",
"SileroVadOptions", "SileroVadOptions",
"VideoStreamHandler", "VideoStreamHandler",
"CloseStream",
] ]

File diff suppressed because one or more lines are too long

View File

@@ -37,6 +37,7 @@ from numpy import typing as npt
from fastrtc.utils import ( from fastrtc.utils import (
AdditionalOutputs, AdditionalOutputs,
CloseStream,
DataChannel, DataChannel,
WebRTCError, WebRTCError,
create_message, create_message,
@@ -54,9 +55,14 @@ VideoNDArray: TypeAlias = Union[
] ]
VideoEmitType = ( VideoEmitType = (
VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs VideoNDArray
| tuple[VideoNDArray, AdditionalOutputs]
| tuple[VideoNDArray, CloseStream]
| AdditionalOutputs
| CloseStream
) )
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType] VideoEventGenerator = Generator[VideoEmitType, None, None]
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType | VideoEventGenerator]
@dataclass @dataclass
@@ -172,6 +178,12 @@ class VideoCallback(VideoStreamTrack):
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array) args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
array, outputs = split_output(self.event_handler(*args)) array, outputs = split_output(self.event_handler(*args))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return None
if ( if (
isinstance(outputs, AdditionalOutputs) isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs and self.set_additional_outputs
@@ -444,6 +456,12 @@ class VideoStreamHandler_(VideoCallback):
): ):
self.set_additional_outputs(outputs) self.set_additional_outputs(outputs)
self.channel.send(create_message("fetch_output", [])) self.channel.send(create_message("fetch_output", []))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return
if array is None and self.mode == "send": if array is None and self.mode == "send":
return return
@@ -586,6 +604,12 @@ class AudioCallback(AudioStreamTrack):
await self.start() await self.start()
frame = await self.queue.get() frame = await self.queue.get()
if isinstance(frame, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", frame.msg)
)
self.stop()
return
logger.debug("frame %s", frame) logger.debug("frame %s", frame)
data_time = frame.time data_time = frame.time
@@ -675,6 +699,12 @@ class ServerToClientVideo(VideoStreamTrack):
) )
try: try:
next_array, outputs = split_output(next(self.generator)) next_array, outputs = split_output(next(self.generator))
if isinstance(outputs, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", outputs.msg)
)
self.stop()
return
if ( if (
isinstance(outputs, AdditionalOutputs) isinstance(outputs, AdditionalOutputs)
and self.set_additional_outputs and self.set_additional_outputs
@@ -770,6 +800,12 @@ class ServerToClientAudio(AudioStreamTrack):
await self.start() await self.start()
data = await self.queue.get() data = await self.queue.get()
if isinstance(data, CloseStream):
cast(DataChannel, self.channel).send(
create_message("end_stream", data.msg)
)
self.stop()
return
if data is None: if data is None:
self.stop() self.stop()
return return

View File

@@ -32,6 +32,11 @@ class AdditionalOutputs:
self.args = args self.args = args
class CloseStream:
def __init__(self, msg: str = "Stream closed") -> None:
self.msg = msg
class DataChannel(Protocol): class DataChannel(Protocol):
def send(self, message: str) -> None: ... def send(self, message: str) -> None: ...
@@ -39,6 +44,7 @@ class DataChannel(Protocol):
def create_message( def create_message(
type: Literal[ type: Literal[
"send_input", "send_input",
"end_stream",
"fetch_output", "fetch_output",
"stopword", "stopword",
"error", "error",
@@ -98,9 +104,13 @@ class WebRTCError(Exception):
_send_log(message, "error") _send_log(message, "error")
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]: def split_output(
data: tuple | Any,
) -> tuple[Any, AdditionalOutputs | CloseStream | None]:
if isinstance(data, AdditionalOutputs): if isinstance(data, AdditionalOutputs):
return None, data return None, data
if isinstance(data, CloseStream):
return None, data
if isinstance(data, tuple): if isinstance(data, tuple):
# handle the bare audio case # handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray): if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
@@ -109,11 +119,11 @@ def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
raise ValueError( raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs." "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
) )
if not isinstance(data[-1], AdditionalOutputs): if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
raise ValueError( raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs." "The last element of the tuple must be an instance of AdditionalOutputs."
) )
return data[0], cast(AdditionalOutputs, data[1]) return data[0], cast(AdditionalOutputs | CloseStream, data[1])
return data, None return data, None
@@ -152,6 +162,8 @@ async def player_worker_decode(
cast(DataChannel, channel()).send(create_message("fetch_output", [])) cast(DataChannel, channel()).send(create_message("fetch_output", []))
if frame is None: if frame is None:
if isinstance(outputs, CloseStream):
await queue.put(outputs)
if quit_on_none: if quit_on_none:
await queue.put(None) await queue.put(None)
break break
@@ -203,7 +215,8 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples audio_samples += processed_frame.samples
await queue.put(processed_frame) await queue.put(processed_frame)
if isinstance(outputs, CloseStream):
await queue.put(outputs)
except (TimeoutError, asyncio.TimeoutError): except (TimeoutError, asyncio.TimeoutError):
logger.warning( logger.warning(
"Timeout in frame processing cycle after %s seconds - resetting", 60 "Timeout in frame processing cycle after %s seconds - resetting", 60

View File

@@ -47,6 +47,8 @@
msg?.type === "error" msg?.type === "error"
) { ) {
gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message); gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message);
} else if (msg?.type === "end_stream") {
gradio.dispatch("warning", msg.data);
} else if (msg?.type === "fetch_output") { } else if (msg?.type === "fetch_output") {
gradio.dispatch("state_change"); gradio.dispatch("state_change");
} else if (msg?.type === "send_input") { } else if (msg?.type === "send_input") {

View File

@@ -51,12 +51,16 @@
} }
}); });
let _on_change_cb = (msg: "change" | "tick" | "stopword") => { let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
if (msg === "stopword") { if (msg === "stopword") {
stopword_recognized = true; stopword_recognized = true;
setTimeout(() => { setTimeout(() => {
stopword_recognized = false; stopword_recognized = false;
}, 3000); }, 3000);
} else if (msg.type === "end_stream") {
stream_state = "closed";
stop(pc);
on_change_cb(msg);
} else { } else {
console.debug("calling on_change_cb with msg", msg); console.debug("calling on_change_cb with msg", msg);
on_change_cb(msg); on_change_cb(msg);

View File

@@ -29,6 +29,17 @@
let pc: RTCPeerConnection; let pc: RTCPeerConnection;
let _webrtc_id = Math.random().toString(36).substring(2); let _webrtc_id = Math.random().toString(36).substring(2);
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
if (msg.type === "end_stream") {
on_change_cb(msg);
stream_state = "closed";
stop(pc);
} else {
console.debug("calling on_change_cb with msg", msg);
on_change_cb(msg);
}
};
const dispatch = createEventDispatcher<{ const dispatch = createEventDispatcher<{
tick: undefined; tick: undefined;
error: string; error: string;
@@ -75,7 +86,7 @@
server.offer, server.offer,
_webrtc_id, _webrtc_id,
"audio", "audio",
on_change_cb, _on_change_cb,
) )
.then((connection) => { .then((connection) => {
clearTimeout(timeoutId); clearTimeout(timeoutId);

View File

@@ -25,6 +25,17 @@
tick: undefined; tick: undefined;
}>(); }>();
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
if (msg.type === "end_stream") {
on_change_cb(msg);
stream_state = "closed";
stop(pc);
} else {
console.debug("calling on_change_cb with msg", msg);
on_change_cb(msg);
}
};
let stream_state = "closed"; let stream_state = "closed";
$: if (value === "start_webrtc_stream") { $: if (value === "start_webrtc_stream") {
@@ -62,7 +73,7 @@
server.offer, server.offer,
_webrtc_id, _webrtc_id,
"video", "video",
on_change_cb, _on_change_cb,
) )
.then((connection) => { .then((connection) => {
clearTimeout(timeoutId); clearTimeout(timeoutId);

View File

@@ -124,6 +124,18 @@
} }
} }
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
if (msg.type === "end_stream") {
on_change_cb(msg);
stream_state = "closed";
stop(pc);
access_webcam();
} else {
console.debug("calling on_change_cb with msg", msg);
on_change_cb(msg);
}
};
let recording = false; let recording = false;
let stream: MediaStream; let stream: MediaStream;
@@ -171,7 +183,7 @@
server.offer, server.offer,
webrtc_id, webrtc_id,
"video", "video",
on_change_cb, _on_change_cb,
rtp_params, rtp_params,
undefined, undefined,
reject_cb, reject_cb,

View File

@@ -80,7 +80,8 @@ export async function start(
event_json?.type === "error" || event_json?.type === "error" ||
event_json?.type === "send_input" || event_json?.type === "send_input" ||
event_json?.type === "fetch_output" || event_json?.type === "fetch_output" ||
event_json?.type === "stopword" event_json?.type === "stopword" ||
event_json?.type === "end_stream"
) { ) {
on_change_cb(event_json ?? event.data); on_change_cb(event_json ?? event.data);
} }