Enforce modern typing (#258)

* Allow UP

* Upgrade typing

* test smolagents

* Change to contextlib

---------

Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com>
This commit is contained in:
Marcus Valtonen Örnhag
2025-04-08 22:46:12 +02:00
committed by GitHub
parent a07e9439b6
commit f70b27bd41
15 changed files with 43 additions and 47 deletions

View File

@@ -2,7 +2,6 @@ import logging
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import List
import click
import numpy as np
@@ -102,7 +101,7 @@ class SileroVADModel:
return h, c
@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
def collect_chunks(audio: np.ndarray, chunks: list[AudioChunk]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
if not chunks:
return np.array([], dtype=np.float32)
@@ -116,7 +115,7 @@ class SileroVADModel:
audio: np.ndarray,
vad_options: SileroVadOptions,
**kwargs,
) -> List[AudioChunk]:
) -> list[AudioChunk]:
"""This method is used for splitting long audios into speech chunks using silero VAD.
Args:

View File

@@ -1,9 +1,10 @@
import asyncio
import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from dataclasses import dataclass, field
from logging import getLogger
from threading import Event
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
from typing import Any, Literal, cast
import numpy as np
from numpy.typing import NDArray

View File

@@ -1,7 +1,8 @@
import asyncio
import logging
import re
from typing import Callable, Literal
from collections.abc import Callable
from typing import Literal
import numpy as np

View File

@@ -1,11 +1,10 @@
import logging
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from pathlib import Path
from typing import (
Any,
AsyncContextManager,
Callable,
Literal,
Optional,
TypedDict,
cast,
)
@@ -29,8 +28,8 @@ curr_dir = Path(__file__).parent
class Body(BaseModel):
sdp: Optional[str] = None
candidate: Optional[dict[str, Any]] = None
sdp: str | None = None
candidate: dict[str, Any] | None = None
type: str
webrtc_id: str
@@ -253,7 +252,7 @@ class Stream(WebRTCConnectionMixin):
return wrapper
def _inject_startup_message(
self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
self, lifespan: Callable[[FastAPI], AbstractAsyncContextManager] | None = None
):
"""
Create a FastAPI lifespan context manager to print startup messages and check environment.

View File

@@ -1,8 +1,9 @@
import asyncio
import re
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass
from functools import lru_cache
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
from typing import Literal, Protocol, TypeVar
import numpy as np
from huggingface_hub import hf_hub_download

View File

@@ -12,15 +12,12 @@ import time
import traceback
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import (
Any,
Generator,
Literal,
Tuple,
TypeAlias,
Union,
cast,
)
@@ -51,11 +48,11 @@ from fastrtc.utils import (
logger = logging.getLogger(__name__)
VideoNDArray: TypeAlias = Union[
np.ndarray[Any, np.dtype[np.uint8]],
np.ndarray[Any, np.dtype[np.uint16]],
np.ndarray[Any, np.dtype[np.float32]],
]
VideoNDArray: TypeAlias = (
np.ndarray[Any, np.dtype[np.uint8]]
| np.ndarray[Any, np.dtype[np.uint16]]
| np.ndarray[Any, np.dtype[np.float32]]
)
VideoEmitType = (
VideoNDArray
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
else:
raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError
@@ -906,7 +903,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args)
self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate"""
if self.readyState != "live":
raise MediaStreamError

View File

@@ -7,9 +7,10 @@ import json
import logging
import tempfile
import traceback
from collections.abc import Callable
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
from typing import Any, Literal, Protocol, TypedDict, cast
import av
import librosa
@@ -136,7 +137,7 @@ def split_output(
raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
)
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
if not isinstance(data[-1], AdditionalOutputs | CloseStream):
raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs."
)

View File

@@ -3,15 +3,13 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from collections.abc import Callable, Iterable, Sequence
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Iterable,
Literal,
ParamSpec,
Sequence,
TypeVar,
cast,
)

View File

@@ -6,10 +6,9 @@ import asyncio
import inspect
import logging
from collections import defaultdict
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass, field
from typing import (
AsyncGenerator,
Literal,
ParamSpec,
TypeVar,

View File

@@ -2,7 +2,8 @@ import asyncio
import audioop
import base64
import logging
from typing import Any, Awaitable, Callable, Optional, cast
from collections.abc import Awaitable, Callable
from typing import Any, cast
import anyio
import librosa
@@ -57,9 +58,9 @@ class WebSocketHandler:
):
self.stream_handler = stream_handler
self.stream_handler._clear_queue = self._clear_queue
self.websocket: Optional[WebSocket] = None
self._emit_task: Optional[asyncio.Task] = None
self.stream_id: Optional[str] = None
self.websocket: WebSocket | None = None
self._emit_task: asyncio.Task | None = None
self.stream_id: str | None = None
self.set_additional_outputs_factory = additional_outputs_factory
self.set_additional_outputs: Callable[[AdditionalOutputs], None]
self.set_handler = set_handler
@@ -67,8 +68,8 @@ class WebSocketHandler:
self.clean_up = clean_up
self.queue = asyncio.Queue()
self.playing_durations = [] # Track durations of frames being played
self._frame_cleanup_task: Optional[asyncio.Task] = None
self._graceful_shutdown_task: Optional[asyncio.Task] = None
self._frame_cleanup_task: asyncio.Task | None = None
self._graceful_shutdown_task: asyncio.Task | None = None
def _clear_queue(self):
old_queue = self.queue

View File

@@ -1,8 +1,8 @@
import asyncio
import base64
import os
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import AsyncGenerator
import librosa
import numpy as np

View File

@@ -1,5 +1,6 @@
from functools import lru_cache
from typing import Generator, Literal
from collections.abc import Generator
from functools import cache
from typing import Literal
import gradio as gr
import numpy as np
@@ -17,7 +18,7 @@ from numpy.typing import NDArray
load_dotenv()
@lru_cache(maxsize=None)
@cache
def load_moonshine(
model_name: Literal["moonshine/base", "moonshine/tiny"],
) -> MoonshineOnnxModel:

View File

@@ -3,7 +3,8 @@ import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal
from collections.abc import AsyncGenerator
from typing import Literal
import gradio as gr
import numpy as np

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
from dotenv import load_dotenv
from fastrtc import (
@@ -22,7 +21,7 @@ stt_model = get_stt_model()
tts_model = get_tts_model()
# Conversation state to maintain history
conversation_state: List[Dict[str, str]] = []
conversation_state: list[dict[str, str]] = []
# System prompt for agent
system_prompt = """You are a helpful assistant that can helps with finding places to
@@ -78,9 +77,7 @@ def process_response(audio):
response_content = agent.run(input_text)
# Convert response to audio using TTS model
for audio_chunk in tts_model.stream_tts_sync(response_content or ""):
# Yield the audio chunk
yield audio_chunk
yield from tts_model.stream_tts_sync(response_content or "")
stream = Stream(

View File

@@ -105,12 +105,12 @@ select = [
"W",
"Q",
"I",
"UP",
]
# These can be turned on when the framework is more mature (Too many errors right now)
exclude = [
"D",
"UP"
]
# Avoid enforcing line-length violations (`E501`)