mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 01:49:23 +08:00
* Demo * patient intake * cartesia * Add cartesia * Fix * lint * Move test * Fix * Fix * Fix * Fix
310 lines
11 KiB
Python
310 lines
11 KiB
Python
import asyncio
|
|
from typing import (
|
|
Literal,
|
|
cast,
|
|
)
|
|
|
|
import pytest
|
|
from aiortc import (
|
|
AudioStreamTrack,
|
|
RTCPeerConnection,
|
|
RTCSessionDescription,
|
|
VideoStreamTrack,
|
|
)
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
from fastrtc.stream import Body
|
|
from fastrtc.tracks import HandlerType
|
|
from fastrtc.webrtc_connection_mixin import WebRTCConnectionMixin
|
|
|
|
|
|
class MinimalTestStream(WebRTCConnectionMixin):
    """Smallest ``WebRTCConnectionMixin`` host used to exercise the offer route.

    The constructor stores the stream configuration on the instance, and
    :meth:`mount` exposes a single ``POST /webrtc/offer`` endpoint that is
    handled by :meth:`offer`.
    """

    def __init__(
        self,
        handler: HandlerType = lambda x: x,
        *,
        mode: Literal["send-receive", "receive", "send"] = "send-receive",
        modality: Literal["video", "audio", "audio-video"] = "video",
        concurrency_limit: int | None | Literal["default"] = "default",
        time_limit: float | None = None,
        allow_extra_tracks: bool = False,
    ):
        """Initialise the mixin, then record the stream configuration.

        Args:
            handler: Stored as ``event_handler``; defaults to an identity lambda.
            mode: Stored as ``mode``.
            modality: Stored as ``modality``.
            concurrency_limit: ``"default"`` and ``None`` both resolve to ``1``;
                any other value is stored unchanged.
            time_limit: Stored as ``time_limit``.
            allow_extra_tracks: Stored as ``allow_extra_tracks``.
        """
        WebRTCConnectionMixin.__init__(self)

        # "default" and None are both treated as a single-connection limit.
        if concurrency_limit in ("default", None):
            resolved_limit = 1
        else:
            resolved_limit = concurrency_limit

        self.mode = mode
        self.modality = modality
        self.event_handler = handler
        self.concurrency_limit = cast(int, resolved_limit)
        self.time_limit = time_limit
        self.allow_extra_tracks = allow_extra_tracks
        self.server_rtc_configuration = None

    def mount(self, app: FastAPI, path: str = ""):
        """Register ``POST {path}/webrtc/offer`` (served by :meth:`offer`) on ``app``."""
        from fastapi import APIRouter

        offer_router = APIRouter(prefix=path)
        register_post = offer_router.post("/webrtc/offer")
        register_post(self.offer)
        app.include_router(offer_router)

    async def offer(self, body: Body):
        """Forward an offer request to ``handle_offer``.

        The request body is passed as ``body.model_dump()``. ``set_outputs`` is
        whatever ``set_additional_outputs(body.webrtc_id)`` returns.
        """
        offer_payload = body.model_dump()
        output_setter = self.set_additional_outputs(body.webrtc_id)
        return await self.handle_offer(offer_payload, set_outputs=output_setter)
|
|
|
|
|
|
@pytest.fixture
def test_client_and_stream(request):
    """Yield a ``(TestClient, MinimalTestStream)`` pair backed by a fresh FastAPI app.

    When the fixture is parametrized indirectly, ``request.param`` supplies the
    keyword arguments for ``MinimalTestStream``. Otherwise the stream is built
    with its defaults.
    """
    application = FastAPI()
    stream_kwargs = getattr(request, "param", {})
    stream = MinimalTestStream(**stream_kwargs)
    stream.mount(application)
    yield TestClient(application), stream
|
|
|
|
|
|
class TestWebRTCConnectionMixin:
    """Tests for ``WebRTCConnectionMixin`` driven through ``MinimalTestStream``.

    Each test:

    * builds one or more local aiortc peer connections;
    * posts their offers to the ``/webrtc/offer`` route mounted by the
      ``test_client_and_stream`` fixture;
    * checks the response;
    * closes every peer connection it opened.
    """

    @staticmethod
    async def setup_peer_connection(audio=False, video=False):
        """Create a local peer connection holding a pending offer.

        A data channel named ``"test-data-channel"`` is always created. A
        default ``AudioStreamTrack`` and/or ``VideoStreamTrack`` is added when
        ``audio`` / ``video`` is true. The generated offer is set as the local
        description.

        Returns:
            ``(pc, channel)``: the ``RTCPeerConnection`` and its data channel.
        """
        pc = RTCPeerConnection()
        channel = pc.createDataChannel("test-data-channel")
        if audio:
            audio_track = AudioStreamTrack()
            pc.addTrack(audio_track)
        if video:
            video_track = VideoStreamTrack()
            pc.addTrack(video_track)

        await pc.setLocalDescription(await pc.createOffer())
        return pc, channel

    @staticmethod
    async def send_offer(
        pc,
        client,
        audio=False,
        video=False,
        webrtc_id="test_id",
        response_code=200,
        return_status_and_metadata=False,
    ):
        """POST ``pc``'s local offer to ``/webrtc/offer`` and validate the reply.

        Args:
            pc: Peer connection whose ``localDescription`` is sent.
            client: Test client bound to the app under test.
            audio: If true, assert the answer SDP contains ``m=audio``.
            video: If true, assert the answer SDP contains ``m=video``.
            webrtc_id: Sent as ``webrtc_id`` in the JSON body. The key is
                omitted entirely when this is ``None``.
            response_code: Expected HTTP status code. For any value other than
                200 the helper returns ``None`` right after checking it.
            return_status_and_metadata: If true, expect a JSON body with
                ``status`` and ``meta`` keys and return them, instead of
                treating the body as an SDP answer.

        Returns:
            ``(status, meta)`` when ``return_status_and_metadata`` is true,
            otherwise ``None``. On a successful answer, the remote description
            is applied to ``pc`` and the helper sleeps 0.5 s before returning.
        """
        body = {
            "sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type,
        }
        if webrtc_id is not None:
            body["webrtc_id"] = webrtc_id
        response = client.post(
            "/webrtc/offer",
            headers={"Content-Type": "application/json"},
            json=body,
        )
        assert response.status_code == response_code
        if response_code != 200:
            return
        out = response.json()
        if return_status_and_metadata:
            assert "status" in out and "meta" in out
            return out["status"], out["meta"]
        assert "type" in out and out["type"] == "answer"
        assert "webrtc-datachannel" in out["sdp"]
        if audio:
            assert "m=audio" in out["sdp"]
        if video:
            assert "m=video" in out["sdp"]

        await pc.setRemoteDescription(RTCSessionDescription(out["sdp"], out["type"]))

        # Allow data to stream
        await asyncio.sleep(0.5)

    @staticmethod
    async def close_peer_connection(pc):
        """Close ``pc`` and assert its connection, ICE and signaling states are "closed"."""
        await pc.close()
        assert pc.connectionState == "closed"
        assert pc.iceConnectionState == "closed"
        assert pc.signalingState == "closed"

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "audio"}], indirect=True
    )
    @pytest.mark.parametrize("audio", [True, False])
    async def test_successful_connection_audio(self, test_client_and_stream, audio):
        test_client, stream = test_client_and_stream
        pc, _ = await self.setup_peer_connection(audio)
        await self.send_offer(pc, test_client, audio)
        # TODO: test the media stream itself, e.g. when audio is not part of the offer.
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "video"}], indirect=True
    )
    @pytest.mark.parametrize("video", [True, False])
    async def test_successful_connection_video(self, test_client_and_stream, video):
        test_client, stream = test_client_and_stream
        pc, _ = await self.setup_peer_connection(video=video)
        await self.send_offer(pc, test_client, video=video)
        # TODO: test the media stream itself, e.g. when video is not part of the offer.
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "audio"}], indirect=True
    )
    @pytest.mark.parametrize("audio", [True, False])
    async def test_unsuccessful_connection_audio(self, test_client_and_stream, audio):
        test_client, stream = test_client_and_stream
        # A video track is offered to an audio-only stream and must be rejected.
        pc, _ = await self.setup_peer_connection(audio=audio, video=True)
        with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
            await self.send_offer(pc, test_client, audio=audio, video=True)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "video"}], indirect=True
    )
    @pytest.mark.parametrize("video", [True, False])
    async def test_unsuccessful_connection_video(self, test_client_and_stream, video):
        test_client, stream = test_client_and_stream
        # An audio track is offered to a video-only stream and must be rejected.
        pc, _ = await self.setup_peer_connection(audio=True, video=video)
        with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
            await self.send_offer(pc, test_client, audio=True, video=video)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    async def test_unsuccessful_webrtc_offer_no_webrtc_id(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc, _ = await self.setup_peer_connection()
        # Omitting webrtc_id from the request body should fail request validation.
        await self.send_offer(
            pc,
            test_client,
            webrtc_id=None,
            response_code=422,
        )
        # Close the local peer connection like every other test does, so its
        # transports are not leaked past the end of the test.
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "dummy"}], indirect=True
    )
    @pytest.mark.parametrize(
        "audio, video",
        [
            (True, False),
            (False, True),
            (True, True),
        ],
    )
    async def test_incorrect_modality(self, test_client_and_stream, audio, video):
        test_client, stream = test_client_and_stream
        pc, _ = await self.setup_peer_connection(audio=audio, video=video)
        with pytest.raises(ValueError, match=r"Modality must be .*"):
            await self.send_offer(pc, test_client, audio=audio, video=video)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
    )
    async def test_concurrency_limit_reached_two_peers(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc1, _ = await self.setup_peer_connection(video=True)
        pc2, _ = await self.setup_peer_connection(video=True)
        # Both offers use the default webrtc_id, so the second one collides.
        await self.send_offer(pc1, test_client)
        status, metadata = await self.send_offer(
            pc2, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
    )
    async def test_concurrency_limit_reached_three_peers_same_id(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc1, _ = await self.setup_peer_connection(video=True)
        pc2, _ = await self.setup_peer_connection(video=True)
        pc3, _ = await self.setup_peer_connection(video=True)
        await self.send_offer(pc1, test_client)
        status, metadata = await self.send_offer(
            pc2, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}
        status, metadata = await self.send_offer(
            pc3, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)
        await self.close_peer_connection(pc3)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
    )
    async def test_concurrency_limit_reached_three_peers(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc1, _ = await self.setup_peer_connection(video=True)
        pc2, _ = await self.setup_peer_connection(video=True)
        pc3, _ = await self.setup_peer_connection(video=True)
        await self.send_offer(pc1, test_client, webrtc_id="foo")
        await self.send_offer(pc2, test_client, webrtc_id="bar")
        status, metadata = await self.send_offer(
            pc3, test_client, webrtc_id="baz", return_status_and_metadata=True
        )

        assert status == "failed"
        assert metadata == {"error": "concurrency_limit_reached", "limit": 2}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)
        await self.close_peer_connection(pc3)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
    )
    async def test_concurrency_limit_reached_peers_with_no_mediastreams(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc1, _ = await self.setup_peer_connection()
        pc2, _ = await self.setup_peer_connection()
        await self.send_offer(pc1, test_client, webrtc_id="foo")
        status, metadata = await self.send_offer(
            pc2, test_client, webrtc_id="bar", return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "concurrency_limit_reached", "limit": 1}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream",
        [
            {"allow_extra_tracks": True, "modality": "audio"},
            {"allow_extra_tracks": True, "modality": "video"},
        ],
        indirect=True,
    )
    async def test_successful_connection_allow_extra_tracks(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc, _ = await self.setup_peer_connection(audio=True, video=True)
        await self.send_offer(pc, test_client, audio=True, video=True)
        await self.close_peer_connection(pc)
|