mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Improve Interruption Handling (#134)
* Clear websocket queue on interrupt * add code
This commit is contained in:
@@ -184,10 +184,11 @@ class ReplyOnPause(StreamHandler):
|
|||||||
self.process_audio(frame, self.state)
|
self.process_audio(frame, self.state)
|
||||||
if self.state.pause_detected:
|
if self.state.pause_detected:
|
||||||
self.event.set()
|
self.event.set()
|
||||||
if self.can_interrupt:
|
if self.can_interrupt and self.state.responding:
|
||||||
self.clear_queue()
|
|
||||||
self._close_generator()
|
self._close_generator()
|
||||||
self.generator = None
|
self.generator = None
|
||||||
|
if self.can_interrupt:
|
||||||
|
self.clear_queue()
|
||||||
|
|
||||||
def _close_generator(self):
|
def _close_generator(self):
|
||||||
"""Properly close the generator to ensure resources are released."""
|
"""Properly close the generator to ensure resources are released."""
|
||||||
|
|||||||
@@ -509,7 +509,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
handler.phone_mode = True
|
handler.phone_mode = True
|
||||||
|
|
||||||
async def set_handler(s: str, a: WebSocketHandler):
|
async def set_handler(s: str, a: WebSocketHandler):
|
||||||
if len(self.connections) >= self.concurrency_limit:
|
if len(self.connections) >= self.concurrency_limit: # type: ignore
|
||||||
await cast(WebSocket, a.websocket).send_json(
|
await cast(WebSocket, a.websocket).send_json(
|
||||||
{
|
{
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
@@ -532,7 +532,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
handler.phone_mode = False
|
handler.phone_mode = False
|
||||||
|
|
||||||
async def set_handler(s: str, a: WebSocketHandler):
|
async def set_handler(s: str, a: WebSocketHandler):
|
||||||
if len(self.connections) >= self.concurrency_limit:
|
if len(self.connections) >= self.concurrency_limit: # type: ignore
|
||||||
await cast(WebSocket, a.websocket).send_json(
|
await cast(WebSocket, a.websocket).send_json(
|
||||||
{
|
{
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
|
|||||||
@@ -431,10 +431,14 @@ class AudioCallback(AudioStreamTrack):
|
|||||||
self.set_additional_outputs = set_additional_outputs
|
self.set_additional_outputs = set_additional_outputs
|
||||||
|
|
||||||
def clear_queue(self):
|
def clear_queue(self):
|
||||||
if self.queue:
|
logger.debug("clearing queue")
|
||||||
while not self.queue.empty():
|
logger.debug("queue size: %d", self.queue.qsize())
|
||||||
self.queue.get_nowait()
|
i = 0
|
||||||
self._start = None
|
while not self.queue.empty():
|
||||||
|
self.queue.get_nowait()
|
||||||
|
i += 1
|
||||||
|
logger.debug("popped %d items from queue", i)
|
||||||
|
self._start = None
|
||||||
|
|
||||||
def set_channel(self, channel: DataChannel):
|
def set_channel(self, channel: DataChannel):
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class WebSocketHandler:
|
|||||||
],
|
],
|
||||||
):
|
):
|
||||||
self.stream_handler = stream_handler
|
self.stream_handler = stream_handler
|
||||||
self.stream_handler._clear_queue = lambda: None
|
self.stream_handler._clear_queue = self._clear_queue
|
||||||
self.websocket: Optional[WebSocket] = None
|
self.websocket: Optional[WebSocket] = None
|
||||||
self._emit_task: Optional[asyncio.Task] = None
|
self._emit_task: Optional[asyncio.Task] = None
|
||||||
self.stream_id: Optional[str] = None
|
self.stream_id: Optional[str] = None
|
||||||
@@ -64,6 +64,20 @@ class WebSocketHandler:
|
|||||||
self.set_handler = set_handler
|
self.set_handler = set_handler
|
||||||
self.quit = asyncio.Event()
|
self.quit = asyncio.Event()
|
||||||
self.clean_up = clean_up
|
self.clean_up = clean_up
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
|
|
||||||
|
def _clear_queue(self):
|
||||||
|
old_queue = self.queue
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
|
logger.debug("clearing queue")
|
||||||
|
i = 0
|
||||||
|
while not old_queue.empty():
|
||||||
|
try:
|
||||||
|
old_queue.get_nowait()
|
||||||
|
i += 1
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
logger.debug("popped %d items from queue", i)
|
||||||
|
|
||||||
def set_args(self, args: list[Any]):
|
def set_args(self, args: list[Any]):
|
||||||
self.stream_handler.set_args(args)
|
self.stream_handler.set_args(args)
|
||||||
@@ -77,6 +91,7 @@ class WebSocketHandler:
|
|||||||
self.stream_handler._loop = loop
|
self.stream_handler._loop = loop
|
||||||
self.stream_handler.set_channel(self.data_channel)
|
self.stream_handler.set_channel(self.data_channel)
|
||||||
self._emit_task = asyncio.create_task(self._emit_loop())
|
self._emit_task = asyncio.create_task(self._emit_loop())
|
||||||
|
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
||||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||||
start_up = self.stream_handler.start_up()
|
start_up = self.stream_handler.start_up()
|
||||||
else:
|
else:
|
||||||
@@ -137,17 +152,32 @@ class WebSocketHandler:
|
|||||||
finally:
|
finally:
|
||||||
if self._emit_task:
|
if self._emit_task:
|
||||||
self._emit_task.cancel()
|
self._emit_task.cancel()
|
||||||
|
if self._emit_to_queue_task:
|
||||||
|
self._emit_to_queue_task.cancel()
|
||||||
if self.start_up_task:
|
if self.start_up_task:
|
||||||
self.start_up_task.cancel()
|
self.start_up_task.cancel()
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
async def _emit_loop(self):
|
async def _emit_to_queue(self):
|
||||||
try:
|
try:
|
||||||
while not self.quit.is_set():
|
while not self.quit.is_set():
|
||||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||||
output = await self.stream_handler.emit()
|
output = await self.stream_handler.emit()
|
||||||
else:
|
else:
|
||||||
output = await run_sync(self.stream_handler.emit)
|
output = await run_sync(self.stream_handler.emit)
|
||||||
|
self.queue.put_nowait(output)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Emit loop cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
logger.debug("Error in emit loop: %s", e)
|
||||||
|
|
||||||
|
async def _emit_loop(self):
|
||||||
|
try:
|
||||||
|
while not self.quit.is_set():
|
||||||
|
output = await self.queue.get()
|
||||||
|
|
||||||
if output is not None:
|
if output is not None:
|
||||||
frame, output = split_output(output)
|
frame, output = split_output(output)
|
||||||
|
|||||||
Reference in New Issue
Block a user