Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)
ReplyOnPause and ReplyOnStopWords can be interrupted (#119)
* Add all this code
* add code
* Fix demo

---------

Co-authored-by: Freddy Boulton <freddyboulton@hf-freddy.local>
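In practice, the change means a reply produced by `ReplyOnPause` (or `ReplyOnStopWords`) can now be cut short when the user starts speaking again: the pending output queue is cleared and the running generator is dropped. A hedged usage sketch (the handler body below is illustrative, not code from this commit):

```python
import numpy as np
from fastrtc import ReplyOnPause, Stream


def response(audio: tuple[int, np.ndarray]):
    # Illustrative reply: stream back short chunks of silence at 24 kHz.
    for _ in range(50):
        yield (24000, np.zeros((1, 480), dtype=np.int16))


stream = Stream(
    handler=ReplyOnPause(
        response,
        can_interrupt=True,  # new flag; True is the default
    ),
    modality="audio",
    mode="send-receive",
)
```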
@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from logging import getLogger
 from threading import Event
@@ -59,6 +59,10 @@ class AppState:
     stopped: bool = False
     buffer: np.ndarray | None = None
     responded_audio: bool = False
+    interrupted: asyncio.Event = field(default_factory=asyncio.Event)
+
+    def new(self):
+        return AppState()


 ReplyFnGenerator = (
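A side note on the new `interrupted` field (an annotation, not part of the diff): a plain class-level `asyncio.Event()` default would be shared by every `AppState` instance, so `field(default_factory=asyncio.Event)` is used to give each state object its own event. A minimal sketch of the difference:

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class WithFactory:
    # default_factory runs once per instance, so each state gets its own Event
    interrupted: asyncio.Event = field(default_factory=asyncio.Event)


a, b = WithFactory(), WithFactory()
assert a.interrupted is not b.interrupted  # independent events
```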
@@ -91,6 +95,7 @@ class ReplyOnPause(StreamHandler):
         fn: ReplyFnGenerator,
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -102,6 +107,7 @@ class ReplyOnPause(StreamHandler):
             output_frame_size,
             input_sample_rate=input_sample_rate,
         )
+        self.can_interrupt = can_interrupt
         self.expected_layout: Literal["mono", "stereo"] = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
@@ -123,6 +129,7 @@ class ReplyOnPause(StreamHandler):
             self.fn,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
@@ -170,11 +177,14 @@ class ReplyOnPause(StreamHandler):
         state.pause_detected = pause_detected

     def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        if self.state.responding:
+        if self.state.responding and not self.can_interrupt:
             return
         self.process_audio(frame, self.state)
         if self.state.pause_detected:
             self.event.set()
+            if self.can_interrupt:
+                self.clear_queue()
+                self.generator = None

     def reset(self):
         super().reset()
@@ -207,6 +217,7 @@ class ReplyOnPause(StreamHandler):
         else:
             self.generator = self.fn((self.state.sampling_rate, audio)) # type: ignore
         logger.debug("Latest args: %s", self.latest_args)
+        self.state = self.state.new()
         self.state.responding = True
         try:
             if self.is_async:
@@ -23,6 +23,9 @@ class ReplyOnStopWordsState(AppState):
     post_stop_word_buffer: np.ndarray | None = None
     started_talking_pre_stop_word: bool = False

+    def new(self):
+        return ReplyOnStopWordsState()
+

 class ReplyOnStopWords(ReplyOnPause):
     def __init__(
@@ -31,6 +34,7 @@ class ReplyOnStopWords(ReplyOnPause):
         stop_words: list[str],
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -40,6 +44,7 @@ class ReplyOnStopWords(ReplyOnPause):
             fn,
             algo_options=algo_options,
             model_options=model_options,
+            can_interrupt=can_interrupt,
             expected_layout=expected_layout,
             output_sample_rate=output_sample_rate,
             output_frame_size=output_frame_size,
@@ -144,6 +149,7 @@ class ReplyOnStopWords(ReplyOnPause):
             self.stop_words,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
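For reference, a minimal usage sketch of the extended `ReplyOnStopWords` signature (the handler body and stop word are illustrative, and the import assumes the class is re-exported from the package root like `ReplyOnPause`); passing `can_interrupt=False` keeps the old non-interruptible behaviour:

```python
import numpy as np
from fastrtc import ReplyOnStopWords, Stream


def respond(audio: tuple[int, np.ndarray]):
    # Illustrative reply: one short chunk of silence at 24 kHz.
    yield (24000, np.zeros((1, 480), dtype=np.int16))


stream = Stream(
    handler=ReplyOnStopWords(
        respond,
        stop_words=["computer"],  # reply only after the stop word plus a pause
        can_interrupt=False,      # let each reply finish even if the user keeps talking
    ),
    modality="audio",
    mode="send-receive",
)
```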
@@ -360,7 +360,7 @@ class Stream(WebRTCConnectionMixin):
            image = WebRTC(
                label="Stream",
                rtc_configuration=self.rtc_configuration,
-                mode="send-receive",
+                mode="send",
                modality="audio",
                icon=ui_args.get("icon"),
                icon_button_color=ui_args.get("icon_button_color"),
@@ -505,7 +505,7 @@ class Stream(WebRTCConnectionMixin):
         return HTMLResponse(content=str(response), media_type="application/xml")

     async def telephone_handler(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy()) # type: ignore
         handler.phone_mode = True

         async def set_handler(s: str, a: WebSocketHandler):
@@ -528,7 +528,7 @@ class Stream(WebRTCConnectionMixin):
         await ws.handle_websocket(websocket)

     async def websocket_offer(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy()) # type: ignore
         handler.phone_mode = False

         async def set_handler(s: str, a: WebSocketHandler):
@@ -188,6 +188,11 @@ class StreamHandlerBase(ABC):
         self.args_set = asyncio.Event()
         self.channel_set = asyncio.Event()
         self._phone_mode = False
+        self._clear_queue: Callable | None = None
+
+    @property
+    def clear_queue(self) -> Callable:
+        return cast(Callable, self._clear_queue)

     @property
     def loop(self) -> asyncio.AbstractEventLoop:
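To make the callback indirection above easier to follow (a sketch with made-up class names, not code from this commit): the handler only holds a `_clear_queue` slot, and whichever track owns the outbound audio queue injects its own flush function, which the handler can then invoke when the user barges in.

```python
from typing import Callable, cast


class HandlerSketch:
    def __init__(self) -> None:
        self._clear_queue: Callable | None = None  # injected later by the owning track

    @property
    def clear_queue(self) -> Callable:
        return cast(Callable, self._clear_queue)


class TrackSketch:
    def __init__(self, handler: HandlerSketch) -> None:
        self.queue: list[bytes] = []
        handler._clear_queue = self.clear_queue  # wire the flush callback

    def clear_queue(self) -> None:
        self.queue.clear()  # drop audio that has not been sent yet


handler = HandlerSketch()
track = TrackSketch(handler)
track.queue.append(b"pending audio")
handler.clear_queue()  # the handler can now flush the track's queue
assert track.queue == []
```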
@@ -237,8 +242,11 @@ class StreamHandlerBase(ABC):
         logger.debug("Sent msg %s", msg)

     def send_message_sync(self, msg: str):
-        asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
-        logger.debug("Sent msg %s", msg)
+        try:
+            asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
+            logger.debug("Sent msg %s", msg)
+        except Exception as e:
+            logger.debug("Exception sending msg %s", e)

     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
@@ -411,6 +419,7 @@ class AudioCallback(AudioStreamTrack):
         super().__init__()
         self.track = track
         self.event_handler = cast(StreamHandlerImpl, event_handler)
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
@@ -421,6 +430,12 @@ class AudioCallback(AudioStreamTrack):
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs

+    def clear_queue(self):
+        if self.queue:
+            while not self.queue.empty():
+                self.queue.get_nowait()
+            self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
         self.event_handler.set_channel(channel)
@@ -608,6 +623,7 @@ class ServerToClientAudio(AudioStreamTrack):
     ) -> None:
         self.generator: Generator[Any, None, Any] | None = None
         self.event_handler = event_handler
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.args_set = threading.Event()
@@ -619,6 +635,11 @@ class ServerToClientAudio(AudioStreamTrack):
         self._start: float | None = None
         super().__init__()

+    def clear_queue(self):
+        while not self.queue.empty():
+            self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel

@@ -320,7 +320,7 @@ def audio_to_int16(
     >>> audio_int16 = audio_to_int16(audio_tuple)
     """
     if audio[1].dtype == np.int16:
-        return audio[1]
+        return audio[1] # type: ignore
     elif audio[1].dtype == np.float32:
         # Convert float32 to int16 by scaling to the int16 range
         return (audio[1] * 32767.0).astype(np.int16)
@@ -55,6 +55,7 @@ class WebSocketHandler:
         ],
     ):
         self.stream_handler = stream_handler
+        self.stream_handler._clear_queue = lambda: None
         self.websocket: Optional[WebSocket] = None
         self._emit_task: Optional[asyncio.Task] = None
         self.stream_id: Optional[str] = None
@@ -41,7 +41,7 @@ def response(
     response_text = (
         groq_client.chat.completions.create(
             model="llama-3.1-8b-instant",
-            max_tokens=512,
+            max_tokens=200,
             messages=messages, # type: ignore
         )
         .choices[0]
@@ -49,6 +49,7 @@ def response(
     )

     chatbot.append({"role": "assistant", "content": response_text})
+    yield AdditionalOutputs(chatbot)

     for chunk in tts_client.text_to_speech.convert_as_stream(
         text=response_text, # type: ignore
@@ -58,7 +59,6 @@ def response(
     ):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
         yield (24000, audio_array)
-    yield AdditionalOutputs(chatbot)


 chatbot = gr.Chatbot(type="messages")
@@ -3,6 +3,7 @@ from typing import Generator, Literal

 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
@@ -13,6 +14,8 @@ from fastrtc import (
 from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
 from numpy.typing import NDArray

+load_dotenv()
+

 @lru_cache(maxsize=None)
 def load_moonshine(
@@ -27,6 +30,7 @@ tokenizer = load_tokenizer()
 def stt(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     model_name: Literal["moonshine/base", "moonshine/tiny"],
+    captions: str,
 ) -> Generator[AdditionalOutputs, None, None]:
     moonshine = load_moonshine(model_name)
     sr, audio_np = audio # type: ignore
@@ -35,9 +39,12 @@ def stt(
     if audio_np.ndim == 1:
         audio_np = audio_np.reshape(1, -1)
     tokens = moonshine.generate(audio_np)
-    yield AdditionalOutputs(tokenizer.decode_batch(tokens)[0])
+    yield AdditionalOutputs(
+        (captions + "\n" + tokenizer.decode_batch(tokens)[0]).strip()
+    )


+captions = gr.Textbox(label="Captions")
 stream = Stream(
     ReplyOnPause(stt, input_sample_rate=16000),
     modality="audio",
@@ -55,9 +62,10 @@ stream = Stream(
             choices=["moonshine/base", "moonshine/tiny"],
             value="moonshine/base",
             label="Model",
-        )
+        ),
+        captions,
     ],
-    additional_outputs=[gr.Textbox(label="Captions")],
+    additional_outputs=[captions],
     additional_outputs_handler=lambda prev, current: (prev + "\n" + current).strip(),
 )

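For orientation (an annotation, not part of the diff): the `captions` Textbox is now wired in as both an additional input and an additional output, so each call to `stt` receives the current caption text as its extra argument and yields the appended version back. A rough sketch of one round trip with made-up values:

```python
def append_caption(new_text: str, captions: str) -> str:
    # Mirrors what stt yields via AdditionalOutputs.
    return (captions + "\n" + new_text).strip()


turn1 = append_caption("hello there", captions="")
turn2 = append_caption("how are you", captions=turn1)
assert turn1 == "hello there"
assert turn2 == "hello there\nhow are you"
```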
@@ -15,6 +15,7 @@ from fastrtc import (
 )
 from gradio.utils import get_space
 from groq import AsyncClient
+from pydantic import BaseModel

 cur_dir = Path(__file__).parent

@@ -24,23 +25,23 @@ load_dotenv()
 groq_client = AsyncClient()


-async def transcribe(audio: tuple[int, np.ndarray]):
-    transcript = await groq_client.audio.transcriptions.create(
+async def transcribe(audio: tuple[int, np.ndarray], transcript: str):
+    response = await groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     )
-    yield AdditionalOutputs(transcript.text)
+    yield AdditionalOutputs(transcript + "\n" + response.text)


+transcript = gr.Textbox(label="Transcript")
 stream = Stream(
     ReplyOnPause(transcribe),
     modality="audio",
     mode="send",
-    additional_outputs=[
-        gr.Textbox(label="Transcript"),
-    ],
-    additional_outputs_handler=lambda a, b: a + " " + b,
+    additional_inputs=[transcript],
+    additional_outputs=[transcript],
+    additional_outputs_handler=lambda a, b: b,
     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
     concurrency_limit=5 if get_space() else None,
     time_limit=90 if get_space() else None,
@@ -51,11 +52,21 @@ app = FastAPI()
 stream.mount(app)


+class SendInput(BaseModel):
+    webrtc_id: str
+    transcript: str
+
+
+@app.post("/send_input")
+def send_input(body: SendInput):
+    stream.set_input(body.webrtc_id, body.transcript)
+
+
 @app.get("/transcript")
 def _(webrtc_id: str):
     async def output_stream():
         async for output in stream.output_stream(webrtc_id):
-            transcript = output.args[0]
+            transcript = output.args[0].split("\n")[-1]
             yield f"event: output\ndata: {transcript}\n\n"

     return StreamingResponse(output_stream(), media_type="text/event-stream")
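A hedged client-side sketch of the two new endpoints (the base URL, the `webrtc_id` value, and the `requests` dependency are assumptions for illustration): POST `/send_input` to seed the transcript for a connection, then read transcript updates from the `/transcript` server-sent-events stream.

```python
import requests  # any HTTP client works; requests is assumed here

base = "http://localhost:7860"  # wherever the FastAPI app is served
webrtc_id = "abc123"            # hypothetical id negotiated by the client page

# Seed (or reset) the transcript additional input for this connection.
requests.post(f"{base}/send_input", json={"webrtc_id": webrtc_id, "transcript": ""})

# Stream transcript updates as server-sent events.
with requests.get(f"{base}/transcript", params={"webrtc_id": webrtc_id}, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(line[len("data: "):])
```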
@@ -73,7 +84,7 @@ if __name__ == "__main__":
     import os

     if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
+        stream.ui.launch(server_port=7860)
     elif mode == "PHONE":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:
@@ -193,7 +193,8 @@
     </div>

     <div class="container">
-        <div class="transcript-container" id="transcript"></div>
+        <div class="transcript-container" id="transcript">
+        </div>
         <div class="controls">
             <button id="start-button">Start Recording</button>
         </div>
@@ -220,13 +221,23 @@
                }, 5000);
            }

-            function handleMessage(event) {
+            async function handleMessage(event) {
                // Handle any WebRTC data channel messages if needed
                const eventJson = JSON.parse(event.data);
                if (eventJson.type === "error") {
                    showError(eventJson.message);
+                } else if (eventJson.type === "send_input") {
+                    const response = await fetch('/send_input', {
+                        method: 'POST',
+                        headers: { 'Content-Type': 'application/json' },
+                        body: JSON.stringify({
+                            webrtc_id: webrtc_id,
+                            transcript: ""
+                        })
+                    });
                }
                console.log('Received message:', event.data);
+
            }

            function updateButtonState() {
@@ -3,6 +3,8 @@

 Typically, you want to run a python function whenever a user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class and passing it to the `handler` argument of the `Stream` object. The `ReplyOnPause` class will handle the voice detection and turn taking logic automatically!

+By default, the `ReplyOnPause` handler will allow you to interrupt the response at any time by speaking again. If you do not want to allow interruption, you can set the `can_interrupt` parameter to `False`.
+
 === "Code"

     ```python
     from fastrtc import ReplyOnPause, Stream
@@ -33,13 +35,14 @@ Typically, you want to run a python function whenever a user has stopped speakin
 You can also use an async generator with `ReplyOnPause`.

 !!! tip "Parameters"
-    You can customize the voice detection parameters by passing in `algo_options` and `model_options` to the `ReplyOnPause` class.
+    You can customize the voice detection parameters by passing in `algo_options` and `model_options` to the `ReplyOnPause` class. Also, you can set the `can_interrupt` parameter to `False` to prevent the user from interrupting the response. By default, `can_interrupt` is `True`.

     ```python
     from fastrtc import AlgoOptions, SileroVadOptions

     stream = Stream(
         handler=ReplyOnPause(
             response,
+            can_interrupt=True,
             algo_options=AlgoOptions(
                 audio_chunk_duration=0.6,
                 started_talking_threshold=0.2,
@@ -7,10 +7,11 @@
   import { StreamingBar } from "@gradio/statustracker";
   import {
     Circle,
-    Square,
     Spinner,
     Music,
     DropdownArrow,
+    VolumeMuted,
+    VolumeHigh,
     Microphone,
   } from "@gradio/icons";

@@ -77,6 +78,7 @@
   let available_audio_devices: MediaDeviceInfo[];
   let selected_device: MediaDeviceInfo | null = null;
   let mic_accessed = false;
+  let is_muted = false;

   const audio_source_callback = () => {
     if (mode === "send") return stream;
@@ -261,6 +263,13 @@
     options_open = false;
   };

+  function toggleMute(): void {
+    if (audio_player) {
+      audio_player.muted = !audio_player.muted;
+      is_muted = audio_player.muted;
+    }
+  }
+
   $: if (stopword_recognized) {
     notification_sound.play();
   }
@@ -314,19 +323,28 @@
       </div>
     {:else if stream_state === "open"}
       <div class="icon-with-text">
-        <div
-          class="icon"
-          title="stop recording"
-          style={`fill: ${icon_button_color}; stroke: ${icon_button_color}; color: ${icon_button_color};`}
-        >
-          <PulsingIcon
-            audio_source_callback={() => stream}
-            stream_state={"open"}
-            icon={Circle}
-            {icon_button_color}
-            {pulse_color}
-          />
-        </div>
+        {#if mode === "send-receive"}
+          <div
+            class="icon"
+            title="stop recording"
+            style={`fill: ${icon_button_color}; stroke: ${icon_button_color}; color: ${icon_button_color};`}
+          >
+            <PulsingIcon
+              audio_source_callback={() => stream}
+              stream_state={"open"}
+              icon={Circle}
+              {icon_button_color}
+              {pulse_color}
+            />
+          </div>
+        {:else}
+          <div
+            class="icon color-primary"
+            title="start recording"
+          >
+            <Circle />
+          </div>
+        {/if}
         {button_labels.stop || i18n("audio.stop")}
       </div>
     {:else}
@@ -347,6 +365,24 @@
         <DropdownArrow />
       </button>
     {/if}
+    {#if stream_state === "open" && mode === "send-receive"}
+      <button
+        class="mute-button"
+        on:click={toggleMute}
+        aria-label={is_muted ? "unmute audio" : "mute audio"}
+      >
+        <div
+          class="icon"
+          style={`fill: ${icon_button_color}; stroke: ${icon_button_color}; color: ${icon_button_color};`}
+        >
+          {#if is_muted}
+            <VolumeMuted />
+          {:else}
+            <VolumeHigh />
+          {/if}
+        </div>
+      </button>
+    {/if}
     {#if options_open && selected_device}
       <select
         class="select-wrap"
@@ -511,4 +547,11 @@
   .select-wrap > option:last-child {
     border: none;
   }
+
+  .mute-button {
+    background-color: var(--block-background-fill);
+    padding-right: var(--size-2);
+    display: flex;
+    color: var(--button-secondary-text-color);
+  }
 </style>
@@ -8,7 +8,7 @@ build-backend = "hatchling.build"

 [project]
 name = "fastrtc"
-version = "0.0.10"
+version = "0.0.11"
 description = "The realtime communication library for Python"
 readme = "README.md"
 license = "apache-2.0"