This commit is contained in:
freddyaboulton
2024-10-10 16:04:54 -07:00
parent 3777bfe777
commit 85cf0d726e
4 changed files with 172 additions and 47 deletions

View File

@@ -31,7 +31,6 @@ def player_worker_decode(
generator = None
while not thread_quit.is_set():
print("stream.latest_args", stream.latest_args)
if stream.latest_args == "not_set":
continue
if generator is None:
@@ -41,7 +40,7 @@ def player_worker_decode(
except Exception as exc:
if isinstance(exc, StopIteration):
print("Not iterating")
asyncio.run_coroutine_threadsafe(queue.put(frame), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
thread_quit.set()
break
@@ -51,7 +50,8 @@ def player_worker_decode(
if frame_time and frame_time > elapsed_time + 1:
time.sleep(0.1)
sample_rate, audio_array = frame
frame = av.AudioFrame.from_ndarray(audio_array, format="s16", layout="mono")
format = "s16" if audio_array.dtype == "int16" else "fltp"
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono")
frame.sample_rate = sample_rate
for frame in audio_resampler.resample(frame):
# fix timestamps

View File

@@ -135,7 +135,9 @@ class ServerToClientVideo(VideoStreamTrack):
frame.time_base = time_base
return frame
elif self.generator is None:
self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args))
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
try:
next_array = next(self.generator)
@@ -150,6 +152,7 @@ class ServerToClientVideo(VideoStreamTrack):
except Exception as e:
print(e)
import traceback
traceback.print_exc()
@@ -189,41 +192,38 @@ class ServerToClientAudio(AudioStreamTrack):
self.current_timestamp += samples
frame.pts = self.current_timestamp
return frame
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
name="generator-runner",
target=player_worker_decode,
args=(
asyncio.get_event_loop(),
self.event_handler,
self,
self.queue,
False,
self.thread_quit
),
)
self.__thread.start()
def start(self):
if self.__thread is None:
self.__thread = threading.Thread(
name="generator-runner",
target=player_worker_decode,
args=(
asyncio.get_event_loop(),
self.event_handler,
self,
self.queue,
False,
self.thread_quit,
),
)
self.__thread.start()
async def recv(self):
try:
if self.readyState != "live":
raise MediaStreamError
self.start()
data = await self.queue.get()
if data is None:
self.stop()
raise MediaStreamError
return
data_time = data.time
# control playback rate
if (
data_time is not None
):
if data_time is not None:
if self._start is None:
self._start = time.time() - data_time
else:
@@ -238,35 +238,35 @@ class ServerToClientAudio(AudioStreamTrack):
traceback.print_exc()
def stop(self):
super().stop()
self.thread_quit.set()
if self.__thread is not None:
self.__thread.join()
self.__thread = None
super().stop()
# next_frame = await super().recv()
# print("next frame", next_frame)
# return next_frame
#try:
# if self.latest_args == "not_set":
# frame = await self.empty_frame()
# try:
# if self.latest_args == "not_set":
# frame = await self.empty_frame()
# # await self.modify_frame(frame)
# await asyncio.sleep(100 / 22050)
# print("next_frame not set", frame)
# return frame
# if self.generator is None:
# self.generator = cast(
# Generator[Any, None, Any], self.event_handler(*self.latest_args)
# )
# # await self.modify_frame(frame)
# await asyncio.sleep(100 / 22050)
# print("next_frame not set", frame)
# return frame
# if self.generator is None:
# self.generator = cast(
# Generator[Any, None, Any], self.event_handler(*self.latest_args)
# )
# try:
# next_array = next(self.generator)
# print("iteration")
# except StopIteration:
# print("exception")
# self.stop() # type: ignore
# return
# try:
# next_array = next(self.generator)
# print("iteration")
# except StopIteration:
# print("exception")
# self.stop() # type: ignore
# return
# next_frame = self.array_to_frame(next_array)
# # await self.modify_frame(next_frame)
# print("next frame", next_frame)
@@ -525,6 +525,7 @@ class WebRTC(Component):
cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
if self.mode == "receive" and self.modality == "audio":
print("adding")
cb = ServerToClientAudio(cast(Callable, self.event_handler))
@@ -534,6 +535,7 @@ class WebRTC(Component):
# pc.addTrack(player.audio)
pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
print("here")
# handle offer