mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-04 17:39:23 +08:00
Introduce unit tests (#248)
* Proof-of-concept: unittests * Add pytest-asyncio dep * Import Body from stream * Add test for allow_extra_tracks * Cleanup decorators * add test to linting * fix ruff issues * Run formatter * fix * Dont test every python version --------- Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com> Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
committed by
GitHub
parent
0767030997
commit
2331079c0f
22
.github/workflows/tests.yml
vendored
22
.github/workflows/tests.yml
vendored
@@ -16,4 +16,24 @@ jobs:
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
ruff check .
|
ruff check .
|
||||||
ruff format --check --diff .
|
ruff format --check --diff .
|
||||||
pyright
|
pyright
|
||||||
|
test:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python:
|
||||||
|
- '3.10'
|
||||||
|
- '3.13'
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
python -m pip install -U pip
|
||||||
|
pip install .[dev]
|
||||||
|
python -m pytest -s test
|
||||||
|
shell: bash
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -15,6 +15,5 @@ demo/scratch
|
|||||||
.gradio
|
.gradio
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
test/
|
|
||||||
.venv*
|
.venv*
|
||||||
.env
|
.env
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -106,6 +106,7 @@ class WebRTCConnectionMixin:
|
|||||||
def clean_up(self, webrtc_id: str):
|
def clean_up(self, webrtc_id: str):
|
||||||
self.handlers.pop(webrtc_id, None)
|
self.handlers.pop(webrtc_id, None)
|
||||||
self.connection_timeouts.pop(webrtc_id, None)
|
self.connection_timeouts.pop(webrtc_id, None)
|
||||||
|
self.pcs.pop(webrtc_id, None)
|
||||||
connection = self.connections.pop(webrtc_id, [])
|
connection = self.connections.pop(webrtc_id, [])
|
||||||
for conn in connection:
|
for conn in connection:
|
||||||
if isinstance(conn, AudioCallback):
|
if isinstance(conn, AudioCallback):
|
||||||
@@ -229,7 +230,18 @@ class WebRTCConnectionMixin:
|
|||||||
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
content={"status": "failed", "meta": {"error": "connection_closed"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self.connections) >= cast(int, self.concurrency_limit):
|
if body["webrtc_id"] in self.connections:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"status": "failed",
|
||||||
|
"meta": {
|
||||||
|
"error": "connection_already_exists",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.pcs) >= cast(int, self.concurrency_limit):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
content={
|
content={
|
||||||
|
|||||||
@@ -73,6 +73,11 @@
|
|||||||
"error",
|
"error",
|
||||||
`Too many concurrent connections. Please try again later!`,
|
`Too many concurrent connections. Please try again later!`,
|
||||||
);
|
);
|
||||||
|
} else if (
|
||||||
|
msg.status === "failed" &&
|
||||||
|
msg.meta?.error === "connection_already_exists"
|
||||||
|
) {
|
||||||
|
gradio.dispatch("error", "Connection already exists");
|
||||||
} else {
|
} else {
|
||||||
gradio.dispatch("error", "Unexpected server error");
|
gradio.dispatch("error", "Unexpected server error");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ issues = "https://github.com/freddyaboulton/gradio-webrtc/issues"
|
|||||||
Documentation = "https://freddyaboulton.github.io/gradio-webrtc/cookbook/"
|
Documentation = "https://freddyaboulton.github.io/gradio-webrtc/cookbook/"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = ["build", "twine"]
|
dev = ["build", "twine", "httpx", "pytest"]
|
||||||
vad = ["onnxruntime>=1.20.1"]
|
vad = ["onnxruntime>=1.20.1"]
|
||||||
tts = ["kokoro-onnx"]
|
tts = ["kokoro-onnx"]
|
||||||
stopword = ["fastrtc-moonshine-onnx", "onnxruntime>=1.20.1"]
|
stopword = ["fastrtc-moonshine-onnx", "onnxruntime>=1.20.1"]
|
||||||
@@ -82,8 +82,12 @@ artifacts = ["/backend/fastrtc/templates", "*.pyi"]
|
|||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["/backend/fastrtc"]
|
packages = ["/backend/fastrtc"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope="function"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
src = ["demo", "backend/fastrtc"]
|
src = ["demo", "backend/fastrtc", "test"]
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
extend-exclude = ["demo/phonic_chat", "demo/nextjs_voice_chat"]
|
extend-exclude = ["demo/phonic_chat", "demo/nextjs_voice_chat"]
|
||||||
|
|
||||||
|
|||||||
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
308
test/test_webrtc_connection_mixin.py
Normal file
308
test/test_webrtc_connection_mixin.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import (
|
||||||
|
Literal,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiortc import (
|
||||||
|
AudioStreamTrack,
|
||||||
|
RTCPeerConnection,
|
||||||
|
RTCSessionDescription,
|
||||||
|
VideoStreamTrack,
|
||||||
|
)
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from fastrtc.stream import Body
|
||||||
|
from fastrtc.tracks import HandlerType
|
||||||
|
from fastrtc.webrtc_connection_mixin import WebRTCConnectionMixin
|
||||||
|
|
||||||
|
|
||||||
|
class MinimalTestStream(WebRTCConnectionMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
handler: HandlerType = lambda x: x,
|
||||||
|
*,
|
||||||
|
mode: Literal["send-receive", "receive", "send"] = "send-receive",
|
||||||
|
modality: Literal["video", "audio", "audio-video"] = "video",
|
||||||
|
concurrency_limit: int | None | Literal["default"] = "default",
|
||||||
|
time_limit: float | None = None,
|
||||||
|
allow_extra_tracks: bool = False,
|
||||||
|
):
|
||||||
|
WebRTCConnectionMixin.__init__(self)
|
||||||
|
self.mode = mode
|
||||||
|
self.modality = modality
|
||||||
|
self.event_handler = handler
|
||||||
|
self.concurrency_limit = cast(
|
||||||
|
(int),
|
||||||
|
1 if concurrency_limit in ["default", None] else concurrency_limit,
|
||||||
|
)
|
||||||
|
self.time_limit = time_limit
|
||||||
|
self.allow_extra_tracks = allow_extra_tracks
|
||||||
|
|
||||||
|
def mount(self, app: FastAPI, path: str = ""):
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(prefix=path)
|
||||||
|
router.post("/webrtc/offer")(self.offer)
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
async def offer(self, body: Body):
|
||||||
|
return await self.handle_offer(
|
||||||
|
body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def test_client_and_stream(request):
|
||||||
|
app = FastAPI()
|
||||||
|
params = request.param if hasattr(request, "param") else {}
|
||||||
|
stream = MinimalTestStream(**params)
|
||||||
|
stream.mount(app)
|
||||||
|
test_client = TestClient(app)
|
||||||
|
yield test_client, stream
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebRTCConnectionMixin:
|
||||||
|
@staticmethod
|
||||||
|
async def setup_peer_connection(audio=False, video=False):
|
||||||
|
pc = RTCPeerConnection()
|
||||||
|
channel = pc.createDataChannel("test-data-channel")
|
||||||
|
if audio:
|
||||||
|
audio_track = AudioStreamTrack()
|
||||||
|
pc.addTrack(audio_track)
|
||||||
|
if video:
|
||||||
|
video_track = VideoStreamTrack()
|
||||||
|
pc.addTrack(video_track)
|
||||||
|
|
||||||
|
await pc.setLocalDescription(await pc.createOffer())
|
||||||
|
return pc, channel
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_offer(
|
||||||
|
pc,
|
||||||
|
client,
|
||||||
|
audio=False,
|
||||||
|
video=False,
|
||||||
|
webrtc_id="test_id",
|
||||||
|
response_code=200,
|
||||||
|
return_status_and_metadata=False,
|
||||||
|
):
|
||||||
|
body = {
|
||||||
|
"sdp": pc.localDescription.sdp,
|
||||||
|
"type": pc.localDescription.type,
|
||||||
|
}
|
||||||
|
if webrtc_id is not None:
|
||||||
|
body["webrtc_id"] = webrtc_id
|
||||||
|
response = client.post(
|
||||||
|
"/webrtc/offer",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
assert response.status_code == response_code
|
||||||
|
if not response_code == 200:
|
||||||
|
return
|
||||||
|
out = response.json()
|
||||||
|
if return_status_and_metadata:
|
||||||
|
assert "status" in out and "meta" in out
|
||||||
|
return out["status"], out["meta"]
|
||||||
|
assert "type" in out and out["type"] == "answer"
|
||||||
|
assert "webrtc-datachannel" in out["sdp"]
|
||||||
|
if audio:
|
||||||
|
assert "m=audio" in out["sdp"]
|
||||||
|
if video:
|
||||||
|
assert "m=video" in out["sdp"]
|
||||||
|
|
||||||
|
await pc.setRemoteDescription(RTCSessionDescription(out["sdp"], out["type"]))
|
||||||
|
|
||||||
|
# Allow data to stream
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def close_peer_connection(pc):
|
||||||
|
await pc.close()
|
||||||
|
assert pc.connectionState == "closed"
|
||||||
|
assert pc.iceConnectionState == "closed"
|
||||||
|
assert pc.signalingState == "closed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"modality": "audio"}], indirect=True
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("audio", [True, False])
|
||||||
|
async def test_successful_connection_audio(self, test_client_and_stream, audio):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(audio)
|
||||||
|
await self.send_offer(pc, test_client, audio)
|
||||||
|
# TODO: Test stream? E.g., when no audio is not part of the offer...
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"modality": "video"}], indirect=True
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("video", [True, False])
|
||||||
|
async def test_successful_connection_video(self, test_client_and_stream, video):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(video=video)
|
||||||
|
await self.send_offer(pc, test_client, video=video)
|
||||||
|
# TODO: Test stream? E.g., when no video is not part of the offer...
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"modality": "audio"}], indirect=True
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("audio", [True, False])
|
||||||
|
async def test_unsuccessful_connection_audio(self, test_client_and_stream, audio):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(audio=audio, video=True)
|
||||||
|
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
|
||||||
|
await self.send_offer(pc, test_client, audio=audio, video=True)
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"modality": "video"}], indirect=True
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("video", [True, False])
|
||||||
|
async def test_unsuccessful_connection_video(self, test_client_and_stream, video):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(audio=True, video=video)
|
||||||
|
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
|
||||||
|
await self.send_offer(pc, test_client, audio=True, video=video)
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsuccessful_webrtc_offer_no_webrtc_id(self, test_client_and_stream):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection()
|
||||||
|
await self.send_offer(
|
||||||
|
pc,
|
||||||
|
test_client,
|
||||||
|
webrtc_id=None,
|
||||||
|
response_code=422,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"modality": "dummy"}], indirect=True
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"audio, video",
|
||||||
|
[
|
||||||
|
(True, False),
|
||||||
|
(False, True),
|
||||||
|
(True, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_incorrect_modality(self, test_client_and_stream, audio, video):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(audio=audio, video=video)
|
||||||
|
with pytest.raises(ValueError, match=r"Modality must be .*"):
|
||||||
|
await self.send_offer(pc, test_client, audio=audio, video=video)
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
|
||||||
|
)
|
||||||
|
async def test_concurrency_limit_reached_two_peers(self, test_client_and_stream):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc1, channel = await self.setup_peer_connection(video=True)
|
||||||
|
pc2, channel = await self.setup_peer_connection(video=True)
|
||||||
|
await self.send_offer(pc1, test_client)
|
||||||
|
status, metadata = await self.send_offer(
|
||||||
|
pc2, test_client, return_status_and_metadata=True
|
||||||
|
)
|
||||||
|
assert status == "failed"
|
||||||
|
assert metadata == {"error": "connection_already_exists"}
|
||||||
|
|
||||||
|
await self.close_peer_connection(pc1)
|
||||||
|
await self.close_peer_connection(pc2)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
|
||||||
|
)
|
||||||
|
async def test_concurrency_limit_reached_three_peers_same_id(
|
||||||
|
self, test_client_and_stream
|
||||||
|
):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc1, channel = await self.setup_peer_connection(video=True)
|
||||||
|
pc2, channel = await self.setup_peer_connection(video=True)
|
||||||
|
pc3, channel = await self.setup_peer_connection(video=True)
|
||||||
|
await self.send_offer(pc1, test_client)
|
||||||
|
status, metadata = await self.send_offer(
|
||||||
|
pc2, test_client, return_status_and_metadata=True
|
||||||
|
)
|
||||||
|
assert status == "failed"
|
||||||
|
assert metadata == {"error": "connection_already_exists"}
|
||||||
|
status, metadata = await self.send_offer(
|
||||||
|
pc3, test_client, return_status_and_metadata=True
|
||||||
|
)
|
||||||
|
assert status == "failed"
|
||||||
|
assert metadata == {"error": "connection_already_exists"}
|
||||||
|
|
||||||
|
await self.close_peer_connection(pc1)
|
||||||
|
await self.close_peer_connection(pc2)
|
||||||
|
await self.close_peer_connection(pc3)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
|
||||||
|
)
|
||||||
|
async def test_concurrency_limit_reached_three_peers(self, test_client_and_stream):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc1, channel = await self.setup_peer_connection(video=True)
|
||||||
|
pc2, channel = await self.setup_peer_connection(video=True)
|
||||||
|
pc3, channel = await self.setup_peer_connection(video=True)
|
||||||
|
await self.send_offer(pc1, test_client, webrtc_id="foo")
|
||||||
|
await self.send_offer(pc2, test_client, webrtc_id="bar")
|
||||||
|
status, metadata = await self.send_offer(
|
||||||
|
pc3, test_client, webrtc_id="baz", return_status_and_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert status == "failed"
|
||||||
|
assert metadata == {"error": "concurrency_limit_reached", "limit": 2}
|
||||||
|
|
||||||
|
await self.close_peer_connection(pc1)
|
||||||
|
await self.close_peer_connection(pc2)
|
||||||
|
await self.close_peer_connection(pc3)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
|
||||||
|
)
|
||||||
|
async def test_concurrency_limit_reached_peers_with_no_mediastreams(
|
||||||
|
self, test_client_and_stream
|
||||||
|
):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc1, channel = await self.setup_peer_connection()
|
||||||
|
pc2, channel = await self.setup_peer_connection()
|
||||||
|
await self.send_offer(pc1, test_client, webrtc_id="foo")
|
||||||
|
status, metadata = await self.send_offer(
|
||||||
|
pc2, test_client, webrtc_id="bar", return_status_and_metadata=True
|
||||||
|
)
|
||||||
|
assert status == "failed"
|
||||||
|
assert metadata == {"error": "concurrency_limit_reached", "limit": 1}
|
||||||
|
|
||||||
|
await self.close_peer_connection(pc1)
|
||||||
|
await self.close_peer_connection(pc2)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_client_and_stream",
|
||||||
|
[
|
||||||
|
{"allow_extra_tracks": True, "modality": "audio"},
|
||||||
|
{"allow_extra_tracks": True, "modality": "video"},
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
async def test_successful_connection_allow_extra_tracks(
|
||||||
|
self, test_client_and_stream
|
||||||
|
):
|
||||||
|
test_client, stream = test_client_and_stream
|
||||||
|
pc, channel = await self.setup_peer_connection(audio=True, video=True)
|
||||||
|
await self.send_offer(pc, test_client, audio=True, video=True)
|
||||||
|
await self.close_peer_connection(pc)
|
||||||
Reference in New Issue
Block a user