""" A2A Client — send tasks to other agents over the A2A protocol. Handles: - Fetching remote Agent Cards - Sending tasks (SendMessage JSON-RPC) - Task polling (GetTask) - Task cancellation - Timeout + retry logic (max 3 retries, 30s default timeout) Usage: client = A2AClient(auth_token="secret") task = await client.send_message("https://ezra.example.com/a2a/v1", message) status = await client.get_task("https://ezra.example.com/a2a/v1", task_id) """ from __future__ import annotations import asyncio import json import logging import time import uuid from dataclasses import dataclass, field from typing import Any, Optional import aiohttp from nexus.a2a.types import ( A2AError, AgentCard, Artifact, JSONRPCRequest, JSONRPCResponse, Message, Role, Task, TaskState, TaskStatus, TextPart, ) logger = logging.getLogger("nexus.a2a.client") @dataclass class A2AClientConfig: """Client configuration.""" timeout: float = 30.0 # seconds per request max_retries: int = 3 retry_delay: float = 2.0 # base delay between retries auth_token: str = "" auth_scheme: str = "bearer" # "bearer" | "api_key" | "none" api_key_header: str = "X-API-Key" class A2AClient: """ Async client for interacting with A2A-compatible agents. Every agent endpoint is identified by its base URL (e.g. https://ezra.example.com/a2a/v1). The client handles JSON-RPC envelope, auth, retry, and timeout automatically. """ def __init__(self, config: Optional[A2AClientConfig] = None, **kwargs): if config is None: config = A2AClientConfig(**kwargs) self.config = config self._session: Optional[aiohttp.ClientSession] = None self._audit_log: list[dict] = [] async def _get_session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.config.timeout), headers=self._build_auth_headers(), ) return self._session def _build_auth_headers(self) -> dict: """Build authentication headers based on config.""" headers = {"A2A-Version": "1.0", "Content-Type": "application/json"} token = self.config.auth_token if not token: return headers if self.config.auth_scheme == "bearer": headers["Authorization"] = f"Bearer {token}" elif self.config.auth_scheme == "api_key": headers[self.config.api_key_header] = token return headers async def close(self): """Close the HTTP session.""" if self._session and not self._session.closed: await self._session.close() async def _rpc_call( self, endpoint: str, method: str, params: Optional[dict] = None, ) -> dict: """ Make a JSON-RPC call with retry logic. Returns the 'result' field from the response. Raises on JSON-RPC errors. """ session = await self._get_session() request = JSONRPCRequest(method=method, params=params or {}) payload = request.to_dict() last_error = None for attempt in range(1, self.config.max_retries + 1): try: start = time.monotonic() async with session.post(endpoint, json=payload) as resp: elapsed = time.monotonic() - start if resp.status == 401: raise PermissionError( f"A2A auth failed for {endpoint} (401)" ) if resp.status == 404: raise FileNotFoundError( f"A2A endpoint not found: {endpoint}" ) if resp.status >= 500: body = await resp.text() raise ConnectionError( f"A2A server error {resp.status}: {body}" ) data = await resp.json() rpc_resp = JSONRPCResponse( id=str(data.get("id", "")), result=data.get("result"), error=( A2AError.INTERNAL if "error" in data else None ), ) # Log for audit self._audit_log.append({ "timestamp": time.time(), "endpoint": endpoint, "method": method, "request_id": request.id, "status_code": resp.status, "elapsed_ms": int(elapsed * 1000), "attempt": attempt, }) if "error" in data: err = data["error"] logger.error( f"A2A RPC error {err.get('code')}: " f"{err.get('message')}" ) raise RuntimeError( f"A2A error {err.get('code')}: " f"{err.get('message')}" ) return data.get("result", {}) except (asyncio.TimeoutError, aiohttp.ClientError) as e: last_error = e logger.warning( f"A2A request to {endpoint} attempt {attempt}/" f"{self.config.max_retries} failed: {e}" ) if attempt < self.config.max_retries: delay = self.config.retry_delay * attempt await asyncio.sleep(delay) raise ConnectionError( f"A2A request to {endpoint} failed after " f"{self.config.max_retries} retries: {last_error}" ) # --- Core A2A Methods --- async def get_agent_card(self, base_url: str) -> AgentCard: """ Fetch the Agent Card from a remote agent. Tries /.well-known/agent-card.json first, falls back to /agent.json. """ session = await self._get_session() card_urls = [ f"{base_url}/.well-known/agent-card.json", f"{base_url}/agent.json", ] for url in card_urls: try: async with session.get(url) as resp: if resp.status == 200: data = await resp.json() card = AgentCard.from_dict(data) logger.info( f"Fetched agent card: {card.name} " f"({len(card.skills)} skills)" ) return card except Exception: continue raise FileNotFoundError( f"Could not fetch agent card from {base_url}" ) async def send_message( self, endpoint: str, message: Message, accepted_output_modes: Optional[list[str]] = None, history_length: int = 10, return_immediately: bool = False, ) -> Task: """ Send a message to an agent and get a Task back. This is the primary delegation method. """ params = { "message": message.to_dict(), "configuration": { "acceptedOutputModes": accepted_output_modes or ["text/plain"], "historyLength": history_length, "returnImmediately": return_immediately, }, } result = await self._rpc_call(endpoint, "SendMessage", params) # Response is either a Task or Message if "task" in result: task = Task.from_dict(result["task"]) logger.info( f"Task {task.id} created, state={task.status.state.value}" ) return task elif "message" in result: # Wrap message response as a completed task msg = Message.from_dict(result["message"]) task = Task( status=TaskStatus(state=TaskState.COMPLETED), history=[message, msg], artifacts=[ Artifact(parts=msg.parts, name="response") ], ) return task raise ValueError(f"Unexpected response structure: {list(result.keys())}") async def get_task(self, endpoint: str, task_id: str) -> Task: """Get task status by ID.""" result = await self._rpc_call( endpoint, "GetTask", {"id": task_id}, ) return Task.from_dict(result) async def list_tasks( self, endpoint: str, page_size: int = 20, page_token: str = "", ) -> tuple[list[Task], str]: """ List tasks with cursor-based pagination. Returns (tasks, next_page_token). Empty string = last page. """ result = await self._rpc_call( endpoint, "ListTasks", { "pageSize": page_size, "pageToken": page_token, }, ) tasks = [Task.from_dict(t) for t in result.get("tasks", [])] next_token = result.get("nextPageToken", "") return tasks, next_token async def cancel_task(self, endpoint: str, task_id: str) -> Task: """Cancel a running task.""" result = await self._rpc_call( endpoint, "CancelTask", {"id": task_id}, ) return Task.from_dict(result) # --- Convenience Methods --- async def delegate( self, agent_url: str, text: str, skill_id: Optional[str] = None, metadata: Optional[dict] = None, ) -> Task: """ High-level delegation: send a text message to an agent. Args: agent_url: Full URL to agent's A2A endpoint (e.g. https://ezra.example.com/a2a/v1) text: The task description in natural language skill_id: Optional skill to target metadata: Optional metadata dict """ msg_metadata = metadata or {} if skill_id: msg_metadata["targetSkill"] = skill_id message = Message( role=Role.USER, parts=[TextPart(text=text)], metadata=msg_metadata, ) return await self.send_message(agent_url, message) async def wait_for_completion( self, endpoint: str, task_id: str, poll_interval: float = 2.0, max_wait: float = 300.0, ) -> Task: """ Poll a task until it reaches a terminal state. Returns the completed task. """ start = time.monotonic() while True: task = await self.get_task(endpoint, task_id) if task.status.state.terminal: return task elapsed = time.monotonic() - start if elapsed >= max_wait: raise TimeoutError( f"Task {task_id} did not complete within " f"{max_wait}s (state={task.status.state.value})" ) await asyncio.sleep(poll_interval) def get_audit_log(self) -> list[dict]: """Return the audit log of all requests made by this client.""" return list(self._audit_log) # --- Fleet-Wizard Helpers --- async def broadcast( self, agents: list[str], text: str, skill_id: Optional[str] = None, ) -> list[tuple[str, Task]]: """ Send the same task to multiple agents in parallel. Returns list of (agent_url, task) tuples. """ tasks = [] for agent_url in agents: tasks.append( self.delegate(agent_url, text, skill_id=skill_id) ) results = await asyncio.gather(*tasks, return_exceptions=True) paired = [] for agent_url, result in zip(agents, results): if isinstance(result, Exception): logger.error(f"Broadcast to {agent_url} failed: {result}") else: paired.append((agent_url, result)) return paired