stt models (#147)
@@ -13,7 +13,7 @@ from .reply_on_pause import (
     ReplyFnGenerator,
     ReplyOnPause,
 )
-from .speech_to_text import get_stt_model
+from .speech_to_text import get_stt_model, stt_for_chunks
 from .utils import audio_to_float32, create_message

 logger = logging.getLogger(__name__)
@@ -105,10 +105,9 @@ class ReplyOnStopWords(ReplyOnPause):
             dur_vad, chunks = self.model.vad(
                 (16000, state.post_stop_word_buffer),
                 self.model_options,
-                return_chunks=True,
             )
-            text = self.stt_model.stt_for_chunks(
-                (16000, state.post_stop_word_buffer), chunks
+            text = stt_for_chunks(
+                self.stt_model, (16000, state.post_stop_word_buffer), chunks
             )
             logger.debug(f"STT: {text}")
             state.stop_word_detected = self.stop_word_detected(text)
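
The hunk above is the heart of the commit: chunk-level transcription moves off the model object and into a module-level stt_for_chunks that takes the model as its first argument. For orientation, the chunks it receives are the VAD's speech segments addressed by sample offsets, as the slicing later in this diff (audio_np[chunk["start"] : chunk["end"]]) shows. A minimal sketch of that shape, with made-up offsets on a silent 16 kHz buffer:

    import numpy as np

    # stand-in for state.post_stop_word_buffer: 2 s of 16 kHz audio
    audio_np = np.zeros(32000, dtype=np.float32)

    # hypothetical VAD output: start/end are sample indices, not seconds
    chunks = [{"start": 4000, "end": 12000}, {"start": 18000, "end": 28000}]

    for chunk in chunks:
        segment = audio_np[chunk["start"] : chunk["end"]]  # one speech span

The next hunk re-exports the new helper from the speech_to_text package's __init__.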
@@ -1,3 +1,3 @@
-from .stt_ import MoonshineSTT, get_stt_model
+from .stt_ import MoonshineSTT, get_stt_model, stt_for_chunks

-__all__ = ["get_stt_model", "MoonshineSTT", "get_stt_model"]
+__all__ = ["get_stt_model", "MoonshineSTT", "get_stt_model", "stt_for_chunks"]
@@ -15,12 +15,6 @@ curr_dir = Path(__file__).parent
 class STTModel(Protocol):
     def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str: ...

-    def stt_for_chunks(
-        self,
-        audio: tuple[int, NDArray[np.int16 | np.float32]],
-        chunks: list[AudioChunk],
-    ) -> str: ...
-

 class MoonshineSTT(STTModel):
     def __init__(
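
With stt_for_chunks dropped from the Protocol, a conforming speech-to-text backend now needs exactly one method. A minimal sketch of a custom model that structurally satisfies the slimmed-down STTModel (the body is a placeholder, not a real recognizer):

    import numpy as np
    from numpy.typing import NDArray


    class DummySTT:
        # duck-types as STTModel: only stt() is required now
        def stt(self, audio: tuple[int, NDArray[np.int16 | np.float32]]) -> str:
            sr, samples = audio
            # placeholder: report duration instead of transcribing
            return f"<{len(samples) / sr:.2f}s of audio>"

Because STTModel is a typing.Protocol, DummySTT never has to inherit from it; matching the stt signature is enough.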
@@ -49,19 +43,6 @@ class MoonshineSTT(STTModel):
         tokens = self.model.generate(audio_np)
         return self.tokenizer.decode_batch(tokens)[0]

-    def stt_for_chunks(
-        self,
-        audio: tuple[int, NDArray[np.int16 | np.float32]],
-        chunks: list[AudioChunk],
-    ) -> str:
-        sr, audio_np = audio
-        return " ".join(
-            [
-                self.stt((sr, audio_np[chunk["start"] : chunk["end"]]))
-                for chunk in chunks
-            ]
-        )
-

 @lru_cache
 def get_stt_model(
@@ -79,3 +60,17 @@ def get_stt_model(
     m.stt((16000, audio))
     print(click.style("INFO", fg="green") + ":\t STT model warmed up.")
     return m
+
+
+def stt_for_chunks(
+    stt_model: STTModel,
+    audio: tuple[int, NDArray[np.int16 | np.float32]],
+    chunks: list[AudioChunk],
+) -> str:
+    sr, audio_np = audio
+    return " ".join(
+        [
+            stt_model.stt((sr, audio_np[chunk["start"] : chunk["end"]]))
+            for chunk in chunks
+        ]
+    )
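
Putting the pieces together, the relocated helper composes with anything that satisfies STTModel. A usage sketch, assuming get_stt_model's arguments (cut off in this hunk) default sensibly and that the Moonshine weights are available; the chunk offsets are made up:

    import numpy as np
    from fastrtc.speech_to_text import get_stt_model, stt_for_chunks

    model = get_stt_model()  # lru_cache'd; warm-up print happens once

    sr = 16000
    audio = np.zeros(sr, dtype=np.float32)  # 1 s stand-in buffer
    chunks = [{"start": 0, "end": sr // 2}]  # hypothetical VAD segment

    print(stt_for_chunks(model, (sr, audio), chunks))

The remaining hunks reformat the nextjs_voice_chat demo server.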
@@ -1,5 +1,4 @@
 import fastapi
-from fastapi.responses import FileResponse
 from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions
 from fastrtc.utils import audio_to_bytes
 from openai import OpenAI
@@ -9,7 +8,6 @@ from fastapi.middleware.cors import CORSMiddleware
 from elevenlabs import VoiceSettings, stream
 from elevenlabs.client import ElevenLabs
 import numpy as np
-import io

 from .env import LLM_API_KEY, ELEVENLABS_API_KEY

@@ -22,16 +20,14 @@ Begin a conversation with a self-deprecating joke like 'I'm not sure if I'm read

 messages = [{"role": "system", "content": sys_prompt}]

-openai_client = OpenAI(
-    api_key=LLM_API_KEY
-)
+openai_client = OpenAI(api_key=LLM_API_KEY)

 elevenlabs_client = ElevenLabs(api_key=ELEVENLABS_API_KEY)

 logging.basicConfig(level=logging.INFO)

-def echo(audio):

+def echo(audio):
     stt_time = time.time()

     logging.info("Performing STT")
@@ -60,10 +56,7 @@ def echo(audio):
     full_response = ""

     response = openai_client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=messages,
-        max_tokens=200,
-        stream=True
+        model="gpt-3.5-turbo", messages=messages, max_tokens=200, stream=True
     )

     for chunk in response:
@@ -77,18 +70,17 @@ def echo(audio):
         text=text_stream(),
         voice="Rachel", # Cassidy is also really good
         voice_settings=VoiceSettings(
-            similarity_boost=0.9,
-            stability=0.6,
-            style=0.4,
-            speed=1
+            similarity_boost=0.9, stability=0.6, style=0.4, speed=1
         ),
         model="eleven_multilingual_v2",
         output_format="pcm_24000",
-        stream=True
+        stream=True,
     )

     for audio_chunk in audio_stream:
-        audio_array = np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        audio_array = (
+            np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
+        )
         yield (24000, audio_array)

     messages.append({"role": "assistant", "content": full_response + " "})
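
One behavioral note on the reflowed audio_array line: pcm_24000 from ElevenLabs is raw 16-bit PCM, and dividing by 32768.0 rescales the samples to float32 in [-1.0, 1.0), which is what the handler yields alongside the 24000 Hz rate. The same conversion in isolation:

    import numpy as np

    pcm = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
    floats = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
    # floats -> [0.0, 0.5, -1.0]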
@@ -96,22 +88,25 @@ def echo(audio):
     logging.info(f"LLM took {time.time() - llm_time} seconds")


-stream = Stream(ReplyOnPause(echo,
-    algo_options=AlgoOptions(
-        audio_chunk_duration=0.5,
-        started_talking_threshold=0.1,
-        speech_threshold=0.03
-    ),
-    model_options=SileroVadOptions(
-        threshold=0.75,
-        min_speech_duration_ms=250,
-        min_silence_duration_ms=1500,
-        speech_pad_ms=400,
-        max_speech_duration_s=15
-    )),
-    modality="audio",
-    mode="send-receive"
-)
+stream = Stream(
+    ReplyOnPause(
+        echo,
+        algo_options=AlgoOptions(
+            audio_chunk_duration=0.5,
+            started_talking_threshold=0.1,
+            speech_threshold=0.03,
+        ),
+        model_options=SileroVadOptions(
+            threshold=0.75,
+            min_speech_duration_ms=250,
+            min_silence_duration_ms=1500,
+            speech_pad_ms=400,
+            max_speech_duration_s=15,
+        ),
+    ),
+    modality="audio",
+    mode="send-receive",
+)

 app = fastapi.FastAPI()

@@ -125,6 +120,7 @@ app.add_middleware(

 stream.mount(app)

+
 @app.get("/reset")
 async def reset():
     global messages
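
With the middleware and reset endpoint in place, the demo is an ordinary FastAPI app carrying the mounted stream. A hedged sketch of serving it locally; the module path is a guess from the demo/nextjs_voice_chat entry added to pyproject below, and the file name, host, and port are assumptions:

    import uvicorn

    # equivalently: uvicorn demo.nextjs_voice_chat.server:app
    if __name__ == "__main__":
        uvicorn.run(app, host="127.0.0.1", port=8000)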
@@ -28,6 +28,7 @@ nav:
    - Cookbook: cookbook.md
    - Deployment: deployment.md
    - Advanced Configuration: advanced-configuration.md
+   - Speech-to-Text Gallery: speech_to_text_gallery.md
    - VAD Gallery: vad_gallery.md
    - Utils: utils.md
    - Frequently Asked Questions: faq.md
@@ -83,4 +83,4 @@ packages = ["/backend/fastrtc"]

 [tool.ruff]
 target-version = "py310"
-extend-exclude = ["demo/phonic_chat"]
+extend-exclude = ["demo/phonic_chat", "demo/nextjs_voice_chat"]