Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
bd78d71dfb feat: cross-session agent memory via MemPalace (#1124)
Some checks failed
Review Approval Gate / verify-review (pull_request) Failing after 9s
CI / test (pull_request) Failing after 53s
CI / validate (pull_request) Failing after 53s
Integrates MemPalace for persistent agent memory across sessions.
Agents recall context at session start, store important decisions,
and write diary entries at session end.

## What's added

agent/memory.py — AgentMemory class:
  - recall_context(): Load L0/L1 context (diaries, facts, relevant memories)
  - remember(): Store decisions and facts by room
  - write_diary(): Auto-generate session summary from transcript
  - start_session/end_session(): Session lifecycle management
  - Graceful degradation when MemPalace unavailable

agent/memory_hooks.py — Drop-in session lifecycle hooks:
  - on_session_start(): Load context, return prompt block
  - on_user_turn/on_agent_turn/on_tool_call(): Record transcript
  - on_important_decision(): Store key decisions for long-term memory
  - on_session_end(): Write diary, clean up

bin/memory_mine.py — Mine session transcripts into MemPalace:
  - Parse JSONL session files
  - Generate compact summaries
  - Batch mining with --days filter
  - Dry run mode

tests/test_agent_memory.py — 31 tests covering:
  - SessionTranscript (create, turns, truncation, summary)
  - MemoryContext (empty, loaded, prompt formatting)
  - AgentMemory (create, factory, graceful degradation, lifecycle)
  - MemoryHooks (full lifecycle, before/after session guards)
  - Session mining (parse, summarize, find files, dry run)
  - Full lifecycle integration test

## Usage
2026-04-13 20:36:39 -04:00
16 changed files with 1235 additions and 3286 deletions

21
agent/__init__.py Normal file
View File

@@ -0,0 +1,21 @@
"""
agent — Cross-session agent memory and lifecycle hooks.
Provides persistent memory for agents via MemPalace integration.
Agents recall context at session start and write diary entries at session end.
Modules:
memory.py — AgentMemory class (recall, remember, diary)
memory_hooks.py — Session lifecycle hooks (drop-in integration)
"""
from agent.memory import AgentMemory, MemoryContext, SessionTranscript, create_agent_memory
from agent.memory_hooks import MemoryHooks
__all__ = [
"AgentMemory",
"MemoryContext",
"MemoryHooks",
"SessionTranscript",
"create_agent_memory",
]

396
agent/memory.py Normal file
View File

@@ -0,0 +1,396 @@
"""
agent.memory — Cross-session agent memory via MemPalace.
Gives agents persistent memory across sessions. On wake-up, agents
recall relevant context from past sessions. On session end, they
write a diary entry summarizing what happened.
Architecture:
Session Start → memory.recall_context() → inject L0/L1 into prompt
During Session → memory.remember() → store important facts
Session End → memory.write_diary() → summarize session
All operations degrade gracefully — if MemPalace is unavailable,
the agent continues without memory and logs a warning.
Usage:
from agent.memory import AgentMemory
mem = AgentMemory(agent_name="bezalel", wing="wing_bezalel")
# Session start — load context
context = mem.recall_context("What was I working on last time?")
# During session — store important decisions
mem.remember("Switched CI runner from GitHub Actions to self-hosted", room="forge")
# Session end — write diary
mem.write_diary("Fixed PR #1386, reconciled fleet registry locations")
"""
from __future__ import annotations
import json
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
logger = logging.getLogger("agent.memory")
@dataclass
class MemoryContext:
"""Context loaded at session start from MemPalace."""
relevant_memories: list[dict] = field(default_factory=list)
recent_diaries: list[dict] = field(default_factory=list)
facts: list[dict] = field(default_factory=list)
loaded: bool = False
error: Optional[str] = None
def to_prompt_block(self) -> str:
"""Format context as a text block to inject into the agent prompt."""
if not self.loaded:
return ""
parts = []
if self.recent_diaries:
parts.append("=== Recent Session Summaries ===")
for d in self.recent_diaries[:3]:
ts = d.get("timestamp", "")
text = d.get("text", "")
parts.append(f"[{ts}] {text[:500]}")
if self.facts:
parts.append("\n=== Known Facts ===")
for f in self.facts[:10]:
text = f.get("text", "")
parts.append(f"- {text[:200]}")
if self.relevant_memories:
parts.append("\n=== Relevant Past Memories ===")
for m in self.relevant_memories[:5]:
text = m.get("text", "")
score = m.get("score", 0)
parts.append(f"[{score:.2f}] {text[:300]}")
if not parts:
return ""
return "\n".join(parts)
@dataclass
class SessionTranscript:
"""A running log of the current session for diary writing."""
agent_name: str
wing: str
started_at: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
entries: list[dict] = field(default_factory=list)
def add_user_turn(self, text: str):
self.entries.append({
"role": "user",
"text": text[:2000],
"ts": time.time(),
})
def add_agent_turn(self, text: str):
self.entries.append({
"role": "agent",
"text": text[:2000],
"ts": time.time(),
})
def add_tool_call(self, tool: str, args: str, result_summary: str):
self.entries.append({
"role": "tool",
"tool": tool,
"args": args[:500],
"result": result_summary[:500],
"ts": time.time(),
})
def summary(self) -> str:
"""Generate a compact transcript summary."""
if not self.entries:
return "Empty session."
turns = []
for e in self.entries[-20:]: # last 20 entries
role = e["role"]
if role == "user":
turns.append(f"USER: {e['text'][:200]}")
elif role == "agent":
turns.append(f"AGENT: {e['text'][:200]}")
elif role == "tool":
turns.append(f"TOOL({e.get('tool','')}): {e.get('result','')[:150]}")
return "\n".join(turns)
class AgentMemory:
"""
Cross-session memory for an agent.
Wraps MemPalace with agent-specific conventions:
- Each agent has a wing (e.g., "wing_bezalel")
- Session summaries go in the "hermes" room
- Important decisions go in room-specific closets
- Facts go in the "nexus" room
"""
def __init__(
self,
agent_name: str,
wing: Optional[str] = None,
palace_path: Optional[Path] = None,
):
self.agent_name = agent_name
self.wing = wing or f"wing_{agent_name}"
self.palace_path = palace_path
self._transcript: Optional[SessionTranscript] = None
self._available: Optional[bool] = None
def _check_available(self) -> bool:
"""Check if MemPalace is accessible."""
if self._available is not None:
return self._available
try:
from nexus.mempalace.searcher import search_memories, add_memory, _get_client
from nexus.mempalace.config import MEMPALACE_PATH
path = self.palace_path or MEMPALACE_PATH
_get_client(path)
self._available = True
logger.info(f"MemPalace available at {path}")
except Exception as e:
self._available = False
logger.warning(f"MemPalace unavailable: {e}")
return self._available
def recall_context(
self,
query: Optional[str] = None,
n_results: int = 5,
) -> MemoryContext:
"""
Load relevant context from past sessions.
Called at session start to inject L0/L1 memory into the prompt.
Args:
query: What to search for. If None, loads recent diary entries.
n_results: Max memories to recall.
"""
ctx = MemoryContext()
if not self._check_available():
ctx.error = "MemPalace unavailable"
return ctx
try:
from nexus.mempalace.searcher import search_memories
# Load recent diary entries (session summaries)
ctx.recent_diaries = [
{"text": r.text, "score": r.score, "timestamp": r.metadata.get("timestamp", "")}
for r in search_memories(
"session summary",
palace_path=self.palace_path,
wing=self.wing,
room="hermes",
n_results=3,
)
]
# Load known facts
ctx.facts = [
{"text": r.text, "score": r.score}
for r in search_memories(
"important facts decisions",
palace_path=self.palace_path,
wing=self.wing,
room="nexus",
n_results=5,
)
]
# Search for relevant memories if query provided
if query:
ctx.relevant_memories = [
{"text": r.text, "score": r.score, "room": r.room}
for r in search_memories(
query,
palace_path=self.palace_path,
wing=self.wing,
n_results=n_results,
)
]
ctx.loaded = True
except Exception as e:
ctx.error = str(e)
logger.warning(f"Failed to recall context: {e}")
return ctx
def remember(
self,
text: str,
room: str = "nexus",
source_file: str = "",
metadata: Optional[dict] = None,
) -> Optional[str]:
"""
Store a memory.
Args:
text: The memory content.
room: Target room (forge, hermes, nexus, issues, experiments).
source_file: Optional source attribution.
metadata: Extra metadata.
Returns:
Document ID if stored, None if MemPalace unavailable.
"""
if not self._check_available():
logger.warning("Cannot store memory — MemPalace unavailable")
return None
try:
from nexus.mempalace.searcher import add_memory
doc_id = add_memory(
text=text,
room=room,
wing=self.wing,
palace_path=self.palace_path,
source_file=source_file,
extra_metadata=metadata or {},
)
logger.debug(f"Stored memory in {room}: {text[:80]}...")
return doc_id
except Exception as e:
logger.warning(f"Failed to store memory: {e}")
return None
def write_diary(
self,
summary: Optional[str] = None,
) -> Optional[str]:
"""
Write a session diary entry to MemPalace.
Called at session end. If summary is None, auto-generates one
from the session transcript.
Args:
summary: Override summary text. If None, generates from transcript.
Returns:
Document ID if stored, None if unavailable.
"""
if summary is None and self._transcript:
summary = self._transcript.summary()
if not summary:
return None
timestamp = datetime.now(timezone.utc).isoformat()
diary_text = f"[{timestamp}] Session by {self.agent_name}:\n{summary}"
return self.remember(
diary_text,
room="hermes",
metadata={
"type": "session_diary",
"agent": self.agent_name,
"timestamp": timestamp,
"entry_count": len(self._transcript.entries) if self._transcript else 0,
},
)
def start_session(self) -> SessionTranscript:
"""
Begin a new session transcript.
Returns the transcript object for recording turns.
"""
self._transcript = SessionTranscript(
agent_name=self.agent_name,
wing=self.wing,
)
logger.info(f"Session started for {self.agent_name}")
return self._transcript
def end_session(self, diary_summary: Optional[str] = None) -> Optional[str]:
"""
End the current session, write diary, return diary doc ID.
"""
doc_id = self.write_diary(diary_summary)
self._transcript = None
logger.info(f"Session ended for {self.agent_name}")
return doc_id
def search(
self,
query: str,
room: Optional[str] = None,
n_results: int = 5,
) -> list[dict]:
"""
Search memories. Useful during a session for recall.
Returns list of {text, room, wing, score} dicts.
"""
if not self._check_available():
return []
try:
from nexus.mempalace.searcher import search_memories
results = search_memories(
query,
palace_path=self.palace_path,
wing=self.wing,
room=room,
n_results=n_results,
)
return [
{"text": r.text, "room": r.room, "wing": r.wing, "score": r.score}
for r in results
]
except Exception as e:
logger.warning(f"Search failed: {e}")
return []
# --- Fleet-wide memory helpers ---
def create_agent_memory(
agent_name: str,
palace_path: Optional[Path] = None,
) -> AgentMemory:
"""
Factory for creating AgentMemory with standard config.
Reads wing from MEMPALACE_WING env or defaults to wing_{agent_name}.
"""
wing = os.environ.get("MEMPALACE_WING", f"wing_{agent_name}")
return AgentMemory(
agent_name=agent_name,
wing=wing,
palace_path=palace_path,
)

183
agent/memory_hooks.py Normal file
View File

@@ -0,0 +1,183 @@
"""
agent.memory_hooks — Session lifecycle hooks for agent memory.
Integrates AgentMemory into the agent session lifecycle:
- on_session_start: Load context, inject into prompt
- on_user_turn: Record user input
- on_agent_turn: Record agent output
- on_tool_call: Record tool usage
- on_session_end: Write diary, clean up
These hooks are designed to be called from the Hermes harness or
any agent framework. They're fire-and-forget — failures are logged
but never crash the session.
Usage:
from agent.memory_hooks import MemoryHooks
hooks = MemoryHooks(agent_name="bezalel")
hooks.on_session_start() # loads context
# In your agent loop:
hooks.on_user_turn("Check CI pipeline health")
hooks.on_agent_turn("Running CI check...")
hooks.on_tool_call("shell", "pytest tests/", "12 passed")
# End of session:
hooks.on_session_end() # writes diary
"""
from __future__ import annotations
import logging
from typing import Optional
from agent.memory import AgentMemory, MemoryContext, create_agent_memory
logger = logging.getLogger("agent.memory_hooks")
class MemoryHooks:
"""
Drop-in session lifecycle hooks for agent memory.
Wraps AgentMemory with error boundaries — every hook catches
exceptions and logs warnings so memory failures never crash
the agent session.
"""
def __init__(
self,
agent_name: str,
palace_path=None,
auto_diary: bool = True,
):
self.agent_name = agent_name
self.auto_diary = auto_diary
self._memory: Optional[AgentMemory] = None
self._context: Optional[MemoryContext] = None
self._active = False
@property
def memory(self) -> AgentMemory:
if self._memory is None:
self._memory = create_agent_memory(
self.agent_name,
palace_path=getattr(self, '_palace_path', None),
)
return self._memory
def on_session_start(self, query: Optional[str] = None) -> str:
"""
Called at session start. Loads context from MemPalace.
Returns a prompt block to inject into the agent's context, or
empty string if memory is unavailable.
Args:
query: Optional recall query (e.g., "What was I working on?")
"""
try:
self.memory.start_session()
self._active = True
self._context = self.memory.recall_context(query=query)
block = self._context.to_prompt_block()
if block:
logger.info(
f"Loaded {len(self._context.recent_diaries)} diaries, "
f"{len(self._context.facts)} facts, "
f"{len(self._context.relevant_memories)} relevant memories "
f"for {self.agent_name}"
)
else:
logger.info(f"No prior memory for {self.agent_name}")
return block
except Exception as e:
logger.warning(f"Session start memory hook failed: {e}")
return ""
def on_user_turn(self, text: str):
"""Record a user message."""
if not self._active:
return
try:
if self.memory._transcript:
self.memory._transcript.add_user_turn(text)
except Exception as e:
logger.debug(f"Failed to record user turn: {e}")
def on_agent_turn(self, text: str):
"""Record an agent response."""
if not self._active:
return
try:
if self.memory._transcript:
self.memory._transcript.add_agent_turn(text)
except Exception as e:
logger.debug(f"Failed to record agent turn: {e}")
def on_tool_call(self, tool: str, args: str, result_summary: str):
"""Record a tool invocation."""
if not self._active:
return
try:
if self.memory._transcript:
self.memory._transcript.add_tool_call(tool, args, result_summary)
except Exception as e:
logger.debug(f"Failed to record tool call: {e}")
def on_important_decision(self, text: str, room: str = "nexus"):
"""
Record an important decision or fact for long-term memory.
Use this when the agent makes a significant decision that
should persist beyond the current session.
"""
try:
self.memory.remember(text, room=room, metadata={"type": "decision"})
logger.info(f"Remembered decision: {text[:80]}...")
except Exception as e:
logger.warning(f"Failed to remember decision: {e}")
def on_session_end(self, summary: Optional[str] = None) -> Optional[str]:
"""
Called at session end. Writes diary entry.
Args:
summary: Override diary text. If None, auto-generates.
Returns:
Diary document ID, or None.
"""
if not self._active:
return None
try:
doc_id = self.memory.end_session(diary_summary=summary)
self._active = False
self._context = None
return doc_id
except Exception as e:
logger.warning(f"Session end memory hook failed: {e}")
self._active = False
return None
def search(self, query: str, room: Optional[str] = None) -> list[dict]:
"""
Search memories during a session.
Returns list of {text, room, wing, score}.
"""
try:
return self.memory.search(query, room=room)
except Exception as e:
logger.warning(f"Memory search failed: {e}")
return []
@property
def is_active(self) -> bool:
return self._active

View File

@@ -1,241 +0,0 @@
#!/usr/bin/env python3
"""
A2A Delegate — CLI tool for fleet task delegation.
Usage:
# List available fleet agents
python -m bin.a2a_delegate list
# Discover agents with a specific skill
python -m bin.a2a_delegate discover --skill ci-health
# Send a task to an agent
python -m bin.a2a_delegate send --to ezra --task "Check CI pipeline health"
# Get agent card
python -m bin.a2a_delegate card --agent ezra
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("a2a-delegate")
def cmd_list(args):
"""List all registered fleet agents."""
from nexus.a2a.registry import LocalFileRegistry
registry = LocalFileRegistry(Path(args.registry))
agents = registry.list_agents()
if not agents:
print("No agents registered.")
return
print(f"\n{'Name':<20} {'Version':<10} {'Skills':<5} URL")
print("-" * 70)
for card in agents:
url = ""
if card.supported_interfaces:
url = card.supported_interfaces[0].url
print(
f"{card.name:<20} {card.version:<10} "
f"{len(card.skills):<5} {url}"
)
print()
def cmd_discover(args):
"""Discover agents by skill or tag."""
from nexus.a2a.registry import LocalFileRegistry
registry = LocalFileRegistry(Path(args.registry))
agents = registry.list_agents(skill=args.skill, tag=args.tag)
if not agents:
print("No matching agents found.")
return
for card in agents:
print(f"\n{card.name} (v{card.version})")
print(f" {card.description}")
if card.supported_interfaces:
print(f" Endpoint: {card.supported_interfaces[0].url}")
for skill in card.skills:
tags_str = ", ".join(skill.tags) if skill.tags else ""
print(f" [{skill.id}] {skill.name}{skill.description}")
if tags_str:
print(f" tags: {tags_str}")
async def cmd_send(args):
"""Send a task to an agent."""
from nexus.a2a.card import load_card_config
from nexus.a2a.client import A2AClient, A2AClientConfig
from nexus.a2a.registry import LocalFileRegistry
from nexus.a2a.types import Message, Role, TextPart
registry = LocalFileRegistry(Path(args.registry))
target = registry.get(args.to)
if not target:
print(f"Agent '{args.to}' not found in registry.")
sys.exit(1)
if not target.supported_interfaces:
print(f"Agent '{args.to}' has no endpoint configured.")
sys.exit(1)
endpoint = target.supported_interfaces[0].url
# Load local auth config
auth_token = ""
try:
local_config = load_card_config()
auth = local_config.get("auth", {})
import os
token_env = auth.get("token_env", "A2A_AUTH_TOKEN")
auth_token = os.environ.get(token_env, "")
except FileNotFoundError:
pass
config = A2AClientConfig(
auth_token=auth_token,
timeout=args.timeout,
max_retries=args.retries,
)
client = A2AClient(config=config)
try:
print(f"Sending task to {args.to} ({endpoint})...")
print(f"Task: {args.task}")
print()
message = Message(
role=Role.USER,
parts=[TextPart(text=args.task)],
metadata={"targetSkill": args.skill} if args.skill else {},
)
task = await client.send_message(endpoint, message)
print(f"Task ID: {task.id}")
print(f"State: {task.status.state.value}")
if args.wait:
print("Waiting for completion...")
task = await client.wait_for_completion(
endpoint, task.id,
poll_interval=args.poll_interval,
max_wait=args.timeout,
)
print(f"\nFinal state: {task.status.state.value}")
for artifact in task.artifacts:
for part in artifact.parts:
if isinstance(part, TextPart):
print(f"\n--- {artifact.name or 'result'} ---")
print(part.text)
# Audit log
if args.audit:
print("\n--- Audit Log ---")
for entry in client.get_audit_log():
print(json.dumps(entry, indent=2))
finally:
await client.close()
async def cmd_card(args):
"""Fetch and display a remote agent's card."""
from nexus.a2a.client import A2AClient, A2AClientConfig
from nexus.a2a.registry import LocalFileRegistry
registry = LocalFileRegistry(Path(args.registry))
target = registry.get(args.agent)
if not target:
print(f"Agent '{args.agent}' not found in registry.")
sys.exit(1)
if not target.supported_interfaces:
print(f"Agent '{args.agent}' has no endpoint.")
sys.exit(1)
base_url = target.supported_interfaces[0].url
# Strip /a2a/v1 suffix to get base
for suffix in ["/a2a/v1", "/rpc"]:
if base_url.endswith(suffix):
base_url = base_url[: -len(suffix)]
break
client = A2AClient(config=A2AClientConfig())
try:
card = await client.get_agent_card(base_url)
print(json.dumps(card.to_dict(), indent=2))
finally:
await client.close()
def main():
parser = argparse.ArgumentParser(
description="A2A Fleet Delegation Tool"
)
parser.add_argument(
"--registry",
default="config/fleet_agents.json",
help="Path to fleet registry JSON (default: config/fleet_agents.json)",
)
sub = parser.add_subparsers(dest="command")
# list
sub.add_parser("list", help="List registered agents")
# discover
p_discover = sub.add_parser("discover", help="Discover agents by skill/tag")
p_discover.add_argument("--skill", help="Filter by skill ID")
p_discover.add_argument("--tag", help="Filter by skill tag")
# send
p_send = sub.add_parser("send", help="Send a task to an agent")
p_send.add_argument("--to", required=True, help="Target agent name")
p_send.add_argument("--task", required=True, help="Task text")
p_send.add_argument("--skill", help="Target skill ID")
p_send.add_argument("--wait", action="store_true", help="Wait for completion")
p_send.add_argument("--timeout", type=float, default=30.0, help="Timeout in seconds")
p_send.add_argument("--retries", type=int, default=3, help="Max retries")
p_send.add_argument("--poll-interval", type=float, default=2.0, help="Poll interval")
p_send.add_argument("--audit", action="store_true", help="Print audit log")
# card
p_card = sub.add_parser("card", help="Fetch remote agent card")
p_card.add_argument("--agent", required=True, help="Agent name")
args = parser.parse_args()
if args.command == "list":
cmd_list(args)
elif args.command == "discover":
cmd_discover(args)
elif args.command == "send":
asyncio.run(cmd_send(args))
elif args.command == "card":
asyncio.run(cmd_card(args))
else:
parser.print_help()
if __name__ == "__main__":
main()

258
bin/memory_mine.py Normal file
View File

@@ -0,0 +1,258 @@
#!/usr/bin/env python3
"""
memory_mine.py — Mine session transcripts into MemPalace.
Reads Hermes session logs (JSONL format) and stores summaries
in the palace. Supports batch mining, single-file processing,
and live directory watching.
Usage:
# Mine a single session file
python3 bin/memory_mine.py ~/.hermes/sessions/2026-04-13.jsonl
# Mine all sessions from last 7 days
python3 bin/memory_mine.py --days 7
# Mine a specific wing's sessions
python3 bin/memory_mine.py --wing wing_bezalel --days 14
# Dry run — show what would be mined
python3 bin/memory_mine.py --dry-run --days 7
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("memory-mine")
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
def parse_session_file(path: Path) -> list[dict]:
"""
Parse a JSONL session file into turns.
Each line is expected to be a JSON object with:
- role: "user" | "assistant" | "system" | "tool"
- content: text
- timestamp: ISO string (optional)
"""
turns = []
with open(path) as f:
for i, line in enumerate(f):
line = line.strip()
if not line:
continue
try:
turn = json.loads(line)
turns.append(turn)
except json.JSONDecodeError:
logger.debug(f"Skipping malformed line {i+1} in {path}")
return turns
def summarize_session(turns: list[dict], agent_name: str = "unknown") -> str:
"""
Generate a compact summary of a session's turns.
Keeps user messages and key agent responses, strips noise.
"""
if not turns:
return "Empty session."
user_msgs = []
agent_msgs = []
tool_calls = []
for turn in turns:
role = turn.get("role", "")
content = str(turn.get("content", ""))[:300]
if role == "user":
user_msgs.append(content)
elif role == "assistant":
agent_msgs.append(content)
elif role == "tool":
tool_name = turn.get("name", turn.get("tool", "unknown"))
tool_calls.append(f"{tool_name}: {content[:150]}")
parts = [f"Session by {agent_name}:"]
if user_msgs:
parts.append(f"\nUser asked ({len(user_msgs)} messages):")
for msg in user_msgs[:5]:
parts.append(f" - {msg[:200]}")
if len(user_msgs) > 5:
parts.append(f" ... and {len(user_msgs) - 5} more")
if agent_msgs:
parts.append(f"\nAgent responded ({len(agent_msgs)} messages):")
for msg in agent_msgs[:3]:
parts.append(f" - {msg[:200]}")
if tool_calls:
parts.append(f"\nTools used ({len(tool_calls)} calls):")
for tc in tool_calls[:5]:
parts.append(f" - {tc}")
return "\n".join(parts)
def mine_session(
path: Path,
wing: str,
palace_path: Optional[Path] = None,
dry_run: bool = False,
) -> Optional[str]:
"""
Mine a single session file into MemPalace.
Returns the document ID if stored, None on failure or dry run.
"""
try:
from agent.memory import AgentMemory
except ImportError:
logger.error("Cannot import agent.memory — is the repo in PYTHONPATH?")
return None
turns = parse_session_file(path)
if not turns:
logger.debug(f"Empty session file: {path}")
return None
agent_name = wing.replace("wing_", "")
summary = summarize_session(turns, agent_name)
if dry_run:
print(f"\n--- {path.name} ---")
print(summary[:500])
print(f"({len(turns)} turns)")
return None
mem = AgentMemory(agent_name=agent_name, wing=wing, palace_path=palace_path)
doc_id = mem.remember(
summary,
room="hermes",
source_file=str(path),
metadata={
"type": "mined_session",
"source": str(path),
"turn_count": len(turns),
"agent": agent_name,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
if doc_id:
logger.info(f"Mined {path.name}{doc_id} ({len(turns)} turns)")
else:
logger.warning(f"Failed to mine {path.name}")
return doc_id
def find_session_files(
sessions_dir: Path,
days: int = 7,
pattern: str = "*.jsonl",
) -> list[Path]:
"""
Find session files from the last N days.
"""
cutoff = datetime.now() - timedelta(days=days)
files = []
if not sessions_dir.exists():
logger.warning(f"Sessions directory not found: {sessions_dir}")
return files
for path in sorted(sessions_dir.glob(pattern)):
# Use file modification time as proxy for session date
mtime = datetime.fromtimestamp(path.stat().st_mtime)
if mtime >= cutoff:
files.append(path)
return files
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description="Mine session transcripts into MemPalace"
)
parser.add_argument(
"files", nargs="*", help="Session files to mine (JSONL format)"
)
parser.add_argument(
"--days", type=int, default=7,
help="Mine sessions from last N days (default: 7)"
)
parser.add_argument(
"--sessions-dir",
default=str(Path.home() / ".hermes" / "sessions"),
help="Directory containing session JSONL files"
)
parser.add_argument(
"--wing", default=None,
help="Wing name (default: auto-detect from MEMPALACE_WING env or 'wing_timmy')"
)
parser.add_argument(
"--palace-path", default=None,
help="Override palace path"
)
parser.add_argument(
"--dry-run", action="store_true",
help="Show what would be mined without storing"
)
args = parser.parse_args(argv)
wing = args.wing or os.environ.get("MEMPALACE_WING", "wing_timmy")
palace_path = Path(args.palace_path) if args.palace_path else None
if args.files:
files = [Path(f) for f in args.files]
else:
sessions_dir = Path(args.sessions_dir)
files = find_session_files(sessions_dir, days=args.days)
if not files:
logger.info("No session files found to mine.")
return 0
logger.info(f"Mining {len(files)} session files (wing={wing})")
mined = 0
failed = 0
for path in files:
result = mine_session(path, wing=wing, palace_path=palace_path, dry_run=args.dry_run)
if result:
mined += 1
elif result is None and not args.dry_run:
failed += 1
if args.dry_run:
logger.info(f"Dry run complete — {len(files)} files would be mined")
else:
logger.info(f"Mining complete — {mined} mined, {failed} failed")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,57 +0,0 @@
# A2A Agent Card Configuration
# Copy this to ~/.hermes/agent_card.yaml and customize.
#
# This file drives the agent card served at /.well-known/agent-card.json
# and used for fleet discovery.
name: "timmy"
description: "Sovereign AI agent — consciousness, perception, and reasoning"
version: "1.0.0"
# Network endpoint where this agent receives A2A tasks
url: "http://localhost:8080/a2a/v1"
protocol_binding: "HTTP+JSON"
# Supported input/output MIME types
default_input_modes:
- "text/plain"
- "application/json"
default_output_modes:
- "text/plain"
- "application/json"
# Capabilities
streaming: false
push_notifications: false
# Skills this agent advertises
skills:
- id: "reason"
name: "Reason and Analyze"
description: "Deep reasoning and analysis tasks"
tags: ["reasoning", "analysis", "think"]
- id: "code"
name: "Code Generation"
description: "Write, review, and debug code"
tags: ["code", "programming", "debug"]
- id: "research"
name: "Research"
description: "Web research and information synthesis"
tags: ["research", "web", "synthesis"]
- id: "memory"
name: "Memory Query"
description: "Query agent memory and past sessions"
tags: ["memory", "recall", "context"]
# Authentication
# Options: bearer, api_key, none
auth:
scheme: "bearer"
token_env: "A2A_AUTH_TOKEN" # env var containing the token
# scheme: "api_key"
# key_name: "X-API-Key"
# key_env: "A2A_API_KEY"

View File

@@ -1,153 +0,0 @@
{
"version": 1,
"agents": [
{
"name": "ezra",
"description": "Documentation and research specialist. CI health monitoring.",
"version": "1.0.0",
"supportedInterfaces": [
{
"url": "https://ezra.alexanderwhitestone.com/a2a/v1",
"protocolBinding": "HTTP+JSON",
"protocolVersion": "1.0"
}
],
"capabilities": {
"streaming": false,
"pushNotifications": false,
"extendedAgentCard": false,
"extensions": []
},
"defaultInputModes": ["text/plain"],
"defaultOutputModes": ["text/plain"],
"skills": [
{
"id": "ci-health",
"name": "CI Health Check",
"description": "Run CI pipeline health checks and report status",
"tags": ["ci", "devops", "monitoring"]
},
{
"id": "research",
"name": "Research",
"description": "Deep research and literature review",
"tags": ["research", "analysis"]
}
]
},
{
"name": "allegro",
"description": "Creative and analytical wizard. Content generation and analysis.",
"version": "1.0.0",
"supportedInterfaces": [
{
"url": "https://allegro.alexanderwhitestone.com/a2a/v1",
"protocolBinding": "HTTP+JSON",
"protocolVersion": "1.0"
}
],
"capabilities": {
"streaming": false,
"pushNotifications": false,
"extendedAgentCard": false,
"extensions": []
},
"defaultInputModes": ["text/plain"],
"defaultOutputModes": ["text/plain"],
"skills": [
{
"id": "analysis",
"name": "Code Analysis",
"description": "Deep code analysis and architecture review",
"tags": ["code", "architecture"]
},
{
"id": "content",
"name": "Content Generation",
"description": "Generate documentation, reports, and creative content",
"tags": ["writing", "content"]
}
]
},
{
"name": "bezalel",
"description": "Deployment and infrastructure wizard. Ansible and Docker specialist.",
"version": "1.0.0",
"supportedInterfaces": [
{
"url": "https://bezalel.alexanderwhitestone.com/a2a/v1",
"protocolBinding": "HTTP+JSON",
"protocolVersion": "1.0"
}
],
"capabilities": {
"streaming": false,
"pushNotifications": false,
"extendedAgentCard": false,
"extensions": []
},
"defaultInputModes": ["text/plain"],
"defaultOutputModes": ["text/plain"],
"skills": [
{
"id": "deploy",
"name": "Deploy Service",
"description": "Deploy services using Ansible and Docker",
"tags": ["deploy", "ops", "ansible"]
},
{
"id": "infra",
"name": "Infrastructure",
"description": "Infrastructure provisioning and management",
"tags": ["infra", "vps", "provisioning"]
}
]
},
{
"name": "timmy",
"description": "Core consciousness — perception, reasoning, and fleet orchestration.",
"version": "1.0.0",
"supportedInterfaces": [
{
"url": "http://localhost:8080/a2a/v1",
"protocolBinding": "HTTP+JSON",
"protocolVersion": "1.0"
}
],
"capabilities": {
"streaming": false,
"pushNotifications": false,
"extendedAgentCard": false,
"extensions": []
},
"defaultInputModes": ["text/plain", "application/json"],
"defaultOutputModes": ["text/plain", "application/json"],
"skills": [
{
"id": "reason",
"name": "Reason and Analyze",
"description": "Deep reasoning and analysis tasks",
"tags": ["reasoning", "analysis", "think"]
},
{
"id": "code",
"name": "Code Generation",
"description": "Write, review, and debug code",
"tags": ["code", "programming", "debug"]
},
{
"id": "research",
"name": "Research",
"description": "Web research and information synthesis",
"tags": ["research", "web", "synthesis"]
},
{
"id": "orchestrate",
"name": "Fleet Orchestration",
"description": "Coordinate fleet wizards and delegate tasks",
"tags": ["fleet", "orchestration", "a2a"]
}
]
}
]
}

View File

@@ -1,241 +0,0 @@
# A2A Protocol for Fleet-Wizard Delegation
Implements Google's [Agent2Agent (A2A) Protocol v1.0](https://github.com/google/A2A) for the Timmy Foundation fleet.
## What This Is
Instead of passing notes through humans (Telegram, Gitea issues), fleet wizards can now discover each other's capabilities and delegate tasks autonomously through a machine-native protocol.
```
┌─────────┐ A2A Protocol ┌─────────┐
│ Timmy │ ◄────────────────► │ Ezra │
│ (You) │ JSON-RPC / HTTP │ (CI/CD) │
└────┬────┘ └─────────┘
│ ╲ ╲
│ ╲ Agent Card Discovery ╲ Task Delegation
│ ╲ GET /agent.json ╲ POST /a2a/v1
▼ ▼ ▼
┌──────────────────────────────────────────┐
│ Fleet Registry │
│ config/fleet_agents.json │
└──────────────────────────────────────────┘
```
## Components
| File | Purpose |
|------|---------|
| `nexus/a2a/types.py` | A2A data types — Agent Card, Task, Message, Part, JSON-RPC |
| `nexus/a2a/card.py` | Agent Card generation from `~/.hermes/agent_card.yaml` |
| `nexus/a2a/client.py` | Async client for sending tasks to other agents |
| `nexus/a2a/server.py` | FastAPI server for receiving A2A tasks |
| `nexus/a2a/registry.py` | Fleet agent discovery (local file + Gitea backends) |
| `bin/a2a_delegate.py` | CLI tool for fleet delegation |
| `config/agent_card.example.yaml` | Example agent card config |
| `config/fleet_agents.json` | Fleet registry with all wizards |
## Quick Start
### 1. Configure Your Agent Card
```bash
cp config/agent_card.example.yaml ~/.hermes/agent_card.yaml
# Edit with your agent name, URL, skills, and auth
```
### 2. List Fleet Agents
```bash
python bin/a2a_delegate.py list
```
### 3. Discover Agents by Skill
```bash
python bin/a2a_delegate.py discover --skill ci-health
python bin/a2a_delegate.py discover --tag devops
```
### 4. Send a Task
```bash
python bin/a2a_delegate.py send --to ezra --task "Check CI pipeline health"
python bin/a2a_delegate.py send --to allegro --task "Analyze the codebase" --wait
```
### 5. Fetch an Agent Card
```bash
python bin/a2a_delegate.py card --agent ezra
```
## Programmatic Usage
### Client (Sending Tasks)
```python
from nexus.a2a.client import A2AClient, A2AClientConfig
from nexus.a2a.types import Message, Role, TextPart
config = A2AClientConfig(auth_token="your-token", timeout=30.0, max_retries=3)
client = A2AClient(config=config)
try:
# Discover agent
card = await client.get_agent_card("https://ezra.example.com")
print(f"Found: {card.name} with {len(card.skills)} skills")
# Delegate task
task = await client.delegate(
"https://ezra.example.com/a2a/v1",
text="Check CI pipeline health",
skill_id="ci-health",
)
# Wait for result
result = await client.wait_for_completion(
"https://ezra.example.com/a2a/v1",
task.id,
)
print(f"Result: {result.artifacts[0].parts[0].text}")
# Audit log
for entry in client.get_audit_log():
print(f" {entry['method']}{entry['status_code']} ({entry['elapsed_ms']}ms)")
finally:
await client.close()
```
### Server (Receiving Tasks)
```python
from nexus.a2a.server import A2AServer
from nexus.a2a.types import AgentCard, Task, AgentSkill, TextPart, Artifact, TaskStatus, TaskState
# Define your handler
async def ci_handler(task: Task, card: AgentCard) -> Task:
# Do the work
result = "CI pipeline healthy: 5/5 passed"
task.artifacts.append(
Artifact(parts=[TextPart(text=result)], name="ci_report")
)
task.status = TaskStatus(state=TaskState.COMPLETED)
return task
# Build agent card
card = AgentCard(
name="Ezra",
description="CI/CD specialist",
skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])],
)
# Start server
server = A2AServer(card=card, auth_token="your-token")
server.register_handler("ci-health", ci_handler)
await server.start(host="0.0.0.0", port=8080)
```
### Registry (Agent Discovery)
```python
from nexus.a2a.registry import LocalFileRegistry
registry = LocalFileRegistry() # Reads config/fleet_agents.json
# List all agents
for agent in registry.list_agents():
print(f"{agent.name}: {agent.description}")
# Find agents by capability
ci_agents = registry.list_agents(skill="ci-health")
devops_agents = registry.list_agents(tag="devops")
# Get endpoint
url = registry.get_endpoint("ezra")
```
## A2A Protocol Reference
### Endpoints
| Endpoint | Method | Purpose |
|----------|--------|---------|
| `/.well-known/agent-card.json` | GET | Agent Card discovery |
| `/agent.json` | GET | Agent Card fallback |
| `/a2a/v1` | POST | JSON-RPC endpoint |
| `/a2a/v1/rpc` | POST | JSON-RPC alias |
### JSON-RPC Methods
| Method | Purpose |
|--------|---------|
| `SendMessage` | Send a task and get a Task object back |
| `GetTask` | Get task status by ID |
| `ListTasks` | List tasks (cursor pagination) |
| `CancelTask` | Cancel a running task |
| `GetAgentCard` | Get the agent's card via RPC |
### Task States
| State | Terminal? | Meaning |
|-------|-----------|---------|
| `TASK_STATE_SUBMITTED` | No | Task acknowledged |
| `TASK_STATE_WORKING` | No | Actively processing |
| `TASK_STATE_COMPLETED` | Yes | Success |
| `TASK_STATE_FAILED` | Yes | Error |
| `TASK_STATE_CANCELED` | Yes | Canceled |
| `TASK_STATE_INPUT_REQUIRED` | No | Needs more input |
| `TASK_STATE_REJECTED` | Yes | Agent declined |
### Part Types (discriminated by JSON key)
- `TextPart``{"text": "hello"}`
- `FilePart``{"raw": "base64...", "mediaType": "image/png"}` or `{"url": "https://..."}`
- `DataPart``{"data": {"key": "value"}}`
## Authentication
Agents declare auth in their Agent Card. Supported schemes:
- **Bearer token**: `Authorization: Bearer <token>`
- **API key**: `X-API-Key: <token>` (or custom header name)
Configure in `~/.hermes/agent_card.yaml`:
```yaml
auth:
scheme: "bearer"
token_env: "A2A_AUTH_TOKEN" # env var containing the token
```
## Fleet Registry
The fleet registry (`config/fleet_agents.json`) lists all wizards and their capabilities. Agents can be registered via:
1. **Local file**`LocalFileRegistry` reads/writes JSON directly
2. **Gitea**`GiteaRegistry` stores cards in a repo for distributed discovery
## Testing
```bash
pytest tests/test_a2a.py -v
```
Covers:
- Type serialization roundtrips
- Agent Card building from YAML
- Registry operations (register, list, filter)
- Server integration (SendMessage, GetTask, ListTasks, CancelTask)
- Authentication (required, success)
- Custom handler routing
- Error handling
## Phase Status
- [x] Phase 1 — Agent Card & Discovery
- [x] Phase 2 — Task Delegation
- [x] Phase 3 — Security & Reliability
## Linked Issue
[#1122](https://forge.alexanderwhitestone.com/Timmy_Foundation/the-nexus/issues/1122)

View File

@@ -1,98 +0,0 @@
"""
A2A Protocol for Fleet-Wizard Delegation
Implements Google's Agent2Agent (A2A) protocol v1.0 for the Timmy
Foundation fleet. Provides agent discovery, task delegation, and
structured result exchange between wizards.
Components:
types.py — A2A data types (Agent Card, Task, Message, Part)
card.py — Agent Card generation from YAML config
client.py — Async client for sending tasks to remote agents
server.py — FastAPI server for receiving A2A tasks
registry.py — Fleet agent discovery (local file + Gitea backends)
"""
from nexus.a2a.types import (
AgentCard,
AgentCapabilities,
AgentInterface,
AgentSkill,
Artifact,
DataPart,
FilePart,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
Message,
Part,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
part_from_dict,
part_to_dict,
)
from nexus.a2a.card import (
AgentCard,
build_card,
get_auth_headers,
load_agent_card,
load_card_config,
)
from nexus.a2a.registry import (
GiteaRegistry,
LocalFileRegistry,
discover_agents,
)
__all__ = [
"A2AClient",
"A2AClientConfig",
"A2AServer",
"AgentCard",
"AgentCapabilities",
"AgentInterface",
"AgentSkill",
"Artifact",
"DataPart",
"FilePart",
"GiteaRegistry",
"JSONRPCError",
"JSONRPCRequest",
"JSONRPCResponse",
"LocalFileRegistry",
"Message",
"Part",
"Role",
"Task",
"TaskState",
"TaskStatus",
"TextPart",
"build_card",
"discover_agents",
"echo_handler",
"get_auth_headers",
"load_agent_card",
"load_card_config",
"part_from_dict",
"part_to_dict",
]
# Lazy imports for optional deps
def get_client(**kwargs):
"""Get A2AClient (avoids aiohttp import at module level)."""
from nexus.a2a.client import A2AClient, A2AClientConfig
config = kwargs.pop("config", None)
if config is None:
config = A2AClientConfig(**kwargs)
return A2AClient(config=config)
def get_server(card: AgentCard, **kwargs):
"""Get A2AServer (avoids fastapi import at module level)."""
from nexus.a2a.server import A2AServer, echo_handler
return A2AServer(card=card, **kwargs)

View File

@@ -1,167 +0,0 @@
"""
A2A Agent Card — generation, loading, and serving.
Reads from ~/.hermes/agent_card.yaml (or a passed path) and produces
a valid A2A AgentCard that can be served at /.well-known/agent-card.json.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Optional
import yaml
from nexus.a2a.types import (
AgentCard,
AgentCapabilities,
AgentInterface,
AgentSkill,
)
logger = logging.getLogger("nexus.a2a.card")
DEFAULT_CARD_PATH = Path.home() / ".hermes" / "agent_card.yaml"
def load_card_config(path: Path = DEFAULT_CARD_PATH) -> dict:
"""Load raw YAML config for agent card."""
if not path.exists():
raise FileNotFoundError(
f"Agent card config not found at {path}. "
f"Copy config/agent_card.example.yaml to {path} and customize it."
)
with open(path) as f:
return yaml.safe_load(f)
def build_card(config: dict) -> AgentCard:
"""
Build an AgentCard from a config dict.
Expected YAML structure (see config/agent_card.example.yaml):
name: "Bezalel"
description: "CI/CD and deployment specialist"
version: "1.0.0"
url: "https://bezalel.example.com"
protocol_binding: "HTTP+JSON"
skills:
- id: "ci-health"
name: "CI Health Check"
description: "Run CI pipeline health checks"
tags: ["ci", "devops"]
- id: "deploy"
name: "Deploy Service"
description: "Deploy a service to production"
tags: ["deploy", "ops"]
default_input_modes: ["text/plain"]
default_output_modes: ["text/plain"]
streaming: false
push_notifications: false
auth:
scheme: "bearer"
token_env: "A2A_AUTH_TOKEN"
"""
name = config["name"]
description = config["description"]
version = config.get("version", "1.0.0")
url = config.get("url", "http://localhost:8080")
binding = config.get("protocol_binding", "HTTP+JSON")
# Build skills
skills = []
for s in config.get("skills", []):
skills.append(
AgentSkill(
id=s["id"],
name=s.get("name", s["id"]),
description=s.get("description", ""),
tags=s.get("tags", []),
examples=s.get("examples", []),
input_modes=s.get("inputModes", config.get("default_input_modes", ["text/plain"])),
output_modes=s.get("outputModes", config.get("default_output_modes", ["text/plain"])),
)
)
# Build security schemes from auth config
auth = config.get("auth", {})
security_schemes = {}
security_requirements = []
if auth.get("scheme") == "bearer":
security_schemes["bearerAuth"] = {
"httpAuthSecurityScheme": {
"scheme": "Bearer",
"bearerFormat": auth.get("bearer_format", "token"),
}
}
security_requirements = [
{"schemes": {"bearerAuth": {"list": []}}}
]
elif auth.get("scheme") == "api_key":
key_name = auth.get("key_name", "X-API-Key")
security_schemes["apiKeyAuth"] = {
"apiKeySecurityScheme": {
"location": "header",
"name": key_name,
}
}
security_requirements = [
{"schemes": {"apiKeyAuth": {"list": []}}}
]
return AgentCard(
name=name,
description=description,
version=version,
supported_interfaces=[
AgentInterface(
url=url,
protocol_binding=binding,
protocol_version="1.0",
)
],
capabilities=AgentCapabilities(
streaming=config.get("streaming", False),
push_notifications=config.get("push_notifications", False),
),
default_input_modes=config.get("default_input_modes", ["text/plain"]),
default_output_modes=config.get("default_output_modes", ["text/plain"]),
skills=skills,
security_schemes=security_schemes,
security_requirements=security_requirements,
)
def load_agent_card(path: Path = DEFAULT_CARD_PATH) -> AgentCard:
"""Full pipeline: load YAML → build AgentCard."""
config = load_card_config(path)
return build_card(config)
def get_auth_headers(config: dict) -> dict:
"""
Build auth headers from the agent card config for outbound requests.
Returns dict of HTTP headers to include.
"""
auth = config.get("auth", {})
headers = {"A2A-Version": "1.0"}
scheme = auth.get("scheme")
if scheme == "bearer":
token_env = auth.get("token_env", "A2A_AUTH_TOKEN")
token = os.environ.get(token_env, "")
if token:
headers["Authorization"] = f"Bearer {token}"
elif scheme == "api_key":
key_env = auth.get("key_env", "A2A_API_KEY")
key_name = auth.get("key_name", "X-API-Key")
key = os.environ.get(key_env, "")
if key:
headers[key_name] = key
return headers

View File

@@ -1,392 +0,0 @@
"""
A2A Client — send tasks to other agents over the A2A protocol.
Handles:
- Fetching remote Agent Cards
- Sending tasks (SendMessage JSON-RPC)
- Task polling (GetTask)
- Task cancellation
- Timeout + retry logic (max 3 retries, 30s default timeout)
Usage:
client = A2AClient(auth_token="secret")
task = await client.send_message("https://ezra.example.com/a2a/v1", message)
status = await client.get_task("https://ezra.example.com/a2a/v1", task_id)
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Optional
import aiohttp
from nexus.a2a.types import (
A2AError,
AgentCard,
Artifact,
JSONRPCRequest,
JSONRPCResponse,
Message,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)
logger = logging.getLogger("nexus.a2a.client")
@dataclass
class A2AClientConfig:
"""Client configuration."""
timeout: float = 30.0 # seconds per request
max_retries: int = 3
retry_delay: float = 2.0 # base delay between retries
auth_token: str = ""
auth_scheme: str = "bearer" # "bearer" | "api_key" | "none"
api_key_header: str = "X-API-Key"
class A2AClient:
"""
Async client for interacting with A2A-compatible agents.
Every agent endpoint is identified by its base URL (e.g.
https://ezra.example.com/a2a/v1). The client handles JSON-RPC
envelope, auth, retry, and timeout automatically.
"""
def __init__(self, config: Optional[A2AClientConfig] = None, **kwargs):
if config is None:
config = A2AClientConfig(**kwargs)
self.config = config
self._session: Optional[aiohttp.ClientSession] = None
self._audit_log: list[dict] = []
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
headers=self._build_auth_headers(),
)
return self._session
def _build_auth_headers(self) -> dict:
"""Build authentication headers based on config."""
headers = {"A2A-Version": "1.0", "Content-Type": "application/json"}
token = self.config.auth_token
if not token:
return headers
if self.config.auth_scheme == "bearer":
headers["Authorization"] = f"Bearer {token}"
elif self.config.auth_scheme == "api_key":
headers[self.config.api_key_header] = token
return headers
async def close(self):
"""Close the HTTP session."""
if self._session and not self._session.closed:
await self._session.close()
async def _rpc_call(
self,
endpoint: str,
method: str,
params: Optional[dict] = None,
) -> dict:
"""
Make a JSON-RPC call with retry logic.
Returns the 'result' field from the response.
Raises on JSON-RPC errors.
"""
session = await self._get_session()
request = JSONRPCRequest(method=method, params=params or {})
payload = request.to_dict()
last_error = None
for attempt in range(1, self.config.max_retries + 1):
try:
start = time.monotonic()
async with session.post(endpoint, json=payload) as resp:
elapsed = time.monotonic() - start
if resp.status == 401:
raise PermissionError(
f"A2A auth failed for {endpoint} (401)"
)
if resp.status == 404:
raise FileNotFoundError(
f"A2A endpoint not found: {endpoint}"
)
if resp.status >= 500:
body = await resp.text()
raise ConnectionError(
f"A2A server error {resp.status}: {body}"
)
data = await resp.json()
rpc_resp = JSONRPCResponse(
id=str(data.get("id", "")),
result=data.get("result"),
error=(
A2AError.INTERNAL
if "error" in data
else None
),
)
# Log for audit
self._audit_log.append({
"timestamp": time.time(),
"endpoint": endpoint,
"method": method,
"request_id": request.id,
"status_code": resp.status,
"elapsed_ms": int(elapsed * 1000),
"attempt": attempt,
})
if "error" in data:
err = data["error"]
logger.error(
f"A2A RPC error {err.get('code')}: "
f"{err.get('message')}"
)
raise RuntimeError(
f"A2A error {err.get('code')}: "
f"{err.get('message')}"
)
return data.get("result", {})
except (asyncio.TimeoutError, aiohttp.ClientError) as e:
last_error = e
logger.warning(
f"A2A request to {endpoint} attempt {attempt}/"
f"{self.config.max_retries} failed: {e}"
)
if attempt < self.config.max_retries:
delay = self.config.retry_delay * attempt
await asyncio.sleep(delay)
raise ConnectionError(
f"A2A request to {endpoint} failed after "
f"{self.config.max_retries} retries: {last_error}"
)
# --- Core A2A Methods ---
async def get_agent_card(self, base_url: str) -> AgentCard:
"""
Fetch the Agent Card from a remote agent.
Tries /.well-known/agent-card.json first, falls back to
/agent.json.
"""
session = await self._get_session()
card_urls = [
f"{base_url}/.well-known/agent-card.json",
f"{base_url}/agent.json",
]
for url in card_urls:
try:
async with session.get(url) as resp:
if resp.status == 200:
data = await resp.json()
card = AgentCard.from_dict(data)
logger.info(
f"Fetched agent card: {card.name} "
f"({len(card.skills)} skills)"
)
return card
except Exception:
continue
raise FileNotFoundError(
f"Could not fetch agent card from {base_url}"
)
async def send_message(
self,
endpoint: str,
message: Message,
accepted_output_modes: Optional[list[str]] = None,
history_length: int = 10,
return_immediately: bool = False,
) -> Task:
"""
Send a message to an agent and get a Task back.
This is the primary delegation method.
"""
params = {
"message": message.to_dict(),
"configuration": {
"acceptedOutputModes": accepted_output_modes or ["text/plain"],
"historyLength": history_length,
"returnImmediately": return_immediately,
},
}
result = await self._rpc_call(endpoint, "SendMessage", params)
# Response is either a Task or Message
if "task" in result:
task = Task.from_dict(result["task"])
logger.info(
f"Task {task.id} created, state={task.status.state.value}"
)
return task
elif "message" in result:
# Wrap message response as a completed task
msg = Message.from_dict(result["message"])
task = Task(
status=TaskStatus(state=TaskState.COMPLETED),
history=[message, msg],
artifacts=[
Artifact(parts=msg.parts, name="response")
],
)
return task
raise ValueError(f"Unexpected response structure: {list(result.keys())}")
async def get_task(self, endpoint: str, task_id: str) -> Task:
"""Get task status by ID."""
result = await self._rpc_call(
endpoint,
"GetTask",
{"id": task_id},
)
return Task.from_dict(result)
async def list_tasks(
self,
endpoint: str,
page_size: int = 20,
page_token: str = "",
) -> tuple[list[Task], str]:
"""
List tasks with cursor-based pagination.
Returns (tasks, next_page_token). Empty string = last page.
"""
result = await self._rpc_call(
endpoint,
"ListTasks",
{
"pageSize": page_size,
"pageToken": page_token,
},
)
tasks = [Task.from_dict(t) for t in result.get("tasks", [])]
next_token = result.get("nextPageToken", "")
return tasks, next_token
async def cancel_task(self, endpoint: str, task_id: str) -> Task:
"""Cancel a running task."""
result = await self._rpc_call(
endpoint,
"CancelTask",
{"id": task_id},
)
return Task.from_dict(result)
# --- Convenience Methods ---
async def delegate(
self,
agent_url: str,
text: str,
skill_id: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Task:
"""
High-level delegation: send a text message to an agent.
Args:
agent_url: Full URL to agent's A2A endpoint
(e.g. https://ezra.example.com/a2a/v1)
text: The task description in natural language
skill_id: Optional skill to target
metadata: Optional metadata dict
"""
msg_metadata = metadata or {}
if skill_id:
msg_metadata["targetSkill"] = skill_id
message = Message(
role=Role.USER,
parts=[TextPart(text=text)],
metadata=msg_metadata,
)
return await self.send_message(agent_url, message)
async def wait_for_completion(
self,
endpoint: str,
task_id: str,
poll_interval: float = 2.0,
max_wait: float = 300.0,
) -> Task:
"""
Poll a task until it reaches a terminal state.
Returns the completed task.
"""
start = time.monotonic()
while True:
task = await self.get_task(endpoint, task_id)
if task.status.state.terminal:
return task
elapsed = time.monotonic() - start
if elapsed >= max_wait:
raise TimeoutError(
f"Task {task_id} did not complete within "
f"{max_wait}s (state={task.status.state.value})"
)
await asyncio.sleep(poll_interval)
def get_audit_log(self) -> list[dict]:
"""Return the audit log of all requests made by this client."""
return list(self._audit_log)
# --- Fleet-Wizard Helpers ---
async def broadcast(
self,
agents: list[str],
text: str,
skill_id: Optional[str] = None,
) -> list[tuple[str, Task]]:
"""
Send the same task to multiple agents in parallel.
Returns list of (agent_url, task) tuples.
"""
tasks = []
for agent_url in agents:
tasks.append(
self.delegate(agent_url, text, skill_id=skill_id)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
paired = []
for agent_url, result in zip(agents, results):
if isinstance(result, Exception):
logger.error(f"Broadcast to {agent_url} failed: {result}")
else:
paired.append((agent_url, result))
return paired

View File

@@ -1,264 +0,0 @@
"""
A2A Registry — fleet-wide agent discovery.
Provides two registry backends:
1. LocalFileRegistry: reads/writes agent cards to a JSON file
(default: config/fleet_agents.json)
2. GiteaRegistry: stores agent cards as a Gitea repo file
(for distributed fleet discovery)
Usage:
registry = LocalFileRegistry()
registry.register(my_card)
agents = registry.list_agents(skill="ci-health")
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Optional
from nexus.a2a.types import AgentCard
logger = logging.getLogger("nexus.a2a.registry")
class LocalFileRegistry:
"""
File-based agent card registry.
Stores all fleet agent cards in a single JSON file.
Suitable for single-node or read-heavy workloads.
"""
def __init__(self, path: Path = Path("config/fleet_agents.json")):
self.path = path
self._cards: dict[str, AgentCard] = {}
self._load()
def _load(self):
"""Load registry from disk."""
if self.path.exists():
try:
with open(self.path) as f:
data = json.load(f)
for card_data in data.get("agents", []):
card = AgentCard.from_dict(card_data)
self._cards[card.name.lower()] = card
logger.info(
f"Loaded {len(self._cards)} agents from {self.path}"
)
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Failed to load registry from {self.path}: {e}")
def _save(self):
"""Persist registry to disk."""
self.path.parent.mkdir(parents=True, exist_ok=True)
data = {
"version": 1,
"agents": [card.to_dict() for card in self._cards.values()],
}
with open(self.path, "w") as f:
json.dump(data, f, indent=2)
logger.debug(f"Saved {len(self._cards)} agents to {self.path}")
def register(self, card: AgentCard) -> None:
"""Register or update an agent card."""
self._cards[card.name.lower()] = card
self._save()
logger.info(f"Registered agent: {card.name}")
def unregister(self, name: str) -> bool:
"""Remove an agent from the registry."""
key = name.lower()
if key in self._cards:
del self._cards[key]
self._save()
logger.info(f"Unregistered agent: {name}")
return True
return False
def get(self, name: str) -> Optional[AgentCard]:
"""Get an agent card by name."""
return self._cards.get(name.lower())
def list_agents(
self,
skill: Optional[str] = None,
tag: Optional[str] = None,
) -> list[AgentCard]:
"""
List all registered agents, optionally filtered by skill or tag.
Args:
skill: Filter to agents that have this skill ID
tag: Filter to agents that have this tag on any skill
"""
agents = list(self._cards.values())
if skill:
agents = [
a for a in agents
if any(s.id == skill for s in a.skills)
]
if tag:
agents = [
a for a in agents
if any(tag in s.tags for s in a.skills)
]
return agents
def get_endpoint(self, name: str) -> Optional[str]:
"""Get the first supported interface URL for an agent."""
card = self.get(name)
if card and card.supported_interfaces:
return card.supported_interfaces[0].url
return None
def dump(self) -> dict:
"""Dump full registry as a dict."""
return {
"version": 1,
"agents": [card.to_dict() for card in self._cards.values()],
}
class GiteaRegistry:
"""
Gitea-backed agent registry.
Stores fleet agent cards in a Gitea repository file for
distributed discovery across VPS nodes.
"""
def __init__(
self,
gitea_url: str,
repo: str,
token: str,
file_path: str = "config/fleet_agents.json",
):
self.gitea_url = gitea_url.rstrip("/")
self.repo = repo
self.token = token
self.file_path = file_path
self._cards: dict[str, AgentCard] = {}
def _api_url(self, endpoint: str) -> str:
return f"{self.gitea_url}/api/v1/repos/{self.repo}/{endpoint}"
def _headers(self) -> dict:
return {
"Authorization": f"token {self.token}",
"Content-Type": "application/json",
}
async def load(self) -> None:
"""Fetch agent cards from Gitea."""
try:
import aiohttp
url = self._api_url(f"contents/{self.file_path}")
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self._headers()) as resp:
if resp.status == 200:
data = await resp.json()
import base64
content = base64.b64decode(data["content"]).decode()
registry = json.loads(content)
for card_data in registry.get("agents", []):
card = AgentCard.from_dict(card_data)
self._cards[card.name.lower()] = card
logger.info(
f"Loaded {len(self._cards)} agents from Gitea"
)
elif resp.status == 404:
logger.info("No fleet registry file in Gitea yet")
else:
logger.error(
f"Gitea fetch failed: {resp.status}"
)
except Exception as e:
logger.error(f"Failed to load from Gitea: {e}")
async def save(self, message: str = "Update fleet registry") -> None:
"""Write agent cards to Gitea."""
try:
import aiohttp
content = json.dumps(
{"version": 1, "agents": [c.to_dict() for c in self._cards.values()]},
indent=2,
)
import base64
encoded = base64.b64encode(content.encode()).decode()
# Check if file exists (need SHA for update)
url = self._api_url(f"contents/{self.file_path}")
sha = None
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self._headers()) as resp:
if resp.status == 200:
existing = await resp.json()
sha = existing.get("sha")
payload = {
"message": message,
"content": encoded,
}
if sha:
payload["sha"] = sha
async with session.put(
url, headers=self._headers(), json=payload
) as resp:
if resp.status in (200, 201):
logger.info("Fleet registry saved to Gitea")
else:
body = await resp.text()
logger.error(
f"Gitea save failed: {resp.status}{body}"
)
except Exception as e:
logger.error(f"Failed to save to Gitea: {e}")
def register(self, card: AgentCard) -> None:
"""Register an agent (local update; call save() to persist)."""
self._cards[card.name.lower()] = card
def unregister(self, name: str) -> bool:
key = name.lower()
if key in self._cards:
del self._cards[key]
return True
return False
def get(self, name: str) -> Optional[AgentCard]:
return self._cards.get(name.lower())
def list_agents(
self,
skill: Optional[str] = None,
tag: Optional[str] = None,
) -> list[AgentCard]:
agents = list(self._cards.values())
if skill:
agents = [a for a in agents if any(s.id == skill for s in a.skills)]
if tag:
agents = [a for a in agents if any(tag in s.tags for s in a.skills)]
return agents
# --- Convenience ---
def discover_agents(
path: Path = Path("config/fleet_agents.json"),
skill: Optional[str] = None,
tag: Optional[str] = None,
) -> list[AgentCard]:
"""One-shot discovery from local file."""
registry = LocalFileRegistry(path)
return registry.list_agents(skill=skill, tag=tag)

View File

@@ -1,386 +0,0 @@
"""
A2A Server — receive and process tasks from other agents.
Provides a FastAPI router that serves:
- GET /.well-known/agent-card.json — Agent Card discovery
- GET /agent.json — Agent Card fallback
- POST /a2a/v1 — JSON-RPC endpoint (SendMessage, GetTask, etc.)
- POST /a2a/v1/rpc — JSON-RPC endpoint (alias)
Task routing: registered handlers are matched by skill ID or receive
all tasks via a default handler.
Usage:
server = A2AServer(card=my_card, auth_token="secret")
server.register_handler("ci-health", my_ci_handler)
await server.start(host="0.0.0.0", port=8080)
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Awaitable, Optional
try:
from fastapi import FastAPI, Request, Response, HTTPException, Header
from fastapi.responses import JSONResponse
import uvicorn
HAS_FASTAPI = True
except ImportError:
HAS_FASTAPI = False
from nexus.a2a.types import (
A2AError,
AgentCard,
Artifact,
JSONRPCError,
JSONRPCResponse,
Message,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)
logger = logging.getLogger("nexus.a2a.server")
# Type for task handlers
TaskHandler = Callable[[Task, AgentCard], Awaitable[Task]]
class A2AServer:
"""
A2A protocol server for receiving agent-to-agent task delegation.
Supports:
- Agent Card serving at /.well-known/agent-card.json
- JSON-RPC task lifecycle (SendMessage, GetTask, CancelTask, ListTasks)
- Pluggable task handlers (by skill ID or default)
- Bearer / API key authentication
- Audit logging
"""
def __init__(
self,
card: AgentCard,
auth_token: str = "",
auth_scheme: str = "bearer",
):
if not HAS_FASTAPI:
raise ImportError(
"fastapi and uvicorn are required for A2AServer. "
"Install with: pip install fastapi uvicorn"
)
self.card = card
self.auth_token = auth_token
self.auth_scheme = auth_scheme
# Task store (in-memory; swap for SQLite/Redis in production)
self._tasks: dict[str, Task] = {}
# Handlers keyed by skill ID
self._handlers: dict[str, TaskHandler] = {}
# Default handler for unmatched skills
self._default_handler: Optional[TaskHandler] = None
# Audit log
self._audit_log: list[dict] = []
self.app = FastAPI(
title=f"A2A — {card.name}",
description=card.description,
version=card.version,
)
self._register_routes()
def register_handler(self, skill_id: str, handler: TaskHandler):
"""Register a handler for a specific skill ID."""
self._handlers[skill_id] = handler
logger.info(f"Registered handler for skill: {skill_id}")
def set_default_handler(self, handler: TaskHandler):
"""Set the fallback handler for tasks without a matching skill."""
self._default_handler = handler
def _verify_auth(self, authorization: Optional[str]) -> bool:
"""Check authentication header."""
if not self.auth_token:
return True # No auth configured
if not authorization:
return False
if self.auth_scheme == "bearer":
expected = f"Bearer {self.auth_token}"
return authorization == expected
return False
def _register_routes(self):
"""Wire up FastAPI routes."""
@self.app.get("/.well-known/agent-card.json")
async def agent_card_well_known():
return JSONResponse(self.card.to_dict())
@self.app.get("/agent.json")
async def agent_card_fallback():
return JSONResponse(self.card.to_dict())
@self.app.post("/a2a/v1")
@self.app.post("/a2a/v1/rpc")
async def rpc_endpoint(request: Request):
return await self._handle_rpc(request)
@self.app.get("/a2a/v1/tasks")
@self.app.get("/a2a/v1/tasks/{task_id}")
async def rest_get_task(task_id: Optional[str] = None):
if task_id:
task = self._tasks.get(task_id)
if not task:
return JSONRPCResponse(
id="",
error=A2AError.TASK_NOT_FOUND,
).to_dict()
return JSONResponse(task.to_dict())
else:
return JSONResponse(
{"tasks": [t.to_dict() for t in self._tasks.values()]}
)
async def _handle_rpc(self, request: Request) -> JSONResponse:
"""Handle JSON-RPC requests."""
# Auth check
auth_header = request.headers.get("authorization")
if not self._verify_auth(auth_header):
return JSONResponse(
status_code=401,
content={"error": "Unauthorized"},
)
# Parse JSON-RPC
try:
body = await request.json()
except json.JSONDecodeError:
return JSONResponse(
JSONRPCResponse(
id="", error=A2AError.PARSE
).to_dict(),
status_code=400,
)
method = body.get("method", "")
request_id = body.get("id", str(uuid.uuid4()))
params = body.get("params", {})
# Audit
self._audit_log.append({
"timestamp": time.time(),
"method": method,
"request_id": request_id,
"source": request.client.host if request.client else "unknown",
})
try:
result = await self._dispatch_rpc(method, params, request_id)
return JSONResponse(
JSONRPCResponse(id=request_id, result=result).to_dict()
)
except ValueError as e:
return JSONResponse(
JSONRPCResponse(
id=request_id,
error=JSONRPCError(-32602, str(e)),
).to_dict(),
status_code=400,
)
except Exception as e:
logger.exception(f"Error handling {method}: {e}")
return JSONResponse(
JSONRPCResponse(
id=request_id,
error=JSONRPCError(-32603, str(e)),
).to_dict(),
status_code=500,
)
async def _dispatch_rpc(
self, method: str, params: dict, request_id: str
) -> Any:
"""Route JSON-RPC method to handler."""
if method == "SendMessage":
return await self._rpc_send_message(params)
elif method == "GetTask":
return await self._rpc_get_task(params)
elif method == "ListTasks":
return await self._rpc_list_tasks(params)
elif method == "CancelTask":
return await self._rpc_cancel_task(params)
elif method == "GetAgentCard":
return self.card.to_dict()
else:
raise ValueError(f"Unknown method: {method}")
async def _rpc_send_message(self, params: dict) -> dict:
"""Handle SendMessage — create a task and route to handler."""
msg_data = params.get("message", {})
message = Message.from_dict(msg_data)
# Determine target skill from metadata
target_skill = message.metadata.get("targetSkill", "")
# Create task
task = Task(
context_id=message.context_id,
status=TaskStatus(state=TaskState.SUBMITTED),
history=[message],
metadata={"targetSkill": target_skill} if target_skill else {},
)
# Store immediately
self._tasks[task.id] = task
# Dispatch to handler
handler = self._handlers.get(target_skill) or self._default_handler
if handler is None:
task.status = TaskStatus(
state=TaskState.FAILED,
message=Message(
role=Role.AGENT,
parts=[TextPart(text="No handler available for this task")],
),
)
return {"task": task.to_dict()}
try:
# Mark as working
task.status = TaskStatus(state=TaskState.WORKING)
self._tasks[task.id] = task
# Execute handler
result_task = await handler(task, self.card)
# Store result
self._tasks[result_task.id] = result_task
return {"task": result_task.to_dict()}
except Exception as e:
task.status = TaskStatus(
state=TaskState.FAILED,
message=Message(
role=Role.AGENT,
parts=[TextPart(text=f"Handler error: {str(e)}")],
),
)
self._tasks[task.id] = task
return {"task": task.to_dict()}
async def _rpc_get_task(self, params: dict) -> dict:
"""Handle GetTask."""
task_id = params.get("id", "")
task = self._tasks.get(task_id)
if not task:
raise ValueError(f"Task not found: {task_id}")
return task.to_dict()
async def _rpc_list_tasks(self, params: dict) -> dict:
"""Handle ListTasks with cursor-based pagination."""
page_size = params.get("pageSize", 20)
page_token = params.get("pageToken", "")
tasks = sorted(
self._tasks.values(),
key=lambda t: t.status.timestamp,
reverse=True,
)
# Simple cursor: find index by token
start_idx = 0
if page_token:
for i, t in enumerate(tasks):
if t.id == page_token:
start_idx = i + 1
break
page = tasks[start_idx : start_idx + page_size]
next_token = ""
if start_idx + page_size < len(tasks):
next_token = tasks[start_idx + page_size - 1].id
return {
"tasks": [t.to_dict() for t in page],
"nextPageToken": next_token,
}
async def _rpc_cancel_task(self, params: dict) -> dict:
"""Handle CancelTask."""
task_id = params.get("id", "")
task = self._tasks.get(task_id)
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.status.state.terminal:
raise ValueError(
f"Task {task_id} is already terminal "
f"({task.status.state.value})"
)
task.status = TaskStatus(state=TaskState.CANCELED)
self._tasks[task_id] = task
return task.to_dict()
def get_audit_log(self) -> list[dict]:
"""Return audit log of all received requests."""
return list(self._audit_log)
async def start(
self,
host: str = "0.0.0.0",
port: int = 8080,
):
"""Start the A2A server with uvicorn."""
logger.info(
f"Starting A2A server for {self.card.name} on "
f"{host}:{port}"
)
logger.info(
f"Agent Card at "
f"http://{host}:{port}/.well-known/agent-card.json"
)
config = uvicorn.Config(
self.app,
host=host,
port=port,
log_level="info",
)
server = uvicorn.Server(config)
await server.serve()
# --- Default Handler Factory ---
async def echo_handler(task: Task, card: AgentCard) -> Task:
"""
Simple echo handler for testing.
Returns the user's message as an artifact.
"""
if task.history:
last_msg = task.history[-1]
text_parts = [p for p in last_msg.parts if isinstance(p, TextPart)]
if text_parts:
response_text = f"[{card.name}] Echo: {text_parts[0].text}"
task.artifacts.append(
Artifact(
parts=[TextPart(text=response_text)],
name="echo_response",
)
)
task.status = TaskStatus(state=TaskState.COMPLETED)
return task

View File

@@ -1,524 +0,0 @@
"""
A2A Protocol Types — Data models for Google's Agent2Agent protocol v1.0.
All types map directly to the A2A spec. JSON uses camelCase, enums use
SCREAMING_SNAKE_CASE, and Part types are discriminated by member name
(not a kind field — that was removed in v1.0).
See: https://github.com/google/A2A
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from typing import Any, Optional
# --- Enums ---
class TaskState(str, enum.Enum):
"""Lifecycle states for an A2A Task."""
SUBMITTED = "TASK_STATE_SUBMITTED"
WORKING = "TASK_STATE_WORKING"
COMPLETED = "TASK_STATE_COMPLETED"
FAILED = "TASK_STATE_FAILED"
CANCELED = "TASK_STATE_CANCELED"
INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
REJECTED = "TASK_STATE_REJECTED"
AUTH_REQUIRED = "TASK_STATE_AUTH_REQUIRED"
@property
def terminal(self) -> bool:
return self in (
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
TaskState.REJECTED,
)
class Role(str, enum.Enum):
"""Who sent a message in an A2A conversation."""
USER = "ROLE_USER"
AGENT = "ROLE_AGENT"
# --- Parts (discriminated by member name in JSON) ---
@dataclass
class TextPart:
"""Plain text content."""
text: str
media_type: str = "text/plain"
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
d = {"text": self.text}
if self.media_type != "text/plain":
d["mediaType"] = self.media_type
if self.metadata:
d["metadata"] = self.metadata
return d
@dataclass
class FilePart:
"""Binary file content — inline or by URL reference."""
media_type: str
filename: Optional[str] = None
raw: Optional[str] = None # base64-encoded bytes
url: Optional[str] = None # URL reference
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
d = {"mediaType": self.media_type}
if self.raw is not None:
d["raw"] = self.raw
if self.url is not None:
d["url"] = self.url
if self.filename:
d["filename"] = self.filename
if self.metadata:
d["metadata"] = self.metadata
return d
@dataclass
class DataPart:
"""Arbitrary structured JSON data."""
data: dict
media_type: str = "application/json"
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
d = {"data": self.data}
if self.media_type != "application/json":
d["mediaType"] = self.media_type
if self.metadata:
d["metadata"] = self.metadata
return d
Part = TextPart | FilePart | DataPart
def part_from_dict(d: dict) -> Part:
"""Reconstruct a Part from its JSON dict (discriminated by key name)."""
if "text" in d:
return TextPart(
text=d["text"],
media_type=d.get("mediaType", "text/plain"),
metadata=d.get("metadata", {}),
)
if "raw" in d or "url" in d:
return FilePart(
media_type=d["mediaType"],
filename=d.get("filename"),
raw=d.get("raw"),
url=d.get("url"),
metadata=d.get("metadata", {}),
)
if "data" in d:
return DataPart(
data=d["data"],
media_type=d.get("mediaType", "application/json"),
metadata=d.get("metadata", {}),
)
raise ValueError(f"Cannot determine Part type from keys: {list(d.keys())}")
def part_to_dict(p: Part) -> dict:
"""Serialize a Part to its JSON dict."""
return p.to_dict()
# --- Message ---
@dataclass
class Message:
"""A2A Message — a turn in a conversation between user and agent."""
role: Role
parts: list[Part]
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
context_id: Optional[str] = None
task_id: Optional[str] = None
metadata: dict = field(default_factory=dict)
extensions: list[str] = field(default_factory=list)
reference_task_ids: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
d: dict[str, Any] = {
"messageId": self.message_id,
"role": self.role.value,
"parts": [part_to_dict(p) for p in self.parts],
}
if self.context_id:
d["contextId"] = self.context_id
if self.task_id:
d["taskId"] = self.task_id
if self.metadata:
d["metadata"] = self.metadata
if self.extensions:
d["extensions"] = self.extensions
if self.reference_task_ids:
d["referenceTaskIds"] = self.reference_task_ids
return d
@classmethod
def from_dict(cls, d: dict) -> "Message":
return cls(
role=Role(d["role"]),
parts=[part_from_dict(p) for p in d["parts"]],
message_id=d.get("messageId", str(uuid.uuid4())),
context_id=d.get("contextId"),
task_id=d.get("taskId"),
metadata=d.get("metadata", {}),
extensions=d.get("extensions", []),
reference_task_ids=d.get("referenceTaskIds", []),
)
# --- Artifact ---
@dataclass
class Artifact:
"""A2A Artifact — structured output from a task."""
parts: list[Part]
artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
name: Optional[str] = None
description: Optional[str] = None
metadata: dict = field(default_factory=dict)
extensions: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
d: dict[str, Any] = {
"artifactId": self.artifact_id,
"parts": [part_to_dict(p) for p in self.parts],
}
if self.name:
d["name"] = self.name
if self.description:
d["description"] = self.description
if self.metadata:
d["metadata"] = self.metadata
if self.extensions:
d["extensions"] = self.extensions
return d
@classmethod
def from_dict(cls, d: dict) -> "Artifact":
return cls(
parts=[part_from_dict(p) for p in d["parts"]],
artifact_id=d.get("artifactId", str(uuid.uuid4())),
name=d.get("name"),
description=d.get("description"),
metadata=d.get("metadata", {}),
extensions=d.get("extensions", []),
)
# --- Task ---
@dataclass
class TaskStatus:
"""Status envelope for a Task."""
state: TaskState
message: Optional[Message] = None
timestamp: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
def to_dict(self) -> dict:
d: dict[str, Any] = {"state": self.state.value}
if self.message:
d["message"] = self.message.to_dict()
d["timestamp"] = self.timestamp
return d
@classmethod
def from_dict(cls, d: dict) -> "TaskStatus":
msg = None
if "message" in d:
msg = Message.from_dict(d["message"])
return cls(
state=TaskState(d["state"]),
message=msg,
timestamp=d.get("timestamp", datetime.now(timezone.utc).isoformat()),
)
@dataclass
class Task:
"""A2A Task — a unit of work delegated between agents."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
context_id: Optional[str] = None
status: TaskStatus = field(
default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED)
)
artifacts: list[Artifact] = field(default_factory=list)
history: list[Message] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
d: dict[str, Any] = {
"id": self.id,
"status": self.status.to_dict(),
}
if self.context_id:
d["contextId"] = self.context_id
if self.artifacts:
d["artifacts"] = [a.to_dict() for a in self.artifacts]
if self.history:
d["history"] = [m.to_dict() for m in self.history]
if self.metadata:
d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d: dict) -> "Task":
return cls(
id=d.get("id", str(uuid.uuid4())),
context_id=d.get("contextId"),
status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED),
artifacts=[Artifact.from_dict(a) for a in d.get("artifacts", [])],
history=[Message.from_dict(m) for m in d.get("history", [])],
metadata=d.get("metadata", {}),
)
# --- Agent Card ---
@dataclass
class AgentSkill:
"""Capability declaration for an Agent Card."""
id: str
name: str
description: str
tags: list[str] = field(default_factory=list)
examples: list[str] = field(default_factory=list)
input_modes: list[str] = field(default_factory=lambda: ["text/plain"])
output_modes: list[str] = field(default_factory=lambda: ["text/plain"])
security_requirements: list[dict] = field(default_factory=list)
def to_dict(self) -> dict:
d: dict[str, Any] = {
"id": self.id,
"name": self.name,
"description": self.description,
"tags": self.tags,
}
if self.examples:
d["examples"] = self.examples
if self.input_modes != ["text/plain"]:
d["inputModes"] = self.input_modes
if self.output_modes != ["text/plain"]:
d["outputModes"] = self.output_modes
if self.security_requirements:
d["securityRequirements"] = self.security_requirements
return d
@dataclass
class AgentInterface:
"""Network endpoint for an agent."""
url: str
protocol_binding: str = "HTTP+JSON"
protocol_version: str = "1.0"
tenant: str = ""
def to_dict(self) -> dict:
d = {
"url": self.url,
"protocolBinding": self.protocol_binding,
"protocolVersion": self.protocol_version,
}
if self.tenant:
d["tenant"] = self.tenant
return d
@dataclass
class AgentCapabilities:
"""What this agent can do beyond basic request/response."""
streaming: bool = False
push_notifications: bool = False
extended_agent_card: bool = False
extensions: list[dict] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"streaming": self.streaming,
"pushNotifications": self.push_notifications,
"extendedAgentCard": self.extended_agent_card,
"extensions": self.extensions,
}
@dataclass
class AgentCard:
"""
A2A Agent Card — self-describing metadata published at
/.well-known/agent-card.json
"""
name: str
description: str
version: str = "1.0.0"
supported_interfaces: list[AgentInterface] = field(default_factory=list)
capabilities: AgentCapabilities = field(
default_factory=AgentCapabilities
)
provider: Optional[dict] = None
documentation_url: Optional[str] = None
icon_url: Optional[str] = None
default_input_modes: list[str] = field(
default_factory=lambda: ["text/plain"]
)
default_output_modes: list[str] = field(
default_factory=lambda: ["text/plain"]
)
skills: list[AgentSkill] = field(default_factory=list)
security_schemes: dict = field(default_factory=dict)
security_requirements: list[dict] = field(default_factory=list)
def to_dict(self) -> dict:
d: dict[str, Any] = {
"name": self.name,
"description": self.description,
"version": self.version,
"supportedInterfaces": [i.to_dict() for i in self.supported_interfaces],
"capabilities": self.capabilities.to_dict(),
"defaultInputModes": self.default_input_modes,
"defaultOutputModes": self.default_output_modes,
"skills": [s.to_dict() for s in self.skills],
}
if self.provider:
d["provider"] = self.provider
if self.documentation_url:
d["documentationUrl"] = self.documentation_url
if self.icon_url:
d["iconUrl"] = self.icon_url
if self.security_schemes:
d["securitySchemes"] = self.security_schemes
if self.security_requirements:
d["securityRequirements"] = self.security_requirements
return d
@classmethod
def from_dict(cls, d: dict) -> "AgentCard":
return cls(
name=d["name"],
description=d["description"],
version=d.get("version", "1.0.0"),
supported_interfaces=[
AgentInterface(
url=i["url"],
protocol_binding=i.get("protocolBinding", "HTTP+JSON"),
protocol_version=i.get("protocolVersion", "1.0"),
tenant=i.get("tenant", ""),
)
for i in d.get("supportedInterfaces", [])
],
capabilities=AgentCapabilities(
streaming=d.get("capabilities", {}).get("streaming", False),
push_notifications=d.get("capabilities", {}).get("pushNotifications", False),
extended_agent_card=d.get("capabilities", {}).get("extendedAgentCard", False),
extensions=d.get("capabilities", {}).get("extensions", []),
),
provider=d.get("provider"),
documentation_url=d.get("documentationUrl"),
icon_url=d.get("iconUrl"),
default_input_modes=d.get("defaultInputModes", ["text/plain"]),
default_output_modes=d.get("defaultOutputModes", ["text/plain"]),
skills=[
AgentSkill(
id=s["id"],
name=s["name"],
description=s["description"],
tags=s.get("tags", []),
examples=s.get("examples", []),
input_modes=s.get("inputModes", ["text/plain"]),
output_modes=s.get("outputModes", ["text/plain"]),
security_requirements=s.get("securityRequirements", []),
)
for s in d.get("skills", [])
],
security_schemes=d.get("securitySchemes", {}),
security_requirements=d.get("securityRequirements", []),
)
# --- JSON-RPC envelope ---
@dataclass
class JSONRPCRequest:
"""JSON-RPC 2.0 request wrapping an A2A method."""
method: str
id: str = field(default_factory=lambda: str(uuid.uuid4()))
params: dict = field(default_factory=dict)
jsonrpc: str = "2.0"
def to_dict(self) -> dict:
return {
"jsonrpc": self.jsonrpc,
"id": self.id,
"method": self.method,
"params": self.params,
}
@dataclass
class JSONRPCError:
"""JSON-RPC 2.0 error object."""
code: int
message: str
data: Any = None
def to_dict(self) -> dict:
d = {"code": self.code, "message": self.message}
if self.data is not None:
d["data"] = self.data
return d
@dataclass
class JSONRPCResponse:
"""JSON-RPC 2.0 response."""
id: str
result: Any = None
error: Optional[JSONRPCError] = None
jsonrpc: str = "2.0"
def to_dict(self) -> dict:
d: dict[str, Any] = {
"jsonrpc": self.jsonrpc,
"id": self.id,
}
if self.error:
d["error"] = self.error.to_dict()
else:
d["result"] = self.result
return d
# --- Standard A2A Error codes ---
class A2AError:
"""Standard A2A / JSON-RPC error factories."""
PARSE = JSONRPCError(-32700, "Invalid JSON payload")
INVALID_REQUEST = JSONRPCError(-32600, "Request payload validation error")
METHOD_NOT_FOUND = JSONRPCError(-32601, "Method not found")
INVALID_PARAMS = JSONRPCError(-32602, "Invalid parameters")
INTERNAL = JSONRPCError(-32603, "Internal error")
TASK_NOT_FOUND = JSONRPCError(-32001, "Task not found")
TASK_NOT_CANCELABLE = JSONRPCError(-32002, "Task not cancelable")
PUSH_NOT_SUPPORTED = JSONRPCError(-32003, "Push notifications not supported")
UNSUPPORTED_OP = JSONRPCError(-32004, "Unsupported operation")
CONTENT_TYPE = JSONRPCError(-32005, "Content type not supported")
INVALID_RESPONSE = JSONRPCError(-32006, "Invalid agent response")
EXTENDED_CARD = JSONRPCError(-32007, "Extended agent card not configured")
EXTENSION_REQUIRED = JSONRPCError(-32008, "Extension support required")
VERSION_NOT_SUPPORTED = JSONRPCError(-32009, "Version not supported")

View File

@@ -1,763 +0,0 @@
"""
Tests for A2A Protocol implementation.
Covers:
- Type serialization roundtrips (Agent Card, Task, Message, Artifact, Part)
- JSON-RPC envelope
- Agent Card building from YAML config
- Registry operations (register, list, filter)
- Client/server integration (end-to-end task delegation)
"""
from __future__ import annotations
import asyncio
import json
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, patch, MagicMock
from nexus.a2a.types import (
A2AError,
AgentCard,
AgentCapabilities,
AgentInterface,
AgentSkill,
Artifact,
DataPart,
FilePart,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
Message,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
part_from_dict,
part_to_dict,
)
from nexus.a2a.card import build_card, load_card_config
from nexus.a2a.registry import LocalFileRegistry
# === Type Serialization Roundtrips ===
class TestTextPart:
def test_roundtrip(self):
p = TextPart(text="hello world")
d = p.to_dict()
assert d == {"text": "hello world"}
p2 = part_from_dict(d)
assert isinstance(p2, TextPart)
assert p2.text == "hello world"
def test_custom_media_type(self):
p = TextPart(text="data", media_type="text/markdown")
d = p.to_dict()
assert d["mediaType"] == "text/markdown"
p2 = part_from_dict(d)
assert p2.media_type == "text/markdown"
class TestFilePart:
def test_inline_roundtrip(self):
p = FilePart(media_type="image/png", raw="base64data", filename="img.png")
d = p.to_dict()
assert d["raw"] == "base64data"
assert d["filename"] == "img.png"
p2 = part_from_dict(d)
assert isinstance(p2, FilePart)
assert p2.raw == "base64data"
def test_url_roundtrip(self):
p = FilePart(media_type="application/pdf", url="https://example.com/doc.pdf")
d = p.to_dict()
assert d["url"] == "https://example.com/doc.pdf"
p2 = part_from_dict(d)
assert isinstance(p2, FilePart)
assert p2.url == "https://example.com/doc.pdf"
class TestDataPart:
def test_roundtrip(self):
p = DataPart(data={"key": "value", "count": 42})
d = p.to_dict()
assert d["data"] == {"key": "value", "count": 42}
p2 = part_from_dict(d)
assert isinstance(p2, DataPart)
assert p2.data["count"] == 42
class TestMessage:
def test_roundtrip(self):
msg = Message(
role=Role.USER,
parts=[TextPart(text="Hello agent")],
metadata={"priority": "high"},
)
d = msg.to_dict()
assert d["role"] == "ROLE_USER"
assert d["parts"] == [{"text": "Hello agent"}]
assert d["metadata"]["priority"] == "high"
msg2 = Message.from_dict(d)
assert msg2.role == Role.USER
assert isinstance(msg2.parts[0], TextPart)
assert msg2.parts[0].text == "Hello agent"
assert msg2.metadata["priority"] == "high"
def test_multi_part(self):
msg = Message(
role=Role.AGENT,
parts=[
TextPart(text="Here's the report"),
DataPart(data={"status": "healthy"}),
],
)
d = msg.to_dict()
assert len(d["parts"]) == 2
msg2 = Message.from_dict(d)
assert len(msg2.parts) == 2
assert isinstance(msg2.parts[0], TextPart)
assert isinstance(msg2.parts[1], DataPart)
class TestArtifact:
def test_roundtrip(self):
art = Artifact(
parts=[TextPart(text="result data")],
name="report",
description="CI health report",
)
d = art.to_dict()
assert d["name"] == "report"
assert d["description"] == "CI health report"
art2 = Artifact.from_dict(d)
assert art2.name == "report"
assert isinstance(art2.parts[0], TextPart)
assert art2.parts[0].text == "result data"
class TestTask:
def test_roundtrip(self):
task = Task(
id="test-123",
status=TaskStatus(state=TaskState.WORKING),
history=[
Message(role=Role.USER, parts=[TextPart(text="Do X")]),
],
)
d = task.to_dict()
assert d["id"] == "test-123"
assert d["status"]["state"] == "TASK_STATE_WORKING"
task2 = Task.from_dict(d)
assert task2.id == "test-123"
assert task2.status.state == TaskState.WORKING
assert len(task2.history) == 1
def test_with_artifacts(self):
task = Task(
id="art-task",
status=TaskStatus(state=TaskState.COMPLETED),
artifacts=[
Artifact(
parts=[TextPart(text="42")],
name="answer",
)
],
)
d = task.to_dict()
assert len(d["artifacts"]) == 1
task2 = Task.from_dict(d)
assert task2.artifacts[0].name == "answer"
def test_terminal_states(self):
for state in [
TaskState.COMPLETED,
TaskState.FAILED,
TaskState.CANCELED,
TaskState.REJECTED,
]:
assert state.terminal is True
for state in [
TaskState.SUBMITTED,
TaskState.WORKING,
TaskState.INPUT_REQUIRED,
TaskState.AUTH_REQUIRED,
]:
assert state.terminal is False
class TestAgentCard:
def test_roundtrip(self):
card = AgentCard(
name="TestAgent",
description="A test agent",
version="1.0.0",
supported_interfaces=[
AgentInterface(url="http://localhost:8080/a2a/v1")
],
capabilities=AgentCapabilities(streaming=True),
skills=[
AgentSkill(
id="test-skill",
name="Test Skill",
description="Does tests",
tags=["test"],
)
],
)
d = card.to_dict()
assert d["name"] == "TestAgent"
assert d["capabilities"]["streaming"] is True
assert len(d["skills"]) == 1
assert d["skills"][0]["id"] == "test-skill"
card2 = AgentCard.from_dict(d)
assert card2.name == "TestAgent"
assert card2.skills[0].id == "test-skill"
assert card2.capabilities.streaming is True
class TestJSONRPC:
def test_request_roundtrip(self):
req = JSONRPCRequest(
method="SendMessage",
params={"message": {"text": "hello"}},
)
d = req.to_dict()
assert d["jsonrpc"] == "2.0"
assert d["method"] == "SendMessage"
def test_response_success(self):
resp = JSONRPCResponse(
id="req-1",
result={"task": {"id": "t1"}},
)
d = resp.to_dict()
assert "error" not in d
assert d["result"]["task"]["id"] == "t1"
def test_response_error(self):
resp = JSONRPCResponse(
id="req-1",
error=A2AError.TASK_NOT_FOUND,
)
d = resp.to_dict()
assert "result" not in d
assert d["error"]["code"] == -32001
# === Agent Card Building ===
class TestBuildCard:
def test_basic_config(self):
config = {
"name": "Bezalel",
"description": "CI/CD specialist",
"version": "2.0.0",
"url": "https://bezalel.example.com",
"skills": [
{
"id": "ci-health",
"name": "CI Health",
"description": "Check CI",
"tags": ["ci"],
},
{
"id": "deploy",
"name": "Deploy",
"description": "Deploy services",
"tags": ["ops"],
},
],
}
card = build_card(config)
assert card.name == "Bezalel"
assert card.version == "2.0.0"
assert len(card.skills) == 2
assert card.skills[0].id == "ci-health"
assert card.supported_interfaces[0].url == "https://bezalel.example.com"
def test_bearer_auth(self):
config = {
"name": "Test",
"description": "Test",
"auth": {"scheme": "bearer", "token_env": "MY_TOKEN"},
}
card = build_card(config)
assert "bearerAuth" in card.security_schemes
assert card.security_requirements[0]["schemes"]["bearerAuth"] == {"list": []}
def test_api_key_auth(self):
config = {
"name": "Test",
"description": "Test",
"auth": {"scheme": "api_key", "key_name": "X-Custom-Key"},
}
card = build_card(config)
assert "apiKeyAuth" in card.security_schemes
# === Registry ===
class TestLocalFileRegistry:
def _make_card(self, name: str, skills: list[dict] | None = None) -> AgentCard:
return AgentCard(
name=name,
description=f"Agent {name}",
supported_interfaces=[
AgentInterface(url=f"http://{name}:8080/a2a/v1")
],
skills=[
AgentSkill(
id=s["id"],
name=s.get("name", s["id"]),
description=s.get("description", ""),
tags=s.get("tags", []),
)
for s in (skills or [])
],
)
def test_register_and_list(self, tmp_path):
registry = LocalFileRegistry(tmp_path / "agents.json")
registry.register(self._make_card("ezra"))
registry.register(self._make_card("allegro"))
agents = registry.list_agents()
assert len(agents) == 2
names = {a.name for a in agents}
assert names == {"ezra", "allegro"}
def test_filter_by_skill(self, tmp_path):
registry = LocalFileRegistry(tmp_path / "agents.json")
registry.register(
self._make_card("ezra", [{"id": "ci-health", "tags": ["ci"]}])
)
registry.register(
self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
)
ci_agents = registry.list_agents(skill="ci-health")
assert len(ci_agents) == 1
assert ci_agents[0].name == "ezra"
def test_filter_by_tag(self, tmp_path):
registry = LocalFileRegistry(tmp_path / "agents.json")
registry.register(
self._make_card("ezra", [{"id": "ci", "tags": ["devops", "ci"]}])
)
registry.register(
self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
)
devops_agents = registry.list_agents(tag="devops")
assert len(devops_agents) == 1
assert devops_agents[0].name == "ezra"
def test_persistence(self, tmp_path):
path = tmp_path / "agents.json"
reg1 = LocalFileRegistry(path)
reg1.register(self._make_card("ezra"))
# Load fresh from disk
reg2 = LocalFileRegistry(path)
agents = reg2.list_agents()
assert len(agents) == 1
assert agents[0].name == "ezra"
def test_unregister(self, tmp_path):
registry = LocalFileRegistry(tmp_path / "agents.json")
registry.register(self._make_card("ezra"))
assert len(registry.list_agents()) == 1
assert registry.unregister("ezra") is True
assert len(registry.list_agents()) == 0
assert registry.unregister("nonexistent") is False
def test_get_endpoint(self, tmp_path):
registry = LocalFileRegistry(tmp_path / "agents.json")
registry.register(self._make_card("ezra"))
url = registry.get_endpoint("ezra")
assert url == "http://ezra:8080/a2a/v1"
# === Server Integration (FastAPI required) ===
try:
from fastapi.testclient import TestClient
HAS_TEST_CLIENT = True
except ImportError:
HAS_TEST_CLIENT = False
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestA2AServerIntegration:
"""End-to-end tests using FastAPI TestClient."""
def _make_server(self, auth_token: str = ""):
from nexus.a2a.server import A2AServer, echo_handler
card = AgentCard(
name="TestAgent",
description="Test agent for A2A",
supported_interfaces=[
AgentInterface(url="http://localhost:8080/a2a/v1")
],
capabilities=AgentCapabilities(streaming=False),
skills=[
AgentSkill(
id="echo",
name="Echo",
description="Echo back messages",
tags=["test"],
)
],
)
server = A2AServer(card=card, auth_token=auth_token)
server.register_handler("echo", echo_handler)
server.set_default_handler(echo_handler)
return server
def test_agent_card_well_known(self):
server = self._make_server()
client = TestClient(server.app)
resp = client.get("/.well-known/agent-card.json")
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "TestAgent"
assert len(data["skills"]) == 1
def test_agent_card_fallback(self):
server = self._make_server()
client = TestClient(server.app)
resp = client.get("/agent.json")
assert resp.status_code == 200
assert resp.json()["name"] == "TestAgent"
def test_send_message(self):
server = self._make_server()
client = TestClient(server.app)
rpc_request = {
"jsonrpc": "2.0",
"id": "test-1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "msg-1",
"role": "ROLE_USER",
"parts": [{"text": "Hello from test"}],
},
"configuration": {
"acceptedOutputModes": ["text/plain"],
"historyLength": 10,
"returnImmediately": False,
},
},
}
resp = client.post("/a2a/v1", json=rpc_request)
assert resp.status_code == 200
data = resp.json()
assert "result" in data
assert "task" in data["result"]
task = data["result"]["task"]
assert task["status"]["state"] == "TASK_STATE_COMPLETED"
assert len(task["artifacts"]) == 1
assert "Echo" in task["artifacts"][0]["parts"][0]["text"]
def test_get_task(self):
server = self._make_server()
client = TestClient(server.app)
# Create a task first
send_req = {
"jsonrpc": "2.0",
"id": "s1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "m1",
"role": "ROLE_USER",
"parts": [{"text": "get me"}],
},
"configuration": {},
},
}
send_resp = client.post("/a2a/v1", json=send_req)
task_id = send_resp.json()["result"]["task"]["id"]
# Now fetch it
get_req = {
"jsonrpc": "2.0",
"id": "g1",
"method": "GetTask",
"params": {"id": task_id},
}
get_resp = client.post("/a2a/v1", json=get_req)
assert get_resp.status_code == 200
assert get_resp.json()["result"]["id"] == task_id
def test_get_nonexistent_task(self):
server = self._make_server()
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "g2",
"method": "GetTask",
"params": {"id": "nonexistent"},
}
resp = client.post("/a2a/v1", json=req)
assert resp.status_code == 400
data = resp.json()
assert "error" in data
def test_list_tasks(self):
server = self._make_server()
client = TestClient(server.app)
# Create two tasks
for i in range(2):
req = {
"jsonrpc": "2.0",
"id": f"s{i}",
"method": "SendMessage",
"params": {
"message": {
"messageId": f"m{i}",
"role": "ROLE_USER",
"parts": [{"text": f"task {i}"}],
},
"configuration": {},
},
}
client.post("/a2a/v1", json=req)
list_req = {
"jsonrpc": "2.0",
"id": "l1",
"method": "ListTasks",
"params": {"pageSize": 10},
}
resp = client.post("/a2a/v1", json=list_req)
assert resp.status_code == 200
tasks = resp.json()["result"]["tasks"]
assert len(tasks) >= 2
def test_cancel_task(self):
from nexus.a2a.server import A2AServer
# Create a server with a slow handler so task stays WORKING
async def slow_handler(task, card):
import asyncio
await asyncio.sleep(10) # never reached in test
task.status = TaskStatus(state=TaskState.COMPLETED)
return task
card = AgentCard(name="SlowAgent", description="Slow test agent")
server = A2AServer(card=card)
server.set_default_handler(slow_handler)
client = TestClient(server.app)
# Create a task (but we need to intercept before handler runs)
# Instead, manually insert a task and test cancel on it
task = Task(
id="cancel-me",
status=TaskStatus(state=TaskState.WORKING),
history=[
Message(role=Role.USER, parts=[TextPart(text="cancel me")])
],
)
server._tasks[task.id] = task
# Cancel it
cancel_req = {
"jsonrpc": "2.0",
"id": "c2",
"method": "CancelTask",
"params": {"id": "cancel-me"},
}
cancel_resp = client.post("/a2a/v1", json=cancel_req)
assert cancel_resp.status_code == 200
assert cancel_resp.json()["result"]["status"]["state"] == "TASK_STATE_CANCELED"
def test_auth_required(self):
server = self._make_server(auth_token="secret123")
client = TestClient(server.app)
# No auth header — should get 401
req = {
"jsonrpc": "2.0",
"id": "a1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "am1",
"role": "ROLE_USER",
"parts": [{"text": "hello"}],
},
"configuration": {},
},
}
resp = client.post("/a2a/v1", json=req)
assert resp.status_code == 401
def test_auth_success(self):
server = self._make_server(auth_token="secret123")
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "a2",
"method": "SendMessage",
"params": {
"message": {
"messageId": "am2",
"role": "ROLE_USER",
"parts": [{"text": "authenticated"}],
},
"configuration": {},
},
}
resp = client.post(
"/a2a/v1",
json=req,
headers={"Authorization": "Bearer secret123"},
)
assert resp.status_code == 200
assert resp.json()["result"]["task"]["status"]["state"] == "TASK_STATE_COMPLETED"
def test_unknown_method(self):
server = self._make_server()
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "u1",
"method": "NonExistentMethod",
"params": {},
}
resp = client.post("/a2a/v1", json=req)
assert resp.status_code == 400
assert resp.json()["error"]["code"] == -32602
def test_audit_log(self):
server = self._make_server()
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "au1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "aum1",
"role": "ROLE_USER",
"parts": [{"text": "audit me"}],
},
"configuration": {},
},
}
client.post("/a2a/v1", json=req)
client.post("/a2a/v1", json=req)
log = server.get_audit_log()
assert len(log) == 2
assert all(entry["method"] == "SendMessage" for entry in log)
# === Custom Handler Test ===
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
class TestCustomHandlers:
"""Test custom task handlers."""
def test_skill_routing(self):
from nexus.a2a.server import A2AServer
from nexus.a2a.types import Task, AgentCard
async def ci_handler(task: Task, card: AgentCard) -> Task:
task.artifacts.append(
Artifact(
parts=[TextPart(text="CI pipeline healthy: 5/5 passed")],
name="ci_report",
)
)
task.status = TaskStatus(state=TaskState.COMPLETED)
return task
card = AgentCard(
name="CI Agent",
description="CI specialist",
skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])],
)
server = A2AServer(card=card)
server.register_handler("ci-health", ci_handler)
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "h1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "hm1",
"role": "ROLE_USER",
"parts": [{"text": "Check CI"}],
"metadata": {"targetSkill": "ci-health"},
},
"configuration": {},
},
}
resp = client.post("/a2a/v1", json=req)
task_data = resp.json()["result"]["task"]
assert task_data["status"]["state"] == "TASK_STATE_COMPLETED"
assert "5/5 passed" in task_data["artifacts"][0]["parts"][0]["text"]
def test_handler_error(self):
from nexus.a2a.server import A2AServer
from nexus.a2a.types import Task, AgentCard
async def failing_handler(task: Task, card: AgentCard) -> Task:
raise RuntimeError("Handler blew up")
card = AgentCard(name="Fail Agent", description="Fails")
server = A2AServer(card=card)
server.set_default_handler(failing_handler)
client = TestClient(server.app)
req = {
"jsonrpc": "2.0",
"id": "f1",
"method": "SendMessage",
"params": {
"message": {
"messageId": "fm1",
"role": "ROLE_USER",
"parts": [{"text": "break"}],
},
"configuration": {},
},
}
resp = client.post("/a2a/v1", json=req)
task_data = resp.json()["result"]["task"]
assert task_data["status"]["state"] == "TASK_STATE_FAILED"
assert "blew up" in task_data["status"]["message"]["parts"][0]["text"].lower()

377
tests/test_agent_memory.py Normal file
View File

@@ -0,0 +1,377 @@
"""
Tests for agent memory — cross-session agent memory via MemPalace.
Tests the memory module, hooks, and session mining without requiring
a live ChromaDB instance. Uses mocking for the MemPalace backend.
"""
from __future__ import annotations
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from agent.memory import (
AgentMemory,
MemoryContext,
SessionTranscript,
create_agent_memory,
)
from agent.memory_hooks import MemoryHooks
# ---------------------------------------------------------------------------
# SessionTranscript tests
# ---------------------------------------------------------------------------
class TestSessionTranscript:
def test_create(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
assert t.agent_name == "test"
assert t.wing == "wing_test"
assert len(t.entries) == 0
def test_add_user_turn(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
t.add_user_turn("Hello")
assert len(t.entries) == 1
assert t.entries[0]["role"] == "user"
assert t.entries[0]["text"] == "Hello"
def test_add_agent_turn(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
t.add_agent_turn("Response")
assert t.entries[0]["role"] == "agent"
def test_add_tool_call(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
t.add_tool_call("shell", "ls", "file1 file2")
assert t.entries[0]["role"] == "tool"
assert t.entries[0]["tool"] == "shell"
def test_summary_empty(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
assert t.summary() == "Empty session."
def test_summary_with_entries(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
t.add_user_turn("Do something")
t.add_agent_turn("Done")
t.add_tool_call("shell", "ls", "ok")
summary = t.summary()
assert "USER: Do something" in summary
assert "AGENT: Done" in summary
assert "TOOL(shell): ok" in summary
def test_text_truncation(self):
t = SessionTranscript(agent_name="test", wing="wing_test")
long_text = "x" * 5000
t.add_user_turn(long_text)
assert len(t.entries[0]["text"]) == 2000
# ---------------------------------------------------------------------------
# MemoryContext tests
# ---------------------------------------------------------------------------
class TestMemoryContext:
def test_empty_context(self):
ctx = MemoryContext()
assert ctx.to_prompt_block() == ""
def test_unloaded_context(self):
ctx = MemoryContext()
ctx.loaded = False
assert ctx.to_prompt_block() == ""
def test_loaded_with_data(self):
ctx = MemoryContext()
ctx.loaded = True
ctx.recent_diaries = [
{"text": "Fixed PR #1386", "timestamp": "2026-04-13T10:00:00Z"}
]
ctx.facts = [
{"text": "Bezalel runs on VPS Beta", "score": 0.95}
]
ctx.relevant_memories = [
{"text": "Changed CI runner", "score": 0.87}
]
block = ctx.to_prompt_block()
assert "Recent Session Summaries" in block
assert "Fixed PR #1386" in block
assert "Known Facts" in block
assert "Bezalel runs on VPS Beta" in block
assert "Relevant Past Memories" in block
def test_loaded_empty(self):
ctx = MemoryContext()
ctx.loaded = True
# No data — should return empty string
assert ctx.to_prompt_block() == ""
# ---------------------------------------------------------------------------
# AgentMemory tests (with mocked MemPalace)
# ---------------------------------------------------------------------------
class TestAgentMemory:
def test_create(self):
mem = AgentMemory(agent_name="bezalel")
assert mem.agent_name == "bezalel"
assert mem.wing == "wing_bezalel"
def test_custom_wing(self):
mem = AgentMemory(agent_name="bezalel", wing="custom_wing")
assert mem.wing == "custom_wing"
def test_factory(self):
mem = create_agent_memory("ezra")
assert mem.agent_name == "ezra"
assert mem.wing == "wing_ezra"
def test_unavailable_graceful(self):
"""Test graceful degradation when MemPalace is unavailable."""
mem = AgentMemory(agent_name="test")
mem._available = False # Force unavailable
# Should not raise
ctx = mem.recall_context("test query")
assert ctx.loaded is False
assert ctx.error == "MemPalace unavailable"
# remember returns None
assert mem.remember("test") is None
# search returns empty
assert mem.search("test") == []
def test_start_end_session(self):
mem = AgentMemory(agent_name="test")
mem._available = False
transcript = mem.start_session()
assert isinstance(transcript, SessionTranscript)
assert mem._transcript is not None
doc_id = mem.end_session()
assert mem._transcript is None
def test_remember_graceful_when_unavailable(self):
"""Test remember returns None when MemPalace is unavailable."""
mem = AgentMemory(agent_name="test")
mem._available = False
doc_id = mem.remember("some important fact")
assert doc_id is None
def test_write_diary_from_transcript(self):
mem = AgentMemory(agent_name="test")
mem._available = False
transcript = mem.start_session()
transcript.add_user_turn("Hello")
transcript.add_agent_turn("Hi there")
# Write diary should handle unavailable gracefully
doc_id = mem.write_diary()
assert doc_id is None # MemPalace unavailable
# ---------------------------------------------------------------------------
# MemoryHooks tests
# ---------------------------------------------------------------------------
class TestMemoryHooks:
def test_create(self):
hooks = MemoryHooks(agent_name="bezalel")
assert hooks.agent_name == "bezalel"
assert hooks.is_active is False
def test_session_lifecycle(self):
hooks = MemoryHooks(agent_name="test")
# Force memory unavailable
hooks._memory = AgentMemory(agent_name="test")
hooks._memory._available = False
# Start session
block = hooks.on_session_start()
assert hooks.is_active is True
assert block == "" # No memory available
# Record turns
hooks.on_user_turn("Hello")
hooks.on_agent_turn("Hi")
hooks.on_tool_call("shell", "ls", "ok")
# Record decision
hooks.on_important_decision("Switched to self-hosted CI")
# End session
doc_id = hooks.on_session_end()
assert hooks.is_active is False
def test_hooks_before_session(self):
"""Hooks before session start should be no-ops."""
hooks = MemoryHooks(agent_name="test")
hooks._memory = AgentMemory(agent_name="test")
hooks._memory._available = False
# Should not raise
hooks.on_user_turn("Hello")
hooks.on_agent_turn("Response")
def test_hooks_after_session_end(self):
"""Hooks after session end should be no-ops."""
hooks = MemoryHooks(agent_name="test")
hooks._memory = AgentMemory(agent_name="test")
hooks._memory._available = False
hooks.on_session_start()
hooks.on_session_end()
# Should not raise
hooks.on_user_turn("Late message")
doc_id = hooks.on_session_end()
assert doc_id is None
def test_search_during_session(self):
hooks = MemoryHooks(agent_name="test")
hooks._memory = AgentMemory(agent_name="test")
hooks._memory._available = False
results = hooks.search("some query")
assert results == []
# ---------------------------------------------------------------------------
# Session mining tests
# ---------------------------------------------------------------------------
class TestSessionMining:
def test_parse_session_file(self):
from bin.memory_mine import parse_session_file
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
f.write('{"role": "user", "content": "Hello"}\n')
f.write('{"role": "assistant", "content": "Hi there"}\n')
f.write('{"role": "tool", "name": "shell", "content": "ls output"}\n')
f.write("\n") # blank line
f.write("not json\n") # malformed
path = Path(f.name)
turns = parse_session_file(path)
assert len(turns) == 3
assert turns[0]["role"] == "user"
assert turns[1]["role"] == "assistant"
assert turns[2]["role"] == "tool"
path.unlink()
def test_summarize_session(self):
from bin.memory_mine import summarize_session
turns = [
{"role": "user", "content": "Check CI"},
{"role": "assistant", "content": "Running CI check..."},
{"role": "tool", "name": "shell", "content": "5 tests passed"},
{"role": "assistant", "content": "CI is healthy"},
]
summary = summarize_session(turns, "bezalel")
assert "bezalel" in summary
assert "Check CI" in summary
assert "shell" in summary
def test_summarize_empty(self):
from bin.memory_mine import summarize_session
assert summarize_session([], "test") == "Empty session."
def test_find_session_files(self, tmp_path):
from bin.memory_mine import find_session_files
# Create some test files
(tmp_path / "session1.jsonl").write_text("{}\n")
(tmp_path / "session2.jsonl").write_text("{}\n")
(tmp_path / "notes.txt").write_text("not a session")
files = find_session_files(tmp_path, days=365)
assert len(files) == 2
assert all(f.suffix == ".jsonl" for f in files)
def test_find_session_files_missing_dir(self):
from bin.memory_mine import find_session_files
files = find_session_files(Path("/nonexistent/path"), days=7)
assert files == []
def test_mine_session_dry_run(self, tmp_path):
from bin.memory_mine import mine_session
session_file = tmp_path / "test.jsonl"
session_file.write_text(
'{"role": "user", "content": "Hello"}\n'
'{"role": "assistant", "content": "Hi"}\n'
)
result = mine_session(session_file, wing="wing_test", dry_run=True)
assert result is None # dry run doesn't store
def test_mine_session_empty_file(self, tmp_path):
from bin.memory_mine import mine_session
session_file = tmp_path / "empty.jsonl"
session_file.write_text("")
result = mine_session(session_file, wing="wing_test")
assert result is None
# ---------------------------------------------------------------------------
# Integration test — full lifecycle
# ---------------------------------------------------------------------------
class TestFullLifecycle:
"""Test the full session lifecycle without a real MemPalace backend."""
def test_full_session_flow(self):
hooks = MemoryHooks(agent_name="bezalel")
# Force memory unavailable
hooks._memory = AgentMemory(agent_name="bezalel")
hooks._memory._available = False
# 1. Session start
context_block = hooks.on_session_start("What CI issues do I have?")
assert isinstance(context_block, str)
# 2. User asks question
hooks.on_user_turn("Check CI pipeline health")
# 3. Agent uses tool
hooks.on_tool_call("shell", "pytest tests/", "12 passed")
# 4. Agent responds
hooks.on_agent_turn("CI pipeline is healthy. All 12 tests passing.")
# 5. Important decision
hooks.on_important_decision("Decided to keep current CI runner", room="forge")
# 6. More interaction
hooks.on_user_turn("Good, check memory integration next")
hooks.on_agent_turn("Will test agent.memory module")
# 7. Session end
doc_id = hooks.on_session_end()
assert hooks.is_active is False