mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 09:59:22 +08:00
Enforce modern typing (#258)
* Allow UP
* Upgrade typing
* test smolagents
* Change to contextlib

---------

Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com>
This commit is contained in:
committed by
GitHub
parent
a07e9439b6
commit
f70b27bd41
@@ -2,7 +2,6 @@ import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
@@ -102,7 +101,7 @@ class SileroVADModel:
|
||||
return h, c
|
||||
|
||||
@staticmethod
|
||||
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
|
||||
def collect_chunks(audio: np.ndarray, chunks: list[AudioChunk]) -> np.ndarray:
|
||||
"""Collects and concatenates audio chunks."""
|
||||
if not chunks:
|
||||
return np.array([], dtype=np.float32)
|
||||
@@ -116,7 +115,7 @@ class SileroVADModel:
|
||||
audio: np.ndarray,
|
||||
vad_options: SileroVadOptions,
|
||||
**kwargs,
|
||||
) -> List[AudioChunk]:
|
||||
) -> list[AudioChunk]:
|
||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from threading import Event
|
||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Literal
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
cast,
|
||||
)
|
||||
@@ -29,8 +28,8 @@ curr_dir = Path(__file__).parent
|
||||
|
||||
|
||||
class Body(BaseModel):
|
||||
sdp: Optional[str] = None
|
||||
candidate: Optional[dict[str, Any]] = None
|
||||
sdp: str | None = None
|
||||
candidate: dict[str, Any] | None = None
|
||||
type: str
|
||||
webrtc_id: str
|
||||
|
||||
@@ -253,7 +252,7 @@ class Stream(WebRTCConnectionMixin):
|
||||
return wrapper
|
||||
|
||||
def _inject_startup_message(
|
||||
self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
|
||||
self, lifespan: Callable[[FastAPI], AbstractAsyncContextManager] | None = None
|
||||
):
|
||||
"""
|
||||
Create a FastAPI lifespan context manager to print startup messages and check environment.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
|
||||
from typing import Literal, Protocol, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -12,15 +12,12 @@ import time
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Generator,
|
||||
Literal,
|
||||
Tuple,
|
||||
TypeAlias,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -51,11 +48,11 @@ from fastrtc.utils import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VideoNDArray: TypeAlias = Union[
|
||||
np.ndarray[Any, np.dtype[np.uint8]],
|
||||
np.ndarray[Any, np.dtype[np.uint16]],
|
||||
np.ndarray[Any, np.dtype[np.float32]],
|
||||
]
|
||||
VideoNDArray: TypeAlias = (
|
||||
np.ndarray[Any, np.dtype[np.uint8]]
|
||||
| np.ndarray[Any, np.dtype[np.uint16]]
|
||||
| np.ndarray[Any, np.dtype[np.float32]]
|
||||
)
|
||||
|
||||
VideoEmitType = (
|
||||
VideoNDArray
|
||||
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
|
||||
else:
|
||||
raise WebRTCError(str(e)) from e
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
||||
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
|
||||
"""Override to control frame rate"""
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
@@ -906,7 +903,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.latest_args = list(args)
|
||||
self.args_set.set()
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
||||
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
|
||||
"""Override to control frame rate"""
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
@@ -7,9 +7,10 @@ import json
|
||||
import logging
|
||||
import tempfile
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
||||
from typing import Any, Literal, Protocol, TypedDict, cast
|
||||
|
||||
import av
|
||||
import librosa
|
||||
@@ -136,7 +137,7 @@ def split_output(
|
||||
raise ValueError(
|
||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||
)
|
||||
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
|
||||
if not isinstance(data[-1], AdditionalOutputs | CloseStream):
|
||||
raise ValueError(
|
||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||
)
|
||||
|
||||
@@ -3,15 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Concatenate,
|
||||
Iterable,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -6,10 +6,9 @@ import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
|
||||
@@ -2,7 +2,8 @@ import asyncio
|
||||
import audioop
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Optional, cast
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import anyio
|
||||
import librosa
|
||||
@@ -57,9 +58,9 @@ class WebSocketHandler:
|
||||
):
|
||||
self.stream_handler = stream_handler
|
||||
self.stream_handler._clear_queue = self._clear_queue
|
||||
self.websocket: Optional[WebSocket] = None
|
||||
self._emit_task: Optional[asyncio.Task] = None
|
||||
self.stream_id: Optional[str] = None
|
||||
self.websocket: WebSocket | None = None
|
||||
self._emit_task: asyncio.Task | None = None
|
||||
self.stream_id: str | None = None
|
||||
self.set_additional_outputs_factory = additional_outputs_factory
|
||||
self.set_additional_outputs: Callable[[AdditionalOutputs], None]
|
||||
self.set_handler = set_handler
|
||||
@@ -67,8 +68,8 @@ class WebSocketHandler:
|
||||
self.clean_up = clean_up
|
||||
self.queue = asyncio.Queue()
|
||||
self.playing_durations = [] # Track durations of frames being played
|
||||
self._frame_cleanup_task: Optional[asyncio.Task] = None
|
||||
self._graceful_shutdown_task: Optional[asyncio.Task] = None
|
||||
self._frame_cleanup_task: asyncio.Task | None = None
|
||||
self._graceful_shutdown_task: asyncio.Task | None = None
|
||||
|
||||
def _clear_queue(self):
|
||||
old_queue = self.queue
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from functools import lru_cache
|
||||
from typing import Generator, Literal
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Literal
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
@@ -17,7 +18,7 @@ from numpy.typing import NDArray
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@cache
|
||||
def load_moonshine(
|
||||
model_name: Literal["moonshine/base", "moonshine/tiny"],
|
||||
) -> MoonshineOnnxModel:
|
||||
|
||||
@@ -3,7 +3,8 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from typing import AsyncGenerator, Literal
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Literal
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastrtc import (
|
||||
@@ -22,7 +21,7 @@ stt_model = get_stt_model()
|
||||
tts_model = get_tts_model()
|
||||
|
||||
# Conversation state to maintain history
|
||||
conversation_state: List[Dict[str, str]] = []
|
||||
conversation_state: list[dict[str, str]] = []
|
||||
|
||||
# System prompt for agent
|
||||
system_prompt = """You are a helpful assistant that can helps with finding places to
|
||||
@@ -78,9 +77,7 @@ def process_response(audio):
|
||||
response_content = agent.run(input_text)
|
||||
|
||||
# Convert response to audio using TTS model
|
||||
for audio_chunk in tts_model.stream_tts_sync(response_content or ""):
|
||||
# Yield the audio chunk
|
||||
yield audio_chunk
|
||||
yield from tts_model.stream_tts_sync(response_content or "")
|
||||
|
||||
|
||||
stream = Stream(
|
||||
|
||||
@@ -105,12 +105,12 @@ select = [
|
||||
"W",
|
||||
"Q",
|
||||
"I",
|
||||
"UP",
|
||||
]
|
||||
|
||||
# These can be turned on when the framework is more mature (Too many errors right now)
|
||||
exclude = [
|
||||
"D",
|
||||
"UP"
|
||||
]
|
||||
|
||||
# Avoid enforcing line-length violations (`E501`)
|
||||
|
||||
Reference in New Issue
Block a user