This commit is contained in:
freddyaboulton
2024-10-10 16:04:54 -07:00
parent 3777bfe777
commit 85cf0d726e
4 changed files with 172 additions and 47 deletions

View File

@@ -31,7 +31,6 @@ def player_worker_decode(
generator = None generator = None
while not thread_quit.is_set(): while not thread_quit.is_set():
print("stream.latest_args", stream.latest_args)
if stream.latest_args == "not_set": if stream.latest_args == "not_set":
continue continue
if generator is None: if generator is None:
@@ -41,7 +40,7 @@ def player_worker_decode(
except Exception as exc: except Exception as exc:
if isinstance(exc, StopIteration): if isinstance(exc, StopIteration):
print("Not iterating") print("Not iterating")
asyncio.run_coroutine_threadsafe(queue.put(frame), loop) asyncio.run_coroutine_threadsafe(queue.put(None), loop)
thread_quit.set() thread_quit.set()
break break
@@ -51,7 +50,8 @@ def player_worker_decode(
if frame_time and frame_time > elapsed_time + 1: if frame_time and frame_time > elapsed_time + 1:
time.sleep(0.1) time.sleep(0.1)
sample_rate, audio_array = frame sample_rate, audio_array = frame
frame = av.AudioFrame.from_ndarray(audio_array, format="s16", layout="mono") format = "s16" if audio_array.dtype == "int16" else "fltp"
frame = av.AudioFrame.from_ndarray(audio_array, format=format, layout="mono")
frame.sample_rate = sample_rate frame.sample_rate = sample_rate
for frame in audio_resampler.resample(frame): for frame in audio_resampler.resample(frame):
# fix timestamps # fix timestamps

View File

@@ -135,7 +135,9 @@ class ServerToClientVideo(VideoStreamTrack):
frame.time_base = time_base frame.time_base = time_base
return frame return frame
elif self.generator is None: elif self.generator is None:
self.generator = cast(Generator[Any, None, Any], self.event_handler(*self.latest_args)) self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
try: try:
next_array = next(self.generator) next_array = next(self.generator)
@@ -150,6 +152,7 @@ class ServerToClientVideo(VideoStreamTrack):
except Exception as e: except Exception as e:
print(e) print(e)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@@ -193,20 +196,19 @@ class ServerToClientAudio(AudioStreamTrack):
def start(self): def start(self):
if self.__thread is None: if self.__thread is None:
self.__thread = threading.Thread( self.__thread = threading.Thread(
name="generator-runner", name="generator-runner",
target=player_worker_decode, target=player_worker_decode,
args=( args=(
asyncio.get_event_loop(), asyncio.get_event_loop(),
self.event_handler, self.event_handler,
self, self,
self.queue, self.queue,
False, False,
self.thread_quit self.thread_quit,
), ),
) )
self.__thread.start() self.__thread.start()
async def recv(self): async def recv(self):
try: try:
if self.readyState != "live": if self.readyState != "live":
@@ -216,14 +218,12 @@ class ServerToClientAudio(AudioStreamTrack):
data = await self.queue.get() data = await self.queue.get()
if data is None: if data is None:
self.stop() self.stop()
raise MediaStreamError return
data_time = data.time data_time = data.time
# control playback rate # control playback rate
if ( if data_time is not None:
data_time is not None
):
if self._start is None: if self._start is None:
self._start = time.time() - data_time self._start = time.time() - data_time
else: else:
@@ -238,35 +238,35 @@ class ServerToClientAudio(AudioStreamTrack):
traceback.print_exc() traceback.print_exc()
def stop(self): def stop(self):
super().stop()
self.thread_quit.set() self.thread_quit.set()
if self.__thread is not None: if self.__thread is not None:
self.__thread.join() self.__thread.join()
self.__thread = None self.__thread = None
super().stop()
# next_frame = await super().recv() # next_frame = await super().recv()
# print("next frame", next_frame) # print("next frame", next_frame)
# return next_frame # return next_frame
#try: # try:
# if self.latest_args == "not_set": # if self.latest_args == "not_set":
# frame = await self.empty_frame() # frame = await self.empty_frame()
# # await self.modify_frame(frame) # # await self.modify_frame(frame)
# await asyncio.sleep(100 / 22050) # await asyncio.sleep(100 / 22050)
# print("next_frame not set", frame) # print("next_frame not set", frame)
# return frame # return frame
# if self.generator is None: # if self.generator is None:
# self.generator = cast( # self.generator = cast(
# Generator[Any, None, Any], self.event_handler(*self.latest_args) # Generator[Any, None, Any], self.event_handler(*self.latest_args)
# ) # )
# try: # try:
# next_array = next(self.generator) # next_array = next(self.generator)
# print("iteration") # print("iteration")
# except StopIteration: # except StopIteration:
# print("exception") # print("exception")
# self.stop() # type: ignore # self.stop() # type: ignore
# return # return
# next_frame = self.array_to_frame(next_array) # next_frame = self.array_to_frame(next_array)
# # await self.modify_frame(next_frame) # # await self.modify_frame(next_frame)
# print("next frame", next_frame) # print("next frame", next_frame)
@@ -525,6 +525,7 @@ class WebRTC(Component):
cb = ServerToClientVideo(cast(Callable, self.event_handler)) cb = ServerToClientVideo(cast(Callable, self.event_handler))
pc.addTrack(cb) pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
if self.mode == "receive" and self.modality == "audio": if self.mode == "receive" and self.modality == "audio":
print("adding") print("adding")
cb = ServerToClientAudio(cast(Callable, self.event_handler)) cb = ServerToClientAudio(cast(Callable, self.event_handler))
@@ -534,6 +535,7 @@ class WebRTC(Component):
# pc.addTrack(player.audio) # pc.addTrack(player.audio)
pc.addTrack(cb) pc.addTrack(cb)
self.connections[body["webrtc_id"]] = cb self.connections[body["webrtc_id"]] = cb
cb.on("ended", lambda: self.connections.pop(body["webrtc_id"], None))
print("here") print("here")
# handle offer # handle offer

View File

@@ -0,0 +1,121 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
export let numBars = 16;
export let stream_state: "open" | "closed" = "closed";
export let audio_source: HTMLAudioElement;
let audioContext: AudioContext;
let analyser;
let dataArray;
let animationId;
let is_muted = false;
$: containerWidth = `calc((var(--boxSize) + var(--gutter)) * ${numBars})`;
$: if(stream_state === "open") setupAudioContext()
onDestroy(() => {
if (animationId) {
cancelAnimationFrame(animationId);
}
if (audioContext) {
audioContext.close();
}
});
function setupAudioContext() {
console.log("set up")
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
console.log("audio_source", audio_source.srcObject);
const source = audioContext.createMediaStreamSource(audio_source.srcObject);
source.connect(analyser);
analyser.connect(audioContext.destination);
analyser.fftSize = 64;
dataArray = new Uint8Array(analyser.frequencyBinCount);
updateBars();
}
function updateBars() {
analyser.getByteFrequencyData(dataArray);
const bars = document.querySelectorAll('.box');
for (let i = 0; i < bars.length; i++) {
const barHeight = (dataArray[i] / 255) * 2; // Amplify the effect
bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
}
animationId = requestAnimationFrame(updateBars);
}
function togglePlayPause() {
if (is_muted) {
audio_source.muted = false;
} else {
audio_source.muted = true;
}
is_muted = !is_muted;
}
</script>
<div class="waveContainer">
<div class="boxContainer" style:width={containerWidth}>
{#each Array(numBars) as _}
<div class="box"></div>
{/each}
</div>
<button class="playPauseButton" on:click={togglePlayPause}>
{is_muted ? '🔈' : '🔇'}
</button>
</div>
<style>
.waveContainer {
position: relative;
display: flex;
flex-direction: column;
align-items: center;
}
.boxContainer {
display: flex;
justify-content: space-between;
height: 64px;
--boxSize: 8px;
--gutter: 4px;
}
.box {
height: 100%;
width: var(--boxSize);
background: var(--color-accent);
border-radius: 8px;
transition: transform 0.05s ease;
}
.playPauseButton {
margin-top: 10px;
padding: 10px 20px;
font-size: 24px;
cursor: pointer;
background: none;
border: none;
border-radius: 5px;
color: black;
}
:global(body) {
display: flex;
justify-content: center;
background: black;
margin: 0;
padding: 0;
align-items: center;
height: 100vh;
color: white;
font-family: Arial, sans-serif;
}

View File

@@ -9,6 +9,7 @@
import { onMount } from "svelte"; import { onMount } from "svelte";
import { start, stop } from "./webrtc_utils"; import { start, stop } from "./webrtc_utils";
import AudioWave from "./AudioWave.svelte";
export let value: string | null = null; export let value: string | null = null;
@@ -56,7 +57,6 @@
] ]
}; };
pc = new RTCPeerConnection(rtc_configuration); pc = new RTCPeerConnection(rtc_configuration);
console.log("config", pc.getConfiguration());
pc.addEventListener("connectionstatechange", pc.addEventListener("connectionstatechange",
async (event) => { async (event) => {
switch(pc.connectionState) { switch(pc.connectionState) {
@@ -95,12 +95,14 @@
<audio <audio
class="standard-player" class="standard-player"
class:hidden={value === "__webrtc_value__"} class:hidden={value === "__webrtc_value__"}
controls
on:load on:load
bind:this={audio_player} bind:this={audio_player}
on:ended={() => dispatch("stop")} on:ended={() => dispatch("stop")}
on:play={() => dispatch("play")} on:play={() => dispatch("play")}
/> />
{#if value !== "__webrtc_value__"}
<AudioWave audio_source={audio_player} {stream_state}/>
{/if}
{#if value === "__webrtc_value__"} {#if value === "__webrtc_value__"}
<Empty size="small"> <Empty size="small">
<Music /> <Music />