mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
@@ -32,6 +32,11 @@ class AdditionalOutputs:
|
||||
self.args = args
|
||||
|
||||
|
||||
class CloseStream:
|
||||
def __init__(self, msg: str = "Stream closed") -> None:
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class DataChannel(Protocol):
|
||||
def send(self, message: str) -> None: ...
|
||||
|
||||
@@ -39,6 +44,7 @@ class DataChannel(Protocol):
|
||||
def create_message(
|
||||
type: Literal[
|
||||
"send_input",
|
||||
"end_stream",
|
||||
"fetch_output",
|
||||
"stopword",
|
||||
"error",
|
||||
@@ -98,9 +104,13 @@ class WebRTCError(Exception):
|
||||
_send_log(message, "error")
|
||||
|
||||
|
||||
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
def split_output(
|
||||
data: tuple | Any,
|
||||
) -> tuple[Any, AdditionalOutputs | CloseStream | None]:
|
||||
if isinstance(data, AdditionalOutputs):
|
||||
return None, data
|
||||
if isinstance(data, CloseStream):
|
||||
return None, data
|
||||
if isinstance(data, tuple):
|
||||
# handle the bare audio case
|
||||
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
|
||||
@@ -109,11 +119,11 @@ def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
|
||||
raise ValueError(
|
||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||
)
|
||||
if not isinstance(data[-1], AdditionalOutputs):
|
||||
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
|
||||
raise ValueError(
|
||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||
)
|
||||
return data[0], cast(AdditionalOutputs, data[1])
|
||||
return data[0], cast(AdditionalOutputs | CloseStream, data[1])
|
||||
return data, None
|
||||
|
||||
|
||||
@@ -152,6 +162,8 @@ async def player_worker_decode(
|
||||
cast(DataChannel, channel()).send(create_message("fetch_output", []))
|
||||
|
||||
if frame is None:
|
||||
if isinstance(outputs, CloseStream):
|
||||
await queue.put(outputs)
|
||||
if quit_on_none:
|
||||
await queue.put(None)
|
||||
break
|
||||
@@ -203,7 +215,8 @@ async def player_worker_decode(
|
||||
processed_frame.time_base = audio_time_base
|
||||
audio_samples += processed_frame.samples
|
||||
await queue.put(processed_frame)
|
||||
|
||||
if isinstance(outputs, CloseStream):
|
||||
await queue.put(outputs)
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
logger.warning(
|
||||
"Timeout in frame processing cycle after %s seconds - resetting", 60
|
||||
|
||||
Reference in New Issue
Block a user