""" A2A Server — receive and process tasks from other agents. Provides a FastAPI router that serves: - GET /.well-known/agent-card.json — Agent Card discovery - GET /agent.json — Agent Card fallback - POST /a2a/v1 — JSON-RPC endpoint (SendMessage, GetTask, etc.) - POST /a2a/v1/rpc — JSON-RPC endpoint (alias) Task routing: registered handlers are matched by skill ID or receive all tasks via a default handler. Usage: server = A2AServer(card=my_card, auth_token="secret") server.register_handler("ci-health", my_ci_handler) await server.start(host="0.0.0.0", port=8080) """ from __future__ import annotations import asyncio import json import logging import time import uuid from datetime import datetime, timezone from typing import Any, Callable, Awaitable, Optional try: from fastapi import FastAPI, Request, Response, HTTPException, Header from fastapi.responses import JSONResponse import uvicorn HAS_FASTAPI = True except ImportError: HAS_FASTAPI = False from nexus.a2a.types import ( A2AError, AgentCard, Artifact, JSONRPCError, JSONRPCResponse, Message, Role, Task, TaskState, TaskStatus, TextPart, ) logger = logging.getLogger("nexus.a2a.server") # Type for task handlers TaskHandler = Callable[[Task, AgentCard], Awaitable[Task]] class A2AServer: """ A2A protocol server for receiving agent-to-agent task delegation. Supports: - Agent Card serving at /.well-known/agent-card.json - JSON-RPC task lifecycle (SendMessage, GetTask, CancelTask, ListTasks) - Pluggable task handlers (by skill ID or default) - Bearer / API key authentication - Audit logging """ def __init__( self, card: AgentCard, auth_token: str = "", auth_scheme: str = "bearer", ): if not HAS_FASTAPI: raise ImportError( "fastapi and uvicorn are required for A2AServer. " "Install with: pip install fastapi uvicorn" ) self.card = card self.auth_token = auth_token self.auth_scheme = auth_scheme # Task store (in-memory; swap for SQLite/Redis in production) self._tasks: dict[str, Task] = {} # Handlers keyed by skill ID self._handlers: dict[str, TaskHandler] = {} # Default handler for unmatched skills self._default_handler: Optional[TaskHandler] = None # Audit log self._audit_log: list[dict] = [] self.app = FastAPI( title=f"A2A — {card.name}", description=card.description, version=card.version, ) self._register_routes() def register_handler(self, skill_id: str, handler: TaskHandler): """Register a handler for a specific skill ID.""" self._handlers[skill_id] = handler logger.info(f"Registered handler for skill: {skill_id}") def set_default_handler(self, handler: TaskHandler): """Set the fallback handler for tasks without a matching skill.""" self._default_handler = handler def _verify_auth(self, authorization: Optional[str]) -> bool: """Check authentication header.""" if not self.auth_token: return True # No auth configured if not authorization: return False if self.auth_scheme == "bearer": expected = f"Bearer {self.auth_token}" return authorization == expected return False def _register_routes(self): """Wire up FastAPI routes.""" @self.app.get("/.well-known/agent-card.json") async def agent_card_well_known(): return JSONResponse(self.card.to_dict()) @self.app.get("/agent.json") async def agent_card_fallback(): return JSONResponse(self.card.to_dict()) @self.app.post("/a2a/v1") @self.app.post("/a2a/v1/rpc") async def rpc_endpoint(request: Request): return await self._handle_rpc(request) @self.app.get("/a2a/v1/tasks") @self.app.get("/a2a/v1/tasks/{task_id}") async def rest_get_task(task_id: Optional[str] = None): if task_id: task = self._tasks.get(task_id) if not task: return JSONRPCResponse( id="", error=A2AError.TASK_NOT_FOUND, ).to_dict() return JSONResponse(task.to_dict()) else: return JSONResponse( {"tasks": [t.to_dict() for t in self._tasks.values()]} ) async def _handle_rpc(self, request: Request) -> JSONResponse: """Handle JSON-RPC requests.""" # Auth check auth_header = request.headers.get("authorization") if not self._verify_auth(auth_header): return JSONResponse( status_code=401, content={"error": "Unauthorized"}, ) # Parse JSON-RPC try: body = await request.json() except json.JSONDecodeError: return JSONResponse( JSONRPCResponse( id="", error=A2AError.PARSE ).to_dict(), status_code=400, ) method = body.get("method", "") request_id = body.get("id", str(uuid.uuid4())) params = body.get("params", {}) # Audit self._audit_log.append({ "timestamp": time.time(), "method": method, "request_id": request_id, "source": request.client.host if request.client else "unknown", }) try: result = await self._dispatch_rpc(method, params, request_id) return JSONResponse( JSONRPCResponse(id=request_id, result=result).to_dict() ) except ValueError as e: return JSONResponse( JSONRPCResponse( id=request_id, error=JSONRPCError(-32602, str(e)), ).to_dict(), status_code=400, ) except Exception as e: logger.exception(f"Error handling {method}: {e}") return JSONResponse( JSONRPCResponse( id=request_id, error=JSONRPCError(-32603, str(e)), ).to_dict(), status_code=500, ) async def _dispatch_rpc( self, method: str, params: dict, request_id: str ) -> Any: """Route JSON-RPC method to handler.""" if method == "SendMessage": return await self._rpc_send_message(params) elif method == "GetTask": return await self._rpc_get_task(params) elif method == "ListTasks": return await self._rpc_list_tasks(params) elif method == "CancelTask": return await self._rpc_cancel_task(params) elif method == "GetAgentCard": return self.card.to_dict() else: raise ValueError(f"Unknown method: {method}") async def _rpc_send_message(self, params: dict) -> dict: """Handle SendMessage — create a task and route to handler.""" msg_data = params.get("message", {}) message = Message.from_dict(msg_data) # Determine target skill from metadata target_skill = message.metadata.get("targetSkill", "") # Create task task = Task( context_id=message.context_id, status=TaskStatus(state=TaskState.SUBMITTED), history=[message], metadata={"targetSkill": target_skill} if target_skill else {}, ) # Store immediately self._tasks[task.id] = task # Dispatch to handler handler = self._handlers.get(target_skill) or self._default_handler if handler is None: task.status = TaskStatus( state=TaskState.FAILED, message=Message( role=Role.AGENT, parts=[TextPart(text="No handler available for this task")], ), ) return {"task": task.to_dict()} try: # Mark as working task.status = TaskStatus(state=TaskState.WORKING) self._tasks[task.id] = task # Execute handler result_task = await handler(task, self.card) # Store result self._tasks[result_task.id] = result_task return {"task": result_task.to_dict()} except Exception as e: task.status = TaskStatus( state=TaskState.FAILED, message=Message( role=Role.AGENT, parts=[TextPart(text=f"Handler error: {str(e)}")], ), ) self._tasks[task.id] = task return {"task": task.to_dict()} async def _rpc_get_task(self, params: dict) -> dict: """Handle GetTask.""" task_id = params.get("id", "") task = self._tasks.get(task_id) if not task: raise ValueError(f"Task not found: {task_id}") return task.to_dict() async def _rpc_list_tasks(self, params: dict) -> dict: """Handle ListTasks with cursor-based pagination.""" page_size = params.get("pageSize", 20) page_token = params.get("pageToken", "") tasks = sorted( self._tasks.values(), key=lambda t: t.status.timestamp, reverse=True, ) # Simple cursor: find index by token start_idx = 0 if page_token: for i, t in enumerate(tasks): if t.id == page_token: start_idx = i + 1 break page = tasks[start_idx : start_idx + page_size] next_token = "" if start_idx + page_size < len(tasks): next_token = tasks[start_idx + page_size - 1].id return { "tasks": [t.to_dict() for t in page], "nextPageToken": next_token, } async def _rpc_cancel_task(self, params: dict) -> dict: """Handle CancelTask.""" task_id = params.get("id", "") task = self._tasks.get(task_id) if not task: raise ValueError(f"Task not found: {task_id}") if task.status.state.terminal: raise ValueError( f"Task {task_id} is already terminal " f"({task.status.state.value})" ) task.status = TaskStatus(state=TaskState.CANCELED) self._tasks[task_id] = task return task.to_dict() def get_audit_log(self) -> list[dict]: """Return audit log of all received requests.""" return list(self._audit_log) async def start( self, host: str = "0.0.0.0", port: int = 8080, ): """Start the A2A server with uvicorn.""" logger.info( f"Starting A2A server for {self.card.name} on " f"{host}:{port}" ) logger.info( f"Agent Card at " f"http://{host}:{port}/.well-known/agent-card.json" ) config = uvicorn.Config( self.app, host=host, port=port, log_level="info", ) server = uvicorn.Server(config) await server.serve() # --- Default Handler Factory --- async def echo_handler(task: Task, card: AgentCard) -> Task: """ Simple echo handler for testing. Returns the user's message as an artifact. """ if task.history: last_msg = task.history[-1] text_parts = [p for p in last_msg.parts if isinstance(p, TextPart)] if text_parts: response_text = f"[{card.name}] Echo: {text_parts[0].text}" task.artifacts.append( Artifact( parts=[TextPart(text=response_text)], name="echo_response", ) ) task.status = TaskStatus(state=TaskState.COMPLETED) return task