From 0767030997199a005014e4e6ead3beb556a67868 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcus=20Valtonen=20=C3=96rnhag?= Date: Sat, 5 Apr 2025 20:19:05 +0200 Subject: [PATCH] Introduce static type checking with pyright (#255) --- .github/workflows/tests.yml | 6 +++-- backend/fastrtc/stream.py | 33 +++++++++++++-------------- backend/fastrtc/text_to_speech/tts.py | 15 +++++++----- backend/fastrtc/tracks.py | 2 +- backend/fastrtc/utils.py | 2 ++ pyproject.toml | 13 +++++++++++ 6 files changed, 45 insertions(+), 26 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ce5759a..0b91e09 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,6 +12,8 @@ jobs: python-version: '3.10' - name: Run linters run: | - pip install ruff + pip install ruff pyright + pip install -e . ruff check . - ruff format --check --diff . \ No newline at end of file + ruff format --check --diff . + pyright \ No newline at end of file diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py index b6d383f..aa0b964 100644 --- a/backend/fastrtc/stream.py +++ b/backend/fastrtc/stream.py @@ -506,23 +506,22 @@ class Stream(WebRTCConnectionMixin): for component in additional_input_components: component.render() button = gr.Button("Start Stream", variant="primary") - if additional_output_components: - with gr.Column(): - output_video = WebRTC( - label="Audio Stream", - rtc_configuration=self.rtc_configuration, - track_constraints=self.track_constraints, - mode="receive", - modality="audio", - icon=ui_args.get("icon"), - icon_button_color=ui_args.get("icon_button_color"), - pulse_color=ui_args.get("pulse_color"), - icon_radius=ui_args.get("icon_radius"), - ) - self.webrtc_component = output_video - for component in additional_output_components: - if component not in same_components: - component.render() + with gr.Column(): + output_video = WebRTC( + label="Audio Stream", + rtc_configuration=self.rtc_configuration, + track_constraints=self.track_constraints, + mode="receive", + modality="audio", + icon=ui_args.get("icon"), + icon_button_color=ui_args.get("icon_button_color"), + pulse_color=ui_args.get("pulse_color"), + icon_radius=ui_args.get("icon_radius"), + ) + self.webrtc_component = output_video + for component in additional_output_components: + if component not in same_components: + component.render() output_video.stream( fn=self.event_handler, inputs=self.additional_input_components, diff --git a/backend/fastrtc/text_to_speech/tts.py b/backend/fastrtc/text_to_speech/tts.py index ee477fd..00a6523 100644 --- a/backend/fastrtc/text_to_speech/tts.py +++ b/backend/fastrtc/text_to_speech/tts.py @@ -2,7 +2,7 @@ import asyncio import re from dataclasses import dataclass from functools import lru_cache -from typing import AsyncGenerator, Generator, Literal, Protocol +from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar import numpy as np from huggingface_hub import hf_hub_download @@ -13,17 +13,20 @@ class TTSOptions: pass -class TTSModel(Protocol): +T = TypeVar("T", bound=TTSOptions, contravariant=True) + + +class TTSModel(Protocol[T]): def tts( - self, text: str, options: TTSOptions | None = None + self, text: str, options: T | None = None ) -> tuple[int, NDArray[np.float32]]: ... - async def stream_tts( - self, text: str, options: TTSOptions | None = None + def stream_tts( + self, text: str, options: T | None = None ) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ... def stream_tts_sync( - self, text: str, options: TTSOptions | None = None + self, text: str, options: T | None = None ) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ... diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py index bbfa42d..5b6155a 100644 --- a/backend/fastrtc/tracks.py +++ b/backend/fastrtc/tracks.py @@ -976,7 +976,7 @@ class ServerToClientAudio(AudioStreamTrack): ) -> None: self.generator: Generator[Any, None, Any] | None = None self.event_handler = event_handler - self.event_handler._clear_queue = self.clear_queue + self.event_handler._clear_queue = self.clear_queue # pyright: ignore self.current_timestamp = 0 self.latest_args: str | list[Any] = "not_set" self.args_set = threading.Event() diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py index 5f2802d..04587c4 100644 --- a/backend/fastrtc/utils.py +++ b/backend/fastrtc/utils.py @@ -196,6 +196,8 @@ async def player_worker_decode( layout = "mono" elif len(frame) == 3: sample_rate, audio_array, layout = frame + else: + raise ValueError(f"frame must be of length 2 or 3, got: {len(frame)}") logger.debug( "received array with shape %s sample rate %s layout %s", diff --git a/pyproject.toml b/pyproject.toml index 73ab183..da598c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,3 +118,16 @@ convention = "google" [tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402"] "demo/talk_to_smolagents/app.py" = ["W291"] + +[tool.pyright] +include = ["backend/fastrtc"] +exclude = [ + "**/__pycache__", + "**/*.pyi", +] + +reportMissingImports = false +reportMissingTypeStubs = false + +pythonVersion = "3.10" +pythonPlatform = "Linux" \ No newline at end of file