mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Introduce static type checking with pyright (#255)
This commit is contained in:
committed by
GitHub
parent
d7995b8116
commit
0767030997
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
||||
python-version: '3.10'
|
||||
- name: Run linters
|
||||
run: |
|
||||
pip install ruff
|
||||
pip install ruff pyright
|
||||
pip install -e .
|
||||
ruff check .
|
||||
ruff format --check --diff .
|
||||
ruff format --check --diff .
|
||||
pyright
|
||||
@@ -506,23 +506,22 @@ class Stream(WebRTCConnectionMixin):
|
||||
for component in additional_input_components:
|
||||
component.render()
|
||||
button = gr.Button("Start Stream", variant="primary")
|
||||
if additional_output_components:
|
||||
with gr.Column():
|
||||
output_video = WebRTC(
|
||||
label="Audio Stream",
|
||||
rtc_configuration=self.rtc_configuration,
|
||||
track_constraints=self.track_constraints,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
icon=ui_args.get("icon"),
|
||||
icon_button_color=ui_args.get("icon_button_color"),
|
||||
pulse_color=ui_args.get("pulse_color"),
|
||||
icon_radius=ui_args.get("icon_radius"),
|
||||
)
|
||||
self.webrtc_component = output_video
|
||||
for component in additional_output_components:
|
||||
if component not in same_components:
|
||||
component.render()
|
||||
with gr.Column():
|
||||
output_video = WebRTC(
|
||||
label="Audio Stream",
|
||||
rtc_configuration=self.rtc_configuration,
|
||||
track_constraints=self.track_constraints,
|
||||
mode="receive",
|
||||
modality="audio",
|
||||
icon=ui_args.get("icon"),
|
||||
icon_button_color=ui_args.get("icon_button_color"),
|
||||
pulse_color=ui_args.get("pulse_color"),
|
||||
icon_radius=ui_args.get("icon_radius"),
|
||||
)
|
||||
self.webrtc_component = output_video
|
||||
for component in additional_output_components:
|
||||
if component not in same_components:
|
||||
component.render()
|
||||
output_video.stream(
|
||||
fn=self.event_handler,
|
||||
inputs=self.additional_input_components,
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import AsyncGenerator, Generator, Literal, Protocol
|
||||
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -13,17 +13,20 @@ class TTSOptions:
|
||||
pass
|
||||
|
||||
|
||||
class TTSModel(Protocol):
|
||||
T = TypeVar("T", bound=TTSOptions, contravariant=True)
|
||||
|
||||
|
||||
class TTSModel(Protocol[T]):
|
||||
def tts(
|
||||
self, text: str, options: TTSOptions | None = None
|
||||
self, text: str, options: T | None = None
|
||||
) -> tuple[int, NDArray[np.float32]]: ...
|
||||
|
||||
async def stream_tts(
|
||||
self, text: str, options: TTSOptions | None = None
|
||||
def stream_tts(
|
||||
self, text: str, options: T | None = None
|
||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
|
||||
|
||||
def stream_tts_sync(
|
||||
self, text: str, options: TTSOptions | None = None
|
||||
self, text: str, options: T | None = None
|
||||
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
|
||||
|
||||
|
||||
|
||||
@@ -976,7 +976,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
) -> None:
|
||||
self.generator: Generator[Any, None, Any] | None = None
|
||||
self.event_handler = event_handler
|
||||
self.event_handler._clear_queue = self.clear_queue
|
||||
self.event_handler._clear_queue = self.clear_queue # pyright: ignore
|
||||
self.current_timestamp = 0
|
||||
self.latest_args: str | list[Any] = "not_set"
|
||||
self.args_set = threading.Event()
|
||||
|
||||
@@ -196,6 +196,8 @@ async def player_worker_decode(
|
||||
layout = "mono"
|
||||
elif len(frame) == 3:
|
||||
sample_rate, audio_array, layout = frame
|
||||
else:
|
||||
raise ValueError(f"frame must be of length 2 or 3, got: {len(frame)}")
|
||||
|
||||
logger.debug(
|
||||
"received array with shape %s sample rate %s layout %s",
|
||||
|
||||
@@ -118,3 +118,16 @@ convention = "google"
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["E402"]
|
||||
"demo/talk_to_smolagents/app.py" = ["W291"]
|
||||
|
||||
[tool.pyright]
|
||||
include = ["backend/fastrtc"]
|
||||
exclude = [
|
||||
"**/__pycache__",
|
||||
"**/*.pyi",
|
||||
]
|
||||
|
||||
reportMissingImports = false
|
||||
reportMissingTypeStubs = false
|
||||
|
||||
pythonVersion = "3.10"
|
||||
pythonPlatform = "Linux"
|
||||
Reference in New Issue
Block a user