Mirror of https://github.com/HumanAIGC-Engineering/gradio-webrtc.git (synced 2026-02-05 18:09:23 +08:00)
Add text mode (#321)
* Pretty good spot
* Working draft
* Fix other mode
* Add js to git
* Working
* Add code
* fix
* Fix
* Add code
* Fix submit race condition
* demo
* fix
* Fix
* Fix
demo/text_mode/app.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# /// script
# dependencies = [
#     "fastrtc[vad, stt]==0.0.26.rc1",
#     "openai",
# ]
# ///


import gradio as gr
import huggingface_hub
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    WebRTC,
    WebRTCData,
    WebRTCError,
    get_stt_model,
)
from openai import OpenAI

stt_model = get_stt_model()

conversations = {}


def response(
    data: WebRTCData,
    conversation: list[dict],
    token: str | None = None,
    model: str = "meta-llama/Llama-3.2-3B-Instruct",
    provider: str = "sambanova",
):
    print("conversation before", conversation)
    if not provider.startswith("http") and not token:
        raise WebRTCError("Please add your HF token.")

    if data.audio is not None and data.audio[1].size > 0:
        user_audio_text = stt_model.stt(data.audio)
        conversation.append({"role": "user", "content": user_audio_text})
    else:
        conversation.append({"role": "user", "content": data.textbox})

    yield AdditionalOutputs(conversation)

    if provider.startswith("http"):
        client = OpenAI(base_url=provider, api_key="ollama")
    else:
        client = huggingface_hub.InferenceClient(
            api_key=token,
            provider=provider,  # type: ignore
        )

    request = client.chat.completions.create(
        model=model,
        messages=conversation,  # type: ignore
        temperature=1,
        top_p=0.1,
    )
    response = {"role": "assistant", "content": request.choices[0].message.content}

    conversation.append(response)
    print("conversation after", conversation)
    yield AdditionalOutputs(conversation)


css = """
footer {
    display: none !important;
}
"""

providers = [
    "black-forest-labs",
    "cerebras",
    "cohere",
    "fal-ai",
    "fireworks-ai",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "openai",
    "replicate",
    "sambanova",
    "together",
]


def hide_token(provider: str):
    if provider.startswith("http"):
        return gr.Textbox(visible=False)
    return gr.skip()


with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <h1 style='text-align: center; display: flex; align-items: center; justify-content: center;'>
            <img src="https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/AV_Huggy.png" alt="Streaming Huggy" style="height: 50px; margin-right: 10px"> FastRTC Chat
        </h1>
        """
    )
    with gr.Sidebar():
        token = gr.Textbox(
            placeholder="Place your HF token here", type="password", label="HF Token"
        )
        model = gr.Dropdown(
            choices=["meta-llama/Llama-3.2-3B-Instruct"],
            allow_custom_value=True,
            label="Model",
        )
        provider = gr.Dropdown(
            label="Provider",
            choices=providers,
            value="sambanova",
            info="Select a hf-compatible provider or type the url of your server, e.g. http://127.0.0.1:11434/v1 for ollama",
            allow_custom_value=True,
        )
        provider.change(hide_token, inputs=[provider], outputs=[token])
    cb = gr.Chatbot(type="messages", height=600)
    webrtc = WebRTC(modality="audio", mode="send", variant="textbox")
    webrtc.stream(
        ReplyOnPause(response),
        inputs=[webrtc, cb, token, model, provider],
        outputs=[cb],
        concurrency_limit=100,
    )
    webrtc.on_additional_outputs(
        lambda old, new: new, inputs=[cb], outputs=[cb], concurrency_limit=100
    )

if __name__ == "__main__":
    demo.launch(server_port=6980)
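Note on the new text path: the script header above is PEP 723 inline metadata, so the demo can be launched directly with a compatible runner such as "uv run demo/text_mode/app.py". The sketch below is not part of the commit; it drives the text branch of the response handler by hand, and it assumes WebRTCData can be constructed with audio=None and a textbox string, mirroring the two fields the handler reads.

# Hypothetical sketch, not from this commit. Assumes a
# WebRTCData(audio=..., textbox=...) constructor matching the fields the
# handler reads above; consuming only the first yield avoids any LLM call.
from fastrtc import WebRTCData

history: list[dict] = []
events = response(
    WebRTCData(audio=None, textbox="Hello!"),
    history,
    token="hf_...",  # placeholder; a real HF token is needed past the first yield
)
first = next(events)  # AdditionalOutputs wrapping the updated conversation
print(history)        # [{'role': 'user', 'content': 'Hello!'}]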