final touches

This commit is contained in:
freddyaboulton
2024-10-23 15:52:48 -07:00
parent 1688502e99
commit e87c4d49e8
3 changed files with 75 additions and 62 deletions

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import threading
import time
@@ -148,8 +149,7 @@ class AudioCallback(AudioStreamTrack):
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.queue = asyncio.Queue()
self.thread_quit = threading.Event()
self.__thread = None
self.thread_quit = asyncio.Event()
self._start: float | None = None
self.has_started = False
self.last_timestamp = 0
@@ -165,26 +165,26 @@ class AudioCallback(AudioStreamTrack):
self.event_handler.receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames")
break
logger.debug("MediaStreamError in process_input_frames")
break
def start(self):
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(
loop.run_in_executor, None, self.event_handler.emit
)
asyncio.create_task(self.process_input_frames())
self.__thread = threading.Thread(
target=player_worker_decode,
daemon=False,
args=(
asyncio.get_event_loop(),
self.event_handler.emit,
asyncio.create_task(
player_worker_decode(
callable,
self.queue,
self.thread_quit,
False,
self.event_handler.output_sample_rate,
self.event_handler.output_frame_size,
),
)
)
self.__thread.start()
self.has_started = True
async def recv(self):
@@ -220,9 +220,6 @@ class AudioCallback(AudioStreamTrack):
def stop(self):
logger.debug("audio callback stop")
self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop()
@@ -284,8 +281,8 @@ class ServerToClientAudio(AudioStreamTrack):
self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event()
self.queue = asyncio.Queue()
self.thread_quit = threading.Event()
self.__thread = None
self.thread_quit = asyncio.Event()
self.has_started = False
self._start: float | None = None
super().__init__()
@@ -301,20 +298,18 @@ class ServerToClientAudio(AudioStreamTrack):
self.thread_quit.set()
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
name="generator-runner",
target=player_worker_decode,
daemon=True,
args=(
asyncio.get_event_loop(),
self.next,
if not self.has_started:
loop = asyncio.get_running_loop()
callable = functools.partial(loop.run_in_executor, None, self.next)
asyncio.create_task(
player_worker_decode(
callable,
self.queue,
self.thread_quit,
True,
),
)
)
self.__thread.start()
self.has_started = True
async def recv(self):
try:
@@ -344,10 +339,8 @@ class ServerToClientAudio(AudioStreamTrack):
logger.debug("traceback %s", exec)
def stop(self):
logger.debug("audio-to-client stop callback")
self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop()