mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Enforce modern typing (#258)
* Allow UP * Upgrade typing * test smolagents * Change to contextlib --------- Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com>
This commit is contained in:
committed by
GitHub
parent
a07e9439b6
commit
f70b27bd41
@@ -2,7 +2,6 @@ import logging
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -102,7 +101,7 @@ class SileroVADModel:
|
|||||||
return h, c
|
return h, c
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def collect_chunks(audio: np.ndarray, chunks: List[AudioChunk]) -> np.ndarray:
|
def collect_chunks(audio: np.ndarray, chunks: list[AudioChunk]) -> np.ndarray:
|
||||||
"""Collects and concatenates audio chunks."""
|
"""Collects and concatenates audio chunks."""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return np.array([], dtype=np.float32)
|
return np.array([], dtype=np.float32)
|
||||||
@@ -116,7 +115,7 @@ class SileroVADModel:
|
|||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
vad_options: SileroVadOptions,
|
vad_options: SileroVadOptions,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[AudioChunk]:
|
) -> list[AudioChunk]:
|
||||||
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Generator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Literal
|
from collections.abc import Callable
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AbstractAsyncContextManager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncContextManager,
|
|
||||||
Callable,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
TypedDict,
|
TypedDict,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@@ -29,8 +28,8 @@ curr_dir = Path(__file__).parent
|
|||||||
|
|
||||||
|
|
||||||
class Body(BaseModel):
|
class Body(BaseModel):
|
||||||
sdp: Optional[str] = None
|
sdp: str | None = None
|
||||||
candidate: Optional[dict[str, Any]] = None
|
candidate: dict[str, Any] | None = None
|
||||||
type: str
|
type: str
|
||||||
webrtc_id: str
|
webrtc_id: str
|
||||||
|
|
||||||
@@ -253,7 +252,7 @@ class Stream(WebRTCConnectionMixin):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def _inject_startup_message(
|
def _inject_startup_message(
|
||||||
self, lifespan: Callable[[FastAPI], AsyncContextManager] | None = None
|
self, lifespan: Callable[[FastAPI], AbstractAsyncContextManager] | None = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a FastAPI lifespan context manager to print startup messages and check environment.
|
Create a FastAPI lifespan context manager to print startup messages and check environment.
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import AsyncGenerator, Generator, Literal, Protocol, TypeVar
|
from typing import Literal, Protocol, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|||||||
@@ -12,15 +12,12 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Generator,
|
|
||||||
Literal,
|
Literal,
|
||||||
Tuple,
|
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,11 +48,11 @@ from fastrtc.utils import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VideoNDArray: TypeAlias = Union[
|
VideoNDArray: TypeAlias = (
|
||||||
np.ndarray[Any, np.dtype[np.uint8]],
|
np.ndarray[Any, np.dtype[np.uint8]]
|
||||||
np.ndarray[Any, np.dtype[np.uint16]],
|
| np.ndarray[Any, np.dtype[np.uint16]]
|
||||||
np.ndarray[Any, np.dtype[np.float32]],
|
| np.ndarray[Any, np.dtype[np.float32]]
|
||||||
]
|
)
|
||||||
|
|
||||||
VideoEmitType = (
|
VideoEmitType = (
|
||||||
VideoNDArray
|
VideoNDArray
|
||||||
@@ -219,7 +216,7 @@ class VideoCallback(VideoStreamTrack):
|
|||||||
else:
|
else:
|
||||||
raise WebRTCError(str(e)) from e
|
raise WebRTCError(str(e)) from e
|
||||||
|
|
||||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
|
||||||
"""Override to control frame rate"""
|
"""Override to control frame rate"""
|
||||||
if self.readyState != "live":
|
if self.readyState != "live":
|
||||||
raise MediaStreamError
|
raise MediaStreamError
|
||||||
@@ -906,7 +903,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
|||||||
self.latest_args = list(args)
|
self.latest_args = list(args)
|
||||||
self.args_set.set()
|
self.args_set.set()
|
||||||
|
|
||||||
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
|
async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
|
||||||
"""Override to control frame rate"""
|
"""Override to control frame rate"""
|
||||||
if self.readyState != "live":
|
if self.readyState != "live":
|
||||||
raise MediaStreamError
|
raise MediaStreamError
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
|
from collections.abc import Callable
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Literal, Protocol, TypedDict, cast
|
from typing import Any, Literal, Protocol, TypedDict, cast
|
||||||
|
|
||||||
import av
|
import av
|
||||||
import librosa
|
import librosa
|
||||||
@@ -136,7 +137,7 @@ def split_output(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
"The tuple must have exactly two elements: the data and an instance of AdditionalOutputs."
|
||||||
)
|
)
|
||||||
if not isinstance(data[-1], (AdditionalOutputs, CloseStream)):
|
if not isinstance(data[-1], AdditionalOutputs | CloseStream):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The last element of the tuple must be an instance of AdditionalOutputs."
|
"The last element of the tuple must be an instance of AdditionalOutputs."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,15 +3,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Concatenate,
|
Concatenate,
|
||||||
Iterable,
|
|
||||||
Literal,
|
Literal,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,10 +6,9 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import (
|
||||||
AsyncGenerator,
|
|
||||||
Literal,
|
Literal,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import asyncio
|
|||||||
import audioop
|
import audioop
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Awaitable, Callable, Optional, cast
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import librosa
|
import librosa
|
||||||
@@ -57,9 +58,9 @@ class WebSocketHandler:
|
|||||||
):
|
):
|
||||||
self.stream_handler = stream_handler
|
self.stream_handler = stream_handler
|
||||||
self.stream_handler._clear_queue = self._clear_queue
|
self.stream_handler._clear_queue = self._clear_queue
|
||||||
self.websocket: Optional[WebSocket] = None
|
self.websocket: WebSocket | None = None
|
||||||
self._emit_task: Optional[asyncio.Task] = None
|
self._emit_task: asyncio.Task | None = None
|
||||||
self.stream_id: Optional[str] = None
|
self.stream_id: str | None = None
|
||||||
self.set_additional_outputs_factory = additional_outputs_factory
|
self.set_additional_outputs_factory = additional_outputs_factory
|
||||||
self.set_additional_outputs: Callable[[AdditionalOutputs], None]
|
self.set_additional_outputs: Callable[[AdditionalOutputs], None]
|
||||||
self.set_handler = set_handler
|
self.set_handler = set_handler
|
||||||
@@ -67,8 +68,8 @@ class WebSocketHandler:
|
|||||||
self.clean_up = clean_up
|
self.clean_up = clean_up
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.playing_durations = [] # Track durations of frames being played
|
self.playing_durations = [] # Track durations of frames being played
|
||||||
self._frame_cleanup_task: Optional[asyncio.Task] = None
|
self._frame_cleanup_task: asyncio.Task | None = None
|
||||||
self._graceful_shutdown_task: Optional[asyncio.Task] = None
|
self._graceful_shutdown_task: asyncio.Task | None = None
|
||||||
|
|
||||||
def _clear_queue(self):
|
def _clear_queue(self):
|
||||||
old_queue = self.queue
|
old_queue = self.queue
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from functools import lru_cache
|
from collections.abc import Generator
|
||||||
from typing import Generator, Literal
|
from functools import cache
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -17,7 +18,7 @@ from numpy.typing import NDArray
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def load_moonshine(
|
def load_moonshine(
|
||||||
model_name: Literal["moonshine/base", "moonshine/tiny"],
|
model_name: Literal["moonshine/base", "moonshine/tiny"],
|
||||||
) -> MoonshineOnnxModel:
|
) -> MoonshineOnnxModel:
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import AsyncGenerator, Literal
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastrtc import (
|
from fastrtc import (
|
||||||
@@ -22,7 +21,7 @@ stt_model = get_stt_model()
|
|||||||
tts_model = get_tts_model()
|
tts_model = get_tts_model()
|
||||||
|
|
||||||
# Conversation state to maintain history
|
# Conversation state to maintain history
|
||||||
conversation_state: List[Dict[str, str]] = []
|
conversation_state: list[dict[str, str]] = []
|
||||||
|
|
||||||
# System prompt for agent
|
# System prompt for agent
|
||||||
system_prompt = """You are a helpful assistant that can helps with finding places to
|
system_prompt = """You are a helpful assistant that can helps with finding places to
|
||||||
@@ -78,9 +77,7 @@ def process_response(audio):
|
|||||||
response_content = agent.run(input_text)
|
response_content = agent.run(input_text)
|
||||||
|
|
||||||
# Convert response to audio using TTS model
|
# Convert response to audio using TTS model
|
||||||
for audio_chunk in tts_model.stream_tts_sync(response_content or ""):
|
yield from tts_model.stream_tts_sync(response_content or "")
|
||||||
# Yield the audio chunk
|
|
||||||
yield audio_chunk
|
|
||||||
|
|
||||||
|
|
||||||
stream = Stream(
|
stream = Stream(
|
||||||
|
|||||||
@@ -105,12 +105,12 @@ select = [
|
|||||||
"W",
|
"W",
|
||||||
"Q",
|
"Q",
|
||||||
"I",
|
"I",
|
||||||
|
"UP",
|
||||||
]
|
]
|
||||||
|
|
||||||
# These can be turned on when the framework is more mature (Too many errors right now)
|
# These can be turned on when the framework is more mature (Too many errors right now)
|
||||||
exclude = [
|
exclude = [
|
||||||
"D",
|
"D",
|
||||||
"UP"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Avoid enforcing line-length violations (`E501`)
|
# Avoid enforcing line-length violations (`E501`)
|
||||||
|
|||||||
Reference in New Issue
Block a user