Files
the-nexus/nexus/a2a/client.py
Alexander Whitestone bb9758c4d2
Some checks failed
CI / test (pull_request) Failing after 31s
Review Approval Gate / verify-review (pull_request) Failing after 4s
CI / validate (pull_request) Failing after 30s
feat: implement A2A protocol for fleet-wizard delegation (#1122)
Implements Google Agent2Agent Protocol v1.0 with full fleet integration:

## Phase 1 - Agent Card & Discovery
- Agent Card types with JSON serialization (camelCase, Part discrimination by key)
- Card generation from YAML config (~/.hermes/agent_card.yaml)
- Fleet registry with LocalFileRegistry + GiteaRegistry backends
- Discovery by skill ID or tag

## Phase 2 - Task Delegation
- Async A2A client with JSON-RPC SendMessage/GetTask/ListTasks/CancelTask
- FastAPI server with pluggable task handlers (skill-routed)
- CLI tool (bin/a2a_delegate.py) for fleet delegation
- Broadcast to multiple agents in parallel

## Phase 3 - Security & Reliability
- Bearer token + API key auth (configurable per agent)
- Retry logic (max 3 retries, 30s timeout)
- Audit logging for all inter-agent requests
- Error handling per A2A spec (-32001 to -32009 codes)

## Test Coverage
- 37 tests covering types, card building, registry, server integration
- Auth (required + success), handler routing, error handling

Files:
- nexus/a2a/ (types.py, card.py, client.py, server.py, registry.py)
- bin/a2a_delegate.py (CLI)
- config/ (agent_card.example.yaml, fleet_agents.json)
- docs/A2A_PROTOCOL.md
- tests/test_a2a.py (37 tests, all passing)
2026-04-13 18:31:05 -04:00

393 lines
12 KiB
Python

"""
A2A Client — send tasks to other agents over the A2A protocol.

Handles:
- Fetching remote Agent Cards
- Sending tasks (SendMessage JSON-RPC)
- Task polling (GetTask)
- Task cancellation
- Timeout + retry logic (max 3 retries, 30s default timeout)

Usage:
    client = A2AClient(auth_token="secret")
    task = await client.send_message("https://ezra.example.com/a2a/v1", message)
    status = await client.get_task("https://ezra.example.com/a2a/v1", task_id)
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Optional
import aiohttp
from nexus.a2a.types import (
A2AError,
AgentCard,
Artifact,
JSONRPCRequest,
JSONRPCResponse,
Message,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)
# Module-level logger shared by everything in this file; it is registered
# under the fixed name "nexus.a2a.client".
logger = logging.getLogger("nexus.a2a.client")
@dataclass
class A2AClientConfig:
    """
    Settings bundle consumed by A2AClient.

    All fields have defaults, so ``A2AClientConfig()`` is a valid
    configuration. Field order is part of the positional constructor
    signature and must not be rearranged.
    """

    # Per-request timeout, in seconds.
    timeout: float = 30.0
    # Upper bound used by the client's retry loop.
    max_retries: int = 3
    # Base delay between retries, in seconds.
    retry_delay: float = 2.0
    # Credential sent to the remote agent; an empty string sends none.
    auth_token: str = ""
    # How the credential is presented: "bearer" | "api_key" | "none".
    auth_scheme: str = "bearer"
    # Header name that carries the credential when auth_scheme is "api_key".
    api_key_header: str = "X-API-Key"
class A2AClient:
    """
    Async client for interacting with A2A-compatible agents.

    Every agent endpoint is identified by its base URL (e.g.
    https://ezra.example.com/a2a/v1). The client handles the JSON-RPC
    envelope, auth headers, retry, and timeout automatically.

    A single aiohttp session is created lazily on first use and reused for
    subsequent calls; call ``close()`` when the client is no longer needed.
    """

    def __init__(self, config: Optional[A2AClientConfig] = None, **kwargs):
        """
        Create a client.

        Args:
            config: Ready-made configuration. When provided, ``**kwargs``
                are ignored.
            **kwargs: Forwarded to ``A2AClientConfig`` when ``config`` is
                None (e.g. ``A2AClient(auth_token="secret")``).
        """
        if config is None:
            config = A2AClientConfig(**kwargs)
        self.config = config
        self._session: Optional[aiohttp.ClientSession] = None
        self._audit_log: list[dict] = []

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, (re)creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                headers=self._build_auth_headers(),
            )
        return self._session

    def _build_auth_headers(self) -> dict:
        """
        Build the default request headers based on config.

        The protocol-version and content-type headers are always present.
        An auth header is added only when a token is configured and the
        scheme is "bearer" or "api_key"; any other scheme adds nothing.
        """
        headers = {"A2A-Version": "1.0", "Content-Type": "application/json"}
        token = self.config.auth_token
        if not token:
            return headers
        if self.config.auth_scheme == "bearer":
            headers["Authorization"] = f"Bearer {token}"
        elif self.config.auth_scheme == "api_key":
            headers[self.config.api_key_header] = token
        return headers

    async def close(self):
        """Close the HTTP session if one is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    @staticmethod
    def _unwrap_result(data: dict) -> dict:
        """
        Extract the 'result' member of a decoded JSON-RPC response.

        Returns an empty dict when 'result' is absent or null so callers can
        safely run membership tests on the return value.

        Raises:
            RuntimeError: The response carries an 'error' member.
        """
        if "error" in data:
            err = data["error"]
            if isinstance(err, dict):
                code, message = err.get("code"), err.get("message")
            else:
                # Malformed error member: report it verbatim instead of
                # failing with AttributeError on .get().
                code, message = None, err
            logger.error(f"A2A RPC error {code}: {message}")
            raise RuntimeError(f"A2A error {code}: {message}")
        result = data.get("result")
        return {} if result is None else result

    @staticmethod
    def _card_urls(base_url: str) -> list[str]:
        """Return the candidate Agent Card URLs for base_url, in lookup order."""
        # Strip trailing slashes so the joined URLs never contain "//".
        root = base_url.rstrip("/")
        return [
            f"{root}/.well-known/agent-card.json",
            f"{root}/agent.json",
        ]

    @staticmethod
    def _with_target_skill(
        metadata: Optional[dict],
        skill_id: Optional[str],
    ) -> dict:
        """Return a copy of metadata, with 'targetSkill' set when skill_id is given."""
        # Copy first: the caller's dict must never be modified.
        merged = dict(metadata) if metadata else {}
        if skill_id:
            merged["targetSkill"] = skill_id
        return merged

    async def _rpc_call(
        self,
        endpoint: str,
        method: str,
        params: Optional[dict] = None,
    ) -> dict:
        """
        Make a JSON-RPC call with retry logic.

        Timeouts, aiohttp client errors, and HTTP 5xx responses are retried
        with a linearly growing delay (retry_delay * attempt number). HTTP
        401/404 and JSON-RPC errors are raised immediately.

        Every HTTP response received is appended to the audit log, whatever
        its status. Attempts that fail before a response arrives are only
        reported through the logger.

        Args:
            endpoint: URL the JSON-RPC request is POSTed to.
            method: JSON-RPC method name.
            params: JSON-RPC params; an empty dict is sent when omitted.

        Returns:
            The 'result' member of the response ({} when absent or null).

        Raises:
            PermissionError: The server answered 401.
            FileNotFoundError: The server answered 404.
            RuntimeError: The response carries a JSON-RPC error.
            ConnectionError: Every attempt failed.
        """
        session = await self._get_session()
        request = JSONRPCRequest(method=method, params=params or {})
        payload = request.to_dict()

        # max_retries counts total attempts; always make at least one so a
        # setting of 0 does not fail without ever sending the request.
        attempts = max(1, self.config.max_retries)
        last_error: Optional[BaseException] = None

        for attempt in range(1, attempts + 1):
            try:
                start = time.monotonic()
                async with session.post(endpoint, json=payload) as resp:
                    elapsed = time.monotonic() - start

                    # Record the exchange before any status check so failed
                    # responses are audited as well.
                    self._audit_log.append({
                        "timestamp": time.time(),
                        "endpoint": endpoint,
                        "method": method,
                        "request_id": request.id,
                        "status_code": resp.status,
                        "elapsed_ms": int(elapsed * 1000),
                        "attempt": attempt,
                    })

                    if resp.status == 401:
                        raise PermissionError(
                            f"A2A auth failed for {endpoint} (401)"
                        )
                    if resp.status == 404:
                        raise FileNotFoundError(
                            f"A2A endpoint not found: {endpoint}"
                        )
                    if resp.status >= 500:
                        body = await resp.text()
                        raise ConnectionError(
                            f"A2A server error {resp.status}: {body}"
                        )
                    data = await resp.json()
                return self._unwrap_result(data)
            except (
                asyncio.TimeoutError,
                aiohttp.ClientError,
                ConnectionError,  # includes the 5xx error raised above
            ) as e:
                last_error = e
                logger.warning(
                    f"A2A request to {endpoint} attempt {attempt}/"
                    f"{attempts} failed: {e}"
                )
                if attempt < attempts:
                    delay = self.config.retry_delay * attempt
                    await asyncio.sleep(delay)

        raise ConnectionError(
            f"A2A request to {endpoint} failed after "
            f"{attempts} retries: {last_error}"
        ) from last_error

    # --- Core A2A Methods ---

    async def get_agent_card(self, base_url: str) -> AgentCard:
        """
        Fetch the Agent Card from a remote agent.

        Tries /.well-known/agent-card.json first, falls back to
        /agent.json. A candidate that fails for any reason (non-200 status,
        transport error, unparsable card) is logged at debug level and the
        next one is tried.

        Raises:
            FileNotFoundError: No candidate URL produced a usable card. The
                last underlying exception, if any, is attached as the cause.
        """
        session = await self._get_session()
        last_error: Optional[BaseException] = None

        for url in self._card_urls(base_url):
            try:
                async with session.get(url) as resp:
                    if resp.status != 200:
                        logger.debug(
                            f"Agent card not available at {url} "
                            f"(HTTP {resp.status})"
                        )
                        continue
                    data = await resp.json()
                    card = AgentCard.from_dict(data)
                    logger.info(
                        f"Fetched agent card: {card.name} "
                        f"({len(card.skills)} skills)"
                    )
                    return card
            except Exception as exc:
                # Deliberately best-effort: fall through to the next URL,
                # but keep the failure visible for debugging.
                last_error = exc
                logger.debug(f"Agent card fetch from {url} failed: {exc}")

        raise FileNotFoundError(
            f"Could not fetch agent card from {base_url}"
        ) from last_error

    async def send_message(
        self,
        endpoint: str,
        message: Message,
        accepted_output_modes: Optional[list[str]] = None,
        history_length: int = 10,
        return_immediately: bool = False,
    ) -> Task:
        """
        Send a message to an agent and get a Task back.

        This is the primary delegation method.

        Args:
            endpoint: The agent's A2A endpoint URL.
            message: The message to deliver.
            accepted_output_modes: Sent as 'acceptedOutputModes';
                defaults to ["text/plain"].
            history_length: Sent as 'historyLength'.
            return_immediately: Sent as 'returnImmediately'.

        Returns:
            The task from the response. When the agent answers with a bare
            message instead, it is wrapped in a COMPLETED task whose history
            holds the request and the reply.

        Raises:
            ValueError: The result contains neither 'task' nor 'message'.
        """
        params = {
            "message": message.to_dict(),
            "configuration": {
                "acceptedOutputModes": accepted_output_modes or ["text/plain"],
                "historyLength": history_length,
                "returnImmediately": return_immediately,
            },
        }
        result = await self._rpc_call(endpoint, "SendMessage", params)

        # Response is either a Task or Message
        if "task" in result:
            task = Task.from_dict(result["task"])
            logger.info(
                f"Task {task.id} created, state={task.status.state.value}"
            )
            return task

        if "message" in result:
            # Wrap message response as a completed task
            msg = Message.from_dict(result["message"])
            return Task(
                status=TaskStatus(state=TaskState.COMPLETED),
                history=[message, msg],
                artifacts=[
                    Artifact(parts=msg.parts, name="response")
                ],
            )

        raise ValueError(f"Unexpected response structure: {list(result.keys())}")

    async def get_task(self, endpoint: str, task_id: str) -> Task:
        """Get task status by ID."""
        result = await self._rpc_call(
            endpoint,
            "GetTask",
            {"id": task_id},
        )
        return Task.from_dict(result)

    async def list_tasks(
        self,
        endpoint: str,
        page_size: int = 20,
        page_token: str = "",
    ) -> tuple[list[Task], str]:
        """
        List tasks with cursor-based pagination.

        Returns (tasks, next_page_token). Empty string = last page.
        """
        result = await self._rpc_call(
            endpoint,
            "ListTasks",
            {
                "pageSize": page_size,
                "pageToken": page_token,
            },
        )
        # "or" also normalizes an explicit null from the server.
        tasks = [Task.from_dict(t) for t in result.get("tasks") or []]
        next_token = result.get("nextPageToken") or ""
        return tasks, next_token

    async def cancel_task(self, endpoint: str, task_id: str) -> Task:
        """Cancel a running task."""
        result = await self._rpc_call(
            endpoint,
            "CancelTask",
            {"id": task_id},
        )
        return Task.from_dict(result)

    # --- Convenience Methods ---

    async def delegate(
        self,
        agent_url: str,
        text: str,
        skill_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Task:
        """
        High-level delegation: send a text message to an agent.

        Args:
            agent_url: Full URL to agent's A2A endpoint
                (e.g. https://ezra.example.com/a2a/v1)
            text: The task description in natural language
            skill_id: Optional skill to target
            metadata: Optional metadata dict; it is copied, never modified
        """
        message = Message(
            role=Role.USER,
            parts=[TextPart(text=text)],
            metadata=self._with_target_skill(metadata, skill_id),
        )
        return await self.send_message(agent_url, message)

    async def wait_for_completion(
        self,
        endpoint: str,
        task_id: str,
        poll_interval: float = 2.0,
        max_wait: float = 300.0,
    ) -> Task:
        """
        Poll a task until it reaches a terminal state.

        Returns the completed task.

        Raises:
            TimeoutError: The task is still non-terminal once max_wait
                seconds have elapsed.
        """
        start = time.monotonic()
        while True:
            task = await self.get_task(endpoint, task_id)
            if task.status.state.terminal:
                return task
            elapsed = time.monotonic() - start
            if elapsed >= max_wait:
                raise TimeoutError(
                    f"Task {task_id} did not complete within "
                    f"{max_wait}s (state={task.status.state.value})"
                )
            await asyncio.sleep(poll_interval)

    def get_audit_log(self) -> list[dict]:
        """Return a shallow copy of the audit log of requests made by this client."""
        return list(self._audit_log)

    # --- Fleet-Wizard Helpers ---

    async def broadcast(
        self,
        agents: list[str],
        text: str,
        skill_id: Optional[str] = None,
    ) -> list[tuple[str, Task]]:
        """
        Send the same task to multiple agents in parallel.

        Returns list of (agent_url, task) tuples for the agents that
        accepted the task. Failures are logged and left out of the result.
        """
        results = await asyncio.gather(
            *(self.delegate(url, text, skill_id=skill_id) for url in agents),
            return_exceptions=True,
        )
        paired = []
        for agent_url, result in zip(agents, results):
            # BaseException, not Exception: gather() can also hand back a
            # CancelledError, which must not be paired as if it were a Task.
            if isinstance(result, BaseException):
                logger.error(f"Broadcast to {agent_url} failed: {result}")
            else:
                paired.append((agent_url, result))
        return paired