Introduce static type checking with pyright (#255)

This commit is contained in:
Marcus Valtonen Örnhag
2025-04-05 20:19:05 +02:00
committed by GitHub
parent d7995b8116
commit 0767030997
6 changed files with 45 additions and 26 deletions

View File

@@ -12,6 +12,8 @@ jobs:
python-version: '3.10'
- name: Run linters
run: |
pip install ruff
pip install ruff pyright
pip install -e .
ruff check .
ruff format --check --diff .
ruff format --check --diff .
pyright

View File

@@ -506,23 +506,22 @@ class Stream(WebRTCConnectionMixin):
for component in additional_input_components:
component.render()
button = gr.Button("Start Stream", variant="primary")
if additional_output_components:
with gr.Column():
output_video = WebRTC(
label="Audio Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = output_video
for component in additional_output_components:
if component not in same_components:
component.render()
with gr.Column():
output_video = WebRTC(
label="Audio Stream",
rtc_configuration=self.rtc_configuration,
track_constraints=self.track_constraints,
mode="receive",
modality="audio",
icon=ui_args.get("icon"),
icon_button_color=ui_args.get("icon_button_color"),
pulse_color=ui_args.get("pulse_color"),
icon_radius=ui_args.get("icon_radius"),
)
self.webrtc_component = output_video
for component in additional_output_components:
if component not in same_components:
component.render()
output_video.stream(
fn=self.event_handler,
inputs=self.additional_input_components,

View File

@@ -2,7 +2,7 @@ import asyncio
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import AsyncGenerator, Generator, Literal, Protocol
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
import numpy as np
from huggingface_hub import hf_hub_download
@@ -13,17 +13,20 @@ class TTSOptions:
pass
class TTSModel(Protocol):
T = TypeVar("T", bound=TTSOptions, contravariant=True)
class TTSModel(Protocol[T]):
def tts(
self, text: str, options: TTSOptions | None = None
self, text: str, options: T | None = None
) -> tuple[int, NDArray[np.float32]]: ...
async def stream_tts(
self, text: str, options: TTSOptions | None = None
def stream_tts(
self, text: str, options: T | None = None
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
def stream_tts_sync(
self, text: str, options: TTSOptions | None = None
self, text: str, options: T | None = None
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...

View File

@@ -976,7 +976,7 @@ class ServerToClientAudio(AudioStreamTrack):
) -> None:
self.generator: Generator[Any, None, Any] | None = None
self.event_handler = event_handler
self.event_handler._clear_queue = self.clear_queue
self.event_handler._clear_queue = self.clear_queue # pyright: ignore
self.current_timestamp = 0
self.latest_args: str | list[Any] = "not_set"
self.args_set = threading.Event()

View File

@@ -196,6 +196,8 @@ async def player_worker_decode(
layout = "mono"
elif len(frame) == 3:
sample_rate, audio_array, layout = frame
else:
raise ValueError(f"frame must be of length 2 or 3, got: {len(frame)}")
logger.debug(
"received array with shape %s sample rate %s layout %s",

View File

@@ -118,3 +118,16 @@ convention = "google"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402"]
"demo/talk_to_smolagents/app.py" = ["W291"]
[tool.pyright]
include = ["backend/fastrtc"]
exclude = [
"**/__pycache__",
"**/*.pyi",
]
reportMissingImports = false
reportMissingTypeStubs = false
pythonVersion = "3.10"
pythonPlatform = "Linux"