Enforce modern typing (#258)

* Allow UP

* Upgrade typing

* test smolagents

* Change to contextlib

---------

Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com>
This commit is contained in:
Marcus Valtonen Örnhag
2025-04-08 22:46:12 +02:00
committed by GitHub
parent a07e9439b6
commit f70b27bd41
15 changed files with 43 additions and 47 deletions

View File

@@ -2,7 +2,6 @@ import logging
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import List
import click import click
import numpy as np import numpy as np
@@ -102,7 +101,7 @@ class SileroVADModel:
return h, c return h, c
@staticmethod @staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray: def collect_chunks(audio: np.ndarray, chunks: list[AudioChunk]) -> np.ndarray:
"""Collects and concatenates audio chunks.""" """Collects and concatenates audio chunks."""
if not chunks: if not chunks:
return np.array([], dtype=np.float32) return np.array([], dtype=np.float32)
@@ -116,7 +115,7 @@ class SileroVADModel:
audio: np.ndarray, audio: np.ndarray,
vad_options: SileroVadOptions, vad_options: SileroVadOptions,
**kwargs, **kwargs,
) -> List[AudioChunk]: ) -> list[AudioChunk]:
"""This method is used for splitting long audios into speech chunks using silero VAD. """This method is used for splitting long audios into speech chunks using silero VAD.
Args: Args:

View File

@@ -1,9 +1,10 @@
import asyncio import asyncio
import inspect import inspect
from collections.abc import AsyncGenerator, Callable, Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from logging import getLogger from logging import getLogger
from threading import Event from threading import Event
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast from typing import Any, Literal, cast
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray

View File

@@ -1,7 +1,8 @@
import asyncio import asyncio
import logging import logging
import re import re
from typing import Callable, Literal from collections.abc import Callable
from typing import Literal
import numpy as np import numpy as np

View File

@@ -1,11 +1,10 @@
import logging import logging
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
AsyncContextManager,
Callable,
Literal, Literal,
Optional,
TypedDict, TypedDict,
cast, cast,
) )
@@ -29,8 +28,8 @@ curr_dir = Path(__file__).parent
class Body(BaseModel): class Body(BaseModel):
sdp: Optional[str] = None sdp: str | None = None
candidate: Optional[dict[str, Any]] = None candidate: dict[str, Any] | None = None
type: str type: str
webrtc_id: str webrtc_id: str
@@ -253,7 +252,7 @@ class Stream(WebRTCConnectionMixin):
return wrapper return wrapper
def _inject_startup_message( def _inject_startup_message(
self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None self, lifespan: Callable[[FastAPI], AbstractAsyncContextManager] | None = None
): ):
""" """
Create a FastAPI lifespan context manager to print startup messages and check environment. Create a FastAPI lifespan context manager to print startup messages and check environment.

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
import re import re
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar from typing import Literal, Protocol, TypeVar
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download

View File

@@ -12,15 +12,12 @@ import time
import traceback import traceback
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable, Generator
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
Generator,
Literal, Literal,
Tuple,
TypeAlias, TypeAlias,
Union,
cast, cast,
) )
@@ -51,11 +48,11 @@ from fastrtc.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VideoNDArray: TypeAlias = Union[ VideoNDArray: TypeAlias = (
np.ndarray[Any, np.dtype[np.uint8]], np.ndarray[Any, np.dtype[np.uint8]]
np.ndarray[Any, np.dtype[np.uint16]], | np.ndarray[Any, np.dtype[np.uint16]]
np.ndarray[Any, np.dtype[np.float32]], | np.ndarray[Any, np.dtype[np.float32]]
] )
VideoEmitType = ( VideoEmitType = (
VideoNDArray VideoNDArray
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
else: else:
raise WebRTCError(str(e)) from e raise WebRTCError(str(e)) from e
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]: async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate""" """Override to control frame rate"""
if self.readyState != "live": if self.readyState != "live":
raise MediaStreamError raise MediaStreamError
@@ -906,7 +903,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.latest_args = list(args) self.latest_args = list(args)
self.args_set.set() self.args_set.set()
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]: async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
"""Override to control frame rate""" """Override to control frame rate"""
if self.readyState != "live": if self.readyState != "live":
raise MediaStreamError raise MediaStreamError

View File

@@ -7,9 +7,10 @@ import json
import logging import logging
import tempfile import tempfile
import traceback import traceback
from collections.abc import Callable
from contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Literal, Protocol, TypedDict, cast from typing import Any, Literal, Protocol, TypedDict, cast
import av import av
import librosa import librosa
@@ -136,7 +137,7 @@ def split_output(
raise ValueError( raise ValueError(
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs." "The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
) )
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)): if not isinstance(data[-1], AdditionalOutputs | CloseStream):
raise ValueError( raise ValueError(
"The last element of the tuple must be an instance of AdditionalOutputs." "The last element of the tuple must be an instance of AdditionalOutputs."
) )

View File

@@ -3,15 +3,13 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Callable from collections.abc import Callable, Iterable, Sequence
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Concatenate, Concatenate,
Iterable,
Literal, Literal,
ParamSpec, ParamSpec,
Sequence,
TypeVar, TypeVar,
cast, cast,
) )

View File

@@ -6,10 +6,9 @@ import asyncio
import inspect import inspect
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ( from typing import (
AsyncGenerator,
Literal, Literal,
ParamSpec, ParamSpec,
TypeVar, TypeVar,

View File

@@ -2,7 +2,8 @@ import asyncio
import audioop import audioop
import base64 import base64
import logging import logging
from typing import Any, Awaitable, Callable, Optional, cast from collections.abc import Awaitable, Callable
from typing import Any, cast
import anyio import anyio
import librosa import librosa
@@ -57,9 +58,9 @@ class WebSocketHandler:
): ):
self.stream_handler = stream_handler self.stream_handler = stream_handler
self.stream_handler._clear_queue = self._clear_queue self.stream_handler._clear_queue = self._clear_queue
self.websocket: Optional[WebSocket] = None self.websocket: WebSocket | None = None
self._emit_task: Optional[asyncio.Task] = None self._emit_task: asyncio.Task | None = None
self.stream_id: Optional[str] = None self.stream_id: str | None = None
self.set_additional_outputs_factory = additional_outputs_factory self.set_additional_outputs_factory = additional_outputs_factory
self.set_additional_outputs: Callable[[AdditionalOutputs], None] self.set_additional_outputs: Callable[[AdditionalOutputs], None]
self.set_handler = set_handler self.set_handler = set_handler
@@ -67,8 +68,8 @@ class WebSocketHandler:
self.clean_up = clean_up self.clean_up = clean_up
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.playing_durations = [] # Track durations of frames being played self.playing_durations = [] # Track durations of frames being played
self._frame_cleanup_task: Optional[asyncio.Task] = None self._frame_cleanup_task: asyncio.Task | None = None
self._graceful_shutdown_task: Optional[asyncio.Task] = None self._graceful_shutdown_task: asyncio.Task | None = None
def _clear_queue(self): def _clear_queue(self):
old_queue = self.queue old_queue = self.queue

View File

@@ -1,8 +1,8 @@
import asyncio import asyncio
import base64 import base64
import os import os
from collections.abc import AsyncGenerator
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator
import librosa import librosa
import numpy as np import numpy as np

View File

@@ -1,5 +1,6 @@
from functools import lru_cache from collections.abc import Generator
from typing import Generator, Literal from functools import cache
from typing import Literal
import gradio as gr import gradio as gr
import numpy as np import numpy as np
@@ -17,7 +18,7 @@ from numpy.typing import NDArray
load_dotenv() load_dotenv()
@lru_cache(maxsize=None) @cache
def load_moonshine( def load_moonshine(
model_name: Literal["moonshine/base", "moonshine/tiny"], model_name: Literal["moonshine/base", "moonshine/tiny"],
) -> MoonshineOnnxModel: ) -> MoonshineOnnxModel:

View File

@@ -3,7 +3,8 @@ import base64
import json import json
import os import os
import pathlib import pathlib
from typing import AsyncGenerator, Literal from collections.abc import AsyncGenerator
from typing import Literal
import gradio as gr import gradio as gr
import numpy as np import numpy as np

View File

@@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Dict, List
from dotenv import load_dotenv from dotenv import load_dotenv
from fastrtc import ( from fastrtc import (
@@ -22,7 +21,7 @@ stt_model = get_stt_model()
tts_model = get_tts_model() tts_model = get_tts_model()
# Conversation state to maintain history # Conversation state to maintain history
conversation_state: List[Dict[str, str]] = [] conversation_state: list[dict[str, str]] = []
# System prompt for agent # System prompt for agent
system_prompt = """You are a helpful assistant that can helps with finding places to system_prompt = """You are a helpful assistant that can helps with finding places to
@@ -78,9 +77,7 @@ def process_response(audio):
response_content = agent.run(input_text) response_content = agent.run(input_text)
# Convert response to audio using TTS model # Convert response to audio using TTS model
for audio_chunk in tts_model.stream_tts_sync(response_content or ""): yield from tts_model.stream_tts_sync(response_content or "")
# Yield the audio chunk
yield audio_chunk
stream = Stream( stream = Stream(

View File

@@ -105,12 +105,12 @@ select = [
"W", "W",
"Q", "Q",
"I", "I",
"UP",
] ]
# These can be turned on when the framework is more mature (Too many errors right now) # These can be turned on when the framework is more mature (Too many errors right now)
exclude = [ exclude = [
"D", "D",
"UP"
] ]
# Avoid enforcing line-length violations (`E501`) # Avoid enforcing line-length violations (`E501`)