Close Stream from Backend (#222)

* Close from backend

* Add code
This commit is contained in:
Freddy Boulton
2025-03-28 20:47:34 -04:00
committed by GitHub
parent 71743acb64
commit 8ed27fba78
10 changed files with 2336 additions and 2237 deletions

View File

@@ -32,6 +32,11 @@ class AdditionalOutputs:
self.args = args
class CloseStream:
def __init__(self, msg: str = "Stream closed") -> None:
self.msg = msg
class DataChannel(Protocol):
def send(self, message: str) -> None: ...
@@ -39,6 +44,7 @@ class DataChannel(Protocol):
def create_message(
type: Literal[
"send_input",
"end_stream",
"fetch_output",
"stopword",
"error",
@@ -98,9 +104,13 @@ class WebRTCError(Exception):
_send_log(message, "error")
def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
def split_output(
data: tuple | Any,
) -> tuple[Any, AdditionalOutputs | CloseStream | None]:
if isinstance(data, AdditionalOutputs):
return None, data
if isinstance(data, CloseStream):
return None, data
if isinstance(data, tuple):
# handle the bare audio case
if 2 <= len(data) <= 3 and isinstance(data[1], np.ndarray):
@@ -109,11 +119,11 @@ def split_output(data: tuple | Any) -> tuple[Any, AdditionalOutputs | None]:
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], AdditionalOutputs):
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)
return data[0], cast(AdditionalOutputs, data[1])
return data[0], cast(AdditionalOutputs | CloseStream, data[1])
return data, None
@@ -152,6 +162,8 @@ async def player_worker_decode(
cast(DataChannel, channel()).send(create_message("fetch_output", []))
if frame is None:
if isinstance(outputs, CloseStream):
await queue.put(outputs)
if quit_on_none:
await queue.put(None)
break
@@ -203,7 +215,8 @@ async def player_worker_decode(
processed_frame.time_base = audio_time_base
audio_samples += processed_frame.samples
await queue.put(processed_frame)
if isinstance(outputs, CloseStream):
await queue.put(outputs)
except (TimeoutError, asyncio.TimeoutError):
logger.warning(
"Timeout in frame processing cycle after %s seconds - resetting", 60