mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
@@ -25,6 +25,7 @@ from .tracks import (
|
|||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
|
CloseStream,
|
||||||
Warning,
|
Warning,
|
||||||
WebRTCError,
|
WebRTCError,
|
||||||
aggregate_bytes_to_16bit,
|
aggregate_bytes_to_16bit,
|
||||||
@@ -75,4 +76,5 @@ __all__ = [
|
|||||||
"get_silero_model",
|
"get_silero_model",
|
||||||
"SileroVadOptions",
|
"SileroVadOptions",
|
||||||
"VideoStreamHandler",
|
"VideoStreamHandler",
|
||||||
|
"CloseStream",
|
||||||
]
|
]
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -37,6 +37,7 @@ from numpy import typing as npt
|
|||||||
|
|
||||||
from fastrtc.utils import (
|
from fastrtc.utils import (
|
||||||
AdditionalOutputs,
|
AdditionalOutputs,
|
||||||
|
CloseStream,
|
||||||
DataChannel,
|
DataChannel,
|
||||||
WebRTCError,
|
WebRTCError,
|
||||||
create_message,
|
create_message,
|
||||||
@@ -54,9 +55,14 @@ VideoNDArray: TypeAlias = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
VideoEmitType = (
|
VideoEmitType = (
|
||||||
VideoNDArray | tuple[VideoNDArray, AdditionalOutputs] | AdditionalOutputs
|
VideoNDArray
|
||||||
|
| tuple[VideoNDArray, AdditionalOutputs]
|
||||||
|
| tuple[VideoNDArray, CloseStream]
|
||||||
|
| AdditionalOutputs
|
||||||
|
| CloseStream
|
||||||
)
|
)
|
||||||
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType]
|
VideoEventGenerator = Generator[VideoEmitType, None, None]
|
||||||
|
VideoEventHandler = Callable[[npt.ArrayLike], VideoEmitType | VideoEventGenerator]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -172,6 +178,12 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
|
|
||||||
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
args = self.add_frame_to_payload(cast(list, self.latest_args), frame_array)
|
||||||
array, outputs = split_output(self.event_handler(*args))
|
array, outputs = split_output(self.event_handler(*args))
|
||||||
|
if isinstance(outputs, CloseStream):
|
||||||
|
cast(DataChannel, self.channel).send(
|
||||||
|
create_message("end_stream", outputs.msg)
|
||||||
|
)
|
||||||
|
self.stop()
|
||||||
|
return None
|
||||||
if (
|
if (
|
||||||
isinstance(outputs, AdditionalOutputs)
|
isinstance(outputs, AdditionalOutputs)
|
||||||
and self.set_additional_outputs
|
and self.set_additional_outputs
|
||||||
@@ -444,6 +456,12 @@ class VideoStreamHandler_(VideoCallback):
|
|||||||
):
|
):
|
||||||
self.set_additional_outputs(outputs)
|
self.set_additional_outputs(outputs)
|
||||||
self.channel.send(create_message("fetch_output", []))
|
self.channel.send(create_message("fetch_output", []))
|
||||||
|
if isinstance(outputs, CloseStream):
|
||||||
|
cast(DataChannel, self.channel).send(
|
||||||
|
create_message("end_stream", outputs.msg)
|
||||||
|
)
|
||||||
|
self.stop()
|
||||||
|
return
|
||||||
if array is None and self.mode == "send":
|
if array is None and self.mode == "send":
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -586,6 +604,12 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
frame = await self.queue.get()
|
frame = await self.queue.get()
|
||||||
|
if isinstance(frame, CloseStream):
|
||||||
|
cast(DataChannel, self.channel).send(
|
||||||
|
create_message("end_stream", frame.msg)
|
||||||
|
)
|
||||||
|
self.stop()
|
||||||
|
return
|
||||||
logger.debug("frame %s", frame)
|
logger.debug("frame %s", frame)
|
||||||
|
|
||||||
data_time = frame.time
|
data_time = frame.time
|
||||||
@@ -675,6 +699,12 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
next_array, outputs = split_output(next(self.generator))
|
next_array, outputs = split_output(next(self.generator))
|
||||||
|
if isinstance(outputs, CloseStream):
|
||||||
|
cast(DataChannel, self.channel).send(
|
||||||
|
create_message("end_stream", outputs.msg)
|
||||||
|
)
|
||||||
|
self.stop()
|
||||||
|
return
|
||||||
if (
|
if (
|
||||||
isinstance(outputs, AdditionalOutputs)
|
isinstance(outputs, AdditionalOutputs)
|
||||||
and self.set_additional_outputs
|
and self.set_additional_outputs
|
||||||
@@ -770,6 +800,12 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
|
|
||||||
await self.start()
|
await self.start()
|
||||||
data = await self.queue.get()
|
data = await self.queue.get()
|
||||||
|
if isinstance(data, CloseStream):
|
||||||
|
cast(DataChannel, self.channel).send(
|
||||||
|
create_message("end_stream", data.msg)
|
||||||
|
)
|
||||||
|
self.stop()
|
||||||
|
return
|
||||||
if data is None:
|
if data is None:
|
||||||
self.stop()
|
self.stop()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ class AdditionalOutputs:
|
|||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
|
|
||||||
|
class CloseStream:
|
||||||
|
def __init__(self, msg: str = "Stream closed") -> None:
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
class DataChannel(Protocol):
|
class DataChannel(Protocol):
|
||||||
def send(self, message: str) -> None: ...
|
def send(self, message: str) -> None: ...
|
||||||
|
|
||||||
@@ -39,6 +44,7 @@ class DataChannel(Protocol):
|
|||||||
def create_message(
|
def create_message(
|
||||||
type: Literal[
|
type: Literal[
|
||||||
"send_input",
|
"send_input",
|
||||||
|
"end_stream",
|
||||||
"fetch_output",
|
"fetch_output",
|
||||||
"stopword",
|
"stopword",
|
||||||
"error",
|
"error",
|
||||||
@@ -98,9 +104,13 @@ class WebRTCError(Exception):
|
|||||||
_send_log(message, "error")
|
_send_log(message, "error")
|
||||||
|
|
||||||
|
|
||||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
def split_output(
|
||||||
|
data: tuple | Any,
|
||||||
|
) -> tuple[Any, AdditionalOutputs | CloseStream | None]:
|
||||||
if isinstance(data, AdditionalOutputs):
|
if isinstance(data, AdditionalOutputs):
|
||||||
return None, data
|
return None, data
|
||||||
|
if isinstance(data, CloseStream):
|
||||||
|
return None, data
|
||||||
if isinstance(data, tuple):
|
if isinstance(data, tuple):
|
||||||
# handle the bare audio case
|
# handle the bare audio case
|
||||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||||
@@ -109,11 +119,11 @@ def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||||
)
|
)
|
||||||
if not isinstance(data[-1], AdditionalOutputs):
|
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||||
)
|
)
|
||||||
return data[0], cast(AdditionalOutputs, data[1])
|
return data[0], cast(AdditionalOutputs | CloseStream, data[1])
|
||||||
return data, None
|
return data, None
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +162,8 @@ async def player_worker_decode(
|
|||||||
cast(DataChannel, channel()).send(create_message("fetch_output", []))
|
cast(DataChannel, channel()).send(create_message("fetch_output", []))
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
|
if isinstance(outputs, CloseStream):
|
||||||
|
await queue.put(outputs)
|
||||||
if quit_on_none:
|
if quit_on_none:
|
||||||
await queue.put(None)
|
await queue.put(None)
|
||||||
break
|
break
|
||||||
@@ -203,7 +215,8 @@ async def player_worker_decode(
|
|||||||
processed_frame.time_base = audio_time_base
|
processed_frame.time_base = audio_time_base
|
||||||
audio_samples += processed_frame.samples
|
audio_samples += processed_frame.samples
|
||||||
await queue.put(processed_frame)
|
await queue.put(processed_frame)
|
||||||
|
if isinstance(outputs, CloseStream):
|
||||||
|
await queue.put(outputs)
|
||||||
except (TimeoutError, asyncio.TimeoutError):
|
except (TimeoutError, asyncio.TimeoutError):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Timeout in frame processing cycle after %s seconds - resetting", 60
|
"Timeout in frame processing cycle after %s seconds - resetting", 60
|
||||||
|
|||||||
@@ -47,6 +47,8 @@
|
|||||||
msg?.type === "error"
|
msg?.type === "error"
|
||||||
) {
|
) {
|
||||||
gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message);
|
gradio.dispatch(msg?.type === "error" ? "error" : "warning", msg.message);
|
||||||
|
} else if (msg?.type === "end_stream") {
|
||||||
|
gradio.dispatch("warning", msg.data);
|
||||||
} else if (msg?.type === "fetch_output") {
|
} else if (msg?.type === "fetch_output") {
|
||||||
gradio.dispatch("state_change");
|
gradio.dispatch("state_change");
|
||||||
} else if (msg?.type === "send_input") {
|
} else if (msg?.type === "send_input") {
|
||||||
|
|||||||
@@ -51,12 +51,16 @@
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let _on_change_cb = (msg: "change" | "tick" | "stopword") => {
|
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
|
||||||
if (msg === "stopword") {
|
if (msg === "stopword") {
|
||||||
stopword_recognized = true;
|
stopword_recognized = true;
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
stopword_recognized = false;
|
stopword_recognized = false;
|
||||||
}, 3000);
|
}, 3000);
|
||||||
|
} else if (msg.type === "end_stream") {
|
||||||
|
stream_state = "closed";
|
||||||
|
stop(pc);
|
||||||
|
on_change_cb(msg);
|
||||||
} else {
|
} else {
|
||||||
console.debug("calling on_change_cb with msg", msg);
|
console.debug("calling on_change_cb with msg", msg);
|
||||||
on_change_cb(msg);
|
on_change_cb(msg);
|
||||||
|
|||||||
@@ -29,6 +29,17 @@
|
|||||||
let pc: RTCPeerConnection;
|
let pc: RTCPeerConnection;
|
||||||
let _webrtc_id = Math.random().toString(36).substring(2);
|
let _webrtc_id = Math.random().toString(36).substring(2);
|
||||||
|
|
||||||
|
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
|
||||||
|
if (msg.type === "end_stream") {
|
||||||
|
on_change_cb(msg);
|
||||||
|
stream_state = "closed";
|
||||||
|
stop(pc);
|
||||||
|
} else {
|
||||||
|
console.debug("calling on_change_cb with msg", msg);
|
||||||
|
on_change_cb(msg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const dispatch = createEventDispatcher<{
|
const dispatch = createEventDispatcher<{
|
||||||
tick: undefined;
|
tick: undefined;
|
||||||
error: string;
|
error: string;
|
||||||
@@ -75,7 +86,7 @@
|
|||||||
server.offer,
|
server.offer,
|
||||||
_webrtc_id,
|
_webrtc_id,
|
||||||
"audio",
|
"audio",
|
||||||
on_change_cb,
|
_on_change_cb,
|
||||||
)
|
)
|
||||||
.then((connection) => {
|
.then((connection) => {
|
||||||
clearTimeout(timeoutId);
|
clearTimeout(timeoutId);
|
||||||
|
|||||||
@@ -25,6 +25,17 @@
|
|||||||
tick: undefined;
|
tick: undefined;
|
||||||
}>();
|
}>();
|
||||||
|
|
||||||
|
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
|
||||||
|
if (msg.type === "end_stream") {
|
||||||
|
on_change_cb(msg);
|
||||||
|
stream_state = "closed";
|
||||||
|
stop(pc);
|
||||||
|
} else {
|
||||||
|
console.debug("calling on_change_cb with msg", msg);
|
||||||
|
on_change_cb(msg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let stream_state = "closed";
|
let stream_state = "closed";
|
||||||
|
|
||||||
$: if (value === "start_webrtc_stream") {
|
$: if (value === "start_webrtc_stream") {
|
||||||
@@ -62,7 +73,7 @@
|
|||||||
server.offer,
|
server.offer,
|
||||||
_webrtc_id,
|
_webrtc_id,
|
||||||
"video",
|
"video",
|
||||||
on_change_cb,
|
_on_change_cb,
|
||||||
)
|
)
|
||||||
.then((connection) => {
|
.then((connection) => {
|
||||||
clearTimeout(timeoutId);
|
clearTimeout(timeoutId);
|
||||||
|
|||||||
@@ -124,6 +124,18 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _on_change_cb = (msg: "change" | "tick" | "stopword" | any) => {
|
||||||
|
if (msg.type === "end_stream") {
|
||||||
|
on_change_cb(msg);
|
||||||
|
stream_state = "closed";
|
||||||
|
stop(pc);
|
||||||
|
access_webcam();
|
||||||
|
} else {
|
||||||
|
console.debug("calling on_change_cb with msg", msg);
|
||||||
|
on_change_cb(msg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let recording = false;
|
let recording = false;
|
||||||
let stream: MediaStream;
|
let stream: MediaStream;
|
||||||
|
|
||||||
@@ -171,7 +183,7 @@
|
|||||||
server.offer,
|
server.offer,
|
||||||
webrtc_id,
|
webrtc_id,
|
||||||
"video",
|
"video",
|
||||||
on_change_cb,
|
_on_change_cb,
|
||||||
rtp_params,
|
rtp_params,
|
||||||
undefined,
|
undefined,
|
||||||
reject_cb,
|
reject_cb,
|
||||||
|
|||||||
@@ -80,7 +80,8 @@ export async function start(
|
|||||||
event_json?.type === "error" ||
|
event_json?.type === "error" ||
|
||||||
event_json?.type === "send_input" ||
|
event_json?.type === "send_input" ||
|
||||||
event_json?.type === "fetch_output" ||
|
event_json?.type === "fetch_output" ||
|
||||||
event_json?.type === "stopword"
|
event_json?.type === "stopword" ||
|
||||||
|
event_json?.type === "end_stream"
|
||||||
) {
|
) {
|
||||||
on_change_cb(event_json ?? event.data);
|
on_change_cb(event_json ?? event.data);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user