"""
Tests for A2A Protocol implementation.

Covers:
- Type serialization roundtrips (Agent Card, Task, Message, Artifact, Part)
- JSON-RPC envelope
- Agent Card building from YAML config
- Registry operations (register, list, filter)
- Client/server integration (end-to-end task delegation)
"""

from __future__ import annotations

import asyncio
import json
from pathlib import Path
from unittest.mock import AsyncMock, patch, MagicMock

import pytest

from nexus.a2a.types import (
    A2AError,
    AgentCard,
    AgentCapabilities,
    AgentInterface,
    AgentSkill,
    Artifact,
    DataPart,
    FilePart,
    JSONRPCError,
    JSONRPCRequest,
    JSONRPCResponse,
    Message,
    Role,
    Task,
    TaskState,
    TaskStatus,
    TextPart,
    part_from_dict,
    part_to_dict,
)
from nexus.a2a.card import build_card, load_card_config
from nexus.a2a.registry import LocalFileRegistry


# === Type Serialization Roundtrips ===


class TestTextPart:
    """Serialization roundtrips for ``TextPart``."""

    def test_roundtrip(self):
        """A default text part serializes to a bare ``{"text": ...}`` dict."""
        p = TextPart(text="hello world")
        d = p.to_dict()
        assert d == {"text": "hello world"}
        p2 = part_from_dict(d)
        assert isinstance(p2, TextPart)
        assert p2.text == "hello world"

    def test_custom_media_type(self):
        """A custom media type is emitted as ``mediaType`` and read back."""
        p = TextPart(text="data", media_type="text/markdown")
        d = p.to_dict()
        assert d["mediaType"] == "text/markdown"
        p2 = part_from_dict(d)
        assert p2.media_type == "text/markdown"


class TestFilePart:
    """Serialization roundtrips for ``FilePart`` (inline and URL variants)."""

    def test_inline_roundtrip(self):
        """Inline content (``raw``) and ``filename`` survive serialization."""
        p = FilePart(media_type="image/png", raw="base64data", filename="img.png")
        d = p.to_dict()
        assert d["raw"] == "base64data"
        assert d["filename"] == "img.png"
        p2 = part_from_dict(d)
        assert isinstance(p2, FilePart)
        assert p2.raw == "base64data"

    def test_url_roundtrip(self):
        """A URL-referenced file part survives serialization."""
        p = FilePart(media_type="application/pdf", url="https://example.com/doc.pdf")
        d = p.to_dict()
        assert d["url"] == "https://example.com/doc.pdf"
        p2 = part_from_dict(d)
        assert isinstance(p2, FilePart)
        assert p2.url == "https://example.com/doc.pdf"


class TestDataPart:
    """Serialization roundtrips for ``DataPart``."""

    def test_roundtrip(self):
        """Structured data is carried under the ``data`` key unchanged."""
        p = DataPart(data={"key": "value", "count": 42})
        d = p.to_dict()
        assert d["data"] == {"key": "value", "count": 42}
        p2 = part_from_dict(d)
        assert isinstance(p2, DataPart)
        assert p2.data["count"] == 42


class TestMessage:
    """Serialization roundtrips for ``Message``."""

    def test_roundtrip(self):
        """Role, parts and metadata survive ``to_dict`` / ``from_dict``."""
        msg = Message(
            role=Role.USER,
            parts=[TextPart(text="Hello agent")],
            metadata={"priority": "high"},
        )
        d = msg.to_dict()
        assert d["role"] == "ROLE_USER"
        assert d["parts"] == [{"text": "Hello agent"}]
        assert d["metadata"]["priority"] == "high"

        msg2 = Message.from_dict(d)
        assert msg2.role == Role.USER
        assert isinstance(msg2.parts[0], TextPart)
        assert msg2.parts[0].text == "Hello agent"
        assert msg2.metadata["priority"] == "high"

    def test_multi_part(self):
        """Heterogeneous parts keep their order and concrete types."""
        msg = Message(
            role=Role.AGENT,
            parts=[
                TextPart(text="Here's the report"),
                DataPart(data={"status": "healthy"}),
            ],
        )
        d = msg.to_dict()
        assert len(d["parts"]) == 2

        msg2 = Message.from_dict(d)
        assert len(msg2.parts) == 2
        assert isinstance(msg2.parts[0], TextPart)
        assert isinstance(msg2.parts[1], DataPart)


class TestArtifact:
    """Serialization roundtrips for ``Artifact``."""

    def test_roundtrip(self):
        """Name, description and parts survive ``to_dict`` / ``from_dict``."""
        art = Artifact(
            parts=[TextPart(text="result data")],
            name="report",
            description="CI health report",
        )
        d = art.to_dict()
        assert d["name"] == "report"
        assert d["description"] == "CI health report"

        art2 = Artifact.from_dict(d)
        assert art2.name == "report"
        assert isinstance(art2.parts[0], TextPart)
        assert art2.parts[0].text == "result data"


class TestTask:
    """Serialization roundtrips and state semantics for ``Task``."""

    def test_roundtrip(self):
        """Id, status state and history survive ``to_dict`` / ``from_dict``."""
        task = Task(
            id="test-123",
            status=TaskStatus(state=TaskState.WORKING),
            history=[
                Message(role=Role.USER, parts=[TextPart(text="Do X")]),
            ],
        )
        d = task.to_dict()
        assert d["id"] == "test-123"
        assert d["status"]["state"] == "TASK_STATE_WORKING"

        task2 = Task.from_dict(d)
        assert task2.id == "test-123"
        assert task2.status.state == TaskState.WORKING
        assert len(task2.history) == 1

    def test_with_artifacts(self):
        """Artifacts attached to a task are serialized and restored."""
        task = Task(
            id="art-task",
            status=TaskStatus(state=TaskState.COMPLETED),
            artifacts=[
                Artifact(
                    parts=[TextPart(text="42")],
                    name="answer",
                )
            ],
        )
        d = task.to_dict()
        assert len(d["artifacts"]) == 1

        task2 = Task.from_dict(d)
        assert task2.artifacts[0].name == "answer"

    def test_terminal_states(self):
        """``TaskState.terminal`` is True only for the four end states."""
        for state in [
            TaskState.COMPLETED,
            TaskState.FAILED,
            TaskState.CANCELED,
            TaskState.REJECTED,
        ]:
            assert state.terminal is True

        for state in [
            TaskState.SUBMITTED,
            TaskState.WORKING,
            TaskState.INPUT_REQUIRED,
            TaskState.AUTH_REQUIRED,
        ]:
            assert state.terminal is False


class TestAgentCard:
    """Serialization roundtrips for ``AgentCard``."""

    def test_roundtrip(self):
        """Name, capabilities and skills survive ``to_dict`` / ``from_dict``."""
        card = AgentCard(
            name="TestAgent",
            description="A test agent",
            version="1.0.0",
            supported_interfaces=[
                AgentInterface(url="http://localhost:8080/a2a/v1")
            ],
            capabilities=AgentCapabilities(streaming=True),
            skills=[
                AgentSkill(
                    id="test-skill",
                    name="Test Skill",
                    description="Does tests",
                    tags=["test"],
                )
            ],
        )
        d = card.to_dict()
        assert d["name"] == "TestAgent"
        assert d["capabilities"]["streaming"] is True
        assert len(d["skills"]) == 1
        assert d["skills"][0]["id"] == "test-skill"

        card2 = AgentCard.from_dict(d)
        assert card2.name == "TestAgent"
        assert card2.skills[0].id == "test-skill"
        assert card2.capabilities.streaming is True


class TestJSONRPC:
    """JSON-RPC request/response envelope serialization."""

    def test_request_roundtrip(self):
        """A request serializes with the ``2.0`` version tag and its method."""
        req = JSONRPCRequest(
            method="SendMessage",
            params={"message": {"text": "hello"}},
        )
        d = req.to_dict()
        assert d["jsonrpc"] == "2.0"
        assert d["method"] == "SendMessage"

    def test_response_success(self):
        """A successful response carries ``result`` and no ``error`` key."""
        resp = JSONRPCResponse(
            id="req-1",
            result={"task": {"id": "t1"}},
        )
        d = resp.to_dict()
        assert "error" not in d
        assert d["result"]["task"]["id"] == "t1"

    def test_response_error(self):
        """An error response carries ``error`` and no ``result`` key."""
        resp = JSONRPCResponse(
            id="req-1",
            error=A2AError.TASK_NOT_FOUND,
        )
        d = resp.to_dict()
        assert "result" not in d
        assert d["error"]["code"] == -32001


# === Agent Card Building ===


class TestBuildCard:
    """Building an ``AgentCard`` from a plain config mapping."""

    def test_basic_config(self):
        """Name, version, skills and URL are taken from the config."""
        config = {
            "name": "Bezalel",
            "description": "CI/CD specialist",
            "version": "2.0.0",
            "url": "https://bezalel.example.com",
            "skills": [
                {
                    "id": "ci-health",
                    "name": "CI Health",
                    "description": "Check CI",
                    "tags": ["ci"],
                },
                {
                    "id": "deploy",
                    "name": "Deploy",
                    "description": "Deploy services",
                    "tags": ["ops"],
                },
            ],
        }
        card = build_card(config)
        assert card.name == "Bezalel"
        assert card.version == "2.0.0"
        assert len(card.skills) == 2
        assert card.skills[0].id == "ci-health"
        assert card.supported_interfaces[0].url == "https://bezalel.example.com"

    def test_bearer_auth(self):
        """A ``bearer`` auth config yields a ``bearerAuth`` security scheme."""
        config = {
            "name": "Test",
            "description": "Test",
            "auth": {"scheme": "bearer", "token_env": "MY_TOKEN"},
        }
        card = build_card(config)
        assert "bearerAuth" in card.security_schemes
        assert card.security_requirements[0]["schemes"]["bearerAuth"] == {"list": []}

    def test_api_key_auth(self):
        """An ``api_key`` auth config yields an ``apiKeyAuth`` security scheme."""
        config = {
            "name": "Test",
            "description": "Test",
            "auth": {"scheme": "api_key", "key_name": "X-Custom-Key"},
        }
        card = build_card(config)
        assert "apiKeyAuth" in card.security_schemes


# === Registry ===


class TestLocalFileRegistry:
    """Register / list / filter / persist operations on ``LocalFileRegistry``."""

    def _make_card(self, name: str, skills: list[dict] | None = None) -> AgentCard:
        """Build a minimal card for *name* with one interface and optional skills.

        Each skill dict must contain ``id``; ``name`` defaults to the id,
        ``description`` to an empty string and ``tags`` to an empty list.
        """
        return AgentCard(
            name=name,
            description=f"Agent {name}",
            supported_interfaces=[
                AgentInterface(url=f"http://{name}:8080/a2a/v1")
            ],
            skills=[
                AgentSkill(
                    id=s["id"],
                    name=s.get("name", s["id"]),
                    description=s.get("description", ""),
                    tags=s.get("tags", []),
                )
                for s in (skills or [])
            ],
        )

    def test_register_and_list(self, tmp_path):
        """Every registered agent is returned by ``list_agents``."""
        registry = LocalFileRegistry(tmp_path / "agents.json")
        registry.register(self._make_card("ezra"))
        registry.register(self._make_card("allegro"))

        agents = registry.list_agents()
        assert len(agents) == 2
        names = {a.name for a in agents}
        assert names == {"ezra", "allegro"}

    def test_filter_by_skill(self, tmp_path):
        """``list_agents(skill=...)`` returns only agents with that skill id."""
        registry = LocalFileRegistry(tmp_path / "agents.json")
        registry.register(
            self._make_card("ezra", [{"id": "ci-health", "tags": ["ci"]}])
        )
        registry.register(
            self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
        )

        ci_agents = registry.list_agents(skill="ci-health")
        assert len(ci_agents) == 1
        assert ci_agents[0].name == "ezra"

    def test_filter_by_tag(self, tmp_path):
        """``list_agents(tag=...)`` returns only agents with a skill so tagged."""
        registry = LocalFileRegistry(tmp_path / "agents.json")
        registry.register(
            self._make_card("ezra", [{"id": "ci", "tags": ["devops", "ci"]}])
        )
        registry.register(
            self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
        )

        devops_agents = registry.list_agents(tag="devops")
        assert len(devops_agents) == 1
        assert devops_agents[0].name == "ezra"

    def test_persistence(self, tmp_path):
        """A second registry opened on the same path sees earlier registrations."""
        path = tmp_path / "agents.json"
        reg1 = LocalFileRegistry(path)
        reg1.register(self._make_card("ezra"))

        # Load fresh from disk
        reg2 = LocalFileRegistry(path)
        agents = reg2.list_agents()
        assert len(agents) == 1
        assert agents[0].name == "ezra"

    def test_unregister(self, tmp_path):
        """``unregister`` returns True when it removes an agent, else False."""
        registry = LocalFileRegistry(tmp_path / "agents.json")
        registry.register(self._make_card("ezra"))
        assert len(registry.list_agents()) == 1

        assert registry.unregister("ezra") is True
        assert len(registry.list_agents()) == 0
        assert registry.unregister("nonexistent") is False

    def test_get_endpoint(self, tmp_path):
        """``get_endpoint`` returns the URL of the agent's interface."""
        registry = LocalFileRegistry(tmp_path / "agents.json")
        registry.register(self._make_card("ezra"))

        url = registry.get_endpoint("ezra")
        assert url == "http://ezra:8080/a2a/v1"


# === Server Integration (FastAPI required) ===

try:
    from fastapi.testclient import TestClient

    HAS_TEST_CLIENT = True
except ImportError:
    HAS_TEST_CLIENT = False


# Path the server tests POST JSON-RPC requests to.
A2A_ENDPOINT = "/a2a/v1"


def _rpc_request(request_id: str, method: str, params: dict) -> dict:
    """Return a JSON-RPC 2.0 request envelope as a plain dict."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": method,
        "params": params,
    }


def _send_message_request(
    request_id: str,
    message_id: str,
    text: str,
    *,
    metadata: dict | None = None,
    configuration: dict | None = None,
) -> dict:
    """Return a ``SendMessage`` request carrying one user text part.

    Args:
        request_id: JSON-RPC request id.
        message_id: Value for the message's ``messageId`` field.
        text: Content of the single text part.
        metadata: Optional message metadata; the key is omitted when None.
        configuration: Optional ``configuration`` object; defaults to ``{}``.
    """
    message: dict = {
        "messageId": message_id,
        "role": "ROLE_USER",
        "parts": [{"text": text}],
    }
    if metadata is not None:
        message["metadata"] = metadata
    return _rpc_request(
        request_id,
        "SendMessage",
        {
            "message": message,
            "configuration": {} if configuration is None else configuration,
        },
    )


@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestA2AServerIntegration:
    """End-to-end tests using FastAPI TestClient."""

    def _make_server(self, auth_token: str = ""):
        """Build an ``A2AServer`` whose ``echo`` and default handlers both echo.

        Args:
            auth_token: Passed through to ``A2AServer``; the auth tests use a
                non-empty value to exercise bearer-token checks.
        """
        # Imported lazily so the module still imports when fastapi is missing.
        from nexus.a2a.server import A2AServer, echo_handler

        card = AgentCard(
            name="TestAgent",
            description="Test agent for A2A",
            supported_interfaces=[
                AgentInterface(url="http://localhost:8080/a2a/v1")
            ],
            capabilities=AgentCapabilities(streaming=False),
            skills=[
                AgentSkill(
                    id="echo",
                    name="Echo",
                    description="Echo back messages",
                    tags=["test"],
                )
            ],
        )

        server = A2AServer(card=card, auth_token=auth_token)
        server.register_handler("echo", echo_handler)
        server.set_default_handler(echo_handler)
        return server

    def test_agent_card_well_known(self):
        """The card is served at ``/.well-known/agent-card.json``."""
        server = self._make_server()
        client = TestClient(server.app)

        resp = client.get("/.well-known/agent-card.json")
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "TestAgent"
        assert len(data["skills"]) == 1

    def test_agent_card_fallback(self):
        """The card is also served at the ``/agent.json`` fallback path."""
        server = self._make_server()
        client = TestClient(server.app)

        resp = client.get("/agent.json")
        assert resp.status_code == 200
        assert resp.json()["name"] == "TestAgent"

    def test_send_message(self):
        """``SendMessage`` returns a completed task with one echo artifact."""
        server = self._make_server()
        client = TestClient(server.app)

        rpc_request = _send_message_request(
            "test-1",
            "msg-1",
            "Hello from test",
            configuration={
                "acceptedOutputModes": ["text/plain"],
                "historyLength": 10,
                "returnImmediately": False,
            },
        )

        resp = client.post(A2A_ENDPOINT, json=rpc_request)
        assert resp.status_code == 200
        data = resp.json()
        assert "result" in data
        assert "task" in data["result"]

        task = data["result"]["task"]
        assert task["status"]["state"] == "TASK_STATE_COMPLETED"
        assert len(task["artifacts"]) == 1
        assert "Echo" in task["artifacts"][0]["parts"][0]["text"]

    def test_get_task(self):
        """A task created via ``SendMessage`` can be fetched with ``GetTask``."""
        server = self._make_server()
        client = TestClient(server.app)

        # Create a task first
        send_req = _send_message_request("s1", "m1", "get me")
        send_resp = client.post(A2A_ENDPOINT, json=send_req)
        task_id = send_resp.json()["result"]["task"]["id"]

        # Now fetch it
        get_req = _rpc_request("g1", "GetTask", {"id": task_id})
        get_resp = client.post(A2A_ENDPOINT, json=get_req)
        assert get_resp.status_code == 200
        assert get_resp.json()["result"]["id"] == task_id

    def test_get_nonexistent_task(self):
        """``GetTask`` for an unknown id yields HTTP 400 with an ``error``."""
        server = self._make_server()
        client = TestClient(server.app)

        req = _rpc_request("g2", "GetTask", {"id": "nonexistent"})
        resp = client.post(A2A_ENDPOINT, json=req)
        assert resp.status_code == 400
        data = resp.json()
        assert "error" in data

    def test_list_tasks(self):
        """``ListTasks`` returns at least the tasks created beforehand."""
        server = self._make_server()
        client = TestClient(server.app)

        # Create two tasks
        for i in range(2):
            req = _send_message_request(f"s{i}", f"m{i}", f"task {i}")
            client.post(A2A_ENDPOINT, json=req)

        list_req = _rpc_request("l1", "ListTasks", {"pageSize": 10})
        resp = client.post(A2A_ENDPOINT, json=list_req)
        assert resp.status_code == 200
        tasks = resp.json()["result"]["tasks"]
        assert len(tasks) >= 2

    def test_cancel_task(self):
        """``CancelTask`` moves a WORKING task to the CANCELED state."""
        from nexus.a2a.server import A2AServer

        # Default handler that would take 10 s to finish; this test never
        # sends a message, so it is not expected to run.
        async def slow_handler(task, card):
            await asyncio.sleep(10)  # never reached in test
            task.status = TaskStatus(state=TaskState.COMPLETED)
            return task

        card = AgentCard(name="SlowAgent", description="Slow test agent")
        server = A2AServer(card=card)
        server.set_default_handler(slow_handler)
        client = TestClient(server.app)

        # Rather than racing a running handler, seed a WORKING task directly
        # into the server's task store and cancel that.
        task = Task(
            id="cancel-me",
            status=TaskStatus(state=TaskState.WORKING),
            history=[
                Message(role=Role.USER, parts=[TextPart(text="cancel me")])
            ],
        )
        server._tasks[task.id] = task

        # Cancel it
        cancel_req = _rpc_request("c2", "CancelTask", {"id": "cancel-me"})
        cancel_resp = client.post(A2A_ENDPOINT, json=cancel_req)
        assert cancel_resp.status_code == 200
        assert cancel_resp.json()["result"]["status"]["state"] == "TASK_STATE_CANCELED"

    def test_auth_required(self):
        """With a token configured, a request without credentials gets 401."""
        server = self._make_server(auth_token="secret123")
        client = TestClient(server.app)

        # No auth header — should get 401
        req = _send_message_request("a1", "am1", "hello")
        resp = client.post(A2A_ENDPOINT, json=req)
        assert resp.status_code == 401

    def test_auth_success(self):
        """With a token configured, a matching bearer token is accepted."""
        server = self._make_server(auth_token="secret123")
        client = TestClient(server.app)

        req = _send_message_request("a2", "am2", "authenticated")
        resp = client.post(
            A2A_ENDPOINT,
            json=req,
            headers={"Authorization": "Bearer secret123"},
        )
        assert resp.status_code == 200
        assert resp.json()["result"]["task"]["status"]["state"] == "TASK_STATE_COMPLETED"

    def test_unknown_method(self):
        """An unrecognized method yields HTTP 400 with error code -32602."""
        server = self._make_server()
        client = TestClient(server.app)

        req = _rpc_request("u1", "NonExistentMethod", {})
        resp = client.post(A2A_ENDPOINT, json=req)
        assert resp.status_code == 400
        # NOTE(review): JSON-RPC 2.0 reserves -32601 for "Method not found" and
        # -32602 for "Invalid params"; this pins -32602 — confirm that matches
        # the server's intended behavior.
        assert resp.json()["error"]["code"] == -32602

    def test_audit_log(self):
        """Each handled request adds one entry to the server's audit log."""
        server = self._make_server()
        client = TestClient(server.app)

        req = _send_message_request("au1", "aum1", "audit me")
        client.post(A2A_ENDPOINT, json=req)
        client.post(A2A_ENDPOINT, json=req)

        log = server.get_audit_log()
        assert len(log) == 2
        assert all(entry["method"] == "SendMessage" for entry in log)


# === Custom Handler Test ===


@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestCustomHandlers:
    """Test custom task handlers."""

    def test_skill_routing(self):
        """A message with ``targetSkill`` metadata reaches that skill's handler."""
        from nexus.a2a.server import A2AServer

        async def ci_handler(task: Task, card: AgentCard) -> Task:
            task.artifacts.append(
                Artifact(
                    parts=[TextPart(text="CI pipeline healthy: 5/5 passed")],
                    name="ci_report",
                )
            )
            task.status = TaskStatus(state=TaskState.COMPLETED)
            return task

        card = AgentCard(
            name="CI Agent",
            description="CI specialist",
            skills=[
                AgentSkill(
                    id="ci-health",
                    name="CI Health",
                    description="Check CI",
                    tags=["ci"],
                )
            ],
        )
        server = A2AServer(card=card)
        server.register_handler("ci-health", ci_handler)
        client = TestClient(server.app)

        req = _send_message_request(
            "h1",
            "hm1",
            "Check CI",
            metadata={"targetSkill": "ci-health"},
        )
        resp = client.post(A2A_ENDPOINT, json=req)
        task_data = resp.json()["result"]["task"]
        assert task_data["status"]["state"] == "TASK_STATE_COMPLETED"
        assert "5/5 passed" in task_data["artifacts"][0]["parts"][0]["text"]

    def test_handler_error(self):
        """A handler that raises produces a FAILED task carrying the message."""
        from nexus.a2a.server import A2AServer

        async def failing_handler(task: Task, card: AgentCard) -> Task:
            raise RuntimeError("Handler blew up")

        card = AgentCard(name="Fail Agent", description="Fails")
        server = A2AServer(card=card)
        server.set_default_handler(failing_handler)
        client = TestClient(server.app)

        req = _send_message_request("f1", "fm1", "break")
        resp = client.post(A2A_ENDPOINT, json=req)
        task_data = resp.json()["result"]["task"]
        assert task_data["status"]["state"] == "TASK_STATE_FAILED"
        assert "blew up" in task_data["status"]["message"]["parts"][0]["text"].lower()