mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Fix (#242)
This commit is contained in:
@@ -10,7 +10,7 @@ from numpy.typing import NDArray
|
|||||||
|
|
||||||
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
||||||
from .tracks import EmitType, StreamHandler
|
from .tracks import EmitType, StreamHandler
|
||||||
from .utils import create_message, split_output
|
from .utils import AdditionalOutputs, create_message, split_output
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
@@ -243,7 +243,7 @@ class ReplyOnPause(StreamHandler):
|
|||||||
self.send_message_sync(create_message("log", "response_starting"))
|
self.send_message_sync(create_message("log", "response_starting"))
|
||||||
self.state.responded_audio = True
|
self.state.responded_audio = True
|
||||||
if self.phone_mode:
|
if self.phone_mode:
|
||||||
if additional_outputs:
|
if isinstance(additional_outputs, AdditionalOutputs):
|
||||||
self.latest_args = [None] + list(additional_outputs.args)
|
self.latest_args = [None] + list(additional_outputs.args)
|
||||||
return output
|
return output
|
||||||
except (StopIteration, StopAsyncIteration):
|
except (StopIteration, StopAsyncIteration):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from fastapi import WebSocket
|
|||||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||||
|
|
||||||
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
||||||
from .utils import AdditionalOutputs, DataChannel, split_output
|
from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
|
||||||
|
|
||||||
|
|
||||||
class WebSocketDataChannel(DataChannel):
|
class WebSocketDataChannel(DataChannel):
|
||||||
@@ -66,6 +66,9 @@ class WebSocketHandler:
|
|||||||
self.quit = asyncio.Event()
|
self.quit = asyncio.Event()
|
||||||
self.clean_up = clean_up
|
self.clean_up = clean_up
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
|
self.playing_durations = [] # Track durations of frames being played
|
||||||
|
self._frame_cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
self._graceful_shutdown_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
def _clear_queue(self):
|
def _clear_queue(self):
|
||||||
old_queue = self.queue
|
old_queue = self.queue
|
||||||
@@ -93,6 +96,7 @@ class WebSocketHandler:
|
|||||||
self.stream_handler.set_channel(self.data_channel)
|
self.stream_handler.set_channel(self.data_channel)
|
||||||
self._emit_task = asyncio.create_task(self._emit_loop())
|
self._emit_task = asyncio.create_task(self._emit_loop())
|
||||||
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
||||||
|
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
|
||||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||||
start_up = self.stream_handler.start_up()
|
start_up = self.stream_handler.start_up()
|
||||||
else:
|
else:
|
||||||
@@ -165,6 +169,10 @@ class WebSocketHandler:
|
|||||||
self._emit_task.cancel()
|
self._emit_task.cancel()
|
||||||
if self._emit_to_queue_task:
|
if self._emit_to_queue_task:
|
||||||
self._emit_to_queue_task.cancel()
|
self._emit_to_queue_task.cancel()
|
||||||
|
if self._frame_cleanup_task:
|
||||||
|
self._frame_cleanup_task.cancel()
|
||||||
|
if self._graceful_shutdown_task:
|
||||||
|
self._graceful_shutdown_task.cancel()
|
||||||
if self.start_up_task:
|
if self.start_up_task:
|
||||||
self.start_up_task.cancel()
|
self.start_up_task.cancel()
|
||||||
|
|
||||||
@@ -189,6 +197,45 @@ class WebSocketHandler:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.debug("Error in emit loop: %s", e)
|
logger.debug("Error in emit loop: %s", e)
|
||||||
|
|
||||||
|
async def _cleanup_frames_loop(self):
|
||||||
|
"""Background task that removes frames from playing_durations after they've finished playing."""
|
||||||
|
try:
|
||||||
|
while not self.quit.is_set():
|
||||||
|
if self.playing_durations:
|
||||||
|
duration = self.playing_durations[0]
|
||||||
|
await asyncio.sleep(duration)
|
||||||
|
if self.playing_durations:
|
||||||
|
self.playing_durations.pop(0)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Frame cleanup loop cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error in frame cleanup loop: {e}")
|
||||||
|
|
||||||
|
async def _wait_for_audio_completion(self):
|
||||||
|
"""Wait for all queued audio to finish playing before shutting down."""
|
||||||
|
try:
|
||||||
|
if not self.playing_durations:
|
||||||
|
self.quit.set()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate total remaining playback time
|
||||||
|
total_wait = sum(self.playing_durations)
|
||||||
|
|
||||||
|
if total_wait > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Waiting {total_wait:.2f}s for audio to complete before closing"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(total_wait)
|
||||||
|
|
||||||
|
self.quit.set()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Graceful shutdown cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error in graceful shutdown: {e}")
|
||||||
|
self.quit.set()
|
||||||
|
|
||||||
async def _emit_loop(self):
|
async def _emit_loop(self):
|
||||||
try:
|
try:
|
||||||
while not self.quit.is_set():
|
while not self.quit.is_set():
|
||||||
@@ -196,10 +243,17 @@ class WebSocketHandler:
|
|||||||
|
|
||||||
if output is not None:
|
if output is not None:
|
||||||
frame, output = split_output(output)
|
frame, output = split_output(output)
|
||||||
if output is not None:
|
if isinstance(output, AdditionalOutputs):
|
||||||
self.set_additional_outputs(output)
|
self.set_additional_outputs(output)
|
||||||
|
elif isinstance(output, CloseStream):
|
||||||
|
self._graceful_shutdown_task = asyncio.create_task(
|
||||||
|
self._wait_for_audio_completion()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
if not isinstance(frame, tuple):
|
if not isinstance(frame, tuple):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
target_rate = (
|
target_rate = (
|
||||||
self.stream_handler.output_sample_rate
|
self.stream_handler.output_sample_rate
|
||||||
if not self.stream_handler.phone_mode
|
if not self.stream_handler.phone_mode
|
||||||
@@ -211,6 +265,11 @@ class WebSocketHandler:
|
|||||||
audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
|
audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
|
||||||
|
|
||||||
if self.websocket and self.stream_id:
|
if self.websocket and self.stream_id:
|
||||||
|
sample_rate, audio_array = frame
|
||||||
|
duration = len(audio_array) / sample_rate
|
||||||
|
|
||||||
|
self.playing_durations.append(duration)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"event": "media",
|
"event": "media",
|
||||||
"media": {"payload": audio_payload},
|
"media": {"payload": audio_payload},
|
||||||
|
|||||||
Reference in New Issue
Block a user