Introduce unit tests (#248)

* Proof-of-concept: unittests

* Add pytest-asyncio dep

* Import Body from stream

* Add test for allow_extra_tracks

* Cleanup decorators

* add test to linting

* fix ruff issues

* Run formatter

* fix

* Dont test every python version

---------

Co-authored-by: Marcus Valtonen Örnhag <marcus.valtonen.ornhag@ericsson.com>
Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Marcus Valtonen Örnhag
2025-04-07 23:35:25 +02:00
committed by GitHub
parent 0767030997
commit 2331079c0f
8 changed files with 1246 additions and 898 deletions

0
test/__init__.py Normal file
View File

View File

@@ -0,0 +1,308 @@
import asyncio
from typing import (
Literal,
cast,
)
import pytest
from aiortc import (
AudioStreamTrack,
RTCPeerConnection,
RTCSessionDescription,
VideoStreamTrack,
)
from fastapi import FastAPI
from fastapi.testclient import TestClient
from fastrtc.stream import Body
from fastrtc.tracks import HandlerType
from fastrtc.webrtc_connection_mixin import WebRTCConnectionMixin
class MinimalTestStream(WebRTCConnectionMixin):
def __init__(
self,
handler: HandlerType = lambda x: x,
*,
mode: Literal["send-receive", "receive", "send"] = "send-receive",
modality: Literal["video", "audio", "audio-video"] = "video",
concurrency_limit: int | None | Literal["default"] = "default",
time_limit: float | None = None,
allow_extra_tracks: bool = False,
):
WebRTCConnectionMixin.__init__(self)
self.mode = mode
self.modality = modality
self.event_handler = handler
self.concurrency_limit = cast(
(int),
1 if concurrency_limit in ["default", None] else concurrency_limit,
)
self.time_limit = time_limit
self.allow_extra_tracks = allow_extra_tracks
def mount(self, app: FastAPI, path: str = ""):
from fastapi import APIRouter
router = APIRouter(prefix=path)
router.post("/webrtc/offer")(self.offer)
app.include_router(router)
async def offer(self, body: Body):
return await self.handle_offer(
body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
)
@pytest.fixture()
def test_client_and_stream(request):
app = FastAPI()
params = request.param if hasattr(request, "param") else {}
stream = MinimalTestStream(**params)
stream.mount(app)
test_client = TestClient(app)
yield test_client, stream
class TestWebRTCConnectionMixin:
@staticmethod
async def setup_peer_connection(audio=False, video=False):
pc = RTCPeerConnection()
channel = pc.createDataChannel("test-data-channel")
if audio:
audio_track = AudioStreamTrack()
pc.addTrack(audio_track)
if video:
video_track = VideoStreamTrack()
pc.addTrack(video_track)
await pc.setLocalDescription(await pc.createOffer())
return pc, channel
@staticmethod
async def send_offer(
pc,
client,
audio=False,
video=False,
webrtc_id="test_id",
response_code=200,
return_status_and_metadata=False,
):
body = {
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type,
}
if webrtc_id is not None:
body["webrtc_id"] = webrtc_id
response = client.post(
"/webrtc/offer",
headers={"Content-Type": "application/json"},
json=body,
)
assert response.status_code == response_code
if not response_code == 200:
return
out = response.json()
if return_status_and_metadata:
assert "status" in out and "meta" in out
return out["status"], out["meta"]
assert "type" in out and out["type"] == "answer"
assert "webrtc-datachannel" in out["sdp"]
if audio:
assert "m=audio" in out["sdp"]
if video:
assert "m=video" in out["sdp"]
await pc.setRemoteDescription(RTCSessionDescription(out["sdp"], out["type"]))
# Allow data to stream
await asyncio.sleep(0.5)
@staticmethod
async def close_peer_connection(pc):
await pc.close()
assert pc.connectionState == "closed"
assert pc.iceConnectionState == "closed"
assert pc.signalingState == "closed"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"modality": "audio"}], indirect=True
)
@pytest.mark.parametrize("audio", [True, False])
async def test_successful_connection_audio(self, test_client_and_stream, audio):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(audio)
await self.send_offer(pc, test_client, audio)
# TODO: Test stream? E.g., when no audio is not part of the offer...
await self.close_peer_connection(pc)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"modality": "video"}], indirect=True
)
@pytest.mark.parametrize("video", [True, False])
async def test_successful_connection_video(self, test_client_and_stream, video):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(video=video)
await self.send_offer(pc, test_client, video=video)
# TODO: Test stream? E.g., when no video is not part of the offer...
await self.close_peer_connection(pc)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"modality": "audio"}], indirect=True
)
@pytest.mark.parametrize("audio", [True, False])
async def test_unsuccessful_connection_audio(self, test_client_and_stream, audio):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(audio=audio, video=True)
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
await self.send_offer(pc, test_client, audio=audio, video=True)
await self.close_peer_connection(pc)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"modality": "video"}], indirect=True
)
@pytest.mark.parametrize("video", [True, False])
async def test_unsuccessful_connection_video(self, test_client_and_stream, video):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(audio=True, video=video)
with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
await self.send_offer(pc, test_client, audio=True, video=video)
await self.close_peer_connection(pc)
@pytest.mark.asyncio
async def test_unsuccessful_webrtc_offer_no_webrtc_id(self, test_client_and_stream):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection()
await self.send_offer(
pc,
test_client,
webrtc_id=None,
response_code=422,
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"modality": "dummy"}], indirect=True
)
@pytest.mark.parametrize(
"audio, video",
[
(True, False),
(False, True),
(True, True),
],
)
async def test_incorrect_modality(self, test_client_and_stream, audio, video):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(audio=audio, video=video)
with pytest.raises(ValueError, match=r"Modality must be .*"):
await self.send_offer(pc, test_client, audio=audio, video=video)
await self.close_peer_connection(pc)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
)
async def test_concurrency_limit_reached_two_peers(self, test_client_and_stream):
test_client, stream = test_client_and_stream
pc1, channel = await self.setup_peer_connection(video=True)
pc2, channel = await self.setup_peer_connection(video=True)
await self.send_offer(pc1, test_client)
status, metadata = await self.send_offer(
pc2, test_client, return_status_and_metadata=True
)
assert status == "failed"
assert metadata == {"error": "connection_already_exists"}
await self.close_peer_connection(pc1)
await self.close_peer_connection(pc2)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
)
async def test_concurrency_limit_reached_three_peers_same_id(
self, test_client_and_stream
):
test_client, stream = test_client_and_stream
pc1, channel = await self.setup_peer_connection(video=True)
pc2, channel = await self.setup_peer_connection(video=True)
pc3, channel = await self.setup_peer_connection(video=True)
await self.send_offer(pc1, test_client)
status, metadata = await self.send_offer(
pc2, test_client, return_status_and_metadata=True
)
assert status == "failed"
assert metadata == {"error": "connection_already_exists"}
status, metadata = await self.send_offer(
pc3, test_client, return_status_and_metadata=True
)
assert status == "failed"
assert metadata == {"error": "connection_already_exists"}
await self.close_peer_connection(pc1)
await self.close_peer_connection(pc2)
await self.close_peer_connection(pc3)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
)
async def test_concurrency_limit_reached_three_peers(self, test_client_and_stream):
test_client, stream = test_client_and_stream
pc1, channel = await self.setup_peer_connection(video=True)
pc2, channel = await self.setup_peer_connection(video=True)
pc3, channel = await self.setup_peer_connection(video=True)
await self.send_offer(pc1, test_client, webrtc_id="foo")
await self.send_offer(pc2, test_client, webrtc_id="bar")
status, metadata = await self.send_offer(
pc3, test_client, webrtc_id="baz", return_status_and_metadata=True
)
assert status == "failed"
assert metadata == {"error": "concurrency_limit_reached", "limit": 2}
await self.close_peer_connection(pc1)
await self.close_peer_connection(pc2)
await self.close_peer_connection(pc3)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
)
async def test_concurrency_limit_reached_peers_with_no_mediastreams(
self, test_client_and_stream
):
test_client, stream = test_client_and_stream
pc1, channel = await self.setup_peer_connection()
pc2, channel = await self.setup_peer_connection()
await self.send_offer(pc1, test_client, webrtc_id="foo")
status, metadata = await self.send_offer(
pc2, test_client, webrtc_id="bar", return_status_and_metadata=True
)
assert status == "failed"
assert metadata == {"error": "concurrency_limit_reached", "limit": 1}
await self.close_peer_connection(pc1)
await self.close_peer_connection(pc2)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_client_and_stream",
[
{"allow_extra_tracks": True, "modality": "audio"},
{"allow_extra_tracks": True, "modality": "video"},
],
indirect=True,
)
async def test_successful_connection_allow_extra_tracks(
self, test_client_and_stream
):
test_client, stream = test_client_and_stream
pc, channel = await self.setup_peer_connection(audio=True, video=True)
await self.send_offer(pc, test_client, audio=True, video=True)
await self.close_peer_connection(pc)