mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
Fix (#242)
This commit is contained in:
@@ -10,7 +10,7 @@ from numpy.typing import NDArray
|
||||
|
||||
from .pause_detection import ModelOptions, PauseDetectionModel, get_silero_model
|
||||
from .tracks import EmitType, StreamHandler
|
||||
from .utils import create_message, split_output
|
||||
from .utils import AdditionalOutputs, create_message, split_output
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -243,7 +243,7 @@ class ReplyOnPause(StreamHandler):
|
||||
self.send_message_sync(create_message("log", "response_starting"))
|
||||
self.state.responded_audio = True
|
||||
if self.phone_mode:
|
||||
if additional_outputs:
|
||||
if isinstance(additional_outputs, AdditionalOutputs):
|
||||
self.latest_args = [None] + list(additional_outputs.args)
|
||||
return output
|
||||
except (StopIteration, StopAsyncIteration):
|
||||
|
||||
@@ -11,7 +11,7 @@ from fastapi import WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect, WebSocketState
|
||||
|
||||
from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
||||
from .utils import AdditionalOutputs, DataChannel, split_output
|
||||
from .utils import AdditionalOutputs, CloseStream, DataChannel, split_output
|
||||
|
||||
|
||||
class WebSocketDataChannel(DataChannel):
|
||||
@@ -66,6 +66,9 @@ class WebSocketHandler:
|
||||
self.quit = asyncio.Event()
|
||||
self.clean_up = clean_up
|
||||
self.queue = asyncio.Queue()
|
||||
self.playing_durations = [] # Track durations of frames being played
|
||||
self._frame_cleanup_task: Optional[asyncio.Task] = None
|
||||
self._graceful_shutdown_task: Optional[asyncio.Task] = None
|
||||
|
||||
def _clear_queue(self):
|
||||
old_queue = self.queue
|
||||
@@ -93,6 +96,7 @@ class WebSocketHandler:
|
||||
self.stream_handler.set_channel(self.data_channel)
|
||||
self._emit_task = asyncio.create_task(self._emit_loop())
|
||||
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
|
||||
self._frame_cleanup_task = asyncio.create_task(self._cleanup_frames_loop())
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
start_up = self.stream_handler.start_up()
|
||||
else:
|
||||
@@ -165,6 +169,10 @@ class WebSocketHandler:
|
||||
self._emit_task.cancel()
|
||||
if self._emit_to_queue_task:
|
||||
self._emit_to_queue_task.cancel()
|
||||
if self._frame_cleanup_task:
|
||||
self._frame_cleanup_task.cancel()
|
||||
if self._graceful_shutdown_task:
|
||||
self._graceful_shutdown_task.cancel()
|
||||
if self.start_up_task:
|
||||
self.start_up_task.cancel()
|
||||
|
||||
@@ -189,6 +197,45 @@ class WebSocketHandler:
|
||||
traceback.print_exc()
|
||||
logger.debug("Error in emit loop: %s", e)
|
||||
|
||||
async def _cleanup_frames_loop(self):
|
||||
"""Background task that removes frames from playing_durations after they've finished playing."""
|
||||
try:
|
||||
while not self.quit.is_set():
|
||||
if self.playing_durations:
|
||||
duration = self.playing_durations[0]
|
||||
await asyncio.sleep(duration)
|
||||
if self.playing_durations:
|
||||
self.playing_durations.pop(0)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Frame cleanup loop cancelled")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in frame cleanup loop: {e}")
|
||||
|
||||
async def _wait_for_audio_completion(self):
|
||||
"""Wait for all queued audio to finish playing before shutting down."""
|
||||
try:
|
||||
if not self.playing_durations:
|
||||
self.quit.set()
|
||||
return
|
||||
|
||||
# Calculate total remaining playback time
|
||||
total_wait = sum(self.playing_durations)
|
||||
|
||||
if total_wait > 0:
|
||||
logger.debug(
|
||||
f"Waiting {total_wait:.2f}s for audio to complete before closing"
|
||||
)
|
||||
await asyncio.sleep(total_wait)
|
||||
|
||||
self.quit.set()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Graceful shutdown cancelled")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in graceful shutdown: {e}")
|
||||
self.quit.set()
|
||||
|
||||
async def _emit_loop(self):
|
||||
try:
|
||||
while not self.quit.is_set():
|
||||
@@ -196,10 +243,17 @@ class WebSocketHandler:
|
||||
|
||||
if output is not None:
|
||||
frame, output = split_output(output)
|
||||
if output is not None:
|
||||
if isinstance(output, AdditionalOutputs):
|
||||
self.set_additional_outputs(output)
|
||||
elif isinstance(output, CloseStream):
|
||||
self._graceful_shutdown_task = asyncio.create_task(
|
||||
self._wait_for_audio_completion()
|
||||
)
|
||||
continue
|
||||
|
||||
if not isinstance(frame, tuple):
|
||||
continue
|
||||
|
||||
target_rate = (
|
||||
self.stream_handler.output_sample_rate
|
||||
if not self.stream_handler.phone_mode
|
||||
@@ -211,6 +265,11 @@ class WebSocketHandler:
|
||||
audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
|
||||
|
||||
if self.websocket and self.stream_id:
|
||||
sample_rate, audio_array = frame
|
||||
duration = len(audio_array) / sample_rate
|
||||
|
||||
self.playing_durations.append(duration)
|
||||
|
||||
payload = {
|
||||
"event": "media",
|
||||
"media": {"payload": audio_payload},
|
||||
|
||||
Reference in New Issue
Block a user