From 85cf0d726ebbe0ef140f9c1c1578c3b33956e3e0 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 10 Oct 2024 16:04:54 -0700 Subject: [PATCH] add code --- backend/gradio_webrtc/utils.py | 6 +- backend/gradio_webrtc/webrtc.py | 86 ++++++++++---------- frontend/shared/AudioWave.svelte | 121 +++++++++++++++++++++++++++++ frontend/shared/StaticAudio.svelte | 6 +- 4 files changed, 172 insertions(+), 47 deletions(-) create mode 100644 frontend/shared/AudioWave.svelte diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 5553b23..dedd014 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -31,7 +31,6 @@ def player_worker_decode( generator = None while not thread_quit.is_set(): - print("stream.latest_args", stream.latest_args) if stream.latest_args == "not_set": continue if generator is None: @@ -41,7 +40,7 @@ def player_worker_decode( except Exception as exc: if isinstance(exc, StopIteration): print("Not iterating") - asyncio.run_coroutine_threadsafe(queue.put(frame), loop) + asyncio.run_coroutine_threadsafe(queue.put(None), loop) thread_quit.set() break @@ -51,7 +50,8 @@ def player_worker_decode( if frame_time and frame_time > elapsed_time + 1: time.sleep(0.1) sample_rate, audio_array = frame - frame = av.AudioFrame.from_ndarray(audio_array, format="s16", layout="mono") + format = "s16" if audio_array.dtype == "int16" else "fltp" + frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono") frame.sample_rate = sample_rate for frame in audio_resampler.resample(frame): # fix timestamps diff --git a/backend/gradio_webrtc/webrtc.py b/backend/gradio_webrtc/webrtc.py index c3f16f0..c5a3a13 100644 --- a/backend/gradio_webrtc/webrtc.py +++ b/backend/gradio_webrtc/webrtc.py @@ -135,7 +135,9 @@ class ServerToClientVideo(VideoStreamTrack): frame.time_base = time_base return frame elif self.generator is None: - self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args)) + self.generator = cast( + Generator[Any, None, Any], self.event_handler(*self.latest_args) + ) try: next_array = next(self.generator) @@ -150,6 +152,7 @@ class ServerToClientVideo(VideoStreamTrack): except Exception as e: print(e) import traceback + traceback.print_exc() @@ -189,41 +192,38 @@ class ServerToClientAudio(AudioStreamTrack): self.current_timestamp += samples frame.pts = self.current_timestamp return frame - - def start(self): - if self.__thread is None: - self.__thread = threading.Thread( - name="generator-runner", - target=player_worker_decode, - args=( - asyncio.get_event_loop(), - self.event_handler, - self, - self.queue, - False, - self.thread_quit - ), - ) - self.__thread.start() + def start(self): + if self.__thread is None: + self.__thread = threading.Thread( + name="generator-runner", + target=player_worker_decode, + args=( + asyncio.get_event_loop(), + self.event_handler, + self, + self.queue, + False, + self.thread_quit, + ), + ) + self.__thread.start() async def recv(self): try: if self.readyState != "live": raise MediaStreamError - + self.start() data = await self.queue.get() if data is None: self.stop() - raise MediaStreamError - + return + data_time = data.time # control playback rate - if ( - data_time is not None - ): + if data_time is not None: if self._start is None: self._start = time.time() - data_time else: @@ -238,35 +238,35 @@ class ServerToClientAudio(AudioStreamTrack): traceback.print_exc() def stop(self): - super().stop() self.thread_quit.set() if self.__thread is not None: self.__thread.join() self.__thread = None + super().stop() # next_frame = await super().recv() # print("next frame", next_frame) # return next_frame - #try: - # if self.latest_args == "not_set": - # frame = await self.empty_frame() + # try: + # if self.latest_args == "not_set": + # frame = await self.empty_frame() - # # await self.modify_frame(frame) - # await asyncio.sleep(100 / 22050) - # print("next_frame not set", frame) - # return frame - # if self.generator is None: - # self.generator = cast( - # Generator[Any, None, Any], self.event_handler(*self.latest_args) - # ) + # # await self.modify_frame(frame) + # await asyncio.sleep(100 / 22050) + # print("next_frame not set", frame) + # return frame + # if self.generator is None: + # self.generator = cast( + # Generator[Any, None, Any], self.event_handler(*self.latest_args) + # ) - # try: - # next_array = next(self.generator) - # print("iteration") - # except StopIteration: - # print("exception") - # self.stop() # type: ignore - # return + # try: + # next_array = next(self.generator) + # print("iteration") + # except StopIteration: + # print("exception") + # self.stop() # type: ignore + # return # next_frame = self.array_to_frame(next_array) # # await self.modify_frame(next_frame) # print("next frame", next_frame) @@ -525,6 +525,7 @@ class WebRTC(Component): cb = ServerToClientVideo(cast(Callable, self.event_handler)) pc.addTrack(cb) self.connections[body["webrtc_id"]] = cb + cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None)) if self.mode == "receive" and self.modality == "audio": print("adding") cb = ServerToClientAudio(cast(Callable, self.event_handler)) @@ -534,6 +535,7 @@ class WebRTC(Component): # pc.addTrack(player.audio) pc.addTrack(cb) self.connections[body["webrtc_id"]] = cb + cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None)) print("here") # handle offer diff --git a/frontend/shared/AudioWave.svelte b/frontend/shared/AudioWave.svelte new file mode 100644 index 0000000..eaf6a7c --- /dev/null +++ b/frontend/shared/AudioWave.svelte @@ -0,0 +1,121 @@ + + +
+
+ {#each Array(numBars) as _} +
+ {/each} +
+ +
+ + \ No newline at end of file diff --git a/frontend/shared/StaticAudio.svelte b/frontend/shared/StaticAudio.svelte index e58ce6b..31c6c76 100644 --- a/frontend/shared/StaticAudio.svelte +++ b/frontend/shared/StaticAudio.svelte @@ -9,6 +9,7 @@ import { onMount } from "svelte"; import { start, stop } from "./webrtc_utils"; + import AudioWave from "./AudioWave.svelte"; export let value: string | null = null; @@ -56,7 +57,6 @@ ] }; pc = new RTCPeerConnection(rtc_configuration); - console.log("config", pc.getConfiguration()); pc.addEventListener("connectionstatechange", async (event) => { switch(pc.connectionState) { @@ -95,12 +95,14 @@