diff --git a/backend/gradio_webrtc/__init__.py b/backend/gradio_webrtc/__init__.py index edb1140..08a4395 100644 --- a/backend/gradio_webrtc/__init__.py +++ b/backend/gradio_webrtc/__init__.py @@ -4,12 +4,14 @@ from .credentials import ( get_twilio_turn_credentials, ) from .reply_on_pause import AlgoOptions, ReplyOnPause, SileroVadOptions -from .utils import AdditionalOutputs +from .utils import AdditionalOutputs, audio_to_bytes, audio_to_file from .webrtc import StreamHandler, WebRTC __all__ = [ "AlgoOptions", "AdditionalOutputs", + "audio_to_bytes", + "audio_to_file", "get_hf_turn_credentials", "get_twilio_turn_credentials", "get_turn_credentials", diff --git a/backend/gradio_webrtc/reply_on_pause.py b/backend/gradio_webrtc/reply_on_pause.py index 3146230..bb25171 100644 --- a/backend/gradio_webrtc/reply_on_pause.py +++ b/backend/gradio_webrtc/reply_on_pause.py @@ -70,6 +70,10 @@ ReplyFnGenerator = Union[ ] +async def iterate(generator: Generator) -> Any: + return next(generator) + + class ReplyOnPause(StreamHandler): def __init__( self, @@ -86,6 +90,7 @@ class ReplyOnPause(StreamHandler): self.output_frame_size = output_frame_size self.model = get_vad_model() self.fn = fn + self.is_async = inspect.isasyncgenfunction(fn) self.event = Event() self.state = AppState() self.generator = None @@ -172,6 +177,9 @@ class ReplyOnPause(StreamHandler): self.channel.send("tick") logger.debug("Sent tick") + async def async_iterate(self, generator) -> Any: + return await anext(generator) + def emit(self): if not self.event.is_set(): return None @@ -190,6 +198,11 @@ class ReplyOnPause(StreamHandler): logger.debug("Latest args: %s", self.latest_args) self.state.responding = True try: - return next(self.generator) - except StopIteration: + if self.is_async: + return asyncio.run_coroutine_threadsafe( + self.async_iterate(self.generator), self.loop + ).result() + else: + return next(self.generator) + except (StopIteration, StopAsyncIteration): self.reset() diff --git a/backend/gradio_webrtc/utils.py b/backend/gradio_webrtc/utils.py index 3d1f446..ba0bdab 100644 --- a/backend/gradio_webrtc/utils.py +++ b/backend/gradio_webrtc/utils.py @@ -1,10 +1,13 @@ import asyncio import fractions +import io import logging +import tempfile from typing import Any, Callable, Protocol, cast import av import numpy as np +from pydub import AudioSegment logger = logging.getLogger(__name__) @@ -120,3 +123,67 @@ async def player_worker_decode( logger.debug("traceback %s", exec) logger.error("Error processing frame: %s", str(e)) continue + + +def audio_to_bytes(audio: tuple[int, np.ndarray]) -> bytes: + """ + Convert an audio tuple containing sample rate and numpy array data into bytes. + + Parameters + ---------- + audio : tuple[int, np.ndarray] + A tuple containing: + - sample_rate (int): The audio sample rate in Hz + - data (np.ndarray): The audio data as a numpy array + + Returns + ------- + bytes + The audio data encoded as bytes, suitable for transmission or storage + + Example + ------- + >>> sample_rate = 44100 + >>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples + >>> audio_tuple = (sample_rate, audio_data) + >>> audio_bytes = audio_to_bytes(audio_tuple) + """ + audio_buffer = io.BytesIO() + segment = AudioSegment( + audio[1].tobytes(), + frame_rate=audio[0], + sample_width=audio[1].dtype.itemsize, + channels=1, + ) + segment.export(audio_buffer, format="mp3") + return audio_buffer.getvalue() + + +def audio_to_file(audio: tuple[int, np.ndarray]) -> str: + """ + Save an audio tuple containing sample rate and numpy array data to a file. + + Parameters + ---------- + audio : tuple[int, np.ndarray] + A tuple containing: + - sample_rate (int): The audio sample rate in Hz + - data (np.ndarray): The audio data as a numpy array + + Returns + ------- + str + The path to the saved audio file + + Example + ------- + >>> sample_rate = 44100 + >>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples + >>> audio_tuple = (sample_rate, audio_data) + >>> file_path = audio_to_file(audio_tuple) + >>> print(f"Audio saved to: {file_path}") + """ + bytes_ = audio_to_bytes(audio) + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: + f.write(bytes_) + return f.name diff --git a/docs/cookbook.md b/docs/cookbook.md index 3b45479..4a5d8b1 100644 --- a/docs/cookbook.md +++ b/docs/cookbook.md @@ -24,6 +24,18 @@ [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-claude/blob/main/app.py) +- :speaking_head:{ .lg .middle } __Kyutai Moshi__ + + --- + + Kyutai's moshi is a novel speech-to-speech model for modeling human conversations. + + + + [:octicons-arrow-right-24: Demo](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi) + + [:octicons-code-16: Code](https://huggingface.co/spaces/freddyaboulton/talk-to-moshi/blob/main/app.py) + - :robot:{ .lg .middle } __Llama Code Editor__ --- diff --git a/docs/index.md b/docs/index.md index aeccdb0..9654531 100644 --- a/docs/index.md +++ b/docs/index.md @@ -22,7 +22,4 @@ pip install gradio_webrtc[vad] ``` ## Examples -1. [Object Detection from Webcam with YOLOv10](https://huggingface.co/spaces/freddyaboulton/webrtc-yolov10n) 📷 -2. [Streaming Object Detection from Video with RT-DETR](https://huggingface.co/spaces/freddyaboulton/rt-detr-object-detection-webrtc) 🎥 -3. [Text-to-Speech](https://huggingface.co/spaces/freddyaboulton/parler-tts-streaming-webrtc) 🗣️ -4. [Conversational AI](https://huggingface.co/spaces/freddyaboulton/omni-mini-webrtc) 🤖🗣️ \ No newline at end of file +See the [cookbook](/cookbook) \ No newline at end of file diff --git a/docs/utils.md b/docs/utils.md new file mode 100644 index 0000000..f267a2d --- /dev/null +++ b/docs/utils.md @@ -0,0 +1,54 @@ +# Utils + +## `audio_to_bytes` + +Convert an audio tuple containing sample rate and numpy array data into bytes. +Useful for sending data to external APIs from `ReplyOnPause` handler. + +Parameters +``` +audio : tuple[int, np.ndarray] + A tuple containing: + - sample_rate (int): The audio sample rate in Hz + - data (np.ndarray): The audio data as a numpy array +``` + +Returns +``` +bytes + The audio data encoded as bytes, suitable for transmission or storage +``` + +Example +```python +>>> sample_rate = 44100 +>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples +>>> audio_tuple = (sample_rate, audio_data) +>>> audio_bytes = audio_to_bytes(audio_tuple) +``` + +## `audio_to_file` + +Save an audio tuple containing sample rate and numpy array data to a file. + +Parameters +``` +audio : tuple[int, np.ndarray] + A tuple containing: + - sample_rate (int): The audio sample rate in Hz + - data (np.ndarray): The audio data as a numpy array +``` +Returns +``` +str + The path to the saved audio file +``` +Example +``` +```python +>>> sample_rate = 44100 +>>> audio_data = np.array([0.1, -0.2, 0.3]) # Example audio samples +>>> audio_tuple = (sample_rate, audio_data) +>>> file_path = audio_to_file(audio_tuple) +>>> print(f"Audio saved to: {file_path}") +``` \ No newline at end of file diff --git a/frontend/shared/AudioWave.svelte b/frontend/shared/AudioWave.svelte index bd76e9e..cae982a 100644 --- a/frontend/shared/AudioWave.svelte +++ b/frontend/shared/AudioWave.svelte @@ -41,8 +41,7 @@ function updateBars() { analyser.getByteFrequencyData(dataArray); - - const bars = document.querySelectorAll('.box'); + const bars = document.querySelectorAll('.waveContainer .box'); for (let i = 0; i < bars.length; i++) { const barHeight = (dataArray[i] / 255) * 2; // Amplify the effect bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`; diff --git a/mkdocs.yml b/mkdocs.yml index ad02f7b..a9ae3b2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,6 +19,7 @@ nav: - Cookbook: cookbook.md - Deployment: deployment.md - Advanced Configuration: advanced-configuration.md + - Utils: utils.md - Frequently Asked Questions: faq.md markdown_extensions: - pymdownx.highlight: diff --git a/pyproject.toml b/pyproject.toml index cc31015..7de8e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" [project] name = "gradio_webrtc" -version = "0.0.15" +version = "0.0.16" description = "Stream images in realtime with webrtc" readme = "README.md" license = "apache-2.0" @@ -50,3 +50,6 @@ artifacts = ["/backend/gradio_webrtc/templates", "*.pyi"] [tool.hatch.build.targets.wheel] packages = ["/backend/gradio_webrtc"] + +[tool.ruff] +target-version = "py310" \ No newline at end of file