diff --git a/backend/fastrtc/pause_detection/silero.py b/backend/fastrtc/pause_detection/silero.py
index ecbcd61..d6012b0 100644
--- a/backend/fastrtc/pause_detection/silero.py
+++ b/backend/fastrtc/pause_detection/silero.py
@@ -2,7 +2,6 @@ import logging
 import warnings
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
 
 import click
 import numpy as np
@@ -102,7 +101,7 @@ class SileroVADModel:
         return h, c
 
     @staticmethod
-    def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
+    def collect_chunks(audio: np.ndarray, chunks: list[AudioChunk]) -> np.ndarray:
         """Collects and concatenates audio chunks."""
         if not chunks:
             return np.array([], dtype=np.float32)
@@ -116,7 +115,7 @@
         audio: np.ndarray,
         vad_options: SileroVadOptions,
         **kwargs,
-    ) -> List[AudioChunk]:
+    ) -> list[AudioChunk]:
         """This method is used for splitting long audios into speech chunks using silero VAD.
 
         Args:
diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py
index 84ae358..c4c6ba7 100644
--- a/backend/fastrtc/reply_on_pause.py
+++ b/backend/fastrtc/reply_on_pause.py
@@ -1,9 +1,10 @@
 import asyncio
 import inspect
+from collections.abc import AsyncGenerator, Callable, Generator
 from dataclasses import dataclass, field
 from logging import getLogger
 from threading import Event
-from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
+from typing import Any, Literal, cast
 
 import numpy as np
 from numpy.typing import NDArray
diff --git a/backend/fastrtc/reply_on_stopwords.py b/backend/fastrtc/reply_on_stopwords.py
index ba5a1b7..723e61f 100644
--- a/backend/fastrtc/reply_on_stopwords.py
+++ b/backend/fastrtc/reply_on_stopwords.py
@@ -1,7 +1,8 @@
 import asyncio
 import logging
 import re
-from typing import Callable, Literal
+from collections.abc import Callable
+from typing import Literal
 
 import numpy as np
 
diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py
index aa0b964..0a96a28 100644
--- a/backend/fastrtc/stream.py
+++ b/backend/fastrtc/stream.py
@@ -1,11 +1,10 @@
 import logging
+from collections.abc import Callable
+from contextlib import AbstractAsyncContextManager
 from pathlib import Path
 from typing import (
     Any,
-    AsyncContextManager,
-    Callable,
     Literal,
-    Optional,
     TypedDict,
     cast,
 )
@@ -29,8 +28,8 @@ curr_dir = Path(__file__).parent
 
 
 class Body(BaseModel):
-    sdp: Optional[str] = None
-    candidate: Optional[dict[str, Any]] = None
+    sdp: str | None = None
+    candidate: dict[str, Any] | None = None
     type: str
     webrtc_id: str
 
@@ -253,7 +252,7 @@
         return wrapper
 
     def _inject_startup_message(
-        self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
+        self, lifespan: Callable[[FastAPI], AbstractAsyncContextManager] | None = None
     ):
         """
         Create a FastAPI lifespan context manager to print startup messages and check environment.
diff --git a/backend/fastrtc/text_to_speech/tts.py b/backend/fastrtc/text_to_speech/tts.py
index 00a6523..5e910e1 100644
--- a/backend/fastrtc/text_to_speech/tts.py
+++ b/backend/fastrtc/text_to_speech/tts.py
@@ -1,8 +1,9 @@
 import asyncio
 import re
+from collections.abc import AsyncGenerator, Generator
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
+from typing import Literal, Protocol, TypeVar
 
 import numpy as np
 from huggingface_hub import hf_hub_download
diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py
index 5b6155a..c022633 100644
--- a/backend/fastrtc/tracks.py
+++ b/backend/fastrtc/tracks.py
@@ -12,15 +12,12 @@ import time
 import traceback
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from dataclasses import dataclass
 from typing import (
     Any,
-    Generator,
     Literal,
-    Tuple,
     TypeAlias,
-    Union,
     cast,
 )
 
@@ -51,11 +48,11 @@ from fastrtc.utils import (
 
 logger = logging.getLogger(__name__)
 
-VideoNDArray: TypeAlias = Union[
-    np.ndarray[Any, np.dtype[np.uint8]],
-    np.ndarray[Any, np.dtype[np.uint16]],
-    np.ndarray[Any, np.dtype[np.float32]],
-]
+VideoNDArray: TypeAlias = (
+    np.ndarray[Any, np.dtype[np.uint8]]
+    | np.ndarray[Any, np.dtype[np.uint16]]
+    | np.ndarray[Any, np.dtype[np.float32]]
+)
 
 VideoEmitType = (
     VideoNDArray
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
             else:
                 raise WebRTCError(str(e)) from e
 
-    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+    async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
         """Override to control frame rate"""
         if self.readyState != "live":
             raise MediaStreamError
@@ -906,7 +903,7 @@ class ServerToClientVideo(VideoStreamTrack):
         self.latest_args = list(args)
         self.args_set.set()
 
-    async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
+    async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
         """Override to control frame rate"""
         if self.readyState != "live":
             raise MediaStreamError
diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py
index 04587c4..148bda8 100644
--- a/backend/fastrtc/utils.py
+++ b/backend/fastrtc/utils.py
@@ -7,9 +7,10 @@ import json
 import logging
 import tempfile
 import traceback
+from collections.abc import Callable
 from contextvars import ContextVar
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Protocol, TypedDict, cast
+from typing import Any, Literal, Protocol, TypedDict, cast
 
 import av
 import librosa
@@ -136,7 +137,7 @@ def split_output(
         raise ValueError(
             "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
         )
-    if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
+    if not isinstance(data[-1], AdditionalOutputs | CloseStream):
         raise ValueError(
             "The last element of the tuple must be an instance of AdditionalOutputs."
         )
diff --git a/backend/fastrtc/webrtc.py b/backend/fastrtc/webrtc.py
index 75c4220..cb6edeb 100644
--- a/backend/fastrtc/webrtc.py
+++ b/backend/fastrtc/webrtc.py
@@ -3,15 +3,13 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Sequence
 from typing import (
     TYPE_CHECKING,
     Any,
     Concatenate,
-    Iterable,
     Literal,
     ParamSpec,
-    Sequence,
     TypeVar,
     cast,
 )
diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py
index 3ce24b6..774b507 100644
--- a/backend/fastrtc/webrtc_connection_mixin.py
+++ b/backend/fastrtc/webrtc_connection_mixin.py
@@ -6,10 +6,9 @@ import asyncio
 import inspect
 import logging
 from collections import defaultdict
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 from dataclasses import dataclass, field
 from typing import (
-    AsyncGenerator,
     Literal,
     ParamSpec,
     TypeVar,
diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py
index 2bf60a6..4ee1e87 100644
--- a/backend/fastrtc/websocket.py
+++ b/backend/fastrtc/websocket.py
@@ -2,7 +2,8 @@ import asyncio
 import audioop
 import base64
 import logging
-from typing import Any, Awaitable, Callable, Optional, cast
+from collections.abc import Awaitable, Callable
+from typing import Any, cast
 
 import anyio
 import librosa
@@ -57,9 +58,9 @@
     ):
         self.stream_handler = stream_handler
         self.stream_handler._clear_queue = self._clear_queue
-        self.websocket: Optional[WebSocket] = None
-        self._emit_task: Optional[asyncio.Task] = None
-        self.stream_id: Optional[str] = None
+        self.websocket: WebSocket | None = None
+        self._emit_task: asyncio.Task | None = None
+        self.stream_id: str | None = None
         self.set_additional_outputs_factory = additional_outputs_factory
         self.set_additional_outputs: Callable[[AdditionalOutputs], None]
         self.set_handler = set_handler
@@ -67,8 +68,8 @@
         self.clean_up = clean_up
         self.queue = asyncio.Queue()
         self.playing_durations = []  # Track durations of frames being played
-        self._frame_cleanup_task: Optional[asyncio.Task] = None
-        self._graceful_shutdown_task: Optional[asyncio.Task] = None
+        self._frame_cleanup_task: asyncio.Task | None = None
+        self._graceful_shutdown_task: asyncio.Task | None = None
 
     def _clear_queue(self):
         old_queue = self.queue
diff --git a/demo/gemini_conversation/app.py b/demo/gemini_conversation/app.py
index 9fb0959..1476eac 100644
--- a/demo/gemini_conversation/app.py
+++ b/demo/gemini_conversation/app.py
@@ -1,8 +1,8 @@
 import asyncio
 import base64
 import os
+from collections.abc import AsyncGenerator
 from pathlib import Path
-from typing import AsyncGenerator
 
 import librosa
 import numpy as np
diff --git a/demo/moonshine_live/app.py b/demo/moonshine_live/app.py
index f6db735..a58fb72 100644
--- a/demo/moonshine_live/app.py
+++ b/demo/moonshine_live/app.py
@@ -1,5 +1,6 @@
-from functools import lru_cache
-from typing import Generator, Literal
+from collections.abc import Generator
+from functools import cache
+from typing import Literal
 
 import gradio as gr
 import numpy as np
@@ -17,7 +18,7 @@ from numpy.typing import NDArray
 
 load_dotenv()
 
-@lru_cache(maxsize=None)
+@cache
 def load_moonshine(
     model_name: Literal["moonshine/base", "moonshine/tiny"],
 ) -> MoonshineOnnxModel:
diff --git a/demo/talk_to_gemini/app.py b/demo/talk_to_gemini/app.py
index 7e8929c..52155f1 100644
--- a/demo/talk_to_gemini/app.py
+++ b/demo/talk_to_gemini/app.py
@@ -3,7 +3,8 @@ import base64
 import json
 import os
 import pathlib
-from typing import AsyncGenerator, Literal
+from collections.abc import AsyncGenerator
+from typing import Literal
 
 import gradio as gr
 import numpy as np
diff --git a/demo/talk_to_smolagents/app.py b/demo/talk_to_smolagents/app.py
index 99231d7..66332db 100644
--- a/demo/talk_to_smolagents/app.py
+++ b/demo/talk_to_smolagents/app.py
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Dict, List
 
 from dotenv import load_dotenv
 from fastrtc import (
@@ -22,7 +21,7 @@ stt_model = get_stt_model()
 tts_model = get_tts_model()
 
 # Conversation state to maintain history
-conversation_state: List[Dict[str, str]] = []
+conversation_state: list[dict[str, str]] = []
 
 # System prompt for agent
 system_prompt = """You are a helpful assistant that can helps with finding places to
@@ -78,9 +77,7 @@ def process_response(audio):
     response_content = agent.run(input_text)
 
     # Convert response to audio using TTS model
-    for audio_chunk in tts_model.stream_tts_sync(response_content or ""):
-        # Yield the audio chunk
-        yield audio_chunk
+    yield from tts_model.stream_tts_sync(response_content or "")
 
 
 stream = Stream(
diff --git a/pyproject.toml b/pyproject.toml
index 8f938a5..9720fe9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,12 +105,12 @@ select = [
     "W",
     "Q",
     "I",
+    "UP",
 ]
 
 # These can be turned on when the framework is more mature (Too many errors right now)
 exclude = [
     "D",
-    "UP"
 ]
 
 # Avoid enforcing line-length violations (`E501`)
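
Note: every hunk above applies the same Python 3.10+ cleanups that ruff's pyupgrade ("UP") rules enforce, which is why the pyproject.toml change moves "UP" out of the excluded list and into select. The following is a minimal illustrative sketch of those patterns, assuming Python 3.10 or newer; the names used here (Chunk, collect, describe, load_model, stream, handler) are hypothetical and are not fastrtc APIs.

# Sketch of the pyupgrade-style rewrites applied in the patch above.
# Assumes Python >= 3.10; names are illustrative, not fastrtc code.

from collections.abc import Callable, Generator  # was: from typing import Callable, Generator
from functools import cache  # @cache is equivalent to @lru_cache(maxsize=None)


class Chunk:
    """Stand-in for a payload type such as AudioChunk."""


# PEP 585: builtin generics replace typing.List / typing.Dict / typing.Tuple
def collect(chunks: list[Chunk]) -> dict[str, int]:
    return {"count": len(chunks)}


# PEP 604: "X | None" replaces Optional[X]; union types also work directly in isinstance()
def describe(sdp: str | None = None) -> str:
    if isinstance(sdp, str | bytes):
        return "offer"
    return "no offer"


@cache
def load_model(name: str) -> str:
    # functools.cache gives the same unbounded memoization as lru_cache(maxsize=None)
    return f"model:{name}"


def stream(text: str) -> Generator[str, None, None]:
    # "yield from" replaces a for-loop whose body only re-yields each item
    yield from text.split()


handler: Callable[[str], str] = load_model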