mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
Introduce unit tests (#248)
* Proof-of-concept: unittests * Add pytest-asyncio dep * Import Body from stream * Add test for allow_extra_tracks * Cleanup decorators * add test to linting * fix ruff issues * Run formatter * fix * Dont test every python version --------- Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com> Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
committed by
GitHub
parent
0767030997
commit
2331079c0f
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
308
test/test_webrtc_connection_mixin.py
Normal file
308
test/test_webrtc_connection_mixin.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import asyncio
|
||||
from typing import (
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from aiortc import (
|
||||
AudioStreamTrack,
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
VideoStreamTrack,
|
||||
)
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from fastrtc.stream import Body
|
||||
from fastrtc.tracks import HandlerType
|
||||
from fastrtc.webrtc_connection_mixin import WebRTCConnectionMixin
|
||||
|
||||
|
||||
class MinimalTestStream(WebRTCConnectionMixin):
|
||||
def __init__(
|
||||
self,
|
||||
handler: HandlerType = lambda x: x,
|
||||
*,
|
||||
mode: Literal["send-receive", "receive", "send"] = "send-receive",
|
||||
modality: Literal["video", "audio", "audio-video"] = "video",
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
time_limit: float | None = None,
|
||||
allow_extra_tracks: bool = False,
|
||||
):
|
||||
WebRTCConnectionMixin.__init__(self)
|
||||
self.mode = mode
|
||||
self.modality = modality
|
||||
self.event_handler = handler
|
||||
self.concurrency_limit = cast(
|
||||
(int),
|
||||
1 if concurrency_limit in ["default", None] else concurrency_limit,
|
||||
)
|
||||
self.time_limit = time_limit
|
||||
self.allow_extra_tracks = allow_extra_tracks
|
||||
|
||||
def mount(self, app: FastAPI, path: str = ""):
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(prefix=path)
|
||||
router.post("/webrtc/offer")(self.offer)
|
||||
app.include_router(router)
|
||||
|
||||
async def offer(self, body: Body):
|
||||
return await self.handle_offer(
|
||||
body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_client_and_stream(request):
|
||||
app = FastAPI()
|
||||
params = request.param if hasattr(request, "param") else {}
|
||||
stream = MinimalTestStream(**params)
|
||||
stream.mount(app)
|
||||
test_client = TestClient(app)
|
||||
yield test_client, stream
|
||||
|
||||
|
||||
class TestWebRTCConnectionMixin:
|
||||
@staticmethod
|
||||
async def setup_peer_connection(audio=False, video=False):
|
||||
pc = RTCPeerConnection()
|
||||
channel = pc.createDataChannel("test-data-channel")
|
||||
if audio:
|
||||
audio_track = AudioStreamTrack()
|
||||
pc.addTrack(audio_track)
|
||||
if video:
|
||||
video_track = VideoStreamTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
await pc.setLocalDescription(await pc.createOffer())
|
||||
return pc, channel
|
||||
|
||||
@staticmethod
|
||||
async def send_offer(
|
||||
pc,
|
||||
client,
|
||||
audio=False,
|
||||
video=False,
|
||||
webrtc_id="test_id",
|
||||
response_code=200,
|
||||
return_status_and_metadata=False,
|
||||
):
|
||||
body = {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
"type": pc.localDescription.type,
|
||||
}
|
||||
if webrtc_id is not None:
|
||||
body["webrtc_id"] = webrtc_id
|
||||
response = client.post(
|
||||
"/webrtc/offer",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=body,
|
||||
)
|
||||
assert response.status_code == response_code
|
||||
if not response_code == 200:
|
||||
return
|
||||
out = response.json()
|
||||
if return_status_and_metadata:
|
||||
assert "status" in out and "meta" in out
|
||||
return out["status"], out["meta"]
|
||||
assert "type" in out and out["type"] == "answer"
|
||||
assert "webrtc-datachannel" in out["sdp"]
|
||||
if audio:
|
||||
assert "m=audio" in out["sdp"]
|
||||
if video:
|
||||
assert "m=video" in out["sdp"]
|
||||
|
||||
await pc.setRemoteDescription(RTCSessionDescription(out["sdp"], out["type"]))
|
||||
|
||||
# Allow data to stream
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
@staticmethod
|
||||
async def close_peer_connection(pc):
|
||||
await pc.close()
|
||||
assert pc.connectionState == "closed"
|
||||
assert pc.iceConnectionState == "closed"
|
||||
assert pc.signalingState == "closed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"modality": "audio"}], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize("audio", [True, False])
|
||||
async def test_successful_connection_audio(self, test_client_and_stream, audio):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(audio)
|
||||
await self.send_offer(pc, test_client, audio)
|
||||
# TODO: Test stream? E.g., when no audio is not part of the offer...
|
||||
await self.close_peer_connection(pc)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"modality": "video"}], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize("video", [True, False])
|
||||
async def test_successful_connection_video(self, test_client_and_stream, video):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(video=video)
|
||||
await self.send_offer(pc, test_client, video=video)
|
||||
# TODO: Test stream? E.g., when no video is not part of the offer...
|
||||
await self.close_peer_connection(pc)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"modality": "audio"}], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize("audio", [True, False])
|
||||
async def test_unsuccessful_connection_audio(self, test_client_and_stream, audio):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(audio=audio, video=True)
|
||||
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
|
||||
await self.send_offer(pc, test_client, audio=audio, video=True)
|
||||
await self.close_peer_connection(pc)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"modality": "video"}], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize("video", [True, False])
|
||||
async def test_unsuccessful_connection_video(self, test_client_and_stream, video):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(audio=True, video=video)
|
||||
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
|
||||
await self.send_offer(pc, test_client, audio=True, video=video)
|
||||
await self.close_peer_connection(pc)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsuccessful_webrtc_offer_no_webrtc_id(self, test_client_and_stream):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection()
|
||||
await self.send_offer(
|
||||
pc,
|
||||
test_client,
|
||||
webrtc_id=None,
|
||||
response_code=422,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"modality": "dummy"}], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"audio, video",
|
||||
[
|
||||
(True, False),
|
||||
(False, True),
|
||||
(True, True),
|
||||
],
|
||||
)
|
||||
async def test_incorrect_modality(self, test_client_and_stream, audio, video):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(audio=audio, video=video)
|
||||
with pytest.raises(ValueError, match=r"Modality must be .*"):
|
||||
await self.send_offer(pc, test_client, audio=audio, video=video)
|
||||
await self.close_peer_connection(pc)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
|
||||
)
|
||||
async def test_concurrency_limit_reached_two_peers(self, test_client_and_stream):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc1, channel = await self.setup_peer_connection(video=True)
|
||||
pc2, channel = await self.setup_peer_connection(video=True)
|
||||
await self.send_offer(pc1, test_client)
|
||||
status, metadata = await self.send_offer(
|
||||
pc2, test_client, return_status_and_metadata=True
|
||||
)
|
||||
assert status == "failed"
|
||||
assert metadata == {"error": "connection_already_exists"}
|
||||
|
||||
await self.close_peer_connection(pc1)
|
||||
await self.close_peer_connection(pc2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
|
||||
)
|
||||
async def test_concurrency_limit_reached_three_peers_same_id(
|
||||
self, test_client_and_stream
|
||||
):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc1, channel = await self.setup_peer_connection(video=True)
|
||||
pc2, channel = await self.setup_peer_connection(video=True)
|
||||
pc3, channel = await self.setup_peer_connection(video=True)
|
||||
await self.send_offer(pc1, test_client)
|
||||
status, metadata = await self.send_offer(
|
||||
pc2, test_client, return_status_and_metadata=True
|
||||
)
|
||||
assert status == "failed"
|
||||
assert metadata == {"error": "connection_already_exists"}
|
||||
status, metadata = await self.send_offer(
|
||||
pc3, test_client, return_status_and_metadata=True
|
||||
)
|
||||
assert status == "failed"
|
||||
assert metadata == {"error": "connection_already_exists"}
|
||||
|
||||
await self.close_peer_connection(pc1)
|
||||
await self.close_peer_connection(pc2)
|
||||
await self.close_peer_connection(pc3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
|
||||
)
|
||||
async def test_concurrency_limit_reached_three_peers(self, test_client_and_stream):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc1, channel = await self.setup_peer_connection(video=True)
|
||||
pc2, channel = await self.setup_peer_connection(video=True)
|
||||
pc3, channel = await self.setup_peer_connection(video=True)
|
||||
await self.send_offer(pc1, test_client, webrtc_id="foo")
|
||||
await self.send_offer(pc2, test_client, webrtc_id="bar")
|
||||
status, metadata = await self.send_offer(
|
||||
pc3, test_client, webrtc_id="baz", return_status_and_metadata=True
|
||||
)
|
||||
|
||||
assert status == "failed"
|
||||
assert metadata == {"error": "concurrency_limit_reached", "limit": 2}
|
||||
|
||||
await self.close_peer_connection(pc1)
|
||||
await self.close_peer_connection(pc2)
|
||||
await self.close_peer_connection(pc3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
|
||||
)
|
||||
async def test_concurrency_limit_reached_peers_with_no_mediastreams(
|
||||
self, test_client_and_stream
|
||||
):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc1, channel = await self.setup_peer_connection()
|
||||
pc2, channel = await self.setup_peer_connection()
|
||||
await self.send_offer(pc1, test_client, webrtc_id="foo")
|
||||
status, metadata = await self.send_offer(
|
||||
pc2, test_client, webrtc_id="bar", return_status_and_metadata=True
|
||||
)
|
||||
assert status == "failed"
|
||||
assert metadata == {"error": "concurrency_limit_reached", "limit": 1}
|
||||
|
||||
await self.close_peer_connection(pc1)
|
||||
await self.close_peer_connection(pc2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_client_and_stream",
|
||||
[
|
||||
{"allow_extra_tracks": True, "modality": "audio"},
|
||||
{"allow_extra_tracks": True, "modality": "video"},
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_successful_connection_allow_extra_tracks(
|
||||
self, test_client_and_stream
|
||||
):
|
||||
test_client, stream = test_client_and_stream
|
||||
pc, channel = await self.setup_peer_connection(audio=True, video=True)
|
||||
await self.send_offer(pc, test_client, audio=True, video=True)
|
||||
await self.close_peer_connection(pc)
|
||||
Reference in New Issue
Block a user