diff --git a/demo/video_chat.py b/demo/video_chat.py
new file mode 100644
index 0000000..7de000b
--- /dev/null
+++ b/demo/video_chat.py
@@ -0,0 +1,116 @@
+import asyncio
+import base64
+import os
+import time
+from io import BytesIO
+
+import gradio as gr
+import numpy as np
+from gradio_webrtc import (
+    AsyncAudioVideoStreamHandler,
+    WebRTC,
+    async_aggregate_bytes_to_16bit,
+    VideoEmitType,
+    AudioEmitType,
+)
+from PIL import Image
+
+
+def encode_audio(data: np.ndarray) -> dict:
+    """Encode audio data to send to the server."""
+    return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}
+
+
+def encode_image(data: np.ndarray) -> dict:
+    """Encode a video frame as a base64 JPEG payload."""
+    with BytesIO() as output_bytes:
+        pil_image = Image.fromarray(data)
+        pil_image.save(output_bytes, "JPEG")
+        bytes_data = output_bytes.getvalue()
+        base64_str = str(base64.b64encode(bytes_data), "utf-8")
+        return {"mime_type": "image/jpeg", "data": base64_str}
+
+
+class GeminiHandler(AsyncAudioVideoStreamHandler):
+    def __init__(
+        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
+    ) -> None:
+        super().__init__(
+            expected_layout,
+            output_sample_rate,
+            output_frame_size,
+            input_sample_rate=24000,
+        )
+        self.audio_queue = asyncio.Queue()
+        self.video_queue = asyncio.Queue()
+        self.quit = asyncio.Event()
+        self.session = None
+        self.last_frame_time = 0
+
+    def copy(self) -> "GeminiHandler":
+        return GeminiHandler(
+            expected_layout=self.expected_layout,
+            output_sample_rate=self.output_sample_rate,
+            output_frame_size=self.output_frame_size,
+        )
+
+    async def video_receive(self, frame: np.ndarray):
+        # Gemini wiring, disabled in this demo:
+        # if self.session:
+        #     # send an image at most once per second
+        #     if time.time() - self.last_frame_time > 1:
+        #         self.last_frame_time = time.time()
+        #         await self.session.send(encode_image(frame))
+        #         if self.latest_args[2] is not None:
+        #             await self.session.send(encode_image(self.latest_args[2]))
+        # Invert the red channel so the echoed stream is visibly transformed.
+        new_frame = np.array(frame)
+        new_frame[:, :, 0] = 255 - new_frame[:, :, 0]
+        self.video_queue.put_nowait(new_frame)
+
+    async def video_emit(self) -> VideoEmitType:
+        return await self.video_queue.get()
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        _sample_rate, array = frame
+        self.audio_queue.put_nowait(array)
+
+    async def emit(self) -> AudioEmitType:
+        if not self.args_set.is_set():
+            await self.wait_for_args()
+        array = await self.audio_queue.get()
+        return (self.output_sample_rate, array)
+
+    def shutdown(self) -> None:
+        self.quit.set()
+        self.connection = None
+        self.args_set.clear()
+        self.quit.clear()
+
+
+css = """
+#video-source {max-width: 1500px !important; max-height: 600px !important;}
+"""
+
+with gr.Blocks(css=css) as demo:
+
+    with gr.Column():
+        webrtc = WebRTC(
+            width=500,
+            height=1500,
+            label="Video Chat",
+            modality="audio-video",
+            mode="send-receive",
+            show_local_video="picture-in-picture",
+            elem_id="video-source",
+        )
+        webrtc.stream(
+            GeminiHandler(),
+            inputs=[webrtc],
+            outputs=[webrtc],
+            time_limit=90,
+            concurrency_limit=2,
+        )
+
+
+if __name__ == "__main__":
+    demo.launch()
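The `encode_image` helper above produces a plain dict with a base64-encoded JPEG body, so it can be sanity-checked without a server or a WebRTC connection. A standalone Python sketch (illustrative test code, not part of this diff; the helper is repeated so the snippet runs on its own):

```python
import base64
from io import BytesIO

import numpy as np
from PIL import Image


def encode_image(data: np.ndarray) -> dict:
    """Same helper as in demo/video_chat.py."""
    with BytesIO() as output_bytes:
        Image.fromarray(data).save(output_bytes, "JPEG")
        base64_str = base64.b64encode(output_bytes.getvalue()).decode("utf-8")
        return {"mime_type": "image/jpeg", "data": base64_str}


frame = np.full((48, 64, 3), 128, dtype=np.uint8)  # flat gray test frame
payload = encode_image(frame)

# Reverse the base64 + JPEG steps; JPEG is lossy, so compare approximately.
decoded = np.asarray(Image.open(BytesIO(base64.b64decode(payload["data"]))).convert("RGB"))
assert decoded.shape == frame.shape
assert abs(float(decoded.mean()) - 128.0) < 2.0
```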
diff --git a/frontend/shared/AudioWave.svelte b/frontend/shared/AudioWave.svelte
index bcbb8dc..9a31e80 100644
--- a/frontend/shared/AudioWave.svelte
+++ b/frontend/shared/AudioWave.svelte
@@ -10,6 +10,7 @@
     export let icon: string | undefined | ComponentType = undefined;
     export let icon_button_color: string = "var(--color-accent)";
     export let pulse_color: string = "var(--color-accent)";
+    export let wave_color: string = "var(--color-accent)";

     let audioContext: AudioContext;
     let analyser: AnalyserNode;
@@ -52,12 +53,23 @@
         // Update bars
         const bars = document.querySelectorAll('.gradio-webrtc-waveContainer .gradio-webrtc-box');
         for (let i = 0; i < bars.length; i++) {
-            const barHeight = (dataArray[i] / 255) * 2;
+            const barHeight = dataArray[transformIndex(i)] / 255;
             bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
+            bars[i].style.background = wave_color;
+            bars[i].style.opacity = "0.5";
         }

         animationId = requestAnimationFrame(updateVisualization);
     }
+
+    // Taper the bar heights from both edges toward the center of the row.
+    function transformIndex(index: number): number {
+        const mapping = [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1];
+        if (index < 0 || index >= mapping.length) {
+            throw new Error("Index must be between 0 and 15");
+        }
+        return mapping[index];
+    }
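The hard-coded `mapping` table in `transformIndex` sends even analyser indices ascending across the left half of the bars and odd indices descending across the right half, so the highest indices meet in the middle. A short Python sketch (illustrative only, not part of the diff) that derives the same table for any even bar count:

```python
def center_mapping(n: int) -> list[int]:
    """Even data indices ascend over the left half, odd indices descend
    over the right half, so the largest indices land mid-row."""
    return [2 * i if i < n // 2 else 2 * (n - 1 - i) + 1 for i in range(n)]


# Matches the 16-entry table hard-coded in transformIndex.
assert center_mapping(16) == [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1]
```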
diff --git a/frontend/shared/PulsingIcon.svelte b/frontend/shared/PulsingIcon.svelte
index 450ac57..20494b1 100644
--- a/frontend/shared/PulsingIcon.svelte
+++ b/frontend/shared/PulsingIcon.svelte
@@ -7,6 +7,7 @@
     export let icon: string | ComponentType = undefined;
     export let icon_button_color: string = "var(--color-accent)";
    export let pulse_color: string = "var(--color-accent)";
+
     let audioContext: AudioContext;
     let analyser: AnalyserNode;
diff --git a/frontend/shared/VideoChat.svelte b/frontend/shared/VideoChat.svelte
index b22d4e4..ae9de05 100644
--- a/frontend/shared/VideoChat.svelte
+++ b/frontend/shared/VideoChat.svelte
@@ -3,7 +3,11 @@
     import type { ComponentType } from "svelte";
     import type { I18nFormatter } from "@gradio/utils";

+    import {
+        Spinner,
+    } from "@gradio/icons";
     import WebcamPermissions from "./WebcamPermissions.svelte";
+    import AudioWave from "./AudioWave.svelte";
     import { fade } from "svelte/transition";
     import {
         get_devices,
@@ -95,6 +99,7 @@
     }: true, node, videoDeviceId, track_constraints).then(
         async (local_stream) => {
             stream = local_stream;
+            input_stream = local_stream;
             selected_video_device = available_video_devices.find(
                 (device) => device.deviceId === videoDeviceId
@@ -118,6 +123,7 @@
             webcam_accessed = true;
             let available_devices = await get_devices();
             stream = local_stream;
+            input_stream = local_stream;
             return available_devices
         })
         .then((devices) => {
@@ -156,12 +162,21 @@
     let recording = false;
     let stream: MediaStream;
+    // Camera/mic stream kept for the waveform; a distinct name so the callbacks' local_stream parameters cannot shadow it.
+    let input_stream: MediaStream;
     let webcam_accessed = false;
     let webcam_received = false;
     let pc: RTCPeerConnection;

     export let webrtc_id;
+    export let wave_color: string = "#7873F6";
+
+    const audio_source_callback = () => {
+        if (input_stream) return input_stream;
+        else return localVideoRef.srcObject as MediaStream;
+    };

     async function start_webrtc(): Promise<void> {
         if (stream_state === 'closed') {
             pc = new RTCPeerConnection(rtc_configuration);
@@ -404,8 +418,8 @@
-                {#if volumeMuted}
-                    <!-- muted icon (elided) -->
+                {#if volumeMuted}
+                    <!-- muted icon (elided) -->
                 {:else}
                     <!-- unmuted icon (elided) -->
                 {/if}
@@ -423,7 +437,25 @@
-            <!-- previous call control (elided) -->
+            <button class="chat-btn" on:click={start_webrtc}>
+                {#if stream_state === 'closed'}
+                    <div class="start-chat">Click to start the chat</div>
+                {:else if stream_state === 'waiting'}
+                    <div class="waiting-icon-text">
+                        <div class="icon">
+                            <Spinner />
+                        </div>
+                        Waiting
+                    </div>
+                {:else}
+                    <div class="stop-chat">
+                        <div class="stop-chat-inner"></div>
+                    </div>
+                {/if}
+            </button>
+            {#if stream_state === 'open'}
+                <div class="input-audio-wave">
+                    <AudioWave {audio_source_callback} {wave_color} />
+                </div>
+            {/if}
@@ -561,6 +593,59 @@
     .player-controls {
         height: 15%;
         position: relative;
+        display: flex;
+        justify-content: center;
+        align-items: center;
+
+        .chat-btn {
+            height: 64px;
+            width: 100%;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            border-radius: 999px;
+            opacity: 1;
+            background: linear-gradient(180deg, #7873F6 0%, #524DE1 100%);
+            transition: all 0.3s;
+            z-index: 2;
+        }
+
+        .start-chat {
+            font-size: 16px;
+            font-weight: 500;
+            text-align: center;
+            color: #FFFFFF;
+        }
+
+        .waiting-icon-text {
+            width: 80px;
+            align-items: center;
+            font-size: 16px;
+            font-weight: 500;
+            color: #FFFFFF;
+            margin: 0 var(--spacing-sm);
+            display: flex;
+            justify-content: space-evenly;
+            gap: var(--size-1);
+
+            .icon {
+                width: 25px;
+                height: 25px;
+                fill: #FFFFFF;
+                stroke: #FFFFFF;
+                color: #FFFFFF;
+            }
+        }
+
+        .stop-chat {
+            width: 64px;
+
+            .stop-chat-inner {
+                width: 25px;
+                height: 25px;
+                border-radius: 6.25px;
+                background: #FAFAFA;
+            }
+        }
+    }
+
+    .input-audio-wave {
+        position: absolute;
     }
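A closing note on the Python handler at the top of this diff: with `output_sample_rate=24000` and `output_frame_size=480`, each chunk returned by `emit` represents 480 / 24000 = 20 ms of PCM audio. A minimal sketch (hypothetical helper, not part of gradio_webrtc) of slicing a mono int16 buffer into queue-sized frames:

```python
import asyncio

import numpy as np


async def feed_pcm(audio_queue: asyncio.Queue, pcm: np.ndarray, frame_size: int = 480) -> None:
    """Split a mono int16 buffer into frame_size-sample chunks (20 ms at 24 kHz)
    and enqueue them; GeminiHandler.emit pops one chunk per call."""
    for start in range(0, len(pcm) - frame_size + 1, frame_size):
        audio_queue.put_nowait(pcm[start : start + frame_size])


# One second of silence becomes fifty 20 ms frames.
queue: asyncio.Queue = asyncio.Queue()
asyncio.run(feed_pcm(queue, np.zeros(24000, dtype=np.int16)))
assert queue.qsize() == 50
```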