Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-04 09:29:23 +08:00)
Video Bugfix + generator (#96)
* Code

* Fix demo

* move to init

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
@@ -25,6 +25,7 @@ from .utils import (
     audio_to_bytes,
     audio_to_file,
     audio_to_float32,
+    audio_to_int16,
     wait_for_item,
 )
 from .webrtc import (
@@ -43,6 +44,7 @@ __all__ = [
     "audio_to_bytes",
     "audio_to_file",
     "audio_to_float32",
+    "audio_to_int16",
     "get_hf_turn_credentials",
     "get_twilio_turn_credentials",
     "get_turn_credentials",
@@ -63,7 +63,7 @@ class AppState:

 ReplyFnGenerator = (
     Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         Generator[EmitType, None, None],
     ]
     | Callable[
@@ -75,7 +75,7 @@ ReplyFnGenerator = (
         AsyncGenerator[EmitType, None],
     ]
     | Callable[
-        [tuple[int, NDArray[np.int16]], list[dict[Any, Any]]],
+        [tuple[int, NDArray[np.int16]], Any],
         AsyncGenerator[EmitType, None],
     ]
 )
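A minimal sketch of a generator reply function that satisfies the loosened signature, where the second positional argument is now Any rather than a list of chat-message dicts (the names and echo behavior are illustrative, not from the commit):

import numpy as np
from numpy.typing import NDArray

def echo(audio: tuple[int, NDArray[np.int16]], extra):
    # `extra` may now be any additional value, not only list[dict[Any, Any]]
    sample_rate, samples = audio
    for _ in range(2):
        # assumes EmitType accepts (sample_rate, samples) audio tuples
        yield (sample_rate, samples)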
@@ -62,6 +62,7 @@ class Stream(WebRTCConnectionMixin):
         additional_outputs: list[Component] | None = None,
         ui_args: UIArgs | None = None,
     ):
+        WebRTCConnectionMixin.__init__(self)
         self.mode = mode
         self.modality = modality
         self.rtp_params = rtp_params
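Both Stream here and WebRTC below now call the mixin initializer explicitly so each instance gets its own connection state; a condensed sketch of the pattern (attribute set trimmed for illustration):

class WebRTCConnectionMixin:
    def __init__(self):
        # per-instance state; see the full mixin hunk further down
        self.connections = {}
        self.data_channels = {}

class Stream(WebRTCConnectionMixin):
    def __init__(self, mode, modality):
        WebRTCConnectionMixin.__init__(self)  # give this instance fresh state
        self.mode = mode
        self.modality = modality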
@@ -294,6 +294,40 @@ def audio_to_float32(
     return audio[1].astype(np.float32) / 32768.0


+def audio_to_int16(
+    audio: tuple[int, NDArray[np.int16 | np.float32]],
+) -> NDArray[np.int16]:
+    """
+    Convert an audio tuple containing sample rate and numpy array data to int16.
+
+    Parameters
+    ----------
+    audio : tuple[int, np.ndarray]
+        A tuple containing:
+        - sample_rate (int): The audio sample rate in Hz
+        - data (np.ndarray): The audio data as a numpy array
+
+    Returns
+    -------
+    np.ndarray
+        The audio data as a numpy array with dtype int16
+
+    Example
+    -------
+    >>> sample_rate = 44100
+    >>> audio_data = np.array([0.1, -0.2, 0.3], dtype=np.float32)  # Example audio samples
+    >>> audio_tuple = (sample_rate, audio_data)
+    >>> audio_int16 = audio_to_int16(audio_tuple)
+    """
+    if audio[1].dtype == np.int16:
+        return audio[1]
+    elif audio[1].dtype == np.float32:
+        # Convert float32 to int16 by scaling to the int16 range
+        return (audio[1] * 32767.0).astype(np.int16)
+    else:
+        raise TypeError(f"Unsupported audio data type: {audio[1].dtype}")


 def aggregate_bytes_to_16bit(chunks_iterator):
     """
     Aggregate bytes to 16-bit audio samples.
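Since audio_to_int16 is exported in the __init__ hunk above, a quick usage sketch; note the asymmetric scale factors (multiply by 32767 going in, divide by 32768 coming back via audio_to_float32), so a round trip is close but not bit-exact:

import numpy as np
from fastrtc import audio_to_float32, audio_to_int16

sr = 16000
float_chunk = (sr, np.array([0.5, -0.25, 0.99], dtype=np.float32))
int16_data = audio_to_int16(float_chunk)   # float32 in [-1, 1] -> int16
back = audio_to_float32((sr, int16_data))  # int16 -> float32
assert np.allclose(back, float_chunk[1], atol=1e-3)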
@@ -122,6 +122,7 @@ class WebRTC(Component, WebRTCConnectionMixin):
             button_labels: Text to display on the audio or video start, stop, waiting buttons. Dict with keys "start", "stop", "waiting" mapping to the text to display on the buttons.
             icon_radius: Border radius of the icon button expressed as a percentage of the button size. Default is 50%
         """
+        WebRTCConnectionMixin.__init__(self)
         self.time_limit = time_limit
         self.height = height
         self.width = width
@@ -230,15 +231,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             inputs = [inputs]
         inputs = list(inputs)

-        def handler(webrtc_id: str, *args):
-            if self.additional_outputs[webrtc_id].queue.qsize() > 0:
-                next_outputs = self.additional_outputs[webrtc_id].queue.get_nowait()
-                return fn(*args, *next_outputs.args)  # type: ignore
-            return (
-                tuple([None for _ in range(len(outputs))])
-                if isinstance(outputs, Iterable)
-                else None
-            )
+        async def handler(webrtc_id: str, *args):
+            async for next_outputs in self.output_stream(webrtc_id):
+                yield fn(*args, *next_outputs.args)  # type: ignore

         return self.state_change(  # type: ignore
             fn=handler,
@@ -247,9 +242,9 @@ class WebRTC(Component, WebRTCConnectionMixin):
             js=js,
             concurrency_limit=concurrency_limit,
             concurrency_id=concurrency_id,
-            show_progress=show_progress,
+            show_progress="minimal",
             queue=queue,
-            trigger_mode="multiple",
+            trigger_mode="once",
         )

     def stream(
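The old handler polled queue.qsize() on every state change and padded the outputs with None when nothing was queued; the new one is registered once (trigger_mode="once") and streams every batch as it arrives. A self-contained sketch of that poll-to-push pattern (illustrative names, not the library's internals):

import asyncio

async def output_stream(queue: asyncio.Queue):
    # block until an item is available instead of checking qsize();
    # None acts as a close sentinel in this sketch
    while (item := await queue.get()) is not None:
        yield item

async def handler(queue: asyncio.Queue, fn, *args):
    async for next_outputs in output_stream(queue):
        yield fn(*args, *next_outputs)  # next_outputs assumed to be a tuple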
@@ -35,7 +35,6 @@ from fastrtc.tracks import (
 )
 from fastrtc.utils import (
     AdditionalOutputs,
     DataChannel,
     create_message,
     webrtc_error_handler,
 )
@@ -64,18 +63,20 @@ class OutputQueue:


 class WebRTCConnectionMixin:
-    pcs: set[RTCPeerConnection] = set([])
-    relay = MediaRelay()
-    connections: dict[str, list[Track]] = defaultdict(list)
-    data_channels: dict[str, DataChannel] = {}
-    additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
-    handlers: dict[str, HandlerType | Callable] = {}
-    connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
-    concurrency_limit: int | float
-    event_handler: HandlerType
-    time_limit: float | int | None
-    modality: Literal["video", "audio", "audio-video"]
-    mode: Literal["send", "receive", "send-receive"]
+    def __init__(self):
+        self.pcs = set([])
+        self.relay = MediaRelay()
+        self.connections = defaultdict(list)
+        self.data_channels = {}
+        self.additional_outputs = defaultdict(OutputQueue)
+        self.handlers = {}
+        self.connection_timeouts = defaultdict(asyncio.Event)
+        # These attributes should be set by subclasses:
+        self.concurrency_limit: int | float | None
+        self.event_handler: HandlerType | None
+        self.time_limit: float | None
+        self.modality: Literal["video", "audio", "audio-video"]
+        self.mode: Literal["send", "receive", "send-receive"]

     @staticmethod
     async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
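This is the "move to init" from the commit title: the mutable containers previously lived on the class object, so every instance that used WebRTCConnectionMixin shared one registry of connections, handlers, and data channels. A minimal standalone demonstration of the pitfall being fixed:

from collections import defaultdict

class Shared:
    connections = defaultdict(list)  # class attribute: one dict for all instances

class PerInstance:
    def __init__(self):
        self.connections = defaultdict(list)  # fresh dict per object

a, b = Shared(), Shared()
a.connections["webrtc-id"].append("track")
assert b.connections["webrtc-id"] == ["track"]  # state leaks across instances

c, d = PerInstance(), PerInstance()
c.connections["webrtc-id"].append("track")
assert d.connections["webrtc-id"] == []  # isolated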
@@ -1,7 +1,3 @@
-import subprocess
-
-subprocess.run(["pip", "install", "fastrtc==0.0.4.post1"])
-
 import asyncio
 import base64
 import os
@@ -80,12 +76,11 @@ class PhonicHandler(AsyncStreamHandler):
         return super().shutdown()


-def add_to_chatbot(state, chatbot, message):
-    state.append(message)
-    return state, gr.skip()
+def add_to_chatbot(chatbot, message):
+    chatbot.append(message)
+    return chatbot


-state = gr.State(value=[])
 chatbot = gr.Chatbot(type="messages", value=[])
 stream = Stream(
     handler=PhonicHandler(),
@@ -99,7 +94,7 @@ stream = Stream(
             info="Select a voice from the dropdown",
         )
     ],
-    additional_outputs=[state, chatbot],
+    additional_outputs=[chatbot],
     additional_outputs_handler=add_to_chatbot,
     ui_args={
         "title": "Phonic Chat (Powered by FastRTC ⚡️)",
@@ -109,8 +104,8 @@ stream = Stream(
     time_limit=90 if get_space() else None,
 )

-with stream.ui:
-    state.change(lambda s: s, inputs=state, outputs=chatbot)
+# with stream.ui:
+#     state.change(lambda s: s, inputs=state, outputs=chatbot)

 if __name__ == "__main__":
     if (mode := os.getenv("MODE")) == "UI":
@@ -38,6 +38,7 @@ def response(
 ):
     gradio_chatbot = gradio_chatbot or []
     conversation_state = conversation_state or []
+    print("chatbot", gradio_chatbot)

     text = stt_model.stt(audio)
     sample_rate, array = audio
@@ -47,7 +48,6 @@ def response(
     yield AdditionalOutputs(gradio_chatbot, conversation_state)

-    conversation_state.append({"role": "user", "content": text})

     request = client.chat.completions.create(
         model="meta-llama/Llama-3.2-3B-Instruct",
         messages=conversation_state,  # type: ignore
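For orientation, the contract both demos rely on, matching the handler hunk above where the callback is invoked as fn(*args, *next_outputs.args): the stream handler yields AdditionalOutputs(...), and the registered additional_outputs_handler receives the current component values followed by the yielded args. A hedged sketch (names illustrative):

from fastrtc import AdditionalOutputs

def response(audio, chatbot=None, state=None):
    chatbot, state = chatbot or [], state or []
    # ... transcribe audio and append messages to both lists ...
    yield AdditionalOutputs(chatbot, state)

def merge_outputs(current_chatbot, current_state, new_chatbot, new_state):
    # current component values first, then AdditionalOutputs.args
    return new_chatbot, new_state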
@@ -10,7 +10,7 @@ from fastrtc import Stream
 import gradio as gr
 import numpy as np

-def detection(image):
+def detection(image, slider):
     return np.flip(image, axis=0)

 stream = Stream(
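The extra slider parameter presumably comes from an additional input component; a sketch of plausible wiring (the gr.Slider, its range, and the additional_inputs argument are assumptions, not shown in this hunk):

import gradio as gr
import numpy as np
from fastrtc import Stream

def detection(image, slider):
    # slider: current value of the extra input component (assumed)
    return np.flip(image, axis=0)

stream = Stream(
    handler=detection,
    modality="video",
    mode="send-receive",
    additional_inputs=[gr.Slider(0, 1, value=0.5)],  # assumption
)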
@@ -148,6 +148,7 @@
     {icon}
     {icon_button_color}
     {pulse_color}
+    {icon_radius}
     {button_labels}
     on:clear={() => gradio.dispatch("clear")}
     on:play={() => gradio.dispatch("play")}
@@ -30,6 +30,7 @@
 export let icon: string | undefined | ComponentType = undefined;
 export let icon_button_color: string = "var(--color-accent)";
 export let pulse_color: string = "var(--color-accent)";
+export let icon_radius: number = 50;

 const dispatch = createEventDispatcher<{
     change: FileData | null;
@@ -62,8 +63,8 @@
     {icon}
     {icon_button_color}
     {pulse_color}
-    {button_labels}
+    {icon_radius}
+    {button_labels}
     on:error
     on:start_recording
     on:stop_recording
@@ -33,6 +33,7 @@
 export let rtp_params: RTCRtpParameters = {} as RTCRtpParameters;
 export let icon: string | undefined | ComponentType = undefined;
 export let icon_button_color: string = "var(--color-accent)";
+export let icon_radius: number = 50;
 export let pulse_color: string = "var(--color-accent)";
 export let button_labels: { start: string; stop: string; waiting: string };
@@ -242,6 +243,7 @@
     icon={icon || Mic}
     {icon_button_color}
     {pulse_color}
+    {icon_radius}
 />
 </div>
 {/if}
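On the Python side these props map to the constructor arguments documented in the WebRTC docstring above; a hedged usage sketch (other arguments elided):

import gradio as gr
from fastrtc import WebRTC

with gr.Blocks() as demo:
    webrtc = WebRTC(
        modality="audio",
        mode="send-receive",
        icon_button_color="var(--color-accent)",
        pulse_color="var(--color-accent)",
        icon_radius=25,  # percent of the button size; the default of 50 gives a circle
    )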
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"

 [project]
 name = "fastrtc"
-version = "0.0.8post1"
+version = "0.0.9"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "apache-2.0"