diff --git a/demo/video_chat.py b/demo/video_chat.py new file mode 100644 index 0000000..7de000b --- /dev/null +++ b/demo/video_chat.py @@ -0,0 +1,116 @@ +import asyncio +import base64 +import os +import time +from io import BytesIO + +import gradio as gr +import numpy as np +from gradio_webrtc import ( + AsyncAudioVideoStreamHandler, + WebRTC, + async_aggregate_bytes_to_16bit, + VideoEmitType, + AudioEmitType, +) +from PIL import Image + + +def encode_audio(data: np.ndarray) -> dict: + """Encode Audio data to send to the server""" + return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")} + + +def encode_image(data: np.ndarray) -> dict: + with BytesIO() as output_bytes: + pil_image = Image.fromarray(data) + pil_image.save(output_bytes, "JPEG") + bytes_data = output_bytes.getvalue() + base64_str = str(base64.b64encode(bytes_data), "utf-8") + return {"mime_type": "image/jpeg", "data": base64_str} + + +class GeminiHandler(AsyncAudioVideoStreamHandler): + def __init__( + self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480 + ) -> None: + super().__init__( + expected_layout, + output_sample_rate, + output_frame_size, + input_sample_rate=24000, + ) + self.audio_queue = asyncio.Queue() + self.video_queue = asyncio.Queue() + self.quit = asyncio.Event() + self.session = None + self.last_frame_time = 0 + + def copy(self) -> "GeminiHandler": + return GeminiHandler( + expected_layout=self.expected_layout, + output_sample_rate=self.output_sample_rate, + output_frame_size=self.output_frame_size, + ) + + async def video_receive(self, frame: np.ndarray): + # if self.session: + # # send image every 1 second + # if time.time() - self.last_frame_time > 1: + # self.last_frame_time = time.time() + # await self.session.send(encode_image(frame)) + # if self.latest_args[2] is not None: + # await self.session.send(encode_image(self.latest_args[2])) + # print(frame.shape) + newFrame = np.array(frame) + newFrame[0:, :, 0] = 255 - 
newFrame[0:, :, 0] + self.video_queue.put_nowait(newFrame) + + async def video_emit(self) -> VideoEmitType: + return await self.video_queue.get() + + async def receive(self, frame: tuple[int, np.ndarray]) -> None: + frame_size, array = frame + self.audio_queue.put_nowait(array) + + async def emit(self) -> AudioEmitType: + if not self.args_set.is_set(): + await self.wait_for_args() + array = await self.audio_queue.get() + return (self.output_sample_rate, array) + + def shutdown(self) -> None: + self.quit.set() + self.connection = None + self.args_set.clear() + self.quit.clear() + + + +css = """ +#video-source {max-width: 1500px !important; max-height: 600px !important;} +""" + +with gr.Blocks(css=css) as demo: + + with gr.Column(): + webrtc = WebRTC( + width=500, + height=1500, + label="Video Chat", + modality="audio-video", + mode="send-receive", + show_local_video='picture-in-picture', + elem_id="video-source", + ) + webrtc.stream( + GeminiHandler(), + inputs=[webrtc], + outputs=[webrtc], + time_limit=90, + concurrency_limit=2, + ) + + +if __name__ == "__main__": + demo.launch() diff --git a/frontend/shared/AudioWave.svelte b/frontend/shared/AudioWave.svelte index bcbb8dc..9a31e80 100644 --- a/frontend/shared/AudioWave.svelte +++ b/frontend/shared/AudioWave.svelte @@ -10,6 +10,7 @@ export let icon: string | undefined | ComponentType = undefined; export let icon_button_color: string = "var(--color-accent)"; export let pulse_color: string = "var(--color-accent)"; + export let wave_color: string = "var(--color-accent)"; let audioContext: AudioContext; let analyser: AnalyserNode; @@ -52,12 +53,23 @@ // Update bars const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box'); for (let i = 0; i < bars.length; i++) { - const barHeight = (dataArray[i] / 255) * 2; + const barHeight = (dataArray[transformIndex(i)] / 255); bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`; + bars[i].style.background = wave_color; + 
bars[i].style.opacity = 0.5; } animationId = requestAnimationFrame(updateVisualization); } + + // Wave bar heights converge from both sides toward the middle + function transformIndex(index: number): number { + const mapping = [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1]; + if (index < 0 || index >= mapping.length) { + throw new Error("Index must be between 0 and 15"); + } + return mapping[index]; +}