mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Introduce static type checking with pyright (#255)
This commit is contained in:
committed by
GitHub
parent
d7995b8116
commit
0767030997
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
|||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Run linters
|
- name: Run linters
|
||||||
run: |
|
run: |
|
||||||
pip install ruff
|
pip install ruff pyright
|
||||||
|
pip install -e .
|
||||||
ruff check .
|
ruff check .
|
||||||
ruff format --check --diff .
|
ruff format --check --diff .
|
||||||
|
pyright
|
||||||
@@ -506,23 +506,22 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
for component in additional_input_components:
|
for component in additional_input_components:
|
||||||
component.render()
|
component.render()
|
||||||
button = gr.Button("Start Stream", variant="primary")
|
button = gr.Button("Start Stream", variant="primary")
|
||||||
if additional_output_components:
|
with gr.Column():
|
||||||
with gr.Column():
|
output_video = WebRTC(
|
||||||
output_video = WebRTC(
|
label="Audio Stream",
|
||||||
label="Audio Stream",
|
rtc_configuration=self.rtc_configuration,
|
||||||
rtc_configuration=self.rtc_configuration,
|
track_constraints=self.track_constraints,
|
||||||
track_constraints=self.track_constraints,
|
mode="receive",
|
||||||
mode="receive",
|
modality="audio",
|
||||||
modality="audio",
|
icon=ui_args.get("icon"),
|
||||||
icon=ui_args.get("icon"),
|
icon_button_color=ui_args.get("icon_button_color"),
|
||||||
icon_button_color=ui_args.get("icon_button_color"),
|
pulse_color=ui_args.get("pulse_color"),
|
||||||
pulse_color=ui_args.get("pulse_color"),
|
icon_radius=ui_args.get("icon_radius"),
|
||||||
icon_radius=ui_args.get("icon_radius"),
|
)
|
||||||
)
|
self.webrtc_component = output_video
|
||||||
self.webrtc_component = output_video
|
for component in additional_output_components:
|
||||||
for component in additional_output_components:
|
if component not in same_components:
|
||||||
if component not in same_components:
|
component.render()
|
||||||
component.render()
|
|
||||||
output_video.stream(
|
output_video.stream(
|
||||||
fn=self.event_handler,
|
fn=self.event_handler,
|
||||||
inputs=self.additional_input_components,
|
inputs=self.additional_input_components,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import AsyncGenerator, Generator, Literal, Protocol
|
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -13,17 +13,20 @@ class TTSOptions:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TTSModel(Protocol):
|
T = TypeVar("T", bound=TTSOptions, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSModel(Protocol[T]):
|
||||||
def tts(
|
def tts(
|
||||||
self, text: str, options: TTSOptions | None = None
|
self, text: str, options: T | None = None
|
||||||
) -> tuple[int, NDArray[np.float32]]: ...
|
) -> tuple[int, NDArray[np.float32]]: ...
|
||||||
|
|
||||||
async def stream_tts(
|
def stream_tts(
|
||||||
self, text: str, options: TTSOptions | None = None
|
self, text: str, options: T | None = None
|
||||||
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
|
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]: ...
|
||||||
|
|
||||||
def stream_tts_sync(
|
def stream_tts_sync(
|
||||||
self, text: str, options: TTSOptions | None = None
|
self, text: str, options: T | None = None
|
||||||
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
|
) -> Generator[tuple[int, NDArray[np.float32]], None, None]: ...
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -976,7 +976,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.generator: Generator[Any, None, Any] | None = None
|
self.generator: Generator[Any, None, Any] | None = None
|
||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.event_handler._clear_queue = self.clear_queue
|
self.event_handler._clear_queue = self.clear_queue # pyright: ignore
|
||||||
self.current_timestamp = 0
|
self.current_timestamp = 0
|
||||||
self.latest_args: str | list[Any] = "not_set"
|
self.latest_args: str | list[Any] = "not_set"
|
||||||
self.args_set = threading.Event()
|
self.args_set = threading.Event()
|
||||||
|
|||||||
@@ -196,6 +196,8 @@ async def player_worker_decode(
|
|||||||
layout = "mono"
|
layout = "mono"
|
||||||
elif len(frame) == 3:
|
elif len(frame) == 3:
|
||||||
sample_rate, audio_array, layout = frame
|
sample_rate, audio_array, layout = frame
|
||||||
|
else:
|
||||||
|
raise ValueError(f"frame must be of length 2 or 3, got: {len(frame)}")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"received array with shape %s sample rate %s layout %s",
|
"received array with shape %s sample rate %s layout %s",
|
||||||
|
|||||||
@@ -118,3 +118,16 @@ convention = "google"
|
|||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"__init__.py" = ["E402"]
|
"__init__.py" = ["E402"]
|
||||||
"demo/talk_to_smolagents/app.py" = ["W291"]
|
"demo/talk_to_smolagents/app.py" = ["W291"]
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
include = ["backend/fastrtc"]
|
||||||
|
exclude = [
|
||||||
|
"**/__pycache__",
|
||||||
|
"**/*.pyi",
|
||||||
|
]
|
||||||
|
|
||||||
|
reportMissingImports = false
|
||||||
|
reportMissingTypeStubs = false
|
||||||
|
|
||||||
|
pythonVersion = "3.10"
|
||||||
|
pythonPlatform = "Linux"
|
||||||
Reference in New Issue
Block a user