Files
the-nexus/tests/test_a2a.py
Alexander Whitestone bb9758c4d2
Some checks failed
CI / test (pull_request) Failing after 31s
Review Approval Gate / verify-review (pull_request) Failing after 4s
CI / validate (pull_request) Failing after 30s
feat: implement A2A protocol for fleet-wizard delegation (#1122)
Implements Google Agent2Agent Protocol v1.0 with full fleet integration:

## Phase 1 - Agent Card & Discovery
- Agent Card types with JSON serialization (camelCase, Part discrimination by key)
- Card generation from YAML config (~/.hermes/agent_card.yaml)
- Fleet registry with LocalFileRegistry + GiteaRegistry backends
- Discovery by skill ID or tag

## Phase 2 - Task Delegation
- Async A2A client with JSON-RPC SendMessage/GetTask/ListTasks/CancelTask
- FastAPI server with pluggable task handlers (skill-routed)
- CLI tool (bin/a2a_delegate.py) for fleet delegation
- Broadcast to multiple agents in parallel

## Phase 3 - Security & Reliability
- Bearer token + API key auth (configurable per agent)
- Retry logic (max 3 retries, 30s timeout)
- Audit logging for all inter-agent requests
- Error handling per A2A spec (-32001 to -32009 codes)

## Test Coverage
- 37 tests covering types, card building, registry, server integration
- Auth (required + success), handler routing, error handling

Files:
- nexus/a2a/ (types.py, card.py, client.py, server.py, registry.py)
- bin/a2a_delegate.py (CLI)
- config/ (agent_card.example.yaml, fleet_agents.json)
- docs/A2A_PROTOCOL.md
- tests/test_a2a.py (37 tests, all passing)
2026-04-13 18:31:05 -04:00

764 lines
23 KiB
Python

"""
Tests for A2A Protocol implementation.
Covers:
- Type serialization roundtrips (Agent Card, Task, Message, Artifact, Part)
- JSON-RPC envelope
- Agent Card building from YAML config
- Registry operations (register, list, filter)
- Client/server integration (end-to-end task delegation)
"""
from __future__ import annotations
import asyncio
import json
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, patch, MagicMock
from nexus.a2a.types import (
A2AError,
AgentCard,
AgentCapabilities,
AgentInterface,
AgentSkill,
Artifact,
DataPart,
FilePart,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
Message,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
part_from_dict,
part_to_dict,
)
from nexus.a2a.card import build_card, load_card_config
from nexus.a2a.registry import LocalFileRegistry
# === Type Serialization Roundtrips ===
class TestTextPart:
    """Serialization round-trips for ``TextPart``."""

    def test_roundtrip(self):
        """A text part serializes to a single ``text`` key and parses back."""
        original = TextPart(text="hello world")
        payload = original.to_dict()
        assert payload == {"text": "hello world"}

        restored = part_from_dict(payload)
        assert isinstance(restored, TextPart)
        assert restored.text == "hello world"

    def test_custom_media_type(self):
        """A custom media type is emitted as ``mediaType`` and survives parsing."""
        payload = TextPart(text="data", media_type="text/markdown").to_dict()
        assert payload["mediaType"] == "text/markdown"

        restored = part_from_dict(payload)
        assert restored.media_type == "text/markdown"
class TestFilePart:
    """Serialization round-trips for ``FilePart`` (inline and URL-backed)."""

    def test_inline_roundtrip(self):
        """Inline content keeps its ``raw`` payload and ``filename``."""
        original = FilePart(
            media_type="image/png",
            raw="base64data",
            filename="img.png",
        )
        payload = original.to_dict()
        assert payload["raw"] == "base64data"
        assert payload["filename"] == "img.png"

        restored = part_from_dict(payload)
        assert isinstance(restored, FilePart)
        assert restored.raw == "base64data"

    def test_url_roundtrip(self):
        """A URL-referenced file keeps its ``url`` through a round-trip."""
        doc_url = "https://example.com/doc.pdf"
        original = FilePart(media_type="application/pdf", url=doc_url)
        payload = original.to_dict()
        assert payload["url"] == "https://example.com/doc.pdf"

        restored = part_from_dict(payload)
        assert isinstance(restored, FilePart)
        assert restored.url == "https://example.com/doc.pdf"
class TestDataPart:
    """Serialization round-trip for ``DataPart``."""

    def test_roundtrip(self):
        """Structured data is carried under the ``data`` key and parses back."""
        original = DataPart(data={"key": "value", "count": 42})
        payload = original.to_dict()
        assert payload["data"] == {"key": "value", "count": 42}

        restored = part_from_dict(payload)
        assert isinstance(restored, DataPart)
        assert restored.data["count"] == 42
class TestMessage:
    """Serialization round-trips for ``Message``."""

    def test_roundtrip(self):
        """Role, parts and metadata all survive ``to_dict`` / ``from_dict``."""
        original = Message(
            role=Role.USER,
            parts=[TextPart(text="Hello agent")],
            metadata={"priority": "high"},
        )

        payload = original.to_dict()
        assert payload["role"] == "ROLE_USER"
        assert payload["parts"] == [{"text": "Hello agent"}]
        assert payload["metadata"]["priority"] == "high"

        restored = Message.from_dict(payload)
        assert restored.role == Role.USER
        first_part = restored.parts[0]
        assert isinstance(first_part, TextPart)
        assert restored.parts[0].text == "Hello agent"
        assert restored.metadata["priority"] == "high"

    def test_multi_part(self):
        """A message mixing text and data parts keeps each part's type."""
        report_text = TextPart(text="Here's the report")
        report_data = DataPart(data={"status": "healthy"})
        original = Message(role=Role.AGENT, parts=[report_text, report_data])

        payload = original.to_dict()
        assert len(payload["parts"]) == 2

        restored = Message.from_dict(payload)
        assert len(restored.parts) == 2
        assert isinstance(restored.parts[0], TextPart)
        assert isinstance(restored.parts[1], DataPart)
class TestArtifact:
    """Serialization round-trip for ``Artifact``."""

    def test_roundtrip(self):
        """Name, description and parts survive ``to_dict`` / ``from_dict``."""
        original = Artifact(
            parts=[TextPart(text="result data")],
            name="report",
            description="CI health report",
        )

        payload = original.to_dict()
        assert payload["name"] == "report"
        assert payload["description"] == "CI health report"

        restored = Artifact.from_dict(payload)
        assert restored.name == "report"
        assert isinstance(restored.parts[0], TextPart)
        assert restored.parts[0].text == "result data"
class TestTask:
    """Serialization and state checks for ``Task``."""

    def test_roundtrip(self):
        """Id, status state and history survive ``to_dict`` / ``from_dict``."""
        user_message = Message(role=Role.USER, parts=[TextPart(text="Do X")])
        original = Task(
            id="test-123",
            status=TaskStatus(state=TaskState.WORKING),
            history=[user_message],
        )

        payload = original.to_dict()
        assert payload["id"] == "test-123"
        assert payload["status"]["state"] == "TASK_STATE_WORKING"

        restored = Task.from_dict(payload)
        assert restored.id == "test-123"
        assert restored.status.state == TaskState.WORKING
        assert len(restored.history) == 1

    def test_with_artifacts(self):
        """Artifacts attached to a task are serialized and restored."""
        answer = Artifact(parts=[TextPart(text="42")], name="answer")
        original = Task(
            id="art-task",
            status=TaskStatus(state=TaskState.COMPLETED),
            artifacts=[answer],
        )

        payload = original.to_dict()
        assert len(payload["artifacts"]) == 1

        restored = Task.from_dict(payload)
        assert restored.artifacts[0].name == "answer"

    def test_terminal_states(self):
        """``TaskState.terminal`` is True only for the four finished states."""
        finished = (
            TaskState.COMPLETED,
            TaskState.FAILED,
            TaskState.CANCELED,
            TaskState.REJECTED,
        )
        for state in finished:
            assert state.terminal is True

        unfinished = (
            TaskState.SUBMITTED,
            TaskState.WORKING,
            TaskState.INPUT_REQUIRED,
            TaskState.AUTH_REQUIRED,
        )
        for state in unfinished:
            assert state.terminal is False
class TestAgentCard:
    """Serialization round-trip for ``AgentCard``."""

    def test_roundtrip(self):
        """Name, capabilities and skills survive ``to_dict`` / ``from_dict``."""
        interface = AgentInterface(url="http://localhost:8080/a2a/v1")
        skill = AgentSkill(
            id="test-skill",
            name="Test Skill",
            description="Does tests",
            tags=["test"],
        )
        original = AgentCard(
            name="TestAgent",
            description="A test agent",
            version="1.0.0",
            supported_interfaces=[interface],
            capabilities=AgentCapabilities(streaming=True),
            skills=[skill],
        )

        payload = original.to_dict()
        assert payload["name"] == "TestAgent"
        assert payload["capabilities"]["streaming"] is True
        assert len(payload["skills"]) == 1
        assert payload["skills"][0]["id"] == "test-skill"

        restored = AgentCard.from_dict(payload)
        assert restored.name == "TestAgent"
        assert restored.skills[0].id == "test-skill"
        assert restored.capabilities.streaming is True
class TestJSONRPC:
    """Shape checks for the JSON-RPC request/response envelopes."""

    def test_request_roundtrip(self):
        """A request serializes with the ``2.0`` version tag and its method."""
        request = JSONRPCRequest(
            method="SendMessage",
            params={"message": {"text": "hello"}},
        )
        envelope = request.to_dict()
        assert envelope["jsonrpc"] == "2.0"
        assert envelope["method"] == "SendMessage"

    def test_response_success(self):
        """A successful response carries ``result`` and omits ``error``."""
        response = JSONRPCResponse(id="req-1", result={"task": {"id": "t1"}})
        envelope = response.to_dict()
        assert "error" not in envelope
        assert envelope["result"]["task"]["id"] == "t1"

    def test_response_error(self):
        """An error response carries ``error`` and omits ``result``."""
        response = JSONRPCResponse(id="req-1", error=A2AError.TASK_NOT_FOUND)
        envelope = response.to_dict()
        assert "result" not in envelope
        assert envelope["error"]["code"] == -32001
# === Agent Card Building ===
class TestBuildCard:
    """``build_card`` turns a config mapping into an ``AgentCard``."""

    def test_basic_config(self):
        """Name, version, skills and URL are copied from the config."""
        ci_skill = {
            "id": "ci-health",
            "name": "CI Health",
            "description": "Check CI",
            "tags": ["ci"],
        }
        deploy_skill = {
            "id": "deploy",
            "name": "Deploy",
            "description": "Deploy services",
            "tags": ["ops"],
        }
        settings = {
            "name": "Bezalel",
            "description": "CI/CD specialist",
            "version": "2.0.0",
            "url": "https://bezalel.example.com",
            "skills": [ci_skill, deploy_skill],
        }

        built = build_card(settings)

        assert built.name == "Bezalel"
        assert built.version == "2.0.0"
        assert len(built.skills) == 2
        assert built.skills[0].id == "ci-health"
        assert built.supported_interfaces[0].url == "https://bezalel.example.com"

    def test_bearer_auth(self):
        """A bearer auth config yields a ``bearerAuth`` scheme and requirement."""
        settings = {
            "name": "Test",
            "description": "Test",
            "auth": {"scheme": "bearer", "token_env": "MY_TOKEN"},
        }

        built = build_card(settings)

        assert "bearerAuth" in built.security_schemes
        first_requirement = built.security_requirements[0]
        assert first_requirement["schemes"]["bearerAuth"] == {"list": []}

    def test_api_key_auth(self):
        """An API-key auth config yields an ``apiKeyAuth`` scheme."""
        settings = {
            "name": "Test",
            "description": "Test",
            "auth": {"scheme": "api_key", "key_name": "X-Custom-Key"},
        }

        built = build_card(settings)

        assert "apiKeyAuth" in built.security_schemes
# === Registry ===
class TestLocalFileRegistry:
    """Register / list / filter / persist behaviour of ``LocalFileRegistry``."""

    def _make_card(self, name: str, skills: list[dict] | None = None) -> AgentCard:
        """Build a minimal card for *name* with optional skill specs.

        Each skill spec must have an ``id``; ``name`` falls back to the id,
        ``description`` to an empty string and ``tags`` to an empty list.
        """
        agent_skills = []
        for spec in skills or []:
            agent_skills.append(
                AgentSkill(
                    id=spec["id"],
                    name=spec.get("name", spec["id"]),
                    description=spec.get("description", ""),
                    tags=spec.get("tags", []),
                )
            )
        endpoint = AgentInterface(url=f"http://{name}:8080/a2a/v1")
        return AgentCard(
            name=name,
            description=f"Agent {name}",
            supported_interfaces=[endpoint],
            skills=agent_skills,
        )

    def test_register_and_list(self, tmp_path):
        """Two registered agents are both returned by ``list_agents``."""
        store = LocalFileRegistry(tmp_path / "agents.json")
        store.register(self._make_card("ezra"))
        store.register(self._make_card("allegro"))

        listed = store.list_agents()

        assert len(listed) == 2
        assert {card.name for card in listed} == {"ezra", "allegro"}

    def test_filter_by_skill(self, tmp_path):
        """Filtering by skill id returns only the agent offering that skill."""
        store = LocalFileRegistry(tmp_path / "agents.json")
        ezra = self._make_card("ezra", [{"id": "ci-health", "tags": ["ci"]}])
        store.register(ezra)
        allegro = self._make_card(
            "allegro", [{"id": "research", "tags": ["research"]}]
        )
        store.register(allegro)

        matches = store.list_agents(skill="ci-health")

        assert len(matches) == 1
        assert matches[0].name == "ezra"

    def test_filter_by_tag(self, tmp_path):
        """Filtering by tag returns only the agent with a skill carrying it."""
        store = LocalFileRegistry(tmp_path / "agents.json")
        ezra = self._make_card("ezra", [{"id": "ci", "tags": ["devops", "ci"]}])
        store.register(ezra)
        allegro = self._make_card(
            "allegro", [{"id": "research", "tags": ["research"]}]
        )
        store.register(allegro)

        matches = store.list_agents(tag="devops")

        assert len(matches) == 1
        assert matches[0].name == "ezra"

    def test_persistence(self, tmp_path):
        """A second registry opened on the same file sees earlier registrations."""
        registry_file = tmp_path / "agents.json"
        writer = LocalFileRegistry(registry_file)
        writer.register(self._make_card("ezra"))

        # A fresh instance has to load from disk.
        reader = LocalFileRegistry(registry_file)
        listed = reader.list_agents()

        assert len(listed) == 1
        assert listed[0].name == "ezra"

    def test_unregister(self, tmp_path):
        """``unregister`` removes a known agent and reports unknown names as False."""
        store = LocalFileRegistry(tmp_path / "agents.json")
        store.register(self._make_card("ezra"))
        assert len(store.list_agents()) == 1

        removed = store.unregister("ezra")
        assert removed is True
        assert len(store.list_agents()) == 0

        removed_again = store.unregister("nonexistent")
        assert removed_again is False

    def test_get_endpoint(self, tmp_path):
        """``get_endpoint`` returns the URL of the agent's registered interface."""
        store = LocalFileRegistry(tmp_path / "agents.json")
        store.register(self._make_card("ezra"))

        endpoint = store.get_endpoint("ezra")

        assert endpoint == "http://ezra:8080/a2a/v1"
# === Server Integration (FastAPI required) ===
# Import the FastAPI test client when it is usable; the server integration
# test classes in this module are skipped when HAS_TEST_CLIENT is False.
try:
    from fastapi.testclient import TestClient
except (ImportError, RuntimeError):
    # ImportError: fastapi itself is not installed.
    # RuntimeError: recent Starlette releases raise this (not ImportError)
    # from ``starlette.testclient`` when the ``httpx`` dependency is missing;
    # without catching it the whole module would fail at collection time
    # instead of skipping the server tests.
    HAS_TEST_CLIENT = False
else:
    HAS_TEST_CLIENT = True
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestA2AServerIntegration:
    """End-to-end tests using FastAPI TestClient."""

    # Path every JSON-RPC request in this class is posted to.
    RPC_PATH = "/a2a/v1"

    @staticmethod
    def _rpc(rpc_id, method, params):
        """Build a JSON-RPC 2.0 request envelope."""
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "method": method,
            "params": params,
        }

    @staticmethod
    def _send_message_rpc(rpc_id, message_id, text, configuration=None):
        """Build a ``SendMessage`` request carrying one user text part.

        ``configuration`` defaults to an empty mapping.
        """
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "method": "SendMessage",
            "params": {
                "message": {
                    "messageId": message_id,
                    "role": "ROLE_USER",
                    "parts": [{"text": text}],
                },
                "configuration": {} if configuration is None else configuration,
            },
        }

    def _make_server(self, auth_token: str = ""):
        """Return an ``A2AServer`` with one "echo" skill.

        ``echo_handler`` is registered both for the "echo" skill and as the
        default handler. ``auth_token`` is passed straight to the server.
        """
        # Imported lazily: the server module is only needed when these tests run.
        from nexus.a2a.server import A2AServer, echo_handler

        echo_skill = AgentSkill(
            id="echo",
            name="Echo",
            description="Echo back messages",
            tags=["test"],
        )
        card = AgentCard(
            name="TestAgent",
            description="Test agent for A2A",
            supported_interfaces=[
                AgentInterface(url="http://localhost:8080/a2a/v1")
            ],
            capabilities=AgentCapabilities(streaming=False),
            skills=[echo_skill],
        )
        server = A2AServer(card=card, auth_token=auth_token)
        server.register_handler("echo", echo_handler)
        server.set_default_handler(echo_handler)
        return server

    def test_agent_card_well_known(self):
        """The card is served at the well-known discovery path."""
        http = TestClient(self._make_server().app)

        response = http.get("/.well-known/agent-card.json")

        assert response.status_code == 200
        body = response.json()
        assert body["name"] == "TestAgent"
        assert len(body["skills"]) == 1

    def test_agent_card_fallback(self):
        """The card is also served at ``/agent.json``."""
        http = TestClient(self._make_server().app)

        response = http.get("/agent.json")

        assert response.status_code == 200
        assert response.json()["name"] == "TestAgent"

    def test_send_message(self):
        """``SendMessage`` returns a completed task with one echo artifact."""
        http = TestClient(self._make_server().app)
        request = self._send_message_rpc(
            "test-1",
            "msg-1",
            "Hello from test",
            configuration={
                "acceptedOutputModes": ["text/plain"],
                "historyLength": 10,
                "returnImmediately": False,
            },
        )

        response = http.post(self.RPC_PATH, json=request)

        assert response.status_code == 200
        body = response.json()
        assert "result" in body
        assert "task" in body["result"]
        task = body["result"]["task"]
        assert task["status"]["state"] == "TASK_STATE_COMPLETED"
        assert len(task["artifacts"]) == 1
        assert "Echo" in task["artifacts"][0]["parts"][0]["text"]

    def test_get_task(self):
        """A task created via ``SendMessage`` can be fetched with ``GetTask``."""
        http = TestClient(self._make_server().app)

        # Create a task so there is an id to look up.
        created = http.post(
            self.RPC_PATH,
            json=self._send_message_rpc("s1", "m1", "get me"),
        )
        task_id = created.json()["result"]["task"]["id"]

        fetched = http.post(
            self.RPC_PATH,
            json=self._rpc("g1", "GetTask", {"id": task_id}),
        )

        assert fetched.status_code == 200
        assert fetched.json()["result"]["id"] == task_id

    def test_get_nonexistent_task(self):
        """``GetTask`` for an unknown id answers 400 with an ``error`` member."""
        http = TestClient(self._make_server().app)
        request = self._rpc("g2", "GetTask", {"id": "nonexistent"})

        response = http.post(self.RPC_PATH, json=request)

        assert response.status_code == 400
        body = response.json()
        assert "error" in body

    def test_list_tasks(self):
        """``ListTasks`` returns at least the two tasks created beforehand."""
        http = TestClient(self._make_server().app)

        # Create two tasks.
        for index in range(2):
            http.post(
                self.RPC_PATH,
                json=self._send_message_rpc(
                    f"s{index}", f"m{index}", f"task {index}"
                ),
            )

        response = http.post(
            self.RPC_PATH,
            json=self._rpc("l1", "ListTasks", {"pageSize": 10}),
        )

        assert response.status_code == 200
        listed = response.json()["result"]["tasks"]
        assert len(listed) >= 2

    def test_cancel_task(self):
        """``CancelTask`` moves a WORKING task to the CANCELED state."""
        from nexus.a2a.server import A2AServer

        # A slow default handler; this test never sends a message, so the
        # handler is not expected to run.
        async def slow_handler(task, card):
            await asyncio.sleep(10)  # never reached in test
            task.status = TaskStatus(state=TaskState.COMPLETED)
            return task

        server = A2AServer(
            card=AgentCard(name="SlowAgent", description="Slow test agent")
        )
        server.set_default_handler(slow_handler)
        http = TestClient(server.app)

        # Rather than racing the handler, place a WORKING task directly in
        # the server's task store and cancel that one.
        pending = Task(
            id="cancel-me",
            status=TaskStatus(state=TaskState.WORKING),
            history=[
                Message(role=Role.USER, parts=[TextPart(text="cancel me")])
            ],
        )
        server._tasks[pending.id] = pending

        response = http.post(
            self.RPC_PATH,
            json=self._rpc("c2", "CancelTask", {"id": "cancel-me"}),
        )

        assert response.status_code == 200
        final_state = response.json()["result"]["status"]["state"]
        assert final_state == "TASK_STATE_CANCELED"

    def test_auth_required(self):
        """With a token configured, a request without credentials gets 401."""
        http = TestClient(self._make_server(auth_token="secret123").app)
        request = self._send_message_rpc("a1", "am1", "hello")

        # No Authorization header is sent.
        response = http.post(self.RPC_PATH, json=request)

        assert response.status_code == 401

    def test_auth_success(self):
        """With a token configured, a matching bearer token is accepted."""
        http = TestClient(self._make_server(auth_token="secret123").app)
        request = self._send_message_rpc("a2", "am2", "authenticated")

        response = http.post(
            self.RPC_PATH,
            json=request,
            headers={"Authorization": "Bearer secret123"},
        )

        assert response.status_code == 200
        task = response.json()["result"]["task"]
        assert task["status"]["state"] == "TASK_STATE_COMPLETED"

    def test_unknown_method(self):
        """An unrecognized method answers 400 with error code -32602."""
        http = TestClient(self._make_server().app)
        request = self._rpc("u1", "NonExistentMethod", {})

        response = http.post(self.RPC_PATH, json=request)

        assert response.status_code == 400
        # NOTE(review): JSON-RPC 2.0 defines -32601 as "Method not found" and
        # -32602 as "Invalid params" - confirm -32602 is the intended code.
        assert response.json()["error"]["code"] == -32602

    def test_audit_log(self):
        """Each request adds one entry to the server's audit log."""
        server = self._make_server()
        http = TestClient(server.app)
        request = self._send_message_rpc("au1", "aum1", "audit me")

        for _ in range(2):
            http.post(self.RPC_PATH, json=request)

        entries = server.get_audit_log()
        assert len(entries) == 2
        assert all(entry["method"] == "SendMessage" for entry in entries)
# === Custom Handler Test ===
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestCustomHandlers:
    """Test custom task handlers."""

    def test_skill_routing(self):
        """A message tagged with ``targetSkill`` reaches that skill's handler."""
        from nexus.a2a.server import A2AServer

        async def ci_handler(task: Task, card: AgentCard) -> Task:
            report = Artifact(
                parts=[TextPart(text="CI pipeline healthy: 5/5 passed")],
                name="ci_report",
            )
            task.artifacts.append(report)
            task.status = TaskStatus(state=TaskState.COMPLETED)
            return task

        ci_skill = AgentSkill(
            id="ci-health",
            name="CI Health",
            description="Check CI",
            tags=["ci"],
        )
        server = A2AServer(
            card=AgentCard(
                name="CI Agent",
                description="CI specialist",
                skills=[ci_skill],
            )
        )
        server.register_handler("ci-health", ci_handler)
        http = TestClient(server.app)

        message = {
            "messageId": "hm1",
            "role": "ROLE_USER",
            "parts": [{"text": "Check CI"}],
            "metadata": {"targetSkill": "ci-health"},
        }
        request = {
            "jsonrpc": "2.0",
            "id": "h1",
            "method": "SendMessage",
            "params": {"message": message, "configuration": {}},
        }

        response = http.post("/a2a/v1", json=request)

        task_payload = response.json()["result"]["task"]
        assert task_payload["status"]["state"] == "TASK_STATE_COMPLETED"
        report_text = task_payload["artifacts"][0]["parts"][0]["text"]
        assert "5/5 passed" in report_text

    def test_handler_error(self):
        """A handler that raises yields a FAILED task whose status carries the error text."""
        from nexus.a2a.server import A2AServer

        async def failing_handler(task: Task, card: AgentCard) -> Task:
            raise RuntimeError("Handler blew up")

        server = A2AServer(
            card=AgentCard(name="Fail Agent", description="Fails")
        )
        server.set_default_handler(failing_handler)
        http = TestClient(server.app)

        message = {
            "messageId": "fm1",
            "role": "ROLE_USER",
            "parts": [{"text": "break"}],
        }
        request = {
            "jsonrpc": "2.0",
            "id": "f1",
            "method": "SendMessage",
            "params": {"message": message, "configuration": {}},
        }

        response = http.post("/a2a/v1", json=request)

        task_payload = response.json()["result"]["task"]
        assert task_payload["status"]["state"] == "TASK_STATE_FAILED"
        failure_text = task_payload["status"]["message"]["parts"][0]["text"]
        assert "blew up" in failure_text.lower()