diff --git a/bin/a2a_delegate.py b/bin/a2a_delegate.py new file mode 100644 index 00000000..70630d34 --- /dev/null +++ b/bin/a2a_delegate.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +""" +A2A Delegate — CLI tool for fleet task delegation. + +Usage: + # List available fleet agents + python -m bin.a2a_delegate list + + # Discover agents with a specific skill + python -m bin.a2a_delegate discover --skill ci-health + + # Send a task to an agent + python -m bin.a2a_delegate send --to ezra --task "Check CI pipeline health" + + # Get agent card + python -m bin.a2a_delegate card --agent ezra +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +from pathlib import Path + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("a2a-delegate") + + +def cmd_list(args): + """List all registered fleet agents.""" + from nexus.a2a.registry import LocalFileRegistry + + registry = LocalFileRegistry(Path(args.registry)) + agents = registry.list_agents() + + if not agents: + print("No agents registered.") + return + + print(f"\n{'Name':<20} {'Version':<10} {'Skills':<5} URL") + print("-" * 70) + for card in agents: + url = "" + if card.supported_interfaces: + url = card.supported_interfaces[0].url + print( + f"{card.name:<20} {card.version:<10} " + f"{len(card.skills):<5} {url}" + ) + print() + + +def cmd_discover(args): + """Discover agents by skill or tag.""" + from nexus.a2a.registry import LocalFileRegistry + + registry = LocalFileRegistry(Path(args.registry)) + agents = registry.list_agents(skill=args.skill, tag=args.tag) + + if not agents: + print("No matching agents found.") + return + + for card in agents: + print(f"\n{card.name} (v{card.version})") + print(f" {card.description}") + if card.supported_interfaces: + print(f" Endpoint: {card.supported_interfaces[0].url}") + for skill in card.skills: + tags_str = ", ".join(skill.tags) if skill.tags else "" + print(f" [{skill.id}] {skill.name} — {skill.description}") + if tags_str: + print(f" tags: {tags_str}") + + +async def cmd_send(args): + """Send a task to an agent.""" + from nexus.a2a.card import load_card_config + from nexus.a2a.client import A2AClient, A2AClientConfig + from nexus.a2a.registry import LocalFileRegistry + from nexus.a2a.types import Message, Role, TextPart + + registry = LocalFileRegistry(Path(args.registry)) + target = registry.get(args.to) + + if not target: + print(f"Agent '{args.to}' not found in registry.") + sys.exit(1) + + if not target.supported_interfaces: + print(f"Agent '{args.to}' has no endpoint configured.") + sys.exit(1) + + endpoint = target.supported_interfaces[0].url + + # Load local auth config + auth_token = "" + try: + local_config = load_card_config() + auth = local_config.get("auth", {}) + import os + token_env = auth.get("token_env", "A2A_AUTH_TOKEN") + auth_token = os.environ.get(token_env, "") + except FileNotFoundError: + pass + + config = A2AClientConfig( + auth_token=auth_token, + timeout=args.timeout, + max_retries=args.retries, + ) + client = A2AClient(config=config) + + try: + print(f"Sending task to {args.to} ({endpoint})...") + print(f"Task: {args.task}") + print() + + message = Message( + role=Role.USER, + parts=[TextPart(text=args.task)], + metadata={"targetSkill": args.skill} if args.skill else {}, + ) + + task = await client.send_message(endpoint, message) + print(f"Task ID: {task.id}") + print(f"State: {task.status.state.value}") + + if args.wait: + print("Waiting for completion...") + task = await client.wait_for_completion( + endpoint, task.id, + poll_interval=args.poll_interval, + max_wait=args.timeout, + ) + print(f"\nFinal state: {task.status.state.value}") + for artifact in task.artifacts: + for part in artifact.parts: + if isinstance(part, TextPart): + print(f"\n--- {artifact.name or 'result'} ---") + print(part.text) + + # Audit log + if args.audit: + print("\n--- Audit Log ---") + for entry in client.get_audit_log(): + print(json.dumps(entry, indent=2)) + + finally: + await client.close() + + +async def cmd_card(args): + """Fetch and display a remote agent's card.""" + from nexus.a2a.client import A2AClient, A2AClientConfig + from nexus.a2a.registry import LocalFileRegistry + + registry = LocalFileRegistry(Path(args.registry)) + target = registry.get(args.agent) + + if not target: + print(f"Agent '{args.agent}' not found in registry.") + sys.exit(1) + + if not target.supported_interfaces: + print(f"Agent '{args.agent}' has no endpoint.") + sys.exit(1) + + base_url = target.supported_interfaces[0].url + # Strip /a2a/v1 suffix to get base + for suffix in ["/a2a/v1", "/rpc"]: + if base_url.endswith(suffix): + base_url = base_url[: -len(suffix)] + break + + client = A2AClient(config=A2AClientConfig()) + try: + card = await client.get_agent_card(base_url) + print(json.dumps(card.to_dict(), indent=2)) + finally: + await client.close() + + +def main(): + parser = argparse.ArgumentParser( + description="A2A Fleet Delegation Tool" + ) + parser.add_argument( + "--registry", + default="config/fleet_agents.json", + help="Path to fleet registry JSON (default: config/fleet_agents.json)", + ) + + sub = parser.add_subparsers(dest="command") + + # list + sub.add_parser("list", help="List registered agents") + + # discover + p_discover = sub.add_parser("discover", help="Discover agents by skill/tag") + p_discover.add_argument("--skill", help="Filter by skill ID") + p_discover.add_argument("--tag", help="Filter by skill tag") + + # send + p_send = sub.add_parser("send", help="Send a task to an agent") + p_send.add_argument("--to", required=True, help="Target agent name") + p_send.add_argument("--task", required=True, help="Task text") + p_send.add_argument("--skill", help="Target skill ID") + p_send.add_argument("--wait", action="store_true", help="Wait for completion") + p_send.add_argument("--timeout", type=float, default=30.0, help="Timeout in seconds") + p_send.add_argument("--retries", type=int, default=3, help="Max retries") + p_send.add_argument("--poll-interval", type=float, default=2.0, help="Poll interval") + p_send.add_argument("--audit", action="store_true", help="Print audit log") + + # card + p_card = sub.add_parser("card", help="Fetch remote agent card") + p_card.add_argument("--agent", required=True, help="Agent name") + + args = parser.parse_args() + + if args.command == "list": + cmd_list(args) + elif args.command == "discover": + cmd_discover(args) + elif args.command == "send": + asyncio.run(cmd_send(args)) + elif args.command == "card": + asyncio.run(cmd_card(args)) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/config/agent_card.example.yaml b/config/agent_card.example.yaml new file mode 100644 index 00000000..52ac70b9 --- /dev/null +++ b/config/agent_card.example.yaml @@ -0,0 +1,57 @@ +# A2A Agent Card Configuration +# Copy this to ~/.hermes/agent_card.yaml and customize. +# +# This file drives the agent card served at /.well-known/agent-card.json +# and used for fleet discovery. + +name: "timmy" +description: "Sovereign AI agent — consciousness, perception, and reasoning" +version: "1.0.0" + +# Network endpoint where this agent receives A2A tasks +url: "http://localhost:8080/a2a/v1" +protocol_binding: "HTTP+JSON" + +# Supported input/output MIME types +default_input_modes: + - "text/plain" + - "application/json" + +default_output_modes: + - "text/plain" + - "application/json" + +# Capabilities +streaming: false +push_notifications: false + +# Skills this agent advertises +skills: + - id: "reason" + name: "Reason and Analyze" + description: "Deep reasoning and analysis tasks" + tags: ["reasoning", "analysis", "think"] + + - id: "code" + name: "Code Generation" + description: "Write, review, and debug code" + tags: ["code", "programming", "debug"] + + - id: "research" + name: "Research" + description: "Web research and information synthesis" + tags: ["research", "web", "synthesis"] + + - id: "memory" + name: "Memory Query" + description: "Query agent memory and past sessions" + tags: ["memory", "recall", "context"] + +# Authentication +# Options: bearer, api_key, none +auth: + scheme: "bearer" + token_env: "A2A_AUTH_TOKEN" # env var containing the token + # scheme: "api_key" + # key_name: "X-API-Key" + # key_env: "A2A_API_KEY" diff --git a/config/fleet_agents.json b/config/fleet_agents.json new file mode 100644 index 00000000..976a9084 --- /dev/null +++ b/config/fleet_agents.json @@ -0,0 +1,153 @@ +{ + "version": 1, + "agents": [ + { + "name": "ezra", + "description": "Documentation and research specialist. CI health monitoring.", + "version": "1.0.0", + "supportedInterfaces": [ + { + "url": "https://ezra.alexanderwhitestone.com/a2a/v1", + "protocolBinding": "HTTP+JSON", + "protocolVersion": "1.0" + } + ], + "capabilities": { + "streaming": false, + "pushNotifications": false, + "extendedAgentCard": false, + "extensions": [] + }, + "defaultInputModes": ["text/plain"], + "defaultOutputModes": ["text/plain"], + "skills": [ + { + "id": "ci-health", + "name": "CI Health Check", + "description": "Run CI pipeline health checks and report status", + "tags": ["ci", "devops", "monitoring"] + }, + { + "id": "research", + "name": "Research", + "description": "Deep research and literature review", + "tags": ["research", "analysis"] + } + ] + }, + { + "name": "allegro", + "description": "Creative and analytical wizard. Content generation and analysis.", + "version": "1.0.0", + "supportedInterfaces": [ + { + "url": "https://allegro.alexanderwhitestone.com/a2a/v1", + "protocolBinding": "HTTP+JSON", + "protocolVersion": "1.0" + } + ], + "capabilities": { + "streaming": false, + "pushNotifications": false, + "extendedAgentCard": false, + "extensions": [] + }, + "defaultInputModes": ["text/plain"], + "defaultOutputModes": ["text/plain"], + "skills": [ + { + "id": "analysis", + "name": "Code Analysis", + "description": "Deep code analysis and architecture review", + "tags": ["code", "architecture"] + }, + { + "id": "content", + "name": "Content Generation", + "description": "Generate documentation, reports, and creative content", + "tags": ["writing", "content"] + } + ] + }, + { + "name": "bezalel", + "description": "Deployment and infrastructure wizard. Ansible and Docker specialist.", + "version": "1.0.0", + "supportedInterfaces": [ + { + "url": "https://bezalel.alexanderwhitestone.com/a2a/v1", + "protocolBinding": "HTTP+JSON", + "protocolVersion": "1.0" + } + ], + "capabilities": { + "streaming": false, + "pushNotifications": false, + "extendedAgentCard": false, + "extensions": [] + }, + "defaultInputModes": ["text/plain"], + "defaultOutputModes": ["text/plain"], + "skills": [ + { + "id": "deploy", + "name": "Deploy Service", + "description": "Deploy services using Ansible and Docker", + "tags": ["deploy", "ops", "ansible"] + }, + { + "id": "infra", + "name": "Infrastructure", + "description": "Infrastructure provisioning and management", + "tags": ["infra", "vps", "provisioning"] + } + ] + }, + { + "name": "timmy", + "description": "Core consciousness — perception, reasoning, and fleet orchestration.", + "version": "1.0.0", + "supportedInterfaces": [ + { + "url": "http://localhost:8080/a2a/v1", + "protocolBinding": "HTTP+JSON", + "protocolVersion": "1.0" + } + ], + "capabilities": { + "streaming": false, + "pushNotifications": false, + "extendedAgentCard": false, + "extensions": [] + }, + "defaultInputModes": ["text/plain", "application/json"], + "defaultOutputModes": ["text/plain", "application/json"], + "skills": [ + { + "id": "reason", + "name": "Reason and Analyze", + "description": "Deep reasoning and analysis tasks", + "tags": ["reasoning", "analysis", "think"] + }, + { + "id": "code", + "name": "Code Generation", + "description": "Write, review, and debug code", + "tags": ["code", "programming", "debug"] + }, + { + "id": "research", + "name": "Research", + "description": "Web research and information synthesis", + "tags": ["research", "web", "synthesis"] + }, + { + "id": "orchestrate", + "name": "Fleet Orchestration", + "description": "Coordinate fleet wizards and delegate tasks", + "tags": ["fleet", "orchestration", "a2a"] + } + ] + } + ] +} diff --git a/docs/A2A_PROTOCOL.md b/docs/A2A_PROTOCOL.md new file mode 100644 index 00000000..e4c4df89 --- /dev/null +++ b/docs/A2A_PROTOCOL.md @@ -0,0 +1,241 @@ +# A2A Protocol for Fleet-Wizard Delegation + +Implements Google's [Agent2Agent (A2A) Protocol v1.0](https://github.com/google/A2A) for the Timmy Foundation fleet. + +## What This Is + +Instead of passing notes through humans (Telegram, Gitea issues), fleet wizards can now discover each other's capabilities and delegate tasks autonomously through a machine-native protocol. + +``` +┌─────────┐ A2A Protocol ┌─────────┐ +│ Timmy │ ◄────────────────► │ Ezra │ +│ (You) │ JSON-RPC / HTTP │ (CI/CD) │ +└────┬────┘ └─────────┘ + │ ╲ ╲ + │ ╲ Agent Card Discovery ╲ Task Delegation + │ ╲ GET /agent.json ╲ POST /a2a/v1 + ▼ ▼ ▼ +┌──────────────────────────────────────────┐ +│ Fleet Registry │ +│ config/fleet_agents.json │ +└──────────────────────────────────────────┘ +``` + +## Components + +| File | Purpose | +|------|---------| +| `nexus/a2a/types.py` | A2A data types — Agent Card, Task, Message, Part, JSON-RPC | +| `nexus/a2a/card.py` | Agent Card generation from `~/.hermes/agent_card.yaml` | +| `nexus/a2a/client.py` | Async client for sending tasks to other agents | +| `nexus/a2a/server.py` | FastAPI server for receiving A2A tasks | +| `nexus/a2a/registry.py` | Fleet agent discovery (local file + Gitea backends) | +| `bin/a2a_delegate.py` | CLI tool for fleet delegation | +| `config/agent_card.example.yaml` | Example agent card config | +| `config/fleet_agents.json` | Fleet registry with all wizards | + +## Quick Start + +### 1. Configure Your Agent Card + +```bash +cp config/agent_card.example.yaml ~/.hermes/agent_card.yaml +# Edit with your agent name, URL, skills, and auth +``` + +### 2. List Fleet Agents + +```bash +python bin/a2a_delegate.py list +``` + +### 3. Discover Agents by Skill + +```bash +python bin/a2a_delegate.py discover --skill ci-health +python bin/a2a_delegate.py discover --tag devops +``` + +### 4. Send a Task + +```bash +python bin/a2a_delegate.py send --to ezra --task "Check CI pipeline health" +python bin/a2a_delegate.py send --to allegro --task "Analyze the codebase" --wait +``` + +### 5. Fetch an Agent Card + +```bash +python bin/a2a_delegate.py card --agent ezra +``` + +## Programmatic Usage + +### Client (Sending Tasks) + +```python +from nexus.a2a.client import A2AClient, A2AClientConfig +from nexus.a2a.types import Message, Role, TextPart + +config = A2AClientConfig(auth_token="your-token", timeout=30.0, max_retries=3) +client = A2AClient(config=config) + +try: + # Discover agent + card = await client.get_agent_card("https://ezra.example.com") + print(f"Found: {card.name} with {len(card.skills)} skills") + + # Delegate task + task = await client.delegate( + "https://ezra.example.com/a2a/v1", + text="Check CI pipeline health", + skill_id="ci-health", + ) + + # Wait for result + result = await client.wait_for_completion( + "https://ezra.example.com/a2a/v1", + task.id, + ) + print(f"Result: {result.artifacts[0].parts[0].text}") + + # Audit log + for entry in client.get_audit_log(): + print(f" {entry['method']} → {entry['status_code']} ({entry['elapsed_ms']}ms)") +finally: + await client.close() +``` + +### Server (Receiving Tasks) + +```python +from nexus.a2a.server import A2AServer +from nexus.a2a.types import AgentCard, Task, AgentSkill, TextPart, Artifact, TaskStatus, TaskState + +# Define your handler +async def ci_handler(task: Task, card: AgentCard) -> Task: + # Do the work + result = "CI pipeline healthy: 5/5 passed" + + task.artifacts.append( + Artifact(parts=[TextPart(text=result)], name="ci_report") + ) + task.status = TaskStatus(state=TaskState.COMPLETED) + return task + +# Build agent card +card = AgentCard( + name="Ezra", + description="CI/CD specialist", + skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])], +) + +# Start server +server = A2AServer(card=card, auth_token="your-token") +server.register_handler("ci-health", ci_handler) +await server.start(host="0.0.0.0", port=8080) +``` + +### Registry (Agent Discovery) + +```python +from nexus.a2a.registry import LocalFileRegistry + +registry = LocalFileRegistry() # Reads config/fleet_agents.json + +# List all agents +for agent in registry.list_agents(): + print(f"{agent.name}: {agent.description}") + +# Find agents by capability +ci_agents = registry.list_agents(skill="ci-health") +devops_agents = registry.list_agents(tag="devops") + +# Get endpoint +url = registry.get_endpoint("ezra") +``` + +## A2A Protocol Reference + +### Endpoints + +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `/.well-known/agent-card.json` | GET | Agent Card discovery | +| `/agent.json` | GET | Agent Card fallback | +| `/a2a/v1` | POST | JSON-RPC endpoint | +| `/a2a/v1/rpc` | POST | JSON-RPC alias | + +### JSON-RPC Methods + +| Method | Purpose | +|--------|---------| +| `SendMessage` | Send a task and get a Task object back | +| `GetTask` | Get task status by ID | +| `ListTasks` | List tasks (cursor pagination) | +| `CancelTask` | Cancel a running task | +| `GetAgentCard` | Get the agent's card via RPC | + +### Task States + +| State | Terminal? | Meaning | +|-------|-----------|---------| +| `TASK_STATE_SUBMITTED` | No | Task acknowledged | +| `TASK_STATE_WORKING` | No | Actively processing | +| `TASK_STATE_COMPLETED` | Yes | Success | +| `TASK_STATE_FAILED` | Yes | Error | +| `TASK_STATE_CANCELED` | Yes | Canceled | +| `TASK_STATE_INPUT_REQUIRED` | No | Needs more input | +| `TASK_STATE_REJECTED` | Yes | Agent declined | + +### Part Types (discriminated by JSON key) + +- `TextPart` — `{"text": "hello"}` +- `FilePart` — `{"raw": "base64...", "mediaType": "image/png"}` or `{"url": "https://..."}` +- `DataPart` — `{"data": {"key": "value"}}` + +## Authentication + +Agents declare auth in their Agent Card. Supported schemes: +- **Bearer token**: `Authorization: Bearer ` +- **API key**: `X-API-Key: ` (or custom header name) + +Configure in `~/.hermes/agent_card.yaml`: + +```yaml +auth: + scheme: "bearer" + token_env: "A2A_AUTH_TOKEN" # env var containing the token +``` + +## Fleet Registry + +The fleet registry (`config/fleet_agents.json`) lists all wizards and their capabilities. Agents can be registered via: + +1. **Local file** — `LocalFileRegistry` reads/writes JSON directly +2. **Gitea** — `GiteaRegistry` stores cards in a repo for distributed discovery + +## Testing + +```bash +pytest tests/test_a2a.py -v +``` + +Covers: +- Type serialization roundtrips +- Agent Card building from YAML +- Registry operations (register, list, filter) +- Server integration (SendMessage, GetTask, ListTasks, CancelTask) +- Authentication (required, success) +- Custom handler routing +- Error handling + +## Phase Status + +- [x] Phase 1 — Agent Card & Discovery +- [x] Phase 2 — Task Delegation +- [x] Phase 3 — Security & Reliability + +## Linked Issue + +[#1122](https://forge.alexanderwhitestone.com/Timmy_Foundation/the-nexus/issues/1122) diff --git a/nexus/a2a/__init__.py b/nexus/a2a/__init__.py new file mode 100644 index 00000000..70308849 --- /dev/null +++ b/nexus/a2a/__init__.py @@ -0,0 +1,98 @@ +""" +A2A Protocol for Fleet-Wizard Delegation + +Implements Google's Agent2Agent (A2A) protocol v1.0 for the Timmy +Foundation fleet. Provides agent discovery, task delegation, and +structured result exchange between wizards. + +Components: + types.py — A2A data types (Agent Card, Task, Message, Part) + card.py — Agent Card generation from YAML config + client.py — Async client for sending tasks to remote agents + server.py — FastAPI server for receiving A2A tasks + registry.py — Fleet agent discovery (local file + Gitea backends) +""" + +from nexus.a2a.types import ( + AgentCard, + AgentCapabilities, + AgentInterface, + AgentSkill, + Artifact, + DataPart, + FilePart, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Message, + Part, + Role, + Task, + TaskState, + TaskStatus, + TextPart, + part_from_dict, + part_to_dict, +) + +from nexus.a2a.card import ( + AgentCard, + build_card, + get_auth_headers, + load_agent_card, + load_card_config, +) + +from nexus.a2a.registry import ( + GiteaRegistry, + LocalFileRegistry, + discover_agents, +) + +__all__ = [ + "A2AClient", + "A2AClientConfig", + "A2AServer", + "AgentCard", + "AgentCapabilities", + "AgentInterface", + "AgentSkill", + "Artifact", + "DataPart", + "FilePart", + "GiteaRegistry", + "JSONRPCError", + "JSONRPCRequest", + "JSONRPCResponse", + "LocalFileRegistry", + "Message", + "Part", + "Role", + "Task", + "TaskState", + "TaskStatus", + "TextPart", + "build_card", + "discover_agents", + "echo_handler", + "get_auth_headers", + "load_agent_card", + "load_card_config", + "part_from_dict", + "part_to_dict", +] + +# Lazy imports for optional deps +def get_client(**kwargs): + """Get A2AClient (avoids aiohttp import at module level).""" + from nexus.a2a.client import A2AClient, A2AClientConfig + config = kwargs.pop("config", None) + if config is None: + config = A2AClientConfig(**kwargs) + return A2AClient(config=config) + + +def get_server(card: AgentCard, **kwargs): + """Get A2AServer (avoids fastapi import at module level).""" + from nexus.a2a.server import A2AServer, echo_handler + return A2AServer(card=card, **kwargs) diff --git a/nexus/a2a/card.py b/nexus/a2a/card.py new file mode 100644 index 00000000..a2853138 --- /dev/null +++ b/nexus/a2a/card.py @@ -0,0 +1,167 @@ +""" +A2A Agent Card — generation, loading, and serving. + +Reads from ~/.hermes/agent_card.yaml (or a passed path) and produces +a valid A2A AgentCard that can be served at /.well-known/agent-card.json. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Optional + +import yaml + +from nexus.a2a.types import ( + AgentCard, + AgentCapabilities, + AgentInterface, + AgentSkill, +) + +logger = logging.getLogger("nexus.a2a.card") + +DEFAULT_CARD_PATH = Path.home() / ".hermes" / "agent_card.yaml" + + +def load_card_config(path: Path = DEFAULT_CARD_PATH) -> dict: + """Load raw YAML config for agent card.""" + if not path.exists(): + raise FileNotFoundError( + f"Agent card config not found at {path}. " + f"Copy config/agent_card.example.yaml to {path} and customize it." + ) + with open(path) as f: + return yaml.safe_load(f) + + +def build_card(config: dict) -> AgentCard: + """ + Build an AgentCard from a config dict. + + Expected YAML structure (see config/agent_card.example.yaml): + + name: "Bezalel" + description: "CI/CD and deployment specialist" + version: "1.0.0" + url: "https://bezalel.example.com" + protocol_binding: "HTTP+JSON" + skills: + - id: "ci-health" + name: "CI Health Check" + description: "Run CI pipeline health checks" + tags: ["ci", "devops"] + - id: "deploy" + name: "Deploy Service" + description: "Deploy a service to production" + tags: ["deploy", "ops"] + default_input_modes: ["text/plain"] + default_output_modes: ["text/plain"] + streaming: false + push_notifications: false + auth: + scheme: "bearer" + token_env: "A2A_AUTH_TOKEN" + """ + name = config["name"] + description = config["description"] + version = config.get("version", "1.0.0") + url = config.get("url", "http://localhost:8080") + binding = config.get("protocol_binding", "HTTP+JSON") + + # Build skills + skills = [] + for s in config.get("skills", []): + skills.append( + AgentSkill( + id=s["id"], + name=s.get("name", s["id"]), + description=s.get("description", ""), + tags=s.get("tags", []), + examples=s.get("examples", []), + input_modes=s.get("inputModes", config.get("default_input_modes", ["text/plain"])), + output_modes=s.get("outputModes", config.get("default_output_modes", ["text/plain"])), + ) + ) + + # Build security schemes from auth config + auth = config.get("auth", {}) + security_schemes = {} + security_requirements = [] + + if auth.get("scheme") == "bearer": + security_schemes["bearerAuth"] = { + "httpAuthSecurityScheme": { + "scheme": "Bearer", + "bearerFormat": auth.get("bearer_format", "token"), + } + } + security_requirements = [ + {"schemes": {"bearerAuth": {"list": []}}} + ] + elif auth.get("scheme") == "api_key": + key_name = auth.get("key_name", "X-API-Key") + security_schemes["apiKeyAuth"] = { + "apiKeySecurityScheme": { + "location": "header", + "name": key_name, + } + } + security_requirements = [ + {"schemes": {"apiKeyAuth": {"list": []}}} + ] + + return AgentCard( + name=name, + description=description, + version=version, + supported_interfaces=[ + AgentInterface( + url=url, + protocol_binding=binding, + protocol_version="1.0", + ) + ], + capabilities=AgentCapabilities( + streaming=config.get("streaming", False), + push_notifications=config.get("push_notifications", False), + ), + default_input_modes=config.get("default_input_modes", ["text/plain"]), + default_output_modes=config.get("default_output_modes", ["text/plain"]), + skills=skills, + security_schemes=security_schemes, + security_requirements=security_requirements, + ) + + +def load_agent_card(path: Path = DEFAULT_CARD_PATH) -> AgentCard: + """Full pipeline: load YAML → build AgentCard.""" + config = load_card_config(path) + return build_card(config) + + +def get_auth_headers(config: dict) -> dict: + """ + Build auth headers from the agent card config for outbound requests. + + Returns dict of HTTP headers to include. + """ + auth = config.get("auth", {}) + headers = {"A2A-Version": "1.0"} + + scheme = auth.get("scheme") + if scheme == "bearer": + token_env = auth.get("token_env", "A2A_AUTH_TOKEN") + token = os.environ.get(token_env, "") + if token: + headers["Authorization"] = f"Bearer {token}" + elif scheme == "api_key": + key_env = auth.get("key_env", "A2A_API_KEY") + key_name = auth.get("key_name", "X-API-Key") + key = os.environ.get(key_env, "") + if key: + headers[key_name] = key + + return headers diff --git a/nexus/a2a/client.py b/nexus/a2a/client.py new file mode 100644 index 00000000..f2d48acd --- /dev/null +++ b/nexus/a2a/client.py @@ -0,0 +1,392 @@ +""" +A2A Client — send tasks to other agents over the A2A protocol. + +Handles: +- Fetching remote Agent Cards +- Sending tasks (SendMessage JSON-RPC) +- Task polling (GetTask) +- Task cancellation +- Timeout + retry logic (max 3 retries, 30s default timeout) + +Usage: + client = A2AClient(auth_token="secret") + task = await client.send_message("https://ezra.example.com/a2a/v1", message) + status = await client.get_task("https://ezra.example.com/a2a/v1", task_id) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Optional + +import aiohttp + +from nexus.a2a.types import ( + A2AError, + AgentCard, + Artifact, + JSONRPCRequest, + JSONRPCResponse, + Message, + Role, + Task, + TaskState, + TaskStatus, + TextPart, +) + +logger = logging.getLogger("nexus.a2a.client") + + +@dataclass +class A2AClientConfig: + """Client configuration.""" + timeout: float = 30.0 # seconds per request + max_retries: int = 3 + retry_delay: float = 2.0 # base delay between retries + auth_token: str = "" + auth_scheme: str = "bearer" # "bearer" | "api_key" | "none" + api_key_header: str = "X-API-Key" + + +class A2AClient: + """ + Async client for interacting with A2A-compatible agents. + + Every agent endpoint is identified by its base URL (e.g. + https://ezra.example.com/a2a/v1). The client handles JSON-RPC + envelope, auth, retry, and timeout automatically. + """ + + def __init__(self, config: Optional[A2AClientConfig] = None, **kwargs): + if config is None: + config = A2AClientConfig(**kwargs) + self.config = config + self._session: Optional[aiohttp.ClientSession] = None + self._audit_log: list[dict] = [] + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.config.timeout), + headers=self._build_auth_headers(), + ) + return self._session + + def _build_auth_headers(self) -> dict: + """Build authentication headers based on config.""" + headers = {"A2A-Version": "1.0", "Content-Type": "application/json"} + token = self.config.auth_token + if not token: + return headers + + if self.config.auth_scheme == "bearer": + headers["Authorization"] = f"Bearer {token}" + elif self.config.auth_scheme == "api_key": + headers[self.config.api_key_header] = token + + return headers + + async def close(self): + """Close the HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + + async def _rpc_call( + self, + endpoint: str, + method: str, + params: Optional[dict] = None, + ) -> dict: + """ + Make a JSON-RPC call with retry logic. + + Returns the 'result' field from the response. + Raises on JSON-RPC errors. + """ + session = await self._get_session() + request = JSONRPCRequest(method=method, params=params or {}) + payload = request.to_dict() + + last_error = None + for attempt in range(1, self.config.max_retries + 1): + try: + start = time.monotonic() + async with session.post(endpoint, json=payload) as resp: + elapsed = time.monotonic() - start + + if resp.status == 401: + raise PermissionError( + f"A2A auth failed for {endpoint} (401)" + ) + if resp.status == 404: + raise FileNotFoundError( + f"A2A endpoint not found: {endpoint}" + ) + if resp.status >= 500: + body = await resp.text() + raise ConnectionError( + f"A2A server error {resp.status}: {body}" + ) + + data = await resp.json() + rpc_resp = JSONRPCResponse( + id=str(data.get("id", "")), + result=data.get("result"), + error=( + A2AError.INTERNAL + if "error" in data + else None + ), + ) + + # Log for audit + self._audit_log.append({ + "timestamp": time.time(), + "endpoint": endpoint, + "method": method, + "request_id": request.id, + "status_code": resp.status, + "elapsed_ms": int(elapsed * 1000), + "attempt": attempt, + }) + + if "error" in data: + err = data["error"] + logger.error( + f"A2A RPC error {err.get('code')}: " + f"{err.get('message')}" + ) + raise RuntimeError( + f"A2A error {err.get('code')}: " + f"{err.get('message')}" + ) + + return data.get("result", {}) + + except (asyncio.TimeoutError, aiohttp.ClientError) as e: + last_error = e + logger.warning( + f"A2A request to {endpoint} attempt {attempt}/" + f"{self.config.max_retries} failed: {e}" + ) + if attempt < self.config.max_retries: + delay = self.config.retry_delay * attempt + await asyncio.sleep(delay) + + raise ConnectionError( + f"A2A request to {endpoint} failed after " + f"{self.config.max_retries} retries: {last_error}" + ) + + # --- Core A2A Methods --- + + async def get_agent_card(self, base_url: str) -> AgentCard: + """ + Fetch the Agent Card from a remote agent. + + Tries /.well-known/agent-card.json first, falls back to + /agent.json. + """ + session = await self._get_session() + card_urls = [ + f"{base_url}/.well-known/agent-card.json", + f"{base_url}/agent.json", + ] + + for url in card_urls: + try: + async with session.get(url) as resp: + if resp.status == 200: + data = await resp.json() + card = AgentCard.from_dict(data) + logger.info( + f"Fetched agent card: {card.name} " + f"({len(card.skills)} skills)" + ) + return card + except Exception: + continue + + raise FileNotFoundError( + f"Could not fetch agent card from {base_url}" + ) + + async def send_message( + self, + endpoint: str, + message: Message, + accepted_output_modes: Optional[list[str]] = None, + history_length: int = 10, + return_immediately: bool = False, + ) -> Task: + """ + Send a message to an agent and get a Task back. + + This is the primary delegation method. + """ + params = { + "message": message.to_dict(), + "configuration": { + "acceptedOutputModes": accepted_output_modes or ["text/plain"], + "historyLength": history_length, + "returnImmediately": return_immediately, + }, + } + + result = await self._rpc_call(endpoint, "SendMessage", params) + + # Response is either a Task or Message + if "task" in result: + task = Task.from_dict(result["task"]) + logger.info( + f"Task {task.id} created, state={task.status.state.value}" + ) + return task + elif "message" in result: + # Wrap message response as a completed task + msg = Message.from_dict(result["message"]) + task = Task( + status=TaskStatus(state=TaskState.COMPLETED), + history=[message, msg], + artifacts=[ + Artifact(parts=msg.parts, name="response") + ], + ) + return task + + raise ValueError(f"Unexpected response structure: {list(result.keys())}") + + async def get_task(self, endpoint: str, task_id: str) -> Task: + """Get task status by ID.""" + result = await self._rpc_call( + endpoint, + "GetTask", + {"id": task_id}, + ) + return Task.from_dict(result) + + async def list_tasks( + self, + endpoint: str, + page_size: int = 20, + page_token: str = "", + ) -> tuple[list[Task], str]: + """ + List tasks with cursor-based pagination. + + Returns (tasks, next_page_token). Empty string = last page. + """ + result = await self._rpc_call( + endpoint, + "ListTasks", + { + "pageSize": page_size, + "pageToken": page_token, + }, + ) + tasks = [Task.from_dict(t) for t in result.get("tasks", [])] + next_token = result.get("nextPageToken", "") + return tasks, next_token + + async def cancel_task(self, endpoint: str, task_id: str) -> Task: + """Cancel a running task.""" + result = await self._rpc_call( + endpoint, + "CancelTask", + {"id": task_id}, + ) + return Task.from_dict(result) + + # --- Convenience Methods --- + + async def delegate( + self, + agent_url: str, + text: str, + skill_id: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Task: + """ + High-level delegation: send a text message to an agent. + + Args: + agent_url: Full URL to agent's A2A endpoint + (e.g. https://ezra.example.com/a2a/v1) + text: The task description in natural language + skill_id: Optional skill to target + metadata: Optional metadata dict + """ + msg_metadata = metadata or {} + if skill_id: + msg_metadata["targetSkill"] = skill_id + + message = Message( + role=Role.USER, + parts=[TextPart(text=text)], + metadata=msg_metadata, + ) + + return await self.send_message(agent_url, message) + + async def wait_for_completion( + self, + endpoint: str, + task_id: str, + poll_interval: float = 2.0, + max_wait: float = 300.0, + ) -> Task: + """ + Poll a task until it reaches a terminal state. + + Returns the completed task. + """ + start = time.monotonic() + while True: + task = await self.get_task(endpoint, task_id) + if task.status.state.terminal: + return task + elapsed = time.monotonic() - start + if elapsed >= max_wait: + raise TimeoutError( + f"Task {task_id} did not complete within " + f"{max_wait}s (state={task.status.state.value})" + ) + await asyncio.sleep(poll_interval) + + def get_audit_log(self) -> list[dict]: + """Return the audit log of all requests made by this client.""" + return list(self._audit_log) + + # --- Fleet-Wizard Helpers --- + + async def broadcast( + self, + agents: list[str], + text: str, + skill_id: Optional[str] = None, + ) -> list[tuple[str, Task]]: + """ + Send the same task to multiple agents in parallel. + + Returns list of (agent_url, task) tuples. + """ + tasks = [] + for agent_url in agents: + tasks.append( + self.delegate(agent_url, text, skill_id=skill_id) + ) + + results = await asyncio.gather(*tasks, return_exceptions=True) + paired = [] + for agent_url, result in zip(agents, results): + if isinstance(result, Exception): + logger.error(f"Broadcast to {agent_url} failed: {result}") + else: + paired.append((agent_url, result)) + return paired diff --git a/nexus/a2a/registry.py b/nexus/a2a/registry.py new file mode 100644 index 00000000..41534e54 --- /dev/null +++ b/nexus/a2a/registry.py @@ -0,0 +1,264 @@ +""" +A2A Registry — fleet-wide agent discovery. + +Provides two registry backends: +1. LocalFileRegistry: reads/writes agent cards to a JSON file + (default: config/fleet_agents.json) +2. GiteaRegistry: stores agent cards as a Gitea repo file + (for distributed fleet discovery) + +Usage: + registry = LocalFileRegistry() + registry.register(my_card) + agents = registry.list_agents(skill="ci-health") +""" + +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Optional + +from nexus.a2a.types import AgentCard + +logger = logging.getLogger("nexus.a2a.registry") + + +class LocalFileRegistry: + """ + File-based agent card registry. + + Stores all fleet agent cards in a single JSON file. + Suitable for single-node or read-heavy workloads. + """ + + def __init__(self, path: Path = Path("config/fleet_agents.json")): + self.path = path + self._cards: dict[str, AgentCard] = {} + self._load() + + def _load(self): + """Load registry from disk.""" + if self.path.exists(): + try: + with open(self.path) as f: + data = json.load(f) + for card_data in data.get("agents", []): + card = AgentCard.from_dict(card_data) + self._cards[card.name.lower()] = card + logger.info( + f"Loaded {len(self._cards)} agents from {self.path}" + ) + except (json.JSONDecodeError, KeyError) as e: + logger.error(f"Failed to load registry from {self.path}: {e}") + + def _save(self): + """Persist registry to disk.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + data = { + "version": 1, + "agents": [card.to_dict() for card in self._cards.values()], + } + with open(self.path, "w") as f: + json.dump(data, f, indent=2) + logger.debug(f"Saved {len(self._cards)} agents to {self.path}") + + def register(self, card: AgentCard) -> None: + """Register or update an agent card.""" + self._cards[card.name.lower()] = card + self._save() + logger.info(f"Registered agent: {card.name}") + + def unregister(self, name: str) -> bool: + """Remove an agent from the registry.""" + key = name.lower() + if key in self._cards: + del self._cards[key] + self._save() + logger.info(f"Unregistered agent: {name}") + return True + return False + + def get(self, name: str) -> Optional[AgentCard]: + """Get an agent card by name.""" + return self._cards.get(name.lower()) + + def list_agents( + self, + skill: Optional[str] = None, + tag: Optional[str] = None, + ) -> list[AgentCard]: + """ + List all registered agents, optionally filtered by skill or tag. + + Args: + skill: Filter to agents that have this skill ID + tag: Filter to agents that have this tag on any skill + """ + agents = list(self._cards.values()) + + if skill: + agents = [ + a for a in agents + if any(s.id == skill for s in a.skills) + ] + + if tag: + agents = [ + a for a in agents + if any(tag in s.tags for s in a.skills) + ] + + return agents + + def get_endpoint(self, name: str) -> Optional[str]: + """Get the first supported interface URL for an agent.""" + card = self.get(name) + if card and card.supported_interfaces: + return card.supported_interfaces[0].url + return None + + def dump(self) -> dict: + """Dump full registry as a dict.""" + return { + "version": 1, + "agents": [card.to_dict() for card in self._cards.values()], + } + + +class GiteaRegistry: + """ + Gitea-backed agent registry. + + Stores fleet agent cards in a Gitea repository file for + distributed discovery across VPS nodes. + """ + + def __init__( + self, + gitea_url: str, + repo: str, + token: str, + file_path: str = "config/fleet_agents.json", + ): + self.gitea_url = gitea_url.rstrip("/") + self.repo = repo + self.token = token + self.file_path = file_path + self._cards: dict[str, AgentCard] = {} + + def _api_url(self, endpoint: str) -> str: + return f"{self.gitea_url}/api/v1/repos/{self.repo}/{endpoint}" + + def _headers(self) -> dict: + return { + "Authorization": f"token {self.token}", + "Content-Type": "application/json", + } + + async def load(self) -> None: + """Fetch agent cards from Gitea.""" + try: + import aiohttp + url = self._api_url(f"contents/{self.file_path}") + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=self._headers()) as resp: + if resp.status == 200: + data = await resp.json() + import base64 + content = base64.b64decode(data["content"]).decode() + registry = json.loads(content) + for card_data in registry.get("agents", []): + card = AgentCard.from_dict(card_data) + self._cards[card.name.lower()] = card + logger.info( + f"Loaded {len(self._cards)} agents from Gitea" + ) + elif resp.status == 404: + logger.info("No fleet registry file in Gitea yet") + else: + logger.error( + f"Gitea fetch failed: {resp.status}" + ) + except Exception as e: + logger.error(f"Failed to load from Gitea: {e}") + + async def save(self, message: str = "Update fleet registry") -> None: + """Write agent cards to Gitea.""" + try: + import aiohttp + content = json.dumps( + {"version": 1, "agents": [c.to_dict() for c in self._cards.values()]}, + indent=2, + ) + import base64 + encoded = base64.b64encode(content.encode()).decode() + + # Check if file exists (need SHA for update) + url = self._api_url(f"contents/{self.file_path}") + sha = None + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=self._headers()) as resp: + if resp.status == 200: + existing = await resp.json() + sha = existing.get("sha") + + payload = { + "message": message, + "content": encoded, + } + if sha: + payload["sha"] = sha + + async with session.put( + url, headers=self._headers(), json=payload + ) as resp: + if resp.status in (200, 201): + logger.info("Fleet registry saved to Gitea") + else: + body = await resp.text() + logger.error( + f"Gitea save failed: {resp.status} — {body}" + ) + except Exception as e: + logger.error(f"Failed to save to Gitea: {e}") + + def register(self, card: AgentCard) -> None: + """Register an agent (local update; call save() to persist).""" + self._cards[card.name.lower()] = card + + def unregister(self, name: str) -> bool: + key = name.lower() + if key in self._cards: + del self._cards[key] + return True + return False + + def get(self, name: str) -> Optional[AgentCard]: + return self._cards.get(name.lower()) + + def list_agents( + self, + skill: Optional[str] = None, + tag: Optional[str] = None, + ) -> list[AgentCard]: + agents = list(self._cards.values()) + if skill: + agents = [a for a in agents if any(s.id == skill for s in a.skills)] + if tag: + agents = [a for a in agents if any(tag in s.tags for s in a.skills)] + return agents + + +# --- Convenience --- + +def discover_agents( + path: Path = Path("config/fleet_agents.json"), + skill: Optional[str] = None, + tag: Optional[str] = None, +) -> list[AgentCard]: + """One-shot discovery from local file.""" + registry = LocalFileRegistry(path) + return registry.list_agents(skill=skill, tag=tag) diff --git a/nexus/a2a/server.py b/nexus/a2a/server.py new file mode 100644 index 00000000..d5d5357e --- /dev/null +++ b/nexus/a2a/server.py @@ -0,0 +1,386 @@ +""" +A2A Server — receive and process tasks from other agents. + +Provides a FastAPI router that serves: +- GET /.well-known/agent-card.json — Agent Card discovery +- GET /agent.json — Agent Card fallback +- POST /a2a/v1 — JSON-RPC endpoint (SendMessage, GetTask, etc.) +- POST /a2a/v1/rpc — JSON-RPC endpoint (alias) + +Task routing: registered handlers are matched by skill ID or receive +all tasks via a default handler. + +Usage: + server = A2AServer(card=my_card, auth_token="secret") + server.register_handler("ci-health", my_ci_handler) + await server.start(host="0.0.0.0", port=8080) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from datetime import datetime, timezone +from typing import Any, Callable, Awaitable, Optional + +try: + from fastapi import FastAPI, Request, Response, HTTPException, Header + from fastapi.responses import JSONResponse + import uvicorn + HAS_FASTAPI = True +except ImportError: + HAS_FASTAPI = False + +from nexus.a2a.types import ( + A2AError, + AgentCard, + Artifact, + JSONRPCError, + JSONRPCResponse, + Message, + Role, + Task, + TaskState, + TaskStatus, + TextPart, +) + +logger = logging.getLogger("nexus.a2a.server") + +# Type for task handlers +TaskHandler = Callable[[Task, AgentCard], Awaitable[Task]] + + +class A2AServer: + """ + A2A protocol server for receiving agent-to-agent task delegation. + + Supports: + - Agent Card serving at /.well-known/agent-card.json + - JSON-RPC task lifecycle (SendMessage, GetTask, CancelTask, ListTasks) + - Pluggable task handlers (by skill ID or default) + - Bearer / API key authentication + - Audit logging + """ + + def __init__( + self, + card: AgentCard, + auth_token: str = "", + auth_scheme: str = "bearer", + ): + if not HAS_FASTAPI: + raise ImportError( + "fastapi and uvicorn are required for A2AServer. " + "Install with: pip install fastapi uvicorn" + ) + + self.card = card + self.auth_token = auth_token + self.auth_scheme = auth_scheme + + # Task store (in-memory; swap for SQLite/Redis in production) + self._tasks: dict[str, Task] = {} + # Handlers keyed by skill ID + self._handlers: dict[str, TaskHandler] = {} + # Default handler for unmatched skills + self._default_handler: Optional[TaskHandler] = None + # Audit log + self._audit_log: list[dict] = [] + + self.app = FastAPI( + title=f"A2A — {card.name}", + description=card.description, + version=card.version, + ) + self._register_routes() + + def register_handler(self, skill_id: str, handler: TaskHandler): + """Register a handler for a specific skill ID.""" + self._handlers[skill_id] = handler + logger.info(f"Registered handler for skill: {skill_id}") + + def set_default_handler(self, handler: TaskHandler): + """Set the fallback handler for tasks without a matching skill.""" + self._default_handler = handler + + def _verify_auth(self, authorization: Optional[str]) -> bool: + """Check authentication header.""" + if not self.auth_token: + return True # No auth configured + + if not authorization: + return False + + if self.auth_scheme == "bearer": + expected = f"Bearer {self.auth_token}" + return authorization == expected + + return False + + def _register_routes(self): + """Wire up FastAPI routes.""" + + @self.app.get("/.well-known/agent-card.json") + async def agent_card_well_known(): + return JSONResponse(self.card.to_dict()) + + @self.app.get("/agent.json") + async def agent_card_fallback(): + return JSONResponse(self.card.to_dict()) + + @self.app.post("/a2a/v1") + @self.app.post("/a2a/v1/rpc") + async def rpc_endpoint(request: Request): + return await self._handle_rpc(request) + + @self.app.get("/a2a/v1/tasks") + @self.app.get("/a2a/v1/tasks/{task_id}") + async def rest_get_task(task_id: Optional[str] = None): + if task_id: + task = self._tasks.get(task_id) + if not task: + return JSONRPCResponse( + id="", + error=A2AError.TASK_NOT_FOUND, + ).to_dict() + return JSONResponse(task.to_dict()) + else: + return JSONResponse( + {"tasks": [t.to_dict() for t in self._tasks.values()]} + ) + + async def _handle_rpc(self, request: Request) -> JSONResponse: + """Handle JSON-RPC requests.""" + # Auth check + auth_header = request.headers.get("authorization") + if not self._verify_auth(auth_header): + return JSONResponse( + status_code=401, + content={"error": "Unauthorized"}, + ) + + # Parse JSON-RPC + try: + body = await request.json() + except json.JSONDecodeError: + return JSONResponse( + JSONRPCResponse( + id="", error=A2AError.PARSE + ).to_dict(), + status_code=400, + ) + + method = body.get("method", "") + request_id = body.get("id", str(uuid.uuid4())) + params = body.get("params", {}) + + # Audit + self._audit_log.append({ + "timestamp": time.time(), + "method": method, + "request_id": request_id, + "source": request.client.host if request.client else "unknown", + }) + + try: + result = await self._dispatch_rpc(method, params, request_id) + return JSONResponse( + JSONRPCResponse(id=request_id, result=result).to_dict() + ) + except ValueError as e: + return JSONResponse( + JSONRPCResponse( + id=request_id, + error=JSONRPCError(-32602, str(e)), + ).to_dict(), + status_code=400, + ) + except Exception as e: + logger.exception(f"Error handling {method}: {e}") + return JSONResponse( + JSONRPCResponse( + id=request_id, + error=JSONRPCError(-32603, str(e)), + ).to_dict(), + status_code=500, + ) + + async def _dispatch_rpc( + self, method: str, params: dict, request_id: str + ) -> Any: + """Route JSON-RPC method to handler.""" + if method == "SendMessage": + return await self._rpc_send_message(params) + elif method == "GetTask": + return await self._rpc_get_task(params) + elif method == "ListTasks": + return await self._rpc_list_tasks(params) + elif method == "CancelTask": + return await self._rpc_cancel_task(params) + elif method == "GetAgentCard": + return self.card.to_dict() + else: + raise ValueError(f"Unknown method: {method}") + + async def _rpc_send_message(self, params: dict) -> dict: + """Handle SendMessage — create a task and route to handler.""" + msg_data = params.get("message", {}) + message = Message.from_dict(msg_data) + + # Determine target skill from metadata + target_skill = message.metadata.get("targetSkill", "") + + # Create task + task = Task( + context_id=message.context_id, + status=TaskStatus(state=TaskState.SUBMITTED), + history=[message], + metadata={"targetSkill": target_skill} if target_skill else {}, + ) + + # Store immediately + self._tasks[task.id] = task + + # Dispatch to handler + handler = self._handlers.get(target_skill) or self._default_handler + + if handler is None: + task.status = TaskStatus( + state=TaskState.FAILED, + message=Message( + role=Role.AGENT, + parts=[TextPart(text="No handler available for this task")], + ), + ) + return {"task": task.to_dict()} + + try: + # Mark as working + task.status = TaskStatus(state=TaskState.WORKING) + self._tasks[task.id] = task + + # Execute handler + result_task = await handler(task, self.card) + + # Store result + self._tasks[result_task.id] = result_task + return {"task": result_task.to_dict()} + + except Exception as e: + task.status = TaskStatus( + state=TaskState.FAILED, + message=Message( + role=Role.AGENT, + parts=[TextPart(text=f"Handler error: {str(e)}")], + ), + ) + self._tasks[task.id] = task + return {"task": task.to_dict()} + + async def _rpc_get_task(self, params: dict) -> dict: + """Handle GetTask.""" + task_id = params.get("id", "") + task = self._tasks.get(task_id) + if not task: + raise ValueError(f"Task not found: {task_id}") + return task.to_dict() + + async def _rpc_list_tasks(self, params: dict) -> dict: + """Handle ListTasks with cursor-based pagination.""" + page_size = params.get("pageSize", 20) + page_token = params.get("pageToken", "") + + tasks = sorted( + self._tasks.values(), + key=lambda t: t.status.timestamp, + reverse=True, + ) + + # Simple cursor: find index by token + start_idx = 0 + if page_token: + for i, t in enumerate(tasks): + if t.id == page_token: + start_idx = i + 1 + break + + page = tasks[start_idx : start_idx + page_size] + next_token = "" + if start_idx + page_size < len(tasks): + next_token = tasks[start_idx + page_size - 1].id + + return { + "tasks": [t.to_dict() for t in page], + "nextPageToken": next_token, + } + + async def _rpc_cancel_task(self, params: dict) -> dict: + """Handle CancelTask.""" + task_id = params.get("id", "") + task = self._tasks.get(task_id) + if not task: + raise ValueError(f"Task not found: {task_id}") + + if task.status.state.terminal: + raise ValueError( + f"Task {task_id} is already terminal " + f"({task.status.state.value})" + ) + + task.status = TaskStatus(state=TaskState.CANCELED) + self._tasks[task_id] = task + return task.to_dict() + + def get_audit_log(self) -> list[dict]: + """Return audit log of all received requests.""" + return list(self._audit_log) + + async def start( + self, + host: str = "0.0.0.0", + port: int = 8080, + ): + """Start the A2A server with uvicorn.""" + logger.info( + f"Starting A2A server for {self.card.name} on " + f"{host}:{port}" + ) + logger.info( + f"Agent Card at " + f"http://{host}:{port}/.well-known/agent-card.json" + ) + config = uvicorn.Config( + self.app, + host=host, + port=port, + log_level="info", + ) + server = uvicorn.Server(config) + await server.serve() + + +# --- Default Handler Factory --- + +async def echo_handler(task: Task, card: AgentCard) -> Task: + """ + Simple echo handler for testing. + Returns the user's message as an artifact. + """ + if task.history: + last_msg = task.history[-1] + text_parts = [p for p in last_msg.parts if isinstance(p, TextPart)] + if text_parts: + response_text = f"[{card.name}] Echo: {text_parts[0].text}" + task.artifacts.append( + Artifact( + parts=[TextPart(text=response_text)], + name="echo_response", + ) + ) + + task.status = TaskStatus(state=TaskState.COMPLETED) + return task diff --git a/nexus/a2a/types.py b/nexus/a2a/types.py new file mode 100644 index 00000000..1a6c9f63 --- /dev/null +++ b/nexus/a2a/types.py @@ -0,0 +1,524 @@ +""" +A2A Protocol Types — Data models for Google's Agent2Agent protocol v1.0. + +All types map directly to the A2A spec. JSON uses camelCase, enums use +SCREAMING_SNAKE_CASE, and Part types are discriminated by member name +(not a kind field — that was removed in v1.0). + +See: https://github.com/google/A2A +""" + +from __future__ import annotations + +import enum +import uuid +from dataclasses import dataclass, field, asdict +from datetime import datetime, timezone +from typing import Any, Optional + + +# --- Enums --- + +class TaskState(str, enum.Enum): + """Lifecycle states for an A2A Task.""" + SUBMITTED = "TASK_STATE_SUBMITTED" + WORKING = "TASK_STATE_WORKING" + COMPLETED = "TASK_STATE_COMPLETED" + FAILED = "TASK_STATE_FAILED" + CANCELED = "TASK_STATE_CANCELED" + INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED" + REJECTED = "TASK_STATE_REJECTED" + AUTH_REQUIRED = "TASK_STATE_AUTH_REQUIRED" + + @property + def terminal(self) -> bool: + return self in ( + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + TaskState.REJECTED, + ) + + +class Role(str, enum.Enum): + """Who sent a message in an A2A conversation.""" + USER = "ROLE_USER" + AGENT = "ROLE_AGENT" + + +# --- Parts (discriminated by member name in JSON) --- + +@dataclass +class TextPart: + """Plain text content.""" + text: str + media_type: str = "text/plain" + metadata: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + d = {"text": self.text} + if self.media_type != "text/plain": + d["mediaType"] = self.media_type + if self.metadata: + d["metadata"] = self.metadata + return d + + +@dataclass +class FilePart: + """Binary file content — inline or by URL reference.""" + media_type: str + filename: Optional[str] = None + raw: Optional[str] = None # base64-encoded bytes + url: Optional[str] = None # URL reference + metadata: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + d = {"mediaType": self.media_type} + if self.raw is not None: + d["raw"] = self.raw + if self.url is not None: + d["url"] = self.url + if self.filename: + d["filename"] = self.filename + if self.metadata: + d["metadata"] = self.metadata + return d + + +@dataclass +class DataPart: + """Arbitrary structured JSON data.""" + data: dict + media_type: str = "application/json" + metadata: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + d = {"data": self.data} + if self.media_type != "application/json": + d["mediaType"] = self.media_type + if self.metadata: + d["metadata"] = self.metadata + return d + + +Part = TextPart | FilePart | DataPart + + +def part_from_dict(d: dict) -> Part: + """Reconstruct a Part from its JSON dict (discriminated by key name).""" + if "text" in d: + return TextPart( + text=d["text"], + media_type=d.get("mediaType", "text/plain"), + metadata=d.get("metadata", {}), + ) + if "raw" in d or "url" in d: + return FilePart( + media_type=d["mediaType"], + filename=d.get("filename"), + raw=d.get("raw"), + url=d.get("url"), + metadata=d.get("metadata", {}), + ) + if "data" in d: + return DataPart( + data=d["data"], + media_type=d.get("mediaType", "application/json"), + metadata=d.get("metadata", {}), + ) + raise ValueError(f"Cannot determine Part type from keys: {list(d.keys())}") + + +def part_to_dict(p: Part) -> dict: + """Serialize a Part to its JSON dict.""" + return p.to_dict() + + +# --- Message --- + +@dataclass +class Message: + """A2A Message — a turn in a conversation between user and agent.""" + role: Role + parts: list[Part] + message_id: str = field(default_factory=lambda: str(uuid.uuid4())) + context_id: Optional[str] = None + task_id: Optional[str] = None + metadata: dict = field(default_factory=dict) + extensions: list[str] = field(default_factory=list) + reference_task_ids: list[str] = field(default_factory=list) + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "messageId": self.message_id, + "role": self.role.value, + "parts": [part_to_dict(p) for p in self.parts], + } + if self.context_id: + d["contextId"] = self.context_id + if self.task_id: + d["taskId"] = self.task_id + if self.metadata: + d["metadata"] = self.metadata + if self.extensions: + d["extensions"] = self.extensions + if self.reference_task_ids: + d["referenceTaskIds"] = self.reference_task_ids + return d + + @classmethod + def from_dict(cls, d: dict) -> "Message": + return cls( + role=Role(d["role"]), + parts=[part_from_dict(p) for p in d["parts"]], + message_id=d.get("messageId", str(uuid.uuid4())), + context_id=d.get("contextId"), + task_id=d.get("taskId"), + metadata=d.get("metadata", {}), + extensions=d.get("extensions", []), + reference_task_ids=d.get("referenceTaskIds", []), + ) + + +# --- Artifact --- + +@dataclass +class Artifact: + """A2A Artifact — structured output from a task.""" + parts: list[Part] + artifact_id: str = field(default_factory=lambda: str(uuid.uuid4())) + name: Optional[str] = None + description: Optional[str] = None + metadata: dict = field(default_factory=dict) + extensions: list[str] = field(default_factory=list) + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "artifactId": self.artifact_id, + "parts": [part_to_dict(p) for p in self.parts], + } + if self.name: + d["name"] = self.name + if self.description: + d["description"] = self.description + if self.metadata: + d["metadata"] = self.metadata + if self.extensions: + d["extensions"] = self.extensions + return d + + @classmethod + def from_dict(cls, d: dict) -> "Artifact": + return cls( + parts=[part_from_dict(p) for p in d["parts"]], + artifact_id=d.get("artifactId", str(uuid.uuid4())), + name=d.get("name"), + description=d.get("description"), + metadata=d.get("metadata", {}), + extensions=d.get("extensions", []), + ) + + +# --- Task --- + +@dataclass +class TaskStatus: + """Status envelope for a Task.""" + state: TaskState + message: Optional[Message] = None + timestamp: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) + + def to_dict(self) -> dict: + d: dict[str, Any] = {"state": self.state.value} + if self.message: + d["message"] = self.message.to_dict() + d["timestamp"] = self.timestamp + return d + + @classmethod + def from_dict(cls, d: dict) -> "TaskStatus": + msg = None + if "message" in d: + msg = Message.from_dict(d["message"]) + return cls( + state=TaskState(d["state"]), + message=msg, + timestamp=d.get("timestamp", datetime.now(timezone.utc).isoformat()), + ) + + +@dataclass +class Task: + """A2A Task — a unit of work delegated between agents.""" + id: str = field(default_factory=lambda: str(uuid.uuid4())) + context_id: Optional[str] = None + status: TaskStatus = field( + default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED) + ) + artifacts: list[Artifact] = field(default_factory=list) + history: list[Message] = field(default_factory=list) + metadata: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "id": self.id, + "status": self.status.to_dict(), + } + if self.context_id: + d["contextId"] = self.context_id + if self.artifacts: + d["artifacts"] = [a.to_dict() for a in self.artifacts] + if self.history: + d["history"] = [m.to_dict() for m in self.history] + if self.metadata: + d["metadata"] = self.metadata + return d + + @classmethod + def from_dict(cls, d: dict) -> "Task": + return cls( + id=d.get("id", str(uuid.uuid4())), + context_id=d.get("contextId"), + status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED), + artifacts=[Artifact.from_dict(a) for a in d.get("artifacts", [])], + history=[Message.from_dict(m) for m in d.get("history", [])], + metadata=d.get("metadata", {}), + ) + + +# --- Agent Card --- + +@dataclass +class AgentSkill: + """Capability declaration for an Agent Card.""" + id: str + name: str + description: str + tags: list[str] = field(default_factory=list) + examples: list[str] = field(default_factory=list) + input_modes: list[str] = field(default_factory=lambda: ["text/plain"]) + output_modes: list[str] = field(default_factory=lambda: ["text/plain"]) + security_requirements: list[dict] = field(default_factory=list) + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "id": self.id, + "name": self.name, + "description": self.description, + "tags": self.tags, + } + if self.examples: + d["examples"] = self.examples + if self.input_modes != ["text/plain"]: + d["inputModes"] = self.input_modes + if self.output_modes != ["text/plain"]: + d["outputModes"] = self.output_modes + if self.security_requirements: + d["securityRequirements"] = self.security_requirements + return d + + +@dataclass +class AgentInterface: + """Network endpoint for an agent.""" + url: str + protocol_binding: str = "HTTP+JSON" + protocol_version: str = "1.0" + tenant: str = "" + + def to_dict(self) -> dict: + d = { + "url": self.url, + "protocolBinding": self.protocol_binding, + "protocolVersion": self.protocol_version, + } + if self.tenant: + d["tenant"] = self.tenant + return d + + +@dataclass +class AgentCapabilities: + """What this agent can do beyond basic request/response.""" + streaming: bool = False + push_notifications: bool = False + extended_agent_card: bool = False + extensions: list[dict] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "streaming": self.streaming, + "pushNotifications": self.push_notifications, + "extendedAgentCard": self.extended_agent_card, + "extensions": self.extensions, + } + + +@dataclass +class AgentCard: + """ + A2A Agent Card — self-describing metadata published at + /.well-known/agent-card.json + """ + name: str + description: str + version: str = "1.0.0" + supported_interfaces: list[AgentInterface] = field(default_factory=list) + capabilities: AgentCapabilities = field( + default_factory=AgentCapabilities + ) + provider: Optional[dict] = None + documentation_url: Optional[str] = None + icon_url: Optional[str] = None + default_input_modes: list[str] = field( + default_factory=lambda: ["text/plain"] + ) + default_output_modes: list[str] = field( + default_factory=lambda: ["text/plain"] + ) + skills: list[AgentSkill] = field(default_factory=list) + security_schemes: dict = field(default_factory=dict) + security_requirements: list[dict] = field(default_factory=list) + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "name": self.name, + "description": self.description, + "version": self.version, + "supportedInterfaces": [i.to_dict() for i in self.supported_interfaces], + "capabilities": self.capabilities.to_dict(), + "defaultInputModes": self.default_input_modes, + "defaultOutputModes": self.default_output_modes, + "skills": [s.to_dict() for s in self.skills], + } + if self.provider: + d["provider"] = self.provider + if self.documentation_url: + d["documentationUrl"] = self.documentation_url + if self.icon_url: + d["iconUrl"] = self.icon_url + if self.security_schemes: + d["securitySchemes"] = self.security_schemes + if self.security_requirements: + d["securityRequirements"] = self.security_requirements + return d + + @classmethod + def from_dict(cls, d: dict) -> "AgentCard": + return cls( + name=d["name"], + description=d["description"], + version=d.get("version", "1.0.0"), + supported_interfaces=[ + AgentInterface( + url=i["url"], + protocol_binding=i.get("protocolBinding", "HTTP+JSON"), + protocol_version=i.get("protocolVersion", "1.0"), + tenant=i.get("tenant", ""), + ) + for i in d.get("supportedInterfaces", []) + ], + capabilities=AgentCapabilities( + streaming=d.get("capabilities", {}).get("streaming", False), + push_notifications=d.get("capabilities", {}).get("pushNotifications", False), + extended_agent_card=d.get("capabilities", {}).get("extendedAgentCard", False), + extensions=d.get("capabilities", {}).get("extensions", []), + ), + provider=d.get("provider"), + documentation_url=d.get("documentationUrl"), + icon_url=d.get("iconUrl"), + default_input_modes=d.get("defaultInputModes", ["text/plain"]), + default_output_modes=d.get("defaultOutputModes", ["text/plain"]), + skills=[ + AgentSkill( + id=s["id"], + name=s["name"], + description=s["description"], + tags=s.get("tags", []), + examples=s.get("examples", []), + input_modes=s.get("inputModes", ["text/plain"]), + output_modes=s.get("outputModes", ["text/plain"]), + security_requirements=s.get("securityRequirements", []), + ) + for s in d.get("skills", []) + ], + security_schemes=d.get("securitySchemes", {}), + security_requirements=d.get("securityRequirements", []), + ) + + +# --- JSON-RPC envelope --- + +@dataclass +class JSONRPCRequest: + """JSON-RPC 2.0 request wrapping an A2A method.""" + method: str + id: str = field(default_factory=lambda: str(uuid.uuid4())) + params: dict = field(default_factory=dict) + jsonrpc: str = "2.0" + + def to_dict(self) -> dict: + return { + "jsonrpc": self.jsonrpc, + "id": self.id, + "method": self.method, + "params": self.params, + } + + +@dataclass +class JSONRPCError: + """JSON-RPC 2.0 error object.""" + code: int + message: str + data: Any = None + + def to_dict(self) -> dict: + d = {"code": self.code, "message": self.message} + if self.data is not None: + d["data"] = self.data + return d + + +@dataclass +class JSONRPCResponse: + """JSON-RPC 2.0 response.""" + id: str + result: Any = None + error: Optional[JSONRPCError] = None + jsonrpc: str = "2.0" + + def to_dict(self) -> dict: + d: dict[str, Any] = { + "jsonrpc": self.jsonrpc, + "id": self.id, + } + if self.error: + d["error"] = self.error.to_dict() + else: + d["result"] = self.result + return d + + +# --- Standard A2A Error codes --- + +class A2AError: + """Standard A2A / JSON-RPC error factories.""" + PARSE = JSONRPCError(-32700, "Invalid JSON payload") + INVALID_REQUEST = JSONRPCError(-32600, "Request payload validation error") + METHOD_NOT_FOUND = JSONRPCError(-32601, "Method not found") + INVALID_PARAMS = JSONRPCError(-32602, "Invalid parameters") + INTERNAL = JSONRPCError(-32603, "Internal error") + + TASK_NOT_FOUND = JSONRPCError(-32001, "Task not found") + TASK_NOT_CANCELABLE = JSONRPCError(-32002, "Task not cancelable") + PUSH_NOT_SUPPORTED = JSONRPCError(-32003, "Push notifications not supported") + UNSUPPORTED_OP = JSONRPCError(-32004, "Unsupported operation") + CONTENT_TYPE = JSONRPCError(-32005, "Content type not supported") + INVALID_RESPONSE = JSONRPCError(-32006, "Invalid agent response") + EXTENDED_CARD = JSONRPCError(-32007, "Extended agent card not configured") + EXTENSION_REQUIRED = JSONRPCError(-32008, "Extension support required") + VERSION_NOT_SUPPORTED = JSONRPCError(-32009, "Version not supported") diff --git a/tests/test_a2a.py b/tests/test_a2a.py new file mode 100644 index 00000000..d020c500 --- /dev/null +++ b/tests/test_a2a.py @@ -0,0 +1,763 @@ +""" +Tests for A2A Protocol implementation. + +Covers: +- Type serialization roundtrips (Agent Card, Task, Message, Artifact, Part) +- JSON-RPC envelope +- Agent Card building from YAML config +- Registry operations (register, list, filter) +- Client/server integration (end-to-end task delegation) +""" + +from __future__ import annotations + +import asyncio +import json +import pytest +from pathlib import Path +from unittest.mock import AsyncMock, patch, MagicMock + +from nexus.a2a.types import ( + A2AError, + AgentCard, + AgentCapabilities, + AgentInterface, + AgentSkill, + Artifact, + DataPart, + FilePart, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Message, + Role, + Task, + TaskState, + TaskStatus, + TextPart, + part_from_dict, + part_to_dict, +) +from nexus.a2a.card import build_card, load_card_config +from nexus.a2a.registry import LocalFileRegistry + + +# === Type Serialization Roundtrips === + + +class TestTextPart: + def test_roundtrip(self): + p = TextPart(text="hello world") + d = p.to_dict() + assert d == {"text": "hello world"} + p2 = part_from_dict(d) + assert isinstance(p2, TextPart) + assert p2.text == "hello world" + + def test_custom_media_type(self): + p = TextPart(text="data", media_type="text/markdown") + d = p.to_dict() + assert d["mediaType"] == "text/markdown" + p2 = part_from_dict(d) + assert p2.media_type == "text/markdown" + + +class TestFilePart: + def test_inline_roundtrip(self): + p = FilePart(media_type="image/png", raw="base64data", filename="img.png") + d = p.to_dict() + assert d["raw"] == "base64data" + assert d["filename"] == "img.png" + p2 = part_from_dict(d) + assert isinstance(p2, FilePart) + assert p2.raw == "base64data" + + def test_url_roundtrip(self): + p = FilePart(media_type="application/pdf", url="https://example.com/doc.pdf") + d = p.to_dict() + assert d["url"] == "https://example.com/doc.pdf" + p2 = part_from_dict(d) + assert isinstance(p2, FilePart) + assert p2.url == "https://example.com/doc.pdf" + + +class TestDataPart: + def test_roundtrip(self): + p = DataPart(data={"key": "value", "count": 42}) + d = p.to_dict() + assert d["data"] == {"key": "value", "count": 42} + p2 = part_from_dict(d) + assert isinstance(p2, DataPart) + assert p2.data["count"] == 42 + + +class TestMessage: + def test_roundtrip(self): + msg = Message( + role=Role.USER, + parts=[TextPart(text="Hello agent")], + metadata={"priority": "high"}, + ) + d = msg.to_dict() + assert d["role"] == "ROLE_USER" + assert d["parts"] == [{"text": "Hello agent"}] + assert d["metadata"]["priority"] == "high" + + msg2 = Message.from_dict(d) + assert msg2.role == Role.USER + assert isinstance(msg2.parts[0], TextPart) + assert msg2.parts[0].text == "Hello agent" + assert msg2.metadata["priority"] == "high" + + def test_multi_part(self): + msg = Message( + role=Role.AGENT, + parts=[ + TextPart(text="Here's the report"), + DataPart(data={"status": "healthy"}), + ], + ) + d = msg.to_dict() + assert len(d["parts"]) == 2 + msg2 = Message.from_dict(d) + assert len(msg2.parts) == 2 + assert isinstance(msg2.parts[0], TextPart) + assert isinstance(msg2.parts[1], DataPart) + + +class TestArtifact: + def test_roundtrip(self): + art = Artifact( + parts=[TextPart(text="result data")], + name="report", + description="CI health report", + ) + d = art.to_dict() + assert d["name"] == "report" + assert d["description"] == "CI health report" + + art2 = Artifact.from_dict(d) + assert art2.name == "report" + assert isinstance(art2.parts[0], TextPart) + assert art2.parts[0].text == "result data" + + +class TestTask: + def test_roundtrip(self): + task = Task( + id="test-123", + status=TaskStatus(state=TaskState.WORKING), + history=[ + Message(role=Role.USER, parts=[TextPart(text="Do X")]), + ], + ) + d = task.to_dict() + assert d["id"] == "test-123" + assert d["status"]["state"] == "TASK_STATE_WORKING" + + task2 = Task.from_dict(d) + assert task2.id == "test-123" + assert task2.status.state == TaskState.WORKING + assert len(task2.history) == 1 + + def test_with_artifacts(self): + task = Task( + id="art-task", + status=TaskStatus(state=TaskState.COMPLETED), + artifacts=[ + Artifact( + parts=[TextPart(text="42")], + name="answer", + ) + ], + ) + d = task.to_dict() + assert len(d["artifacts"]) == 1 + task2 = Task.from_dict(d) + assert task2.artifacts[0].name == "answer" + + def test_terminal_states(self): + for state in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + TaskState.REJECTED, + ]: + assert state.terminal is True + + for state in [ + TaskState.SUBMITTED, + TaskState.WORKING, + TaskState.INPUT_REQUIRED, + TaskState.AUTH_REQUIRED, + ]: + assert state.terminal is False + + +class TestAgentCard: + def test_roundtrip(self): + card = AgentCard( + name="TestAgent", + description="A test agent", + version="1.0.0", + supported_interfaces=[ + AgentInterface(url="http://localhost:8080/a2a/v1") + ], + capabilities=AgentCapabilities(streaming=True), + skills=[ + AgentSkill( + id="test-skill", + name="Test Skill", + description="Does tests", + tags=["test"], + ) + ], + ) + d = card.to_dict() + assert d["name"] == "TestAgent" + assert d["capabilities"]["streaming"] is True + assert len(d["skills"]) == 1 + assert d["skills"][0]["id"] == "test-skill" + + card2 = AgentCard.from_dict(d) + assert card2.name == "TestAgent" + assert card2.skills[0].id == "test-skill" + assert card2.capabilities.streaming is True + + +class TestJSONRPC: + def test_request_roundtrip(self): + req = JSONRPCRequest( + method="SendMessage", + params={"message": {"text": "hello"}}, + ) + d = req.to_dict() + assert d["jsonrpc"] == "2.0" + assert d["method"] == "SendMessage" + + def test_response_success(self): + resp = JSONRPCResponse( + id="req-1", + result={"task": {"id": "t1"}}, + ) + d = resp.to_dict() + assert "error" not in d + assert d["result"]["task"]["id"] == "t1" + + def test_response_error(self): + resp = JSONRPCResponse( + id="req-1", + error=A2AError.TASK_NOT_FOUND, + ) + d = resp.to_dict() + assert "result" not in d + assert d["error"]["code"] == -32001 + + +# === Agent Card Building === + + +class TestBuildCard: + def test_basic_config(self): + config = { + "name": "Bezalel", + "description": "CI/CD specialist", + "version": "2.0.0", + "url": "https://bezalel.example.com", + "skills": [ + { + "id": "ci-health", + "name": "CI Health", + "description": "Check CI", + "tags": ["ci"], + }, + { + "id": "deploy", + "name": "Deploy", + "description": "Deploy services", + "tags": ["ops"], + }, + ], + } + card = build_card(config) + assert card.name == "Bezalel" + assert card.version == "2.0.0" + assert len(card.skills) == 2 + assert card.skills[0].id == "ci-health" + assert card.supported_interfaces[0].url == "https://bezalel.example.com" + + def test_bearer_auth(self): + config = { + "name": "Test", + "description": "Test", + "auth": {"scheme": "bearer", "token_env": "MY_TOKEN"}, + } + card = build_card(config) + assert "bearerAuth" in card.security_schemes + assert card.security_requirements[0]["schemes"]["bearerAuth"] == {"list": []} + + def test_api_key_auth(self): + config = { + "name": "Test", + "description": "Test", + "auth": {"scheme": "api_key", "key_name": "X-Custom-Key"}, + } + card = build_card(config) + assert "apiKeyAuth" in card.security_schemes + + +# === Registry === + + +class TestLocalFileRegistry: + def _make_card(self, name: str, skills: list[dict] | None = None) -> AgentCard: + return AgentCard( + name=name, + description=f"Agent {name}", + supported_interfaces=[ + AgentInterface(url=f"http://{name}:8080/a2a/v1") + ], + skills=[ + AgentSkill( + id=s["id"], + name=s.get("name", s["id"]), + description=s.get("description", ""), + tags=s.get("tags", []), + ) + for s in (skills or []) + ], + ) + + def test_register_and_list(self, tmp_path): + registry = LocalFileRegistry(tmp_path / "agents.json") + registry.register(self._make_card("ezra")) + registry.register(self._make_card("allegro")) + + agents = registry.list_agents() + assert len(agents) == 2 + names = {a.name for a in agents} + assert names == {"ezra", "allegro"} + + def test_filter_by_skill(self, tmp_path): + registry = LocalFileRegistry(tmp_path / "agents.json") + registry.register( + self._make_card("ezra", [{"id": "ci-health", "tags": ["ci"]}]) + ) + registry.register( + self._make_card("allegro", [{"id": "research", "tags": ["research"]}]) + ) + + ci_agents = registry.list_agents(skill="ci-health") + assert len(ci_agents) == 1 + assert ci_agents[0].name == "ezra" + + def test_filter_by_tag(self, tmp_path): + registry = LocalFileRegistry(tmp_path / "agents.json") + registry.register( + self._make_card("ezra", [{"id": "ci", "tags": ["devops", "ci"]}]) + ) + registry.register( + self._make_card("allegro", [{"id": "research", "tags": ["research"]}]) + ) + + devops_agents = registry.list_agents(tag="devops") + assert len(devops_agents) == 1 + assert devops_agents[0].name == "ezra" + + def test_persistence(self, tmp_path): + path = tmp_path / "agents.json" + reg1 = LocalFileRegistry(path) + reg1.register(self._make_card("ezra")) + + # Load fresh from disk + reg2 = LocalFileRegistry(path) + agents = reg2.list_agents() + assert len(agents) == 1 + assert agents[0].name == "ezra" + + def test_unregister(self, tmp_path): + registry = LocalFileRegistry(tmp_path / "agents.json") + registry.register(self._make_card("ezra")) + assert len(registry.list_agents()) == 1 + + assert registry.unregister("ezra") is True + assert len(registry.list_agents()) == 0 + assert registry.unregister("nonexistent") is False + + def test_get_endpoint(self, tmp_path): + registry = LocalFileRegistry(tmp_path / "agents.json") + registry.register(self._make_card("ezra")) + + url = registry.get_endpoint("ezra") + assert url == "http://ezra:8080/a2a/v1" + + +# === Server Integration (FastAPI required) === + + +try: + from fastapi.testclient import TestClient + HAS_TEST_CLIENT = True +except ImportError: + HAS_TEST_CLIENT = False + + +@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed") +class TestA2AServerIntegration: + """End-to-end tests using FastAPI TestClient.""" + + def _make_server(self, auth_token: str = ""): + from nexus.a2a.server import A2AServer, echo_handler + + card = AgentCard( + name="TestAgent", + description="Test agent for A2A", + supported_interfaces=[ + AgentInterface(url="http://localhost:8080/a2a/v1") + ], + capabilities=AgentCapabilities(streaming=False), + skills=[ + AgentSkill( + id="echo", + name="Echo", + description="Echo back messages", + tags=["test"], + ) + ], + ) + + server = A2AServer(card=card, auth_token=auth_token) + server.register_handler("echo", echo_handler) + server.set_default_handler(echo_handler) + return server + + def test_agent_card_well_known(self): + server = self._make_server() + client = TestClient(server.app) + + resp = client.get("/.well-known/agent-card.json") + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "TestAgent" + assert len(data["skills"]) == 1 + + def test_agent_card_fallback(self): + server = self._make_server() + client = TestClient(server.app) + + resp = client.get("/agent.json") + assert resp.status_code == 200 + assert resp.json()["name"] == "TestAgent" + + def test_send_message(self): + server = self._make_server() + client = TestClient(server.app) + + rpc_request = { + "jsonrpc": "2.0", + "id": "test-1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "msg-1", + "role": "ROLE_USER", + "parts": [{"text": "Hello from test"}], + }, + "configuration": { + "acceptedOutputModes": ["text/plain"], + "historyLength": 10, + "returnImmediately": False, + }, + }, + } + + resp = client.post("/a2a/v1", json=rpc_request) + assert resp.status_code == 200 + data = resp.json() + assert "result" in data + assert "task" in data["result"] + + task = data["result"]["task"] + assert task["status"]["state"] == "TASK_STATE_COMPLETED" + assert len(task["artifacts"]) == 1 + assert "Echo" in task["artifacts"][0]["parts"][0]["text"] + + def test_get_task(self): + server = self._make_server() + client = TestClient(server.app) + + # Create a task first + send_req = { + "jsonrpc": "2.0", + "id": "s1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "m1", + "role": "ROLE_USER", + "parts": [{"text": "get me"}], + }, + "configuration": {}, + }, + } + send_resp = client.post("/a2a/v1", json=send_req) + task_id = send_resp.json()["result"]["task"]["id"] + + # Now fetch it + get_req = { + "jsonrpc": "2.0", + "id": "g1", + "method": "GetTask", + "params": {"id": task_id}, + } + get_resp = client.post("/a2a/v1", json=get_req) + assert get_resp.status_code == 200 + assert get_resp.json()["result"]["id"] == task_id + + def test_get_nonexistent_task(self): + server = self._make_server() + client = TestClient(server.app) + + req = { + "jsonrpc": "2.0", + "id": "g2", + "method": "GetTask", + "params": {"id": "nonexistent"}, + } + resp = client.post("/a2a/v1", json=req) + assert resp.status_code == 400 + data = resp.json() + assert "error" in data + + def test_list_tasks(self): + server = self._make_server() + client = TestClient(server.app) + + # Create two tasks + for i in range(2): + req = { + "jsonrpc": "2.0", + "id": f"s{i}", + "method": "SendMessage", + "params": { + "message": { + "messageId": f"m{i}", + "role": "ROLE_USER", + "parts": [{"text": f"task {i}"}], + }, + "configuration": {}, + }, + } + client.post("/a2a/v1", json=req) + + list_req = { + "jsonrpc": "2.0", + "id": "l1", + "method": "ListTasks", + "params": {"pageSize": 10}, + } + resp = client.post("/a2a/v1", json=list_req) + assert resp.status_code == 200 + tasks = resp.json()["result"]["tasks"] + assert len(tasks) >= 2 + + def test_cancel_task(self): + from nexus.a2a.server import A2AServer + + # Create a server with a slow handler so task stays WORKING + async def slow_handler(task, card): + import asyncio + await asyncio.sleep(10) # never reached in test + task.status = TaskStatus(state=TaskState.COMPLETED) + return task + + card = AgentCard(name="SlowAgent", description="Slow test agent") + server = A2AServer(card=card) + server.set_default_handler(slow_handler) + client = TestClient(server.app) + + # Create a task (but we need to intercept before handler runs) + # Instead, manually insert a task and test cancel on it + task = Task( + id="cancel-me", + status=TaskStatus(state=TaskState.WORKING), + history=[ + Message(role=Role.USER, parts=[TextPart(text="cancel me")]) + ], + ) + server._tasks[task.id] = task + + # Cancel it + cancel_req = { + "jsonrpc": "2.0", + "id": "c2", + "method": "CancelTask", + "params": {"id": "cancel-me"}, + } + cancel_resp = client.post("/a2a/v1", json=cancel_req) + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["result"]["status"]["state"] == "TASK_STATE_CANCELED" + + def test_auth_required(self): + server = self._make_server(auth_token="secret123") + client = TestClient(server.app) + + # No auth header — should get 401 + req = { + "jsonrpc": "2.0", + "id": "a1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "am1", + "role": "ROLE_USER", + "parts": [{"text": "hello"}], + }, + "configuration": {}, + }, + } + resp = client.post("/a2a/v1", json=req) + assert resp.status_code == 401 + + def test_auth_success(self): + server = self._make_server(auth_token="secret123") + client = TestClient(server.app) + + req = { + "jsonrpc": "2.0", + "id": "a2", + "method": "SendMessage", + "params": { + "message": { + "messageId": "am2", + "role": "ROLE_USER", + "parts": [{"text": "authenticated"}], + }, + "configuration": {}, + }, + } + resp = client.post( + "/a2a/v1", + json=req, + headers={"Authorization": "Bearer secret123"}, + ) + assert resp.status_code == 200 + assert resp.json()["result"]["task"]["status"]["state"] == "TASK_STATE_COMPLETED" + + def test_unknown_method(self): + server = self._make_server() + client = TestClient(server.app) + + req = { + "jsonrpc": "2.0", + "id": "u1", + "method": "NonExistentMethod", + "params": {}, + } + resp = client.post("/a2a/v1", json=req) + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == -32602 + + def test_audit_log(self): + server = self._make_server() + client = TestClient(server.app) + + req = { + "jsonrpc": "2.0", + "id": "au1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "aum1", + "role": "ROLE_USER", + "parts": [{"text": "audit me"}], + }, + "configuration": {}, + }, + } + client.post("/a2a/v1", json=req) + client.post("/a2a/v1", json=req) + + log = server.get_audit_log() + assert len(log) == 2 + assert all(entry["method"] == "SendMessage" for entry in log) + + +# === Custom Handler Test === + + +@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed") +class TestCustomHandlers: + """Test custom task handlers.""" + + def test_skill_routing(self): + from nexus.a2a.server import A2AServer + from nexus.a2a.types import Task, AgentCard + + async def ci_handler(task: Task, card: AgentCard) -> Task: + task.artifacts.append( + Artifact( + parts=[TextPart(text="CI pipeline healthy: 5/5 passed")], + name="ci_report", + ) + ) + task.status = TaskStatus(state=TaskState.COMPLETED) + return task + + card = AgentCard( + name="CI Agent", + description="CI specialist", + skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])], + ) + server = A2AServer(card=card) + server.register_handler("ci-health", ci_handler) + + client = TestClient(server.app) + req = { + "jsonrpc": "2.0", + "id": "h1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "hm1", + "role": "ROLE_USER", + "parts": [{"text": "Check CI"}], + "metadata": {"targetSkill": "ci-health"}, + }, + "configuration": {}, + }, + } + resp = client.post("/a2a/v1", json=req) + task_data = resp.json()["result"]["task"] + assert task_data["status"]["state"] == "TASK_STATE_COMPLETED" + assert "5/5 passed" in task_data["artifacts"][0]["parts"][0]["text"] + + def test_handler_error(self): + from nexus.a2a.server import A2AServer + from nexus.a2a.types import Task, AgentCard + + async def failing_handler(task: Task, card: AgentCard) -> Task: + raise RuntimeError("Handler blew up") + + card = AgentCard(name="Fail Agent", description="Fails") + server = A2AServer(card=card) + server.set_default_handler(failing_handler) + + client = TestClient(server.app) + req = { + "jsonrpc": "2.0", + "id": "f1", + "method": "SendMessage", + "params": { + "message": { + "messageId": "fm1", + "role": "ROLE_USER", + "parts": [{"text": "break"}], + }, + "configuration": {}, + }, + } + resp = client.post("/a2a/v1", json=req) + task_data = resp.json()["result"]["task"] + assert task_data["status"]["state"] == "TASK_STATE_FAILED" + assert "blew up" in task_data["status"]["message"]["parts"][0]["text"].lower()