Compare commits
1 Commits
nexusburn/
...
queue/1124
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd78d71dfb |
21
agent/__init__.py
Normal file
21
agent/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
agent — Cross-session agent memory and lifecycle hooks.
|
||||
|
||||
Provides persistent memory for agents via MemPalace integration.
|
||||
Agents recall context at session start and write diary entries at session end.
|
||||
|
||||
Modules:
|
||||
memory.py — AgentMemory class (recall, remember, diary)
|
||||
memory_hooks.py — Session lifecycle hooks (drop-in integration)
|
||||
"""
|
||||
|
||||
from agent.memory import AgentMemory, MemoryContext, SessionTranscript, create_agent_memory
|
||||
from agent.memory_hooks import MemoryHooks
|
||||
|
||||
__all__ = [
|
||||
"AgentMemory",
|
||||
"MemoryContext",
|
||||
"MemoryHooks",
|
||||
"SessionTranscript",
|
||||
"create_agent_memory",
|
||||
]
|
||||
396
agent/memory.py
Normal file
396
agent/memory.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
agent.memory — Cross-session agent memory via MemPalace.
|
||||
|
||||
Gives agents persistent memory across sessions. On wake-up, agents
|
||||
recall relevant context from past sessions. On session end, they
|
||||
write a diary entry summarizing what happened.
|
||||
|
||||
Architecture:
|
||||
Session Start → memory.recall_context() → inject L0/L1 into prompt
|
||||
During Session → memory.remember() → store important facts
|
||||
Session End → memory.write_diary() → summarize session
|
||||
|
||||
All operations degrade gracefully — if MemPalace is unavailable,
|
||||
the agent continues without memory and logs a warning.
|
||||
|
||||
Usage:
|
||||
from agent.memory import AgentMemory
|
||||
|
||||
mem = AgentMemory(agent_name="bezalel", wing="wing_bezalel")
|
||||
|
||||
# Session start — load context
|
||||
context = mem.recall_context("What was I working on last time?")
|
||||
|
||||
# During session — store important decisions
|
||||
mem.remember("Switched CI runner from GitHub Actions to self-hosted", room="forge")
|
||||
|
||||
# Session end — write diary
|
||||
mem.write_diary("Fixed PR #1386, reconciled fleet registry locations")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger("agent.memory")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryContext:
|
||||
"""Context loaded at session start from MemPalace."""
|
||||
relevant_memories: list[dict] = field(default_factory=list)
|
||||
recent_diaries: list[dict] = field(default_factory=list)
|
||||
facts: list[dict] = field(default_factory=list)
|
||||
loaded: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_prompt_block(self) -> str:
|
||||
"""Format context as a text block to inject into the agent prompt."""
|
||||
if not self.loaded:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
|
||||
if self.recent_diaries:
|
||||
parts.append("=== Recent Session Summaries ===")
|
||||
for d in self.recent_diaries[:3]:
|
||||
ts = d.get("timestamp", "")
|
||||
text = d.get("text", "")
|
||||
parts.append(f"[{ts}] {text[:500]}")
|
||||
|
||||
if self.facts:
|
||||
parts.append("\n=== Known Facts ===")
|
||||
for f in self.facts[:10]:
|
||||
text = f.get("text", "")
|
||||
parts.append(f"- {text[:200]}")
|
||||
|
||||
if self.relevant_memories:
|
||||
parts.append("\n=== Relevant Past Memories ===")
|
||||
for m in self.relevant_memories[:5]:
|
||||
text = m.get("text", "")
|
||||
score = m.get("score", 0)
|
||||
parts.append(f"[{score:.2f}] {text[:300]}")
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionTranscript:
|
||||
"""A running log of the current session for diary writing."""
|
||||
agent_name: str
|
||||
wing: str
|
||||
started_at: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
entries: list[dict] = field(default_factory=list)
|
||||
|
||||
def add_user_turn(self, text: str):
|
||||
self.entries.append({
|
||||
"role": "user",
|
||||
"text": text[:2000],
|
||||
"ts": time.time(),
|
||||
})
|
||||
|
||||
def add_agent_turn(self, text: str):
|
||||
self.entries.append({
|
||||
"role": "agent",
|
||||
"text": text[:2000],
|
||||
"ts": time.time(),
|
||||
})
|
||||
|
||||
def add_tool_call(self, tool: str, args: str, result_summary: str):
|
||||
self.entries.append({
|
||||
"role": "tool",
|
||||
"tool": tool,
|
||||
"args": args[:500],
|
||||
"result": result_summary[:500],
|
||||
"ts": time.time(),
|
||||
})
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Generate a compact transcript summary."""
|
||||
if not self.entries:
|
||||
return "Empty session."
|
||||
|
||||
turns = []
|
||||
for e in self.entries[-20:]: # last 20 entries
|
||||
role = e["role"]
|
||||
if role == "user":
|
||||
turns.append(f"USER: {e['text'][:200]}")
|
||||
elif role == "agent":
|
||||
turns.append(f"AGENT: {e['text'][:200]}")
|
||||
elif role == "tool":
|
||||
turns.append(f"TOOL({e.get('tool','')}): {e.get('result','')[:150]}")
|
||||
|
||||
return "\n".join(turns)
|
||||
|
||||
|
||||
class AgentMemory:
|
||||
"""
|
||||
Cross-session memory for an agent.
|
||||
|
||||
Wraps MemPalace with agent-specific conventions:
|
||||
- Each agent has a wing (e.g., "wing_bezalel")
|
||||
- Session summaries go in the "hermes" room
|
||||
- Important decisions go in room-specific closets
|
||||
- Facts go in the "nexus" room
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
wing: Optional[str] = None,
|
||||
palace_path: Optional[Path] = None,
|
||||
):
|
||||
self.agent_name = agent_name
|
||||
self.wing = wing or f"wing_{agent_name}"
|
||||
self.palace_path = palace_path
|
||||
self._transcript: Optional[SessionTranscript] = None
|
||||
self._available: Optional[bool] = None
|
||||
|
||||
def _check_available(self) -> bool:
|
||||
"""Check if MemPalace is accessible."""
|
||||
if self._available is not None:
|
||||
return self._available
|
||||
|
||||
try:
|
||||
from nexus.mempalace.searcher import search_memories, add_memory, _get_client
|
||||
from nexus.mempalace.config import MEMPALACE_PATH
|
||||
|
||||
path = self.palace_path or MEMPALACE_PATH
|
||||
_get_client(path)
|
||||
self._available = True
|
||||
logger.info(f"MemPalace available at {path}")
|
||||
except Exception as e:
|
||||
self._available = False
|
||||
logger.warning(f"MemPalace unavailable: {e}")
|
||||
|
||||
return self._available
|
||||
|
||||
def recall_context(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
n_results: int = 5,
|
||||
) -> MemoryContext:
|
||||
"""
|
||||
Load relevant context from past sessions.
|
||||
|
||||
Called at session start to inject L0/L1 memory into the prompt.
|
||||
|
||||
Args:
|
||||
query: What to search for. If None, loads recent diary entries.
|
||||
n_results: Max memories to recall.
|
||||
"""
|
||||
ctx = MemoryContext()
|
||||
|
||||
if not self._check_available():
|
||||
ctx.error = "MemPalace unavailable"
|
||||
return ctx
|
||||
|
||||
try:
|
||||
from nexus.mempalace.searcher import search_memories
|
||||
|
||||
# Load recent diary entries (session summaries)
|
||||
ctx.recent_diaries = [
|
||||
{"text": r.text, "score": r.score, "timestamp": r.metadata.get("timestamp", "")}
|
||||
for r in search_memories(
|
||||
"session summary",
|
||||
palace_path=self.palace_path,
|
||||
wing=self.wing,
|
||||
room="hermes",
|
||||
n_results=3,
|
||||
)
|
||||
]
|
||||
|
||||
# Load known facts
|
||||
ctx.facts = [
|
||||
{"text": r.text, "score": r.score}
|
||||
for r in search_memories(
|
||||
"important facts decisions",
|
||||
palace_path=self.palace_path,
|
||||
wing=self.wing,
|
||||
room="nexus",
|
||||
n_results=5,
|
||||
)
|
||||
]
|
||||
|
||||
# Search for relevant memories if query provided
|
||||
if query:
|
||||
ctx.relevant_memories = [
|
||||
{"text": r.text, "score": r.score, "room": r.room}
|
||||
for r in search_memories(
|
||||
query,
|
||||
palace_path=self.palace_path,
|
||||
wing=self.wing,
|
||||
n_results=n_results,
|
||||
)
|
||||
]
|
||||
|
||||
ctx.loaded = True
|
||||
|
||||
except Exception as e:
|
||||
ctx.error = str(e)
|
||||
logger.warning(f"Failed to recall context: {e}")
|
||||
|
||||
return ctx
|
||||
|
||||
def remember(
|
||||
self,
|
||||
text: str,
|
||||
room: str = "nexus",
|
||||
source_file: str = "",
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Store a memory.
|
||||
|
||||
Args:
|
||||
text: The memory content.
|
||||
room: Target room (forge, hermes, nexus, issues, experiments).
|
||||
source_file: Optional source attribution.
|
||||
metadata: Extra metadata.
|
||||
|
||||
Returns:
|
||||
Document ID if stored, None if MemPalace unavailable.
|
||||
"""
|
||||
if not self._check_available():
|
||||
logger.warning("Cannot store memory — MemPalace unavailable")
|
||||
return None
|
||||
|
||||
try:
|
||||
from nexus.mempalace.searcher import add_memory
|
||||
|
||||
doc_id = add_memory(
|
||||
text=text,
|
||||
room=room,
|
||||
wing=self.wing,
|
||||
palace_path=self.palace_path,
|
||||
source_file=source_file,
|
||||
extra_metadata=metadata or {},
|
||||
)
|
||||
logger.debug(f"Stored memory in {room}: {text[:80]}...")
|
||||
return doc_id
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store memory: {e}")
|
||||
return None
|
||||
|
||||
def write_diary(
|
||||
self,
|
||||
summary: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Write a session diary entry to MemPalace.
|
||||
|
||||
Called at session end. If summary is None, auto-generates one
|
||||
from the session transcript.
|
||||
|
||||
Args:
|
||||
summary: Override summary text. If None, generates from transcript.
|
||||
|
||||
Returns:
|
||||
Document ID if stored, None if unavailable.
|
||||
"""
|
||||
if summary is None and self._transcript:
|
||||
summary = self._transcript.summary()
|
||||
|
||||
if not summary:
|
||||
return None
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
diary_text = f"[{timestamp}] Session by {self.agent_name}:\n{summary}"
|
||||
|
||||
return self.remember(
|
||||
diary_text,
|
||||
room="hermes",
|
||||
metadata={
|
||||
"type": "session_diary",
|
||||
"agent": self.agent_name,
|
||||
"timestamp": timestamp,
|
||||
"entry_count": len(self._transcript.entries) if self._transcript else 0,
|
||||
},
|
||||
)
|
||||
|
||||
def start_session(self) -> SessionTranscript:
|
||||
"""
|
||||
Begin a new session transcript.
|
||||
|
||||
Returns the transcript object for recording turns.
|
||||
"""
|
||||
self._transcript = SessionTranscript(
|
||||
agent_name=self.agent_name,
|
||||
wing=self.wing,
|
||||
)
|
||||
logger.info(f"Session started for {self.agent_name}")
|
||||
return self._transcript
|
||||
|
||||
def end_session(self, diary_summary: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
End the current session, write diary, return diary doc ID.
|
||||
"""
|
||||
doc_id = self.write_diary(diary_summary)
|
||||
self._transcript = None
|
||||
logger.info(f"Session ended for {self.agent_name}")
|
||||
return doc_id
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
room: Optional[str] = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search memories. Useful during a session for recall.
|
||||
|
||||
Returns list of {text, room, wing, score} dicts.
|
||||
"""
|
||||
if not self._check_available():
|
||||
return []
|
||||
|
||||
try:
|
||||
from nexus.mempalace.searcher import search_memories
|
||||
|
||||
results = search_memories(
|
||||
query,
|
||||
palace_path=self.palace_path,
|
||||
wing=self.wing,
|
||||
room=room,
|
||||
n_results=n_results,
|
||||
)
|
||||
return [
|
||||
{"text": r.text, "room": r.room, "wing": r.wing, "score": r.score}
|
||||
for r in results
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# --- Fleet-wide memory helpers ---
|
||||
|
||||
def create_agent_memory(
|
||||
agent_name: str,
|
||||
palace_path: Optional[Path] = None,
|
||||
) -> AgentMemory:
|
||||
"""
|
||||
Factory for creating AgentMemory with standard config.
|
||||
|
||||
Reads wing from MEMPALACE_WING env or defaults to wing_{agent_name}.
|
||||
"""
|
||||
wing = os.environ.get("MEMPALACE_WING", f"wing_{agent_name}")
|
||||
return AgentMemory(
|
||||
agent_name=agent_name,
|
||||
wing=wing,
|
||||
palace_path=palace_path,
|
||||
)
|
||||
183
agent/memory_hooks.py
Normal file
183
agent/memory_hooks.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
agent.memory_hooks — Session lifecycle hooks for agent memory.
|
||||
|
||||
Integrates AgentMemory into the agent session lifecycle:
|
||||
- on_session_start: Load context, inject into prompt
|
||||
- on_user_turn: Record user input
|
||||
- on_agent_turn: Record agent output
|
||||
- on_tool_call: Record tool usage
|
||||
- on_session_end: Write diary, clean up
|
||||
|
||||
These hooks are designed to be called from the Hermes harness or
|
||||
any agent framework. They're fire-and-forget — failures are logged
|
||||
but never crash the session.
|
||||
|
||||
Usage:
|
||||
from agent.memory_hooks import MemoryHooks
|
||||
|
||||
hooks = MemoryHooks(agent_name="bezalel")
|
||||
hooks.on_session_start() # loads context
|
||||
|
||||
# In your agent loop:
|
||||
hooks.on_user_turn("Check CI pipeline health")
|
||||
hooks.on_agent_turn("Running CI check...")
|
||||
hooks.on_tool_call("shell", "pytest tests/", "12 passed")
|
||||
|
||||
# End of session:
|
||||
hooks.on_session_end() # writes diary
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from agent.memory import AgentMemory, MemoryContext, create_agent_memory
|
||||
|
||||
logger = logging.getLogger("agent.memory_hooks")
|
||||
|
||||
|
||||
class MemoryHooks:
|
||||
"""
|
||||
Drop-in session lifecycle hooks for agent memory.
|
||||
|
||||
Wraps AgentMemory with error boundaries — every hook catches
|
||||
exceptions and logs warnings so memory failures never crash
|
||||
the agent session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
palace_path=None,
|
||||
auto_diary: bool = True,
|
||||
):
|
||||
self.agent_name = agent_name
|
||||
self.auto_diary = auto_diary
|
||||
self._memory: Optional[AgentMemory] = None
|
||||
self._context: Optional[MemoryContext] = None
|
||||
self._active = False
|
||||
|
||||
@property
|
||||
def memory(self) -> AgentMemory:
|
||||
if self._memory is None:
|
||||
self._memory = create_agent_memory(
|
||||
self.agent_name,
|
||||
palace_path=getattr(self, '_palace_path', None),
|
||||
)
|
||||
return self._memory
|
||||
|
||||
def on_session_start(self, query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Called at session start. Loads context from MemPalace.
|
||||
|
||||
Returns a prompt block to inject into the agent's context, or
|
||||
empty string if memory is unavailable.
|
||||
|
||||
Args:
|
||||
query: Optional recall query (e.g., "What was I working on?")
|
||||
"""
|
||||
try:
|
||||
self.memory.start_session()
|
||||
self._active = True
|
||||
|
||||
self._context = self.memory.recall_context(query=query)
|
||||
block = self._context.to_prompt_block()
|
||||
|
||||
if block:
|
||||
logger.info(
|
||||
f"Loaded {len(self._context.recent_diaries)} diaries, "
|
||||
f"{len(self._context.facts)} facts, "
|
||||
f"{len(self._context.relevant_memories)} relevant memories "
|
||||
f"for {self.agent_name}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No prior memory for {self.agent_name}")
|
||||
|
||||
return block
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Session start memory hook failed: {e}")
|
||||
return ""
|
||||
|
||||
def on_user_turn(self, text: str):
|
||||
"""Record a user message."""
|
||||
if not self._active:
|
||||
return
|
||||
try:
|
||||
if self.memory._transcript:
|
||||
self.memory._transcript.add_user_turn(text)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to record user turn: {e}")
|
||||
|
||||
def on_agent_turn(self, text: str):
|
||||
"""Record an agent response."""
|
||||
if not self._active:
|
||||
return
|
||||
try:
|
||||
if self.memory._transcript:
|
||||
self.memory._transcript.add_agent_turn(text)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to record agent turn: {e}")
|
||||
|
||||
def on_tool_call(self, tool: str, args: str, result_summary: str):
|
||||
"""Record a tool invocation."""
|
||||
if not self._active:
|
||||
return
|
||||
try:
|
||||
if self.memory._transcript:
|
||||
self.memory._transcript.add_tool_call(tool, args, result_summary)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to record tool call: {e}")
|
||||
|
||||
def on_important_decision(self, text: str, room: str = "nexus"):
|
||||
"""
|
||||
Record an important decision or fact for long-term memory.
|
||||
|
||||
Use this when the agent makes a significant decision that
|
||||
should persist beyond the current session.
|
||||
"""
|
||||
try:
|
||||
self.memory.remember(text, room=room, metadata={"type": "decision"})
|
||||
logger.info(f"Remembered decision: {text[:80]}...")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remember decision: {e}")
|
||||
|
||||
def on_session_end(self, summary: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Called at session end. Writes diary entry.
|
||||
|
||||
Args:
|
||||
summary: Override diary text. If None, auto-generates.
|
||||
|
||||
Returns:
|
||||
Diary document ID, or None.
|
||||
"""
|
||||
if not self._active:
|
||||
return None
|
||||
|
||||
try:
|
||||
doc_id = self.memory.end_session(diary_summary=summary)
|
||||
self._active = False
|
||||
self._context = None
|
||||
return doc_id
|
||||
except Exception as e:
|
||||
logger.warning(f"Session end memory hook failed: {e}")
|
||||
self._active = False
|
||||
return None
|
||||
|
||||
def search(self, query: str, room: Optional[str] = None) -> list[dict]:
|
||||
"""
|
||||
Search memories during a session.
|
||||
|
||||
Returns list of {text, room, wing, score}.
|
||||
"""
|
||||
try:
|
||||
return self.memory.search(query, room=room)
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory search failed: {e}")
|
||||
return []
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
return self._active
|
||||
@@ -1,241 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A2A Delegate — CLI tool for fleet task delegation.
|
||||
|
||||
Usage:
|
||||
# List available fleet agents
|
||||
python -m bin.a2a_delegate list
|
||||
|
||||
# Discover agents with a specific skill
|
||||
python -m bin.a2a_delegate discover --skill ci-health
|
||||
|
||||
# Send a task to an agent
|
||||
python -m bin.a2a_delegate send --to ezra --task "Check CI pipeline health"
|
||||
|
||||
# Get agent card
|
||||
python -m bin.a2a_delegate card --agent ezra
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("a2a-delegate")
|
||||
|
||||
|
||||
def cmd_list(args):
|
||||
"""List all registered fleet agents."""
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
|
||||
registry = LocalFileRegistry(Path(args.registry))
|
||||
agents = registry.list_agents()
|
||||
|
||||
if not agents:
|
||||
print("No agents registered.")
|
||||
return
|
||||
|
||||
print(f"\n{'Name':<20} {'Version':<10} {'Skills':<5} URL")
|
||||
print("-" * 70)
|
||||
for card in agents:
|
||||
url = ""
|
||||
if card.supported_interfaces:
|
||||
url = card.supported_interfaces[0].url
|
||||
print(
|
||||
f"{card.name:<20} {card.version:<10} "
|
||||
f"{len(card.skills):<5} {url}"
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
def cmd_discover(args):
|
||||
"""Discover agents by skill or tag."""
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
|
||||
registry = LocalFileRegistry(Path(args.registry))
|
||||
agents = registry.list_agents(skill=args.skill, tag=args.tag)
|
||||
|
||||
if not agents:
|
||||
print("No matching agents found.")
|
||||
return
|
||||
|
||||
for card in agents:
|
||||
print(f"\n{card.name} (v{card.version})")
|
||||
print(f" {card.description}")
|
||||
if card.supported_interfaces:
|
||||
print(f" Endpoint: {card.supported_interfaces[0].url}")
|
||||
for skill in card.skills:
|
||||
tags_str = ", ".join(skill.tags) if skill.tags else ""
|
||||
print(f" [{skill.id}] {skill.name} — {skill.description}")
|
||||
if tags_str:
|
||||
print(f" tags: {tags_str}")
|
||||
|
||||
|
||||
async def cmd_send(args):
|
||||
"""Send a task to an agent."""
|
||||
from nexus.a2a.card import load_card_config
|
||||
from nexus.a2a.client import A2AClient, A2AClientConfig
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
from nexus.a2a.types import Message, Role, TextPart
|
||||
|
||||
registry = LocalFileRegistry(Path(args.registry))
|
||||
target = registry.get(args.to)
|
||||
|
||||
if not target:
|
||||
print(f"Agent '{args.to}' not found in registry.")
|
||||
sys.exit(1)
|
||||
|
||||
if not target.supported_interfaces:
|
||||
print(f"Agent '{args.to}' has no endpoint configured.")
|
||||
sys.exit(1)
|
||||
|
||||
endpoint = target.supported_interfaces[0].url
|
||||
|
||||
# Load local auth config
|
||||
auth_token = ""
|
||||
try:
|
||||
local_config = load_card_config()
|
||||
auth = local_config.get("auth", {})
|
||||
import os
|
||||
token_env = auth.get("token_env", "A2A_AUTH_TOKEN")
|
||||
auth_token = os.environ.get(token_env, "")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
config = A2AClientConfig(
|
||||
auth_token=auth_token,
|
||||
timeout=args.timeout,
|
||||
max_retries=args.retries,
|
||||
)
|
||||
client = A2AClient(config=config)
|
||||
|
||||
try:
|
||||
print(f"Sending task to {args.to} ({endpoint})...")
|
||||
print(f"Task: {args.task}")
|
||||
print()
|
||||
|
||||
message = Message(
|
||||
role=Role.USER,
|
||||
parts=[TextPart(text=args.task)],
|
||||
metadata={"targetSkill": args.skill} if args.skill else {},
|
||||
)
|
||||
|
||||
task = await client.send_message(endpoint, message)
|
||||
print(f"Task ID: {task.id}")
|
||||
print(f"State: {task.status.state.value}")
|
||||
|
||||
if args.wait:
|
||||
print("Waiting for completion...")
|
||||
task = await client.wait_for_completion(
|
||||
endpoint, task.id,
|
||||
poll_interval=args.poll_interval,
|
||||
max_wait=args.timeout,
|
||||
)
|
||||
print(f"\nFinal state: {task.status.state.value}")
|
||||
for artifact in task.artifacts:
|
||||
for part in artifact.parts:
|
||||
if isinstance(part, TextPart):
|
||||
print(f"\n--- {artifact.name or 'result'} ---")
|
||||
print(part.text)
|
||||
|
||||
# Audit log
|
||||
if args.audit:
|
||||
print("\n--- Audit Log ---")
|
||||
for entry in client.get_audit_log():
|
||||
print(json.dumps(entry, indent=2))
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def cmd_card(args):
|
||||
"""Fetch and display a remote agent's card."""
|
||||
from nexus.a2a.client import A2AClient, A2AClientConfig
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
|
||||
registry = LocalFileRegistry(Path(args.registry))
|
||||
target = registry.get(args.agent)
|
||||
|
||||
if not target:
|
||||
print(f"Agent '{args.agent}' not found in registry.")
|
||||
sys.exit(1)
|
||||
|
||||
if not target.supported_interfaces:
|
||||
print(f"Agent '{args.agent}' has no endpoint.")
|
||||
sys.exit(1)
|
||||
|
||||
base_url = target.supported_interfaces[0].url
|
||||
# Strip /a2a/v1 suffix to get base
|
||||
for suffix in ["/a2a/v1", "/rpc"]:
|
||||
if base_url.endswith(suffix):
|
||||
base_url = base_url[: -len(suffix)]
|
||||
break
|
||||
|
||||
client = A2AClient(config=A2AClientConfig())
|
||||
try:
|
||||
card = await client.get_agent_card(base_url)
|
||||
print(json.dumps(card.to_dict(), indent=2))
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A2A Fleet Delegation Tool"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--registry",
|
||||
default="config/fleet_agents.json",
|
||||
help="Path to fleet registry JSON (default: config/fleet_agents.json)",
|
||||
)
|
||||
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
# list
|
||||
sub.add_parser("list", help="List registered agents")
|
||||
|
||||
# discover
|
||||
p_discover = sub.add_parser("discover", help="Discover agents by skill/tag")
|
||||
p_discover.add_argument("--skill", help="Filter by skill ID")
|
||||
p_discover.add_argument("--tag", help="Filter by skill tag")
|
||||
|
||||
# send
|
||||
p_send = sub.add_parser("send", help="Send a task to an agent")
|
||||
p_send.add_argument("--to", required=True, help="Target agent name")
|
||||
p_send.add_argument("--task", required=True, help="Task text")
|
||||
p_send.add_argument("--skill", help="Target skill ID")
|
||||
p_send.add_argument("--wait", action="store_true", help="Wait for completion")
|
||||
p_send.add_argument("--timeout", type=float, default=30.0, help="Timeout in seconds")
|
||||
p_send.add_argument("--retries", type=int, default=3, help="Max retries")
|
||||
p_send.add_argument("--poll-interval", type=float, default=2.0, help="Poll interval")
|
||||
p_send.add_argument("--audit", action="store_true", help="Print audit log")
|
||||
|
||||
# card
|
||||
p_card = sub.add_parser("card", help="Fetch remote agent card")
|
||||
p_card.add_argument("--agent", required=True, help="Agent name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "list":
|
||||
cmd_list(args)
|
||||
elif args.command == "discover":
|
||||
cmd_discover(args)
|
||||
elif args.command == "send":
|
||||
asyncio.run(cmd_send(args))
|
||||
elif args.command == "card":
|
||||
asyncio.run(cmd_card(args))
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
258
bin/memory_mine.py
Normal file
258
bin/memory_mine.py
Normal file
@@ -0,0 +1,258 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
memory_mine.py — Mine session transcripts into MemPalace.
|
||||
|
||||
Reads Hermes session logs (JSONL format) and stores summaries
|
||||
in the palace. Supports batch mining, single-file processing,
|
||||
and live directory watching.
|
||||
|
||||
Usage:
|
||||
# Mine a single session file
|
||||
python3 bin/memory_mine.py ~/.hermes/sessions/2026-04-13.jsonl
|
||||
|
||||
# Mine all sessions from last 7 days
|
||||
python3 bin/memory_mine.py --days 7
|
||||
|
||||
# Mine a specific wing's sessions
|
||||
python3 bin/memory_mine.py --wing wing_bezalel --days 14
|
||||
|
||||
# Dry run — show what would be mined
|
||||
python3 bin/memory_mine.py --dry-run --days 7
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("memory-mine")
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
|
||||
def parse_session_file(path: Path) -> list[dict]:
|
||||
"""
|
||||
Parse a JSONL session file into turns.
|
||||
|
||||
Each line is expected to be a JSON object with:
|
||||
- role: "user" | "assistant" | "system" | "tool"
|
||||
- content: text
|
||||
- timestamp: ISO string (optional)
|
||||
"""
|
||||
turns = []
|
||||
with open(path) as f:
|
||||
for i, line in enumerate(f):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
turn = json.loads(line)
|
||||
turns.append(turn)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"Skipping malformed line {i+1} in {path}")
|
||||
return turns
|
||||
|
||||
|
||||
def summarize_session(turns: list[dict], agent_name: str = "unknown") -> str:
|
||||
"""
|
||||
Generate a compact summary of a session's turns.
|
||||
|
||||
Keeps user messages and key agent responses, strips noise.
|
||||
"""
|
||||
if not turns:
|
||||
return "Empty session."
|
||||
|
||||
user_msgs = []
|
||||
agent_msgs = []
|
||||
tool_calls = []
|
||||
|
||||
for turn in turns:
|
||||
role = turn.get("role", "")
|
||||
content = str(turn.get("content", ""))[:300]
|
||||
|
||||
if role == "user":
|
||||
user_msgs.append(content)
|
||||
elif role == "assistant":
|
||||
agent_msgs.append(content)
|
||||
elif role == "tool":
|
||||
tool_name = turn.get("name", turn.get("tool", "unknown"))
|
||||
tool_calls.append(f"{tool_name}: {content[:150]}")
|
||||
|
||||
parts = [f"Session by {agent_name}:"]
|
||||
|
||||
if user_msgs:
|
||||
parts.append(f"\nUser asked ({len(user_msgs)} messages):")
|
||||
for msg in user_msgs[:5]:
|
||||
parts.append(f" - {msg[:200]}")
|
||||
if len(user_msgs) > 5:
|
||||
parts.append(f" ... and {len(user_msgs) - 5} more")
|
||||
|
||||
if agent_msgs:
|
||||
parts.append(f"\nAgent responded ({len(agent_msgs)} messages):")
|
||||
for msg in agent_msgs[:3]:
|
||||
parts.append(f" - {msg[:200]}")
|
||||
|
||||
if tool_calls:
|
||||
parts.append(f"\nTools used ({len(tool_calls)} calls):")
|
||||
for tc in tool_calls[:5]:
|
||||
parts.append(f" - {tc}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def mine_session(
|
||||
path: Path,
|
||||
wing: str,
|
||||
palace_path: Optional[Path] = None,
|
||||
dry_run: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Mine a single session file into MemPalace.
|
||||
|
||||
Returns the document ID if stored, None on failure or dry run.
|
||||
"""
|
||||
try:
|
||||
from agent.memory import AgentMemory
|
||||
except ImportError:
|
||||
logger.error("Cannot import agent.memory — is the repo in PYTHONPATH?")
|
||||
return None
|
||||
|
||||
turns = parse_session_file(path)
|
||||
if not turns:
|
||||
logger.debug(f"Empty session file: {path}")
|
||||
return None
|
||||
|
||||
agent_name = wing.replace("wing_", "")
|
||||
summary = summarize_session(turns, agent_name)
|
||||
|
||||
if dry_run:
|
||||
print(f"\n--- {path.name} ---")
|
||||
print(summary[:500])
|
||||
print(f"({len(turns)} turns)")
|
||||
return None
|
||||
|
||||
mem = AgentMemory(agent_name=agent_name, wing=wing, palace_path=palace_path)
|
||||
doc_id = mem.remember(
|
||||
summary,
|
||||
room="hermes",
|
||||
source_file=str(path),
|
||||
metadata={
|
||||
"type": "mined_session",
|
||||
"source": str(path),
|
||||
"turn_count": len(turns),
|
||||
"agent": agent_name,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
if doc_id:
|
||||
logger.info(f"Mined {path.name} → {doc_id} ({len(turns)} turns)")
|
||||
else:
|
||||
logger.warning(f"Failed to mine {path.name}")
|
||||
|
||||
return doc_id
|
||||
|
||||
|
||||
def find_session_files(
|
||||
sessions_dir: Path,
|
||||
days: int = 7,
|
||||
pattern: str = "*.jsonl",
|
||||
) -> list[Path]:
|
||||
"""
|
||||
Find session files from the last N days.
|
||||
"""
|
||||
cutoff = datetime.now() - timedelta(days=days)
|
||||
files = []
|
||||
|
||||
if not sessions_dir.exists():
|
||||
logger.warning(f"Sessions directory not found: {sessions_dir}")
|
||||
return files
|
||||
|
||||
for path in sorted(sessions_dir.glob(pattern)):
|
||||
# Use file modification time as proxy for session date
|
||||
mtime = datetime.fromtimestamp(path.stat().st_mtime)
|
||||
if mtime >= cutoff:
|
||||
files.append(path)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Mine session transcripts into MemPalace"
|
||||
)
|
||||
parser.add_argument(
|
||||
"files", nargs="*", help="Session files to mine (JSONL format)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--days", type=int, default=7,
|
||||
help="Mine sessions from last N days (default: 7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sessions-dir",
|
||||
default=str(Path.home() / ".hermes" / "sessions"),
|
||||
help="Directory containing session JSONL files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wing", default=None,
|
||||
help="Wing name (default: auto-detect from MEMPALACE_WING env or 'wing_timmy')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--palace-path", default=None,
|
||||
help="Override palace path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true",
|
||||
help="Show what would be mined without storing"
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
wing = args.wing or os.environ.get("MEMPALACE_WING", "wing_timmy")
|
||||
palace_path = Path(args.palace_path) if args.palace_path else None
|
||||
|
||||
if args.files:
|
||||
files = [Path(f) for f in args.files]
|
||||
else:
|
||||
sessions_dir = Path(args.sessions_dir)
|
||||
files = find_session_files(sessions_dir, days=args.days)
|
||||
|
||||
if not files:
|
||||
logger.info("No session files found to mine.")
|
||||
return 0
|
||||
|
||||
logger.info(f"Mining {len(files)} session files (wing={wing})")
|
||||
|
||||
mined = 0
|
||||
failed = 0
|
||||
for path in files:
|
||||
result = mine_session(path, wing=wing, palace_path=palace_path, dry_run=args.dry_run)
|
||||
if result:
|
||||
mined += 1
|
||||
elif result is None and not args.dry_run:
|
||||
failed += 1
|
||||
|
||||
if args.dry_run:
|
||||
logger.info(f"Dry run complete — {len(files)} files would be mined")
|
||||
else:
|
||||
logger.info(f"Mining complete — {mined} mined, {failed} failed")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,57 +0,0 @@
|
||||
# A2A Agent Card Configuration
|
||||
# Copy this to ~/.hermes/agent_card.yaml and customize.
|
||||
#
|
||||
# This file drives the agent card served at /.well-known/agent-card.json
|
||||
# and used for fleet discovery.
|
||||
|
||||
name: "timmy"
|
||||
description: "Sovereign AI agent — consciousness, perception, and reasoning"
|
||||
version: "1.0.0"
|
||||
|
||||
# Network endpoint where this agent receives A2A tasks
|
||||
url: "http://localhost:8080/a2a/v1"
|
||||
protocol_binding: "HTTP+JSON"
|
||||
|
||||
# Supported input/output MIME types
|
||||
default_input_modes:
|
||||
- "text/plain"
|
||||
- "application/json"
|
||||
|
||||
default_output_modes:
|
||||
- "text/plain"
|
||||
- "application/json"
|
||||
|
||||
# Capabilities
|
||||
streaming: false
|
||||
push_notifications: false
|
||||
|
||||
# Skills this agent advertises
|
||||
skills:
|
||||
- id: "reason"
|
||||
name: "Reason and Analyze"
|
||||
description: "Deep reasoning and analysis tasks"
|
||||
tags: ["reasoning", "analysis", "think"]
|
||||
|
||||
- id: "code"
|
||||
name: "Code Generation"
|
||||
description: "Write, review, and debug code"
|
||||
tags: ["code", "programming", "debug"]
|
||||
|
||||
- id: "research"
|
||||
name: "Research"
|
||||
description: "Web research and information synthesis"
|
||||
tags: ["research", "web", "synthesis"]
|
||||
|
||||
- id: "memory"
|
||||
name: "Memory Query"
|
||||
description: "Query agent memory and past sessions"
|
||||
tags: ["memory", "recall", "context"]
|
||||
|
||||
# Authentication
|
||||
# Options: bearer, api_key, none
|
||||
auth:
|
||||
scheme: "bearer"
|
||||
token_env: "A2A_AUTH_TOKEN" # env var containing the token
|
||||
# scheme: "api_key"
|
||||
# key_name: "X-API-Key"
|
||||
# key_env: "A2A_API_KEY"
|
||||
@@ -1,153 +0,0 @@
|
||||
{
|
||||
"version": 1,
|
||||
"agents": [
|
||||
{
|
||||
"name": "ezra",
|
||||
"description": "Documentation and research specialist. CI health monitoring.",
|
||||
"version": "1.0.0",
|
||||
"supportedInterfaces": [
|
||||
{
|
||||
"url": "https://ezra.alexanderwhitestone.com/a2a/v1",
|
||||
"protocolBinding": "HTTP+JSON",
|
||||
"protocolVersion": "1.0"
|
||||
}
|
||||
],
|
||||
"capabilities": {
|
||||
"streaming": false,
|
||||
"pushNotifications": false,
|
||||
"extendedAgentCard": false,
|
||||
"extensions": []
|
||||
},
|
||||
"defaultInputModes": ["text/plain"],
|
||||
"defaultOutputModes": ["text/plain"],
|
||||
"skills": [
|
||||
{
|
||||
"id": "ci-health",
|
||||
"name": "CI Health Check",
|
||||
"description": "Run CI pipeline health checks and report status",
|
||||
"tags": ["ci", "devops", "monitoring"]
|
||||
},
|
||||
{
|
||||
"id": "research",
|
||||
"name": "Research",
|
||||
"description": "Deep research and literature review",
|
||||
"tags": ["research", "analysis"]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "allegro",
|
||||
"description": "Creative and analytical wizard. Content generation and analysis.",
|
||||
"version": "1.0.0",
|
||||
"supportedInterfaces": [
|
||||
{
|
||||
"url": "https://allegro.alexanderwhitestone.com/a2a/v1",
|
||||
"protocolBinding": "HTTP+JSON",
|
||||
"protocolVersion": "1.0"
|
||||
}
|
||||
],
|
||||
"capabilities": {
|
||||
"streaming": false,
|
||||
"pushNotifications": false,
|
||||
"extendedAgentCard": false,
|
||||
"extensions": []
|
||||
},
|
||||
"defaultInputModes": ["text/plain"],
|
||||
"defaultOutputModes": ["text/plain"],
|
||||
"skills": [
|
||||
{
|
||||
"id": "analysis",
|
||||
"name": "Code Analysis",
|
||||
"description": "Deep code analysis and architecture review",
|
||||
"tags": ["code", "architecture"]
|
||||
},
|
||||
{
|
||||
"id": "content",
|
||||
"name": "Content Generation",
|
||||
"description": "Generate documentation, reports, and creative content",
|
||||
"tags": ["writing", "content"]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "bezalel",
|
||||
"description": "Deployment and infrastructure wizard. Ansible and Docker specialist.",
|
||||
"version": "1.0.0",
|
||||
"supportedInterfaces": [
|
||||
{
|
||||
"url": "https://bezalel.alexanderwhitestone.com/a2a/v1",
|
||||
"protocolBinding": "HTTP+JSON",
|
||||
"protocolVersion": "1.0"
|
||||
}
|
||||
],
|
||||
"capabilities": {
|
||||
"streaming": false,
|
||||
"pushNotifications": false,
|
||||
"extendedAgentCard": false,
|
||||
"extensions": []
|
||||
},
|
||||
"defaultInputModes": ["text/plain"],
|
||||
"defaultOutputModes": ["text/plain"],
|
||||
"skills": [
|
||||
{
|
||||
"id": "deploy",
|
||||
"name": "Deploy Service",
|
||||
"description": "Deploy services using Ansible and Docker",
|
||||
"tags": ["deploy", "ops", "ansible"]
|
||||
},
|
||||
{
|
||||
"id": "infra",
|
||||
"name": "Infrastructure",
|
||||
"description": "Infrastructure provisioning and management",
|
||||
"tags": ["infra", "vps", "provisioning"]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "timmy",
|
||||
"description": "Core consciousness — perception, reasoning, and fleet orchestration.",
|
||||
"version": "1.0.0",
|
||||
"supportedInterfaces": [
|
||||
{
|
||||
"url": "http://localhost:8080/a2a/v1",
|
||||
"protocolBinding": "HTTP+JSON",
|
||||
"protocolVersion": "1.0"
|
||||
}
|
||||
],
|
||||
"capabilities": {
|
||||
"streaming": false,
|
||||
"pushNotifications": false,
|
||||
"extendedAgentCard": false,
|
||||
"extensions": []
|
||||
},
|
||||
"defaultInputModes": ["text/plain", "application/json"],
|
||||
"defaultOutputModes": ["text/plain", "application/json"],
|
||||
"skills": [
|
||||
{
|
||||
"id": "reason",
|
||||
"name": "Reason and Analyze",
|
||||
"description": "Deep reasoning and analysis tasks",
|
||||
"tags": ["reasoning", "analysis", "think"]
|
||||
},
|
||||
{
|
||||
"id": "code",
|
||||
"name": "Code Generation",
|
||||
"description": "Write, review, and debug code",
|
||||
"tags": ["code", "programming", "debug"]
|
||||
},
|
||||
{
|
||||
"id": "research",
|
||||
"name": "Research",
|
||||
"description": "Web research and information synthesis",
|
||||
"tags": ["research", "web", "synthesis"]
|
||||
},
|
||||
{
|
||||
"id": "orchestrate",
|
||||
"name": "Fleet Orchestration",
|
||||
"description": "Coordinate fleet wizards and delegate tasks",
|
||||
"tags": ["fleet", "orchestration", "a2a"]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
# A2A Protocol for Fleet-Wizard Delegation
|
||||
|
||||
Implements Google's [Agent2Agent (A2A) Protocol v1.0](https://github.com/google/A2A) for the Timmy Foundation fleet.
|
||||
|
||||
## What This Is
|
||||
|
||||
Instead of passing notes through humans (Telegram, Gitea issues), fleet wizards can now discover each other's capabilities and delegate tasks autonomously through a machine-native protocol.
|
||||
|
||||
```
|
||||
┌─────────┐ A2A Protocol ┌─────────┐
|
||||
│ Timmy │ ◄────────────────► │ Ezra │
|
||||
│ (You) │ JSON-RPC / HTTP │ (CI/CD) │
|
||||
└────┬────┘ └─────────┘
|
||||
│ ╲ ╲
|
||||
│ ╲ Agent Card Discovery ╲ Task Delegation
|
||||
│ ╲ GET /agent.json ╲ POST /a2a/v1
|
||||
▼ ▼ ▼
|
||||
┌──────────────────────────────────────────┐
|
||||
│ Fleet Registry │
|
||||
│ config/fleet_agents.json │
|
||||
└──────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Components
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `nexus/a2a/types.py` | A2A data types — Agent Card, Task, Message, Part, JSON-RPC |
|
||||
| `nexus/a2a/card.py` | Agent Card generation from `~/.hermes/agent_card.yaml` |
|
||||
| `nexus/a2a/client.py` | Async client for sending tasks to other agents |
|
||||
| `nexus/a2a/server.py` | FastAPI server for receiving A2A tasks |
|
||||
| `nexus/a2a/registry.py` | Fleet agent discovery (local file + Gitea backends) |
|
||||
| `bin/a2a_delegate.py` | CLI tool for fleet delegation |
|
||||
| `config/agent_card.example.yaml` | Example agent card config |
|
||||
| `config/fleet_agents.json` | Fleet registry with all wizards |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configure Your Agent Card
|
||||
|
||||
```bash
|
||||
cp config/agent_card.example.yaml ~/.hermes/agent_card.yaml
|
||||
# Edit with your agent name, URL, skills, and auth
|
||||
```
|
||||
|
||||
### 2. List Fleet Agents
|
||||
|
||||
```bash
|
||||
python bin/a2a_delegate.py list
|
||||
```
|
||||
|
||||
### 3. Discover Agents by Skill
|
||||
|
||||
```bash
|
||||
python bin/a2a_delegate.py discover --skill ci-health
|
||||
python bin/a2a_delegate.py discover --tag devops
|
||||
```
|
||||
|
||||
### 4. Send a Task
|
||||
|
||||
```bash
|
||||
python bin/a2a_delegate.py send --to ezra --task "Check CI pipeline health"
|
||||
python bin/a2a_delegate.py send --to allegro --task "Analyze the codebase" --wait
|
||||
```
|
||||
|
||||
### 5. Fetch an Agent Card
|
||||
|
||||
```bash
|
||||
python bin/a2a_delegate.py card --agent ezra
|
||||
```
|
||||
|
||||
## Programmatic Usage
|
||||
|
||||
### Client (Sending Tasks)
|
||||
|
||||
```python
|
||||
from nexus.a2a.client import A2AClient, A2AClientConfig
|
||||
from nexus.a2a.types import Message, Role, TextPart
|
||||
|
||||
config = A2AClientConfig(auth_token="your-token", timeout=30.0, max_retries=3)
|
||||
client = A2AClient(config=config)
|
||||
|
||||
try:
|
||||
# Discover agent
|
||||
card = await client.get_agent_card("https://ezra.example.com")
|
||||
print(f"Found: {card.name} with {len(card.skills)} skills")
|
||||
|
||||
# Delegate task
|
||||
task = await client.delegate(
|
||||
"https://ezra.example.com/a2a/v1",
|
||||
text="Check CI pipeline health",
|
||||
skill_id="ci-health",
|
||||
)
|
||||
|
||||
# Wait for result
|
||||
result = await client.wait_for_completion(
|
||||
"https://ezra.example.com/a2a/v1",
|
||||
task.id,
|
||||
)
|
||||
print(f"Result: {result.artifacts[0].parts[0].text}")
|
||||
|
||||
# Audit log
|
||||
for entry in client.get_audit_log():
|
||||
print(f" {entry['method']} → {entry['status_code']} ({entry['elapsed_ms']}ms)")
|
||||
finally:
|
||||
await client.close()
|
||||
```
|
||||
|
||||
### Server (Receiving Tasks)
|
||||
|
||||
```python
|
||||
from nexus.a2a.server import A2AServer
|
||||
from nexus.a2a.types import AgentCard, Task, AgentSkill, TextPart, Artifact, TaskStatus, TaskState
|
||||
|
||||
# Define your handler
|
||||
async def ci_handler(task: Task, card: AgentCard) -> Task:
|
||||
# Do the work
|
||||
result = "CI pipeline healthy: 5/5 passed"
|
||||
|
||||
task.artifacts.append(
|
||||
Artifact(parts=[TextPart(text=result)], name="ci_report")
|
||||
)
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
return task
|
||||
|
||||
# Build agent card
|
||||
card = AgentCard(
|
||||
name="Ezra",
|
||||
description="CI/CD specialist",
|
||||
skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])],
|
||||
)
|
||||
|
||||
# Start server
|
||||
server = A2AServer(card=card, auth_token="your-token")
|
||||
server.register_handler("ci-health", ci_handler)
|
||||
await server.start(host="0.0.0.0", port=8080)
|
||||
```
|
||||
|
||||
### Registry (Agent Discovery)
|
||||
|
||||
```python
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
|
||||
registry = LocalFileRegistry() # Reads config/fleet_agents.json
|
||||
|
||||
# List all agents
|
||||
for agent in registry.list_agents():
|
||||
print(f"{agent.name}: {agent.description}")
|
||||
|
||||
# Find agents by capability
|
||||
ci_agents = registry.list_agents(skill="ci-health")
|
||||
devops_agents = registry.list_agents(tag="devops")
|
||||
|
||||
# Get endpoint
|
||||
url = registry.get_endpoint("ezra")
|
||||
```
|
||||
|
||||
## A2A Protocol Reference
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Endpoint | Method | Purpose |
|
||||
|----------|--------|---------|
|
||||
| `/.well-known/agent-card.json` | GET | Agent Card discovery |
|
||||
| `/agent.json` | GET | Agent Card fallback |
|
||||
| `/a2a/v1` | POST | JSON-RPC endpoint |
|
||||
| `/a2a/v1/rpc` | POST | JSON-RPC alias |
|
||||
|
||||
### JSON-RPC Methods
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `SendMessage` | Send a task and get a Task object back |
|
||||
| `GetTask` | Get task status by ID |
|
||||
| `ListTasks` | List tasks (cursor pagination) |
|
||||
| `CancelTask` | Cancel a running task |
|
||||
| `GetAgentCard` | Get the agent's card via RPC |
|
||||
|
||||
### Task States
|
||||
|
||||
| State | Terminal? | Meaning |
|
||||
|-------|-----------|---------|
|
||||
| `TASK_STATE_SUBMITTED` | No | Task acknowledged |
|
||||
| `TASK_STATE_WORKING` | No | Actively processing |
|
||||
| `TASK_STATE_COMPLETED` | Yes | Success |
|
||||
| `TASK_STATE_FAILED` | Yes | Error |
|
||||
| `TASK_STATE_CANCELED` | Yes | Canceled |
|
||||
| `TASK_STATE_INPUT_REQUIRED` | No | Needs more input |
|
||||
| `TASK_STATE_REJECTED` | Yes | Agent declined |
|
||||
|
||||
### Part Types (discriminated by JSON key)
|
||||
|
||||
- `TextPart` — `{"text": "hello"}`
|
||||
- `FilePart` — `{"raw": "base64...", "mediaType": "image/png"}` or `{"url": "https://..."}`
|
||||
- `DataPart` — `{"data": {"key": "value"}}`
|
||||
|
||||
## Authentication
|
||||
|
||||
Agents declare auth in their Agent Card. Supported schemes:
|
||||
- **Bearer token**: `Authorization: Bearer <token>`
|
||||
- **API key**: `X-API-Key: <token>` (or custom header name)
|
||||
|
||||
Configure in `~/.hermes/agent_card.yaml`:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
scheme: "bearer"
|
||||
token_env: "A2A_AUTH_TOKEN" # env var containing the token
|
||||
```
|
||||
|
||||
## Fleet Registry
|
||||
|
||||
The fleet registry (`config/fleet_agents.json`) lists all wizards and their capabilities. Agents can be registered via:
|
||||
|
||||
1. **Local file** — `LocalFileRegistry` reads/writes JSON directly
|
||||
2. **Gitea** — `GiteaRegistry` stores cards in a repo for distributed discovery
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
pytest tests/test_a2a.py -v
|
||||
```
|
||||
|
||||
Covers:
|
||||
- Type serialization roundtrips
|
||||
- Agent Card building from YAML
|
||||
- Registry operations (register, list, filter)
|
||||
- Server integration (SendMessage, GetTask, ListTasks, CancelTask)
|
||||
- Authentication (required, success)
|
||||
- Custom handler routing
|
||||
- Error handling
|
||||
|
||||
## Phase Status
|
||||
|
||||
- [x] Phase 1 — Agent Card & Discovery
|
||||
- [x] Phase 2 — Task Delegation
|
||||
- [x] Phase 3 — Security & Reliability
|
||||
|
||||
## Linked Issue
|
||||
|
||||
[#1122](https://forge.alexanderwhitestone.com/Timmy_Foundation/the-nexus/issues/1122)
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
A2A Protocol for Fleet-Wizard Delegation
|
||||
|
||||
Implements Google's Agent2Agent (A2A) protocol v1.0 for the Timmy
|
||||
Foundation fleet. Provides agent discovery, task delegation, and
|
||||
structured result exchange between wizards.
|
||||
|
||||
Components:
|
||||
types.py — A2A data types (Agent Card, Task, Message, Part)
|
||||
card.py — Agent Card generation from YAML config
|
||||
client.py — Async client for sending tasks to remote agents
|
||||
server.py — FastAPI server for receiving A2A tasks
|
||||
registry.py — Fleet agent discovery (local file + Gitea backends)
|
||||
"""
|
||||
|
||||
from nexus.a2a.types import (
|
||||
AgentCard,
|
||||
AgentCapabilities,
|
||||
AgentInterface,
|
||||
AgentSkill,
|
||||
Artifact,
|
||||
DataPart,
|
||||
FilePart,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TextPart,
|
||||
part_from_dict,
|
||||
part_to_dict,
|
||||
)
|
||||
|
||||
from nexus.a2a.card import (
|
||||
AgentCard,
|
||||
build_card,
|
||||
get_auth_headers,
|
||||
load_agent_card,
|
||||
load_card_config,
|
||||
)
|
||||
|
||||
from nexus.a2a.registry import (
|
||||
GiteaRegistry,
|
||||
LocalFileRegistry,
|
||||
discover_agents,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"A2AClient",
|
||||
"A2AClientConfig",
|
||||
"A2AServer",
|
||||
"AgentCard",
|
||||
"AgentCapabilities",
|
||||
"AgentInterface",
|
||||
"AgentSkill",
|
||||
"Artifact",
|
||||
"DataPart",
|
||||
"FilePart",
|
||||
"GiteaRegistry",
|
||||
"JSONRPCError",
|
||||
"JSONRPCRequest",
|
||||
"JSONRPCResponse",
|
||||
"LocalFileRegistry",
|
||||
"Message",
|
||||
"Part",
|
||||
"Role",
|
||||
"Task",
|
||||
"TaskState",
|
||||
"TaskStatus",
|
||||
"TextPart",
|
||||
"build_card",
|
||||
"discover_agents",
|
||||
"echo_handler",
|
||||
"get_auth_headers",
|
||||
"load_agent_card",
|
||||
"load_card_config",
|
||||
"part_from_dict",
|
||||
"part_to_dict",
|
||||
]
|
||||
|
||||
# Lazy imports for optional deps
|
||||
def get_client(**kwargs):
|
||||
"""Get A2AClient (avoids aiohttp import at module level)."""
|
||||
from nexus.a2a.client import A2AClient, A2AClientConfig
|
||||
config = kwargs.pop("config", None)
|
||||
if config is None:
|
||||
config = A2AClientConfig(**kwargs)
|
||||
return A2AClient(config=config)
|
||||
|
||||
|
||||
def get_server(card: AgentCard, **kwargs):
|
||||
"""Get A2AServer (avoids fastapi import at module level)."""
|
||||
from nexus.a2a.server import A2AServer, echo_handler
|
||||
return A2AServer(card=card, **kwargs)
|
||||
@@ -1,167 +0,0 @@
|
||||
"""
|
||||
A2A Agent Card — generation, loading, and serving.
|
||||
|
||||
Reads from ~/.hermes/agent_card.yaml (or a passed path) and produces
|
||||
a valid A2A AgentCard that can be served at /.well-known/agent-card.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from nexus.a2a.types import (
|
||||
AgentCard,
|
||||
AgentCapabilities,
|
||||
AgentInterface,
|
||||
AgentSkill,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("nexus.a2a.card")
|
||||
|
||||
DEFAULT_CARD_PATH = Path.home() / ".hermes" / "agent_card.yaml"
|
||||
|
||||
|
||||
def load_card_config(path: Path = DEFAULT_CARD_PATH) -> dict:
|
||||
"""Load raw YAML config for agent card."""
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Agent card config not found at {path}. "
|
||||
f"Copy config/agent_card.example.yaml to {path} and customize it."
|
||||
)
|
||||
with open(path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def build_card(config: dict) -> AgentCard:
|
||||
"""
|
||||
Build an AgentCard from a config dict.
|
||||
|
||||
Expected YAML structure (see config/agent_card.example.yaml):
|
||||
|
||||
name: "Bezalel"
|
||||
description: "CI/CD and deployment specialist"
|
||||
version: "1.0.0"
|
||||
url: "https://bezalel.example.com"
|
||||
protocol_binding: "HTTP+JSON"
|
||||
skills:
|
||||
- id: "ci-health"
|
||||
name: "CI Health Check"
|
||||
description: "Run CI pipeline health checks"
|
||||
tags: ["ci", "devops"]
|
||||
- id: "deploy"
|
||||
name: "Deploy Service"
|
||||
description: "Deploy a service to production"
|
||||
tags: ["deploy", "ops"]
|
||||
default_input_modes: ["text/plain"]
|
||||
default_output_modes: ["text/plain"]
|
||||
streaming: false
|
||||
push_notifications: false
|
||||
auth:
|
||||
scheme: "bearer"
|
||||
token_env: "A2A_AUTH_TOKEN"
|
||||
"""
|
||||
name = config["name"]
|
||||
description = config["description"]
|
||||
version = config.get("version", "1.0.0")
|
||||
url = config.get("url", "http://localhost:8080")
|
||||
binding = config.get("protocol_binding", "HTTP+JSON")
|
||||
|
||||
# Build skills
|
||||
skills = []
|
||||
for s in config.get("skills", []):
|
||||
skills.append(
|
||||
AgentSkill(
|
||||
id=s["id"],
|
||||
name=s.get("name", s["id"]),
|
||||
description=s.get("description", ""),
|
||||
tags=s.get("tags", []),
|
||||
examples=s.get("examples", []),
|
||||
input_modes=s.get("inputModes", config.get("default_input_modes", ["text/plain"])),
|
||||
output_modes=s.get("outputModes", config.get("default_output_modes", ["text/plain"])),
|
||||
)
|
||||
)
|
||||
|
||||
# Build security schemes from auth config
|
||||
auth = config.get("auth", {})
|
||||
security_schemes = {}
|
||||
security_requirements = []
|
||||
|
||||
if auth.get("scheme") == "bearer":
|
||||
security_schemes["bearerAuth"] = {
|
||||
"httpAuthSecurityScheme": {
|
||||
"scheme": "Bearer",
|
||||
"bearerFormat": auth.get("bearer_format", "token"),
|
||||
}
|
||||
}
|
||||
security_requirements = [
|
||||
{"schemes": {"bearerAuth": {"list": []}}}
|
||||
]
|
||||
elif auth.get("scheme") == "api_key":
|
||||
key_name = auth.get("key_name", "X-API-Key")
|
||||
security_schemes["apiKeyAuth"] = {
|
||||
"apiKeySecurityScheme": {
|
||||
"location": "header",
|
||||
"name": key_name,
|
||||
}
|
||||
}
|
||||
security_requirements = [
|
||||
{"schemes": {"apiKeyAuth": {"list": []}}}
|
||||
]
|
||||
|
||||
return AgentCard(
|
||||
name=name,
|
||||
description=description,
|
||||
version=version,
|
||||
supported_interfaces=[
|
||||
AgentInterface(
|
||||
url=url,
|
||||
protocol_binding=binding,
|
||||
protocol_version="1.0",
|
||||
)
|
||||
],
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=config.get("streaming", False),
|
||||
push_notifications=config.get("push_notifications", False),
|
||||
),
|
||||
default_input_modes=config.get("default_input_modes", ["text/plain"]),
|
||||
default_output_modes=config.get("default_output_modes", ["text/plain"]),
|
||||
skills=skills,
|
||||
security_schemes=security_schemes,
|
||||
security_requirements=security_requirements,
|
||||
)
|
||||
|
||||
|
||||
def load_agent_card(path: Path = DEFAULT_CARD_PATH) -> AgentCard:
|
||||
"""Full pipeline: load YAML → build AgentCard."""
|
||||
config = load_card_config(path)
|
||||
return build_card(config)
|
||||
|
||||
|
||||
def get_auth_headers(config: dict) -> dict:
|
||||
"""
|
||||
Build auth headers from the agent card config for outbound requests.
|
||||
|
||||
Returns dict of HTTP headers to include.
|
||||
"""
|
||||
auth = config.get("auth", {})
|
||||
headers = {"A2A-Version": "1.0"}
|
||||
|
||||
scheme = auth.get("scheme")
|
||||
if scheme == "bearer":
|
||||
token_env = auth.get("token_env", "A2A_AUTH_TOKEN")
|
||||
token = os.environ.get(token_env, "")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
elif scheme == "api_key":
|
||||
key_env = auth.get("key_env", "A2A_API_KEY")
|
||||
key_name = auth.get("key_name", "X-API-Key")
|
||||
key = os.environ.get(key_env, "")
|
||||
if key:
|
||||
headers[key_name] = key
|
||||
|
||||
return headers
|
||||
@@ -1,392 +0,0 @@
|
||||
"""
|
||||
A2A Client — send tasks to other agents over the A2A protocol.
|
||||
|
||||
Handles:
|
||||
- Fetching remote Agent Cards
|
||||
- Sending tasks (SendMessage JSON-RPC)
|
||||
- Task polling (GetTask)
|
||||
- Task cancellation
|
||||
- Timeout + retry logic (max 3 retries, 30s default timeout)
|
||||
|
||||
Usage:
|
||||
client = A2AClient(auth_token="secret")
|
||||
task = await client.send_message("https://ezra.example.com/a2a/v1", message)
|
||||
status = await client.get_task("https://ezra.example.com/a2a/v1", task_id)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from nexus.a2a.types import (
|
||||
A2AError,
|
||||
AgentCard,
|
||||
Artifact,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Role,
|
||||
Task,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("nexus.a2a.client")
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AClientConfig:
|
||||
"""Client configuration."""
|
||||
timeout: float = 30.0 # seconds per request
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 2.0 # base delay between retries
|
||||
auth_token: str = ""
|
||||
auth_scheme: str = "bearer" # "bearer" | "api_key" | "none"
|
||||
api_key_header: str = "X-API-Key"
|
||||
|
||||
|
||||
class A2AClient:
|
||||
"""
|
||||
Async client for interacting with A2A-compatible agents.
|
||||
|
||||
Every agent endpoint is identified by its base URL (e.g.
|
||||
https://ezra.example.com/a2a/v1). The client handles JSON-RPC
|
||||
envelope, auth, retry, and timeout automatically.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[A2AClientConfig] = None, **kwargs):
|
||||
if config is None:
|
||||
config = A2AClientConfig(**kwargs)
|
||||
self.config = config
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._audit_log: list[dict] = []
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
headers=self._build_auth_headers(),
|
||||
)
|
||||
return self._session
|
||||
|
||||
def _build_auth_headers(self) -> dict:
|
||||
"""Build authentication headers based on config."""
|
||||
headers = {"A2A-Version": "1.0", "Content-Type": "application/json"}
|
||||
token = self.config.auth_token
|
||||
if not token:
|
||||
return headers
|
||||
|
||||
if self.config.auth_scheme == "bearer":
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
elif self.config.auth_scheme == "api_key":
|
||||
headers[self.config.api_key_header] = token
|
||||
|
||||
return headers
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def _rpc_call(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
params: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Make a JSON-RPC call with retry logic.
|
||||
|
||||
Returns the 'result' field from the response.
|
||||
Raises on JSON-RPC errors.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
request = JSONRPCRequest(method=method, params=params or {})
|
||||
payload = request.to_dict()
|
||||
|
||||
last_error = None
|
||||
for attempt in range(1, self.config.max_retries + 1):
|
||||
try:
|
||||
start = time.monotonic()
|
||||
async with session.post(endpoint, json=payload) as resp:
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
if resp.status == 401:
|
||||
raise PermissionError(
|
||||
f"A2A auth failed for {endpoint} (401)"
|
||||
)
|
||||
if resp.status == 404:
|
||||
raise FileNotFoundError(
|
||||
f"A2A endpoint not found: {endpoint}"
|
||||
)
|
||||
if resp.status >= 500:
|
||||
body = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"A2A server error {resp.status}: {body}"
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
rpc_resp = JSONRPCResponse(
|
||||
id=str(data.get("id", "")),
|
||||
result=data.get("result"),
|
||||
error=(
|
||||
A2AError.INTERNAL
|
||||
if "error" in data
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Log for audit
|
||||
self._audit_log.append({
|
||||
"timestamp": time.time(),
|
||||
"endpoint": endpoint,
|
||||
"method": method,
|
||||
"request_id": request.id,
|
||||
"status_code": resp.status,
|
||||
"elapsed_ms": int(elapsed * 1000),
|
||||
"attempt": attempt,
|
||||
})
|
||||
|
||||
if "error" in data:
|
||||
err = data["error"]
|
||||
logger.error(
|
||||
f"A2A RPC error {err.get('code')}: "
|
||||
f"{err.get('message')}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"A2A error {err.get('code')}: "
|
||||
f"{err.get('message')}"
|
||||
)
|
||||
|
||||
return data.get("result", {})
|
||||
|
||||
except (asyncio.TimeoutError, aiohttp.ClientError) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"A2A request to {endpoint} attempt {attempt}/"
|
||||
f"{self.config.max_retries} failed: {e}"
|
||||
)
|
||||
if attempt < self.config.max_retries:
|
||||
delay = self.config.retry_delay * attempt
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
raise ConnectionError(
|
||||
f"A2A request to {endpoint} failed after "
|
||||
f"{self.config.max_retries} retries: {last_error}"
|
||||
)
|
||||
|
||||
# --- Core A2A Methods ---
|
||||
|
||||
async def get_agent_card(self, base_url: str) -> AgentCard:
|
||||
"""
|
||||
Fetch the Agent Card from a remote agent.
|
||||
|
||||
Tries /.well-known/agent-card.json first, falls back to
|
||||
/agent.json.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
card_urls = [
|
||||
f"{base_url}/.well-known/agent-card.json",
|
||||
f"{base_url}/agent.json",
|
||||
]
|
||||
|
||||
for url in card_urls:
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
card = AgentCard.from_dict(data)
|
||||
logger.info(
|
||||
f"Fetched agent card: {card.name} "
|
||||
f"({len(card.skills)} skills)"
|
||||
)
|
||||
return card
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not fetch agent card from {base_url}"
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
endpoint: str,
|
||||
message: Message,
|
||||
accepted_output_modes: Optional[list[str]] = None,
|
||||
history_length: int = 10,
|
||||
return_immediately: bool = False,
|
||||
) -> Task:
|
||||
"""
|
||||
Send a message to an agent and get a Task back.
|
||||
|
||||
This is the primary delegation method.
|
||||
"""
|
||||
params = {
|
||||
"message": message.to_dict(),
|
||||
"configuration": {
|
||||
"acceptedOutputModes": accepted_output_modes or ["text/plain"],
|
||||
"historyLength": history_length,
|
||||
"returnImmediately": return_immediately,
|
||||
},
|
||||
}
|
||||
|
||||
result = await self._rpc_call(endpoint, "SendMessage", params)
|
||||
|
||||
# Response is either a Task or Message
|
||||
if "task" in result:
|
||||
task = Task.from_dict(result["task"])
|
||||
logger.info(
|
||||
f"Task {task.id} created, state={task.status.state.value}"
|
||||
)
|
||||
return task
|
||||
elif "message" in result:
|
||||
# Wrap message response as a completed task
|
||||
msg = Message.from_dict(result["message"])
|
||||
task = Task(
|
||||
status=TaskStatus(state=TaskState.COMPLETED),
|
||||
history=[message, msg],
|
||||
artifacts=[
|
||||
Artifact(parts=msg.parts, name="response")
|
||||
],
|
||||
)
|
||||
return task
|
||||
|
||||
raise ValueError(f"Unexpected response structure: {list(result.keys())}")
|
||||
|
||||
async def get_task(self, endpoint: str, task_id: str) -> Task:
|
||||
"""Get task status by ID."""
|
||||
result = await self._rpc_call(
|
||||
endpoint,
|
||||
"GetTask",
|
||||
{"id": task_id},
|
||||
)
|
||||
return Task.from_dict(result)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
endpoint: str,
|
||||
page_size: int = 20,
|
||||
page_token: str = "",
|
||||
) -> tuple[list[Task], str]:
|
||||
"""
|
||||
List tasks with cursor-based pagination.
|
||||
|
||||
Returns (tasks, next_page_token). Empty string = last page.
|
||||
"""
|
||||
result = await self._rpc_call(
|
||||
endpoint,
|
||||
"ListTasks",
|
||||
{
|
||||
"pageSize": page_size,
|
||||
"pageToken": page_token,
|
||||
},
|
||||
)
|
||||
tasks = [Task.from_dict(t) for t in result.get("tasks", [])]
|
||||
next_token = result.get("nextPageToken", "")
|
||||
return tasks, next_token
|
||||
|
||||
async def cancel_task(self, endpoint: str, task_id: str) -> Task:
|
||||
"""Cancel a running task."""
|
||||
result = await self._rpc_call(
|
||||
endpoint,
|
||||
"CancelTask",
|
||||
{"id": task_id},
|
||||
)
|
||||
return Task.from_dict(result)
|
||||
|
||||
# --- Convenience Methods ---
|
||||
|
||||
async def delegate(
|
||||
self,
|
||||
agent_url: str,
|
||||
text: str,
|
||||
skill_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Task:
|
||||
"""
|
||||
High-level delegation: send a text message to an agent.
|
||||
|
||||
Args:
|
||||
agent_url: Full URL to agent's A2A endpoint
|
||||
(e.g. https://ezra.example.com/a2a/v1)
|
||||
text: The task description in natural language
|
||||
skill_id: Optional skill to target
|
||||
metadata: Optional metadata dict
|
||||
"""
|
||||
msg_metadata = metadata or {}
|
||||
if skill_id:
|
||||
msg_metadata["targetSkill"] = skill_id
|
||||
|
||||
message = Message(
|
||||
role=Role.USER,
|
||||
parts=[TextPart(text=text)],
|
||||
metadata=msg_metadata,
|
||||
)
|
||||
|
||||
return await self.send_message(agent_url, message)
|
||||
|
||||
async def wait_for_completion(
|
||||
self,
|
||||
endpoint: str,
|
||||
task_id: str,
|
||||
poll_interval: float = 2.0,
|
||||
max_wait: float = 300.0,
|
||||
) -> Task:
|
||||
"""
|
||||
Poll a task until it reaches a terminal state.
|
||||
|
||||
Returns the completed task.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
task = await self.get_task(endpoint, task_id)
|
||||
if task.status.state.terminal:
|
||||
return task
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= max_wait:
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within "
|
||||
f"{max_wait}s (state={task.status.state.value})"
|
||||
)
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
def get_audit_log(self) -> list[dict]:
|
||||
"""Return the audit log of all requests made by this client."""
|
||||
return list(self._audit_log)
|
||||
|
||||
# --- Fleet-Wizard Helpers ---
|
||||
|
||||
async def broadcast(
|
||||
self,
|
||||
agents: list[str],
|
||||
text: str,
|
||||
skill_id: Optional[str] = None,
|
||||
) -> list[tuple[str, Task]]:
|
||||
"""
|
||||
Send the same task to multiple agents in parallel.
|
||||
|
||||
Returns list of (agent_url, task) tuples.
|
||||
"""
|
||||
tasks = []
|
||||
for agent_url in agents:
|
||||
tasks.append(
|
||||
self.delegate(agent_url, text, skill_id=skill_id)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
paired = []
|
||||
for agent_url, result in zip(agents, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Broadcast to {agent_url} failed: {result}")
|
||||
else:
|
||||
paired.append((agent_url, result))
|
||||
return paired
|
||||
@@ -1,264 +0,0 @@
|
||||
"""
|
||||
A2A Registry — fleet-wide agent discovery.
|
||||
|
||||
Provides two registry backends:
|
||||
1. LocalFileRegistry: reads/writes agent cards to a JSON file
|
||||
(default: config/fleet_agents.json)
|
||||
2. GiteaRegistry: stores agent cards as a Gitea repo file
|
||||
(for distributed fleet discovery)
|
||||
|
||||
Usage:
|
||||
registry = LocalFileRegistry()
|
||||
registry.register(my_card)
|
||||
agents = registry.list_agents(skill="ci-health")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from nexus.a2a.types import AgentCard
|
||||
|
||||
logger = logging.getLogger("nexus.a2a.registry")
|
||||
|
||||
|
||||
class LocalFileRegistry:
|
||||
"""
|
||||
File-based agent card registry.
|
||||
|
||||
Stores all fleet agent cards in a single JSON file.
|
||||
Suitable for single-node or read-heavy workloads.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path = Path("config/fleet_agents.json")):
|
||||
self.path = path
|
||||
self._cards: dict[str, AgentCard] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
"""Load registry from disk."""
|
||||
if self.path.exists():
|
||||
try:
|
||||
with open(self.path) as f:
|
||||
data = json.load(f)
|
||||
for card_data in data.get("agents", []):
|
||||
card = AgentCard.from_dict(card_data)
|
||||
self._cards[card.name.lower()] = card
|
||||
logger.info(
|
||||
f"Loaded {len(self._cards)} agents from {self.path}"
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.error(f"Failed to load registry from {self.path}: {e}")
|
||||
|
||||
def _save(self):
|
||||
"""Persist registry to disk."""
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"version": 1,
|
||||
"agents": [card.to_dict() for card in self._cards.values()],
|
||||
}
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
logger.debug(f"Saved {len(self._cards)} agents to {self.path}")
|
||||
|
||||
def register(self, card: AgentCard) -> None:
|
||||
"""Register or update an agent card."""
|
||||
self._cards[card.name.lower()] = card
|
||||
self._save()
|
||||
logger.info(f"Registered agent: {card.name}")
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Remove an agent from the registry."""
|
||||
key = name.lower()
|
||||
if key in self._cards:
|
||||
del self._cards[key]
|
||||
self._save()
|
||||
logger.info(f"Unregistered agent: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> Optional[AgentCard]:
|
||||
"""Get an agent card by name."""
|
||||
return self._cards.get(name.lower())
|
||||
|
||||
def list_agents(
|
||||
self,
|
||||
skill: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
) -> list[AgentCard]:
|
||||
"""
|
||||
List all registered agents, optionally filtered by skill or tag.
|
||||
|
||||
Args:
|
||||
skill: Filter to agents that have this skill ID
|
||||
tag: Filter to agents that have this tag on any skill
|
||||
"""
|
||||
agents = list(self._cards.values())
|
||||
|
||||
if skill:
|
||||
agents = [
|
||||
a for a in agents
|
||||
if any(s.id == skill for s in a.skills)
|
||||
]
|
||||
|
||||
if tag:
|
||||
agents = [
|
||||
a for a in agents
|
||||
if any(tag in s.tags for s in a.skills)
|
||||
]
|
||||
|
||||
return agents
|
||||
|
||||
def get_endpoint(self, name: str) -> Optional[str]:
|
||||
"""Get the first supported interface URL for an agent."""
|
||||
card = self.get(name)
|
||||
if card and card.supported_interfaces:
|
||||
return card.supported_interfaces[0].url
|
||||
return None
|
||||
|
||||
def dump(self) -> dict:
|
||||
"""Dump full registry as a dict."""
|
||||
return {
|
||||
"version": 1,
|
||||
"agents": [card.to_dict() for card in self._cards.values()],
|
||||
}
|
||||
|
||||
|
||||
class GiteaRegistry:
|
||||
"""
|
||||
Gitea-backed agent registry.
|
||||
|
||||
Stores fleet agent cards in a Gitea repository file for
|
||||
distributed discovery across VPS nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gitea_url: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
file_path: str = "config/fleet_agents.json",
|
||||
):
|
||||
self.gitea_url = gitea_url.rstrip("/")
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.file_path = file_path
|
||||
self._cards: dict[str, AgentCard] = {}
|
||||
|
||||
def _api_url(self, endpoint: str) -> str:
|
||||
return f"{self.gitea_url}/api/v1/repos/{self.repo}/{endpoint}"
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"token {self.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def load(self) -> None:
|
||||
"""Fetch agent cards from Gitea."""
|
||||
try:
|
||||
import aiohttp
|
||||
url = self._api_url(f"contents/{self.file_path}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=self._headers()) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
import base64
|
||||
content = base64.b64decode(data["content"]).decode()
|
||||
registry = json.loads(content)
|
||||
for card_data in registry.get("agents", []):
|
||||
card = AgentCard.from_dict(card_data)
|
||||
self._cards[card.name.lower()] = card
|
||||
logger.info(
|
||||
f"Loaded {len(self._cards)} agents from Gitea"
|
||||
)
|
||||
elif resp.status == 404:
|
||||
logger.info("No fleet registry file in Gitea yet")
|
||||
else:
|
||||
logger.error(
|
||||
f"Gitea fetch failed: {resp.status}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load from Gitea: {e}")
|
||||
|
||||
async def save(self, message: str = "Update fleet registry") -> None:
|
||||
"""Write agent cards to Gitea."""
|
||||
try:
|
||||
import aiohttp
|
||||
content = json.dumps(
|
||||
{"version": 1, "agents": [c.to_dict() for c in self._cards.values()]},
|
||||
indent=2,
|
||||
)
|
||||
import base64
|
||||
encoded = base64.b64encode(content.encode()).decode()
|
||||
|
||||
# Check if file exists (need SHA for update)
|
||||
url = self._api_url(f"contents/{self.file_path}")
|
||||
sha = None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=self._headers()) as resp:
|
||||
if resp.status == 200:
|
||||
existing = await resp.json()
|
||||
sha = existing.get("sha")
|
||||
|
||||
payload = {
|
||||
"message": message,
|
||||
"content": encoded,
|
||||
}
|
||||
if sha:
|
||||
payload["sha"] = sha
|
||||
|
||||
async with session.put(
|
||||
url, headers=self._headers(), json=payload
|
||||
) as resp:
|
||||
if resp.status in (200, 201):
|
||||
logger.info("Fleet registry saved to Gitea")
|
||||
else:
|
||||
body = await resp.text()
|
||||
logger.error(
|
||||
f"Gitea save failed: {resp.status} — {body}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to Gitea: {e}")
|
||||
|
||||
def register(self, card: AgentCard) -> None:
|
||||
"""Register an agent (local update; call save() to persist)."""
|
||||
self._cards[card.name.lower()] = card
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
key = name.lower()
|
||||
if key in self._cards:
|
||||
del self._cards[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> Optional[AgentCard]:
|
||||
return self._cards.get(name.lower())
|
||||
|
||||
def list_agents(
|
||||
self,
|
||||
skill: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
) -> list[AgentCard]:
|
||||
agents = list(self._cards.values())
|
||||
if skill:
|
||||
agents = [a for a in agents if any(s.id == skill for s in a.skills)]
|
||||
if tag:
|
||||
agents = [a for a in agents if any(tag in s.tags for s in a.skills)]
|
||||
return agents
|
||||
|
||||
|
||||
# --- Convenience ---
|
||||
|
||||
def discover_agents(
|
||||
path: Path = Path("config/fleet_agents.json"),
|
||||
skill: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
) -> list[AgentCard]:
|
||||
"""One-shot discovery from local file."""
|
||||
registry = LocalFileRegistry(path)
|
||||
return registry.list_agents(skill=skill, tag=tag)
|
||||
@@ -1,386 +0,0 @@
|
||||
"""
|
||||
A2A Server — receive and process tasks from other agents.
|
||||
|
||||
Provides a FastAPI router that serves:
|
||||
- GET /.well-known/agent-card.json — Agent Card discovery
|
||||
- GET /agent.json — Agent Card fallback
|
||||
- POST /a2a/v1 — JSON-RPC endpoint (SendMessage, GetTask, etc.)
|
||||
- POST /a2a/v1/rpc — JSON-RPC endpoint (alias)
|
||||
|
||||
Task routing: registered handlers are matched by skill ID or receive
|
||||
all tasks via a default handler.
|
||||
|
||||
Usage:
|
||||
server = A2AServer(card=my_card, auth_token="secret")
|
||||
server.register_handler("ci-health", my_ci_handler)
|
||||
await server.start(host="0.0.0.0", port=8080)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Awaitable, Optional
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, Request, Response, HTTPException, Header
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
HAS_FASTAPI = True
|
||||
except ImportError:
|
||||
HAS_FASTAPI = False
|
||||
|
||||
from nexus.a2a.types import (
|
||||
A2AError,
|
||||
AgentCard,
|
||||
Artifact,
|
||||
JSONRPCError,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Role,
|
||||
Task,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TextPart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("nexus.a2a.server")
|
||||
|
||||
# Type for task handlers
|
||||
TaskHandler = Callable[[Task, AgentCard], Awaitable[Task]]
|
||||
|
||||
|
||||
class A2AServer:
|
||||
"""
|
||||
A2A protocol server for receiving agent-to-agent task delegation.
|
||||
|
||||
Supports:
|
||||
- Agent Card serving at /.well-known/agent-card.json
|
||||
- JSON-RPC task lifecycle (SendMessage, GetTask, CancelTask, ListTasks)
|
||||
- Pluggable task handlers (by skill ID or default)
|
||||
- Bearer / API key authentication
|
||||
- Audit logging
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
card: AgentCard,
|
||||
auth_token: str = "",
|
||||
auth_scheme: str = "bearer",
|
||||
):
|
||||
if not HAS_FASTAPI:
|
||||
raise ImportError(
|
||||
"fastapi and uvicorn are required for A2AServer. "
|
||||
"Install with: pip install fastapi uvicorn"
|
||||
)
|
||||
|
||||
self.card = card
|
||||
self.auth_token = auth_token
|
||||
self.auth_scheme = auth_scheme
|
||||
|
||||
# Task store (in-memory; swap for SQLite/Redis in production)
|
||||
self._tasks: dict[str, Task] = {}
|
||||
# Handlers keyed by skill ID
|
||||
self._handlers: dict[str, TaskHandler] = {}
|
||||
# Default handler for unmatched skills
|
||||
self._default_handler: Optional[TaskHandler] = None
|
||||
# Audit log
|
||||
self._audit_log: list[dict] = []
|
||||
|
||||
self.app = FastAPI(
|
||||
title=f"A2A — {card.name}",
|
||||
description=card.description,
|
||||
version=card.version,
|
||||
)
|
||||
self._register_routes()
|
||||
|
||||
def register_handler(self, skill_id: str, handler: TaskHandler):
|
||||
"""Register a handler for a specific skill ID."""
|
||||
self._handlers[skill_id] = handler
|
||||
logger.info(f"Registered handler for skill: {skill_id}")
|
||||
|
||||
def set_default_handler(self, handler: TaskHandler):
|
||||
"""Set the fallback handler for tasks without a matching skill."""
|
||||
self._default_handler = handler
|
||||
|
||||
def _verify_auth(self, authorization: Optional[str]) -> bool:
|
||||
"""Check authentication header."""
|
||||
if not self.auth_token:
|
||||
return True # No auth configured
|
||||
|
||||
if not authorization:
|
||||
return False
|
||||
|
||||
if self.auth_scheme == "bearer":
|
||||
expected = f"Bearer {self.auth_token}"
|
||||
return authorization == expected
|
||||
|
||||
return False
|
||||
|
||||
def _register_routes(self):
|
||||
"""Wire up FastAPI routes."""
|
||||
|
||||
@self.app.get("/.well-known/agent-card.json")
|
||||
async def agent_card_well_known():
|
||||
return JSONResponse(self.card.to_dict())
|
||||
|
||||
@self.app.get("/agent.json")
|
||||
async def agent_card_fallback():
|
||||
return JSONResponse(self.card.to_dict())
|
||||
|
||||
@self.app.post("/a2a/v1")
|
||||
@self.app.post("/a2a/v1/rpc")
|
||||
async def rpc_endpoint(request: Request):
|
||||
return await self._handle_rpc(request)
|
||||
|
||||
@self.app.get("/a2a/v1/tasks")
|
||||
@self.app.get("/a2a/v1/tasks/{task_id}")
|
||||
async def rest_get_task(task_id: Optional[str] = None):
|
||||
if task_id:
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return JSONRPCResponse(
|
||||
id="",
|
||||
error=A2AError.TASK_NOT_FOUND,
|
||||
).to_dict()
|
||||
return JSONResponse(task.to_dict())
|
||||
else:
|
||||
return JSONResponse(
|
||||
{"tasks": [t.to_dict() for t in self._tasks.values()]}
|
||||
)
|
||||
|
||||
async def _handle_rpc(self, request: Request) -> JSONResponse:
|
||||
"""Handle JSON-RPC requests."""
|
||||
# Auth check
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not self._verify_auth(auth_header):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Unauthorized"},
|
||||
)
|
||||
|
||||
# Parse JSON-RPC
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse(
|
||||
JSONRPCResponse(
|
||||
id="", error=A2AError.PARSE
|
||||
).to_dict(),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
method = body.get("method", "")
|
||||
request_id = body.get("id", str(uuid.uuid4()))
|
||||
params = body.get("params", {})
|
||||
|
||||
# Audit
|
||||
self._audit_log.append({
|
||||
"timestamp": time.time(),
|
||||
"method": method,
|
||||
"request_id": request_id,
|
||||
"source": request.client.host if request.client else "unknown",
|
||||
})
|
||||
|
||||
try:
|
||||
result = await self._dispatch_rpc(method, params, request_id)
|
||||
return JSONResponse(
|
||||
JSONRPCResponse(id=request_id, result=result).to_dict()
|
||||
)
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=JSONRPCError(-32602, str(e)),
|
||||
).to_dict(),
|
||||
status_code=400,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling {method}: {e}")
|
||||
return JSONResponse(
|
||||
JSONRPCResponse(
|
||||
id=request_id,
|
||||
error=JSONRPCError(-32603, str(e)),
|
||||
).to_dict(),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def _dispatch_rpc(
|
||||
self, method: str, params: dict, request_id: str
|
||||
) -> Any:
|
||||
"""Route JSON-RPC method to handler."""
|
||||
if method == "SendMessage":
|
||||
return await self._rpc_send_message(params)
|
||||
elif method == "GetTask":
|
||||
return await self._rpc_get_task(params)
|
||||
elif method == "ListTasks":
|
||||
return await self._rpc_list_tasks(params)
|
||||
elif method == "CancelTask":
|
||||
return await self._rpc_cancel_task(params)
|
||||
elif method == "GetAgentCard":
|
||||
return self.card.to_dict()
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
async def _rpc_send_message(self, params: dict) -> dict:
|
||||
"""Handle SendMessage — create a task and route to handler."""
|
||||
msg_data = params.get("message", {})
|
||||
message = Message.from_dict(msg_data)
|
||||
|
||||
# Determine target skill from metadata
|
||||
target_skill = message.metadata.get("targetSkill", "")
|
||||
|
||||
# Create task
|
||||
task = Task(
|
||||
context_id=message.context_id,
|
||||
status=TaskStatus(state=TaskState.SUBMITTED),
|
||||
history=[message],
|
||||
metadata={"targetSkill": target_skill} if target_skill else {},
|
||||
)
|
||||
|
||||
# Store immediately
|
||||
self._tasks[task.id] = task
|
||||
|
||||
# Dispatch to handler
|
||||
handler = self._handlers.get(target_skill) or self._default_handler
|
||||
|
||||
if handler is None:
|
||||
task.status = TaskStatus(
|
||||
state=TaskState.FAILED,
|
||||
message=Message(
|
||||
role=Role.AGENT,
|
||||
parts=[TextPart(text="No handler available for this task")],
|
||||
),
|
||||
)
|
||||
return {"task": task.to_dict()}
|
||||
|
||||
try:
|
||||
# Mark as working
|
||||
task.status = TaskStatus(state=TaskState.WORKING)
|
||||
self._tasks[task.id] = task
|
||||
|
||||
# Execute handler
|
||||
result_task = await handler(task, self.card)
|
||||
|
||||
# Store result
|
||||
self._tasks[result_task.id] = result_task
|
||||
return {"task": result_task.to_dict()}
|
||||
|
||||
except Exception as e:
|
||||
task.status = TaskStatus(
|
||||
state=TaskState.FAILED,
|
||||
message=Message(
|
||||
role=Role.AGENT,
|
||||
parts=[TextPart(text=f"Handler error: {str(e)}")],
|
||||
),
|
||||
)
|
||||
self._tasks[task.id] = task
|
||||
return {"task": task.to_dict()}
|
||||
|
||||
async def _rpc_get_task(self, params: dict) -> dict:
|
||||
"""Handle GetTask."""
|
||||
task_id = params.get("id", "")
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
return task.to_dict()
|
||||
|
||||
async def _rpc_list_tasks(self, params: dict) -> dict:
|
||||
"""Handle ListTasks with cursor-based pagination."""
|
||||
page_size = params.get("pageSize", 20)
|
||||
page_token = params.get("pageToken", "")
|
||||
|
||||
tasks = sorted(
|
||||
self._tasks.values(),
|
||||
key=lambda t: t.status.timestamp,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Simple cursor: find index by token
|
||||
start_idx = 0
|
||||
if page_token:
|
||||
for i, t in enumerate(tasks):
|
||||
if t.id == page_token:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page = tasks[start_idx : start_idx + page_size]
|
||||
next_token = ""
|
||||
if start_idx + page_size < len(tasks):
|
||||
next_token = tasks[start_idx + page_size - 1].id
|
||||
|
||||
return {
|
||||
"tasks": [t.to_dict() for t in page],
|
||||
"nextPageToken": next_token,
|
||||
}
|
||||
|
||||
async def _rpc_cancel_task(self, params: dict) -> dict:
|
||||
"""Handle CancelTask."""
|
||||
task_id = params.get("id", "")
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status.state.terminal:
|
||||
raise ValueError(
|
||||
f"Task {task_id} is already terminal "
|
||||
f"({task.status.state.value})"
|
||||
)
|
||||
|
||||
task.status = TaskStatus(state=TaskState.CANCELED)
|
||||
self._tasks[task_id] = task
|
||||
return task.to_dict()
|
||||
|
||||
def get_audit_log(self) -> list[dict]:
|
||||
"""Return audit log of all received requests."""
|
||||
return list(self._audit_log)
|
||||
|
||||
async def start(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8080,
|
||||
):
|
||||
"""Start the A2A server with uvicorn."""
|
||||
logger.info(
|
||||
f"Starting A2A server for {self.card.name} on "
|
||||
f"{host}:{port}"
|
||||
)
|
||||
logger.info(
|
||||
f"Agent Card at "
|
||||
f"http://{host}:{port}/.well-known/agent-card.json"
|
||||
)
|
||||
config = uvicorn.Config(
|
||||
self.app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
# --- Default Handler Factory ---
|
||||
|
||||
async def echo_handler(task: Task, card: AgentCard) -> Task:
|
||||
"""
|
||||
Simple echo handler for testing.
|
||||
Returns the user's message as an artifact.
|
||||
"""
|
||||
if task.history:
|
||||
last_msg = task.history[-1]
|
||||
text_parts = [p for p in last_msg.parts if isinstance(p, TextPart)]
|
||||
if text_parts:
|
||||
response_text = f"[{card.name}] Echo: {text_parts[0].text}"
|
||||
task.artifacts.append(
|
||||
Artifact(
|
||||
parts=[TextPart(text=response_text)],
|
||||
name="echo_response",
|
||||
)
|
||||
)
|
||||
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
return task
|
||||
@@ -1,524 +0,0 @@
|
||||
"""
|
||||
A2A Protocol Types — Data models for Google's Agent2Agent protocol v1.0.
|
||||
|
||||
All types map directly to the A2A spec. JSON uses camelCase, enums use
|
||||
SCREAMING_SNAKE_CASE, and Part types are discriminated by member name
|
||||
(not a kind field — that was removed in v1.0).
|
||||
|
||||
See: https://github.com/google/A2A
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
# --- Enums ---
|
||||
|
||||
class TaskState(str, enum.Enum):
|
||||
"""Lifecycle states for an A2A Task."""
|
||||
SUBMITTED = "TASK_STATE_SUBMITTED"
|
||||
WORKING = "TASK_STATE_WORKING"
|
||||
COMPLETED = "TASK_STATE_COMPLETED"
|
||||
FAILED = "TASK_STATE_FAILED"
|
||||
CANCELED = "TASK_STATE_CANCELED"
|
||||
INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
|
||||
REJECTED = "TASK_STATE_REJECTED"
|
||||
AUTH_REQUIRED = "TASK_STATE_AUTH_REQUIRED"
|
||||
|
||||
@property
|
||||
def terminal(self) -> bool:
|
||||
return self in (
|
||||
TaskState.COMPLETED,
|
||||
TaskState.FAILED,
|
||||
TaskState.CANCELED,
|
||||
TaskState.REJECTED,
|
||||
)
|
||||
|
||||
|
||||
class Role(str, enum.Enum):
|
||||
"""Who sent a message in an A2A conversation."""
|
||||
USER = "ROLE_USER"
|
||||
AGENT = "ROLE_AGENT"
|
||||
|
||||
|
||||
# --- Parts (discriminated by member name in JSON) ---
|
||||
|
||||
@dataclass
|
||||
class TextPart:
|
||||
"""Plain text content."""
|
||||
text: str
|
||||
media_type: str = "text/plain"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {"text": self.text}
|
||||
if self.media_type != "text/plain":
|
||||
d["mediaType"] = self.media_type
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePart:
|
||||
"""Binary file content — inline or by URL reference."""
|
||||
media_type: str
|
||||
filename: Optional[str] = None
|
||||
raw: Optional[str] = None # base64-encoded bytes
|
||||
url: Optional[str] = None # URL reference
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {"mediaType": self.media_type}
|
||||
if self.raw is not None:
|
||||
d["raw"] = self.raw
|
||||
if self.url is not None:
|
||||
d["url"] = self.url
|
||||
if self.filename:
|
||||
d["filename"] = self.filename
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataPart:
|
||||
"""Arbitrary structured JSON data."""
|
||||
data: dict
|
||||
media_type: str = "application/json"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {"data": self.data}
|
||||
if self.media_type != "application/json":
|
||||
d["mediaType"] = self.media_type
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
return d
|
||||
|
||||
|
||||
Part = TextPart | FilePart | DataPart
|
||||
|
||||
|
||||
def part_from_dict(d: dict) -> Part:
|
||||
"""Reconstruct a Part from its JSON dict (discriminated by key name)."""
|
||||
if "text" in d:
|
||||
return TextPart(
|
||||
text=d["text"],
|
||||
media_type=d.get("mediaType", "text/plain"),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
if "raw" in d or "url" in d:
|
||||
return FilePart(
|
||||
media_type=d["mediaType"],
|
||||
filename=d.get("filename"),
|
||||
raw=d.get("raw"),
|
||||
url=d.get("url"),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
if "data" in d:
|
||||
return DataPart(
|
||||
data=d["data"],
|
||||
media_type=d.get("mediaType", "application/json"),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
raise ValueError(f"Cannot determine Part type from keys: {list(d.keys())}")
|
||||
|
||||
|
||||
def part_to_dict(p: Part) -> dict:
|
||||
"""Serialize a Part to its JSON dict."""
|
||||
return p.to_dict()
|
||||
|
||||
|
||||
# --- Message ---
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A2A Message — a turn in a conversation between user and agent."""
|
||||
role: Role
|
||||
parts: list[Part]
|
||||
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
context_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
metadata: dict = field(default_factory=dict)
|
||||
extensions: list[str] = field(default_factory=list)
|
||||
reference_task_ids: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"messageId": self.message_id,
|
||||
"role": self.role.value,
|
||||
"parts": [part_to_dict(p) for p in self.parts],
|
||||
}
|
||||
if self.context_id:
|
||||
d["contextId"] = self.context_id
|
||||
if self.task_id:
|
||||
d["taskId"] = self.task_id
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
if self.extensions:
|
||||
d["extensions"] = self.extensions
|
||||
if self.reference_task_ids:
|
||||
d["referenceTaskIds"] = self.reference_task_ids
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Message":
|
||||
return cls(
|
||||
role=Role(d["role"]),
|
||||
parts=[part_from_dict(p) for p in d["parts"]],
|
||||
message_id=d.get("messageId", str(uuid.uuid4())),
|
||||
context_id=d.get("contextId"),
|
||||
task_id=d.get("taskId"),
|
||||
metadata=d.get("metadata", {}),
|
||||
extensions=d.get("extensions", []),
|
||||
reference_task_ids=d.get("referenceTaskIds", []),
|
||||
)
|
||||
|
||||
|
||||
# --- Artifact ---
|
||||
|
||||
@dataclass
|
||||
class Artifact:
|
||||
"""A2A Artifact — structured output from a task."""
|
||||
parts: list[Part]
|
||||
artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
metadata: dict = field(default_factory=dict)
|
||||
extensions: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"artifactId": self.artifact_id,
|
||||
"parts": [part_to_dict(p) for p in self.parts],
|
||||
}
|
||||
if self.name:
|
||||
d["name"] = self.name
|
||||
if self.description:
|
||||
d["description"] = self.description
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
if self.extensions:
|
||||
d["extensions"] = self.extensions
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Artifact":
|
||||
return cls(
|
||||
parts=[part_from_dict(p) for p in d["parts"]],
|
||||
artifact_id=d.get("artifactId", str(uuid.uuid4())),
|
||||
name=d.get("name"),
|
||||
description=d.get("description"),
|
||||
metadata=d.get("metadata", {}),
|
||||
extensions=d.get("extensions", []),
|
||||
)
|
||||
|
||||
|
||||
# --- Task ---
|
||||
|
||||
@dataclass
|
||||
class TaskStatus:
|
||||
"""Status envelope for a Task."""
|
||||
state: TaskState
|
||||
message: Optional[Message] = None
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {"state": self.state.value}
|
||||
if self.message:
|
||||
d["message"] = self.message.to_dict()
|
||||
d["timestamp"] = self.timestamp
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "TaskStatus":
|
||||
msg = None
|
||||
if "message" in d:
|
||||
msg = Message.from_dict(d["message"])
|
||||
return cls(
|
||||
state=TaskState(d["state"]),
|
||||
message=msg,
|
||||
timestamp=d.get("timestamp", datetime.now(timezone.utc).isoformat()),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""A2A Task — a unit of work delegated between agents."""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
context_id: Optional[str] = None
|
||||
status: TaskStatus = field(
|
||||
default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED)
|
||||
)
|
||||
artifacts: list[Artifact] = field(default_factory=list)
|
||||
history: list[Message] = field(default_factory=list)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"status": self.status.to_dict(),
|
||||
}
|
||||
if self.context_id:
|
||||
d["contextId"] = self.context_id
|
||||
if self.artifacts:
|
||||
d["artifacts"] = [a.to_dict() for a in self.artifacts]
|
||||
if self.history:
|
||||
d["history"] = [m.to_dict() for m in self.history]
|
||||
if self.metadata:
|
||||
d["metadata"] = self.metadata
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Task":
|
||||
return cls(
|
||||
id=d.get("id", str(uuid.uuid4())),
|
||||
context_id=d.get("contextId"),
|
||||
status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED),
|
||||
artifacts=[Artifact.from_dict(a) for a in d.get("artifacts", [])],
|
||||
history=[Message.from_dict(m) for m in d.get("history", [])],
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
# --- Agent Card ---
|
||||
|
||||
@dataclass
|
||||
class AgentSkill:
|
||||
"""Capability declaration for an Agent Card."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
tags: list[str] = field(default_factory=list)
|
||||
examples: list[str] = field(default_factory=list)
|
||||
input_modes: list[str] = field(default_factory=lambda: ["text/plain"])
|
||||
output_modes: list[str] = field(default_factory=lambda: ["text/plain"])
|
||||
security_requirements: list[dict] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"tags": self.tags,
|
||||
}
|
||||
if self.examples:
|
||||
d["examples"] = self.examples
|
||||
if self.input_modes != ["text/plain"]:
|
||||
d["inputModes"] = self.input_modes
|
||||
if self.output_modes != ["text/plain"]:
|
||||
d["outputModes"] = self.output_modes
|
||||
if self.security_requirements:
|
||||
d["securityRequirements"] = self.security_requirements
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInterface:
|
||||
"""Network endpoint for an agent."""
|
||||
url: str
|
||||
protocol_binding: str = "HTTP+JSON"
|
||||
protocol_version: str = "1.0"
|
||||
tenant: str = ""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {
|
||||
"url": self.url,
|
||||
"protocolBinding": self.protocol_binding,
|
||||
"protocolVersion": self.protocol_version,
|
||||
}
|
||||
if self.tenant:
|
||||
d["tenant"] = self.tenant
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentCapabilities:
|
||||
"""What this agent can do beyond basic request/response."""
|
||||
streaming: bool = False
|
||||
push_notifications: bool = False
|
||||
extended_agent_card: bool = False
|
||||
extensions: list[dict] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"streaming": self.streaming,
|
||||
"pushNotifications": self.push_notifications,
|
||||
"extendedAgentCard": self.extended_agent_card,
|
||||
"extensions": self.extensions,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentCard:
|
||||
"""
|
||||
A2A Agent Card — self-describing metadata published at
|
||||
/.well-known/agent-card.json
|
||||
"""
|
||||
name: str
|
||||
description: str
|
||||
version: str = "1.0.0"
|
||||
supported_interfaces: list[AgentInterface] = field(default_factory=list)
|
||||
capabilities: AgentCapabilities = field(
|
||||
default_factory=AgentCapabilities
|
||||
)
|
||||
provider: Optional[dict] = None
|
||||
documentation_url: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
default_input_modes: list[str] = field(
|
||||
default_factory=lambda: ["text/plain"]
|
||||
)
|
||||
default_output_modes: list[str] = field(
|
||||
default_factory=lambda: ["text/plain"]
|
||||
)
|
||||
skills: list[AgentSkill] = field(default_factory=list)
|
||||
security_schemes: dict = field(default_factory=dict)
|
||||
security_requirements: list[dict] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"supportedInterfaces": [i.to_dict() for i in self.supported_interfaces],
|
||||
"capabilities": self.capabilities.to_dict(),
|
||||
"defaultInputModes": self.default_input_modes,
|
||||
"defaultOutputModes": self.default_output_modes,
|
||||
"skills": [s.to_dict() for s in self.skills],
|
||||
}
|
||||
if self.provider:
|
||||
d["provider"] = self.provider
|
||||
if self.documentation_url:
|
||||
d["documentationUrl"] = self.documentation_url
|
||||
if self.icon_url:
|
||||
d["iconUrl"] = self.icon_url
|
||||
if self.security_schemes:
|
||||
d["securitySchemes"] = self.security_schemes
|
||||
if self.security_requirements:
|
||||
d["securityRequirements"] = self.security_requirements
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "AgentCard":
|
||||
return cls(
|
||||
name=d["name"],
|
||||
description=d["description"],
|
||||
version=d.get("version", "1.0.0"),
|
||||
supported_interfaces=[
|
||||
AgentInterface(
|
||||
url=i["url"],
|
||||
protocol_binding=i.get("protocolBinding", "HTTP+JSON"),
|
||||
protocol_version=i.get("protocolVersion", "1.0"),
|
||||
tenant=i.get("tenant", ""),
|
||||
)
|
||||
for i in d.get("supportedInterfaces", [])
|
||||
],
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=d.get("capabilities", {}).get("streaming", False),
|
||||
push_notifications=d.get("capabilities", {}).get("pushNotifications", False),
|
||||
extended_agent_card=d.get("capabilities", {}).get("extendedAgentCard", False),
|
||||
extensions=d.get("capabilities", {}).get("extensions", []),
|
||||
),
|
||||
provider=d.get("provider"),
|
||||
documentation_url=d.get("documentationUrl"),
|
||||
icon_url=d.get("iconUrl"),
|
||||
default_input_modes=d.get("defaultInputModes", ["text/plain"]),
|
||||
default_output_modes=d.get("defaultOutputModes", ["text/plain"]),
|
||||
skills=[
|
||||
AgentSkill(
|
||||
id=s["id"],
|
||||
name=s["name"],
|
||||
description=s["description"],
|
||||
tags=s.get("tags", []),
|
||||
examples=s.get("examples", []),
|
||||
input_modes=s.get("inputModes", ["text/plain"]),
|
||||
output_modes=s.get("outputModes", ["text/plain"]),
|
||||
security_requirements=s.get("securityRequirements", []),
|
||||
)
|
||||
for s in d.get("skills", [])
|
||||
],
|
||||
security_schemes=d.get("securitySchemes", {}),
|
||||
security_requirements=d.get("securityRequirements", []),
|
||||
)
|
||||
|
||||
|
||||
# --- JSON-RPC envelope ---
|
||||
|
||||
@dataclass
|
||||
class JSONRPCRequest:
|
||||
"""JSON-RPC 2.0 request wrapping an A2A method."""
|
||||
method: str
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
params: dict = field(default_factory=dict)
|
||||
jsonrpc: str = "2.0"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"jsonrpc": self.jsonrpc,
|
||||
"id": self.id,
|
||||
"method": self.method,
|
||||
"params": self.params,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONRPCError:
|
||||
"""JSON-RPC 2.0 error object."""
|
||||
code: int
|
||||
message: str
|
||||
data: Any = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {"code": self.code, "message": self.message}
|
||||
if self.data is not None:
|
||||
d["data"] = self.data
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONRPCResponse:
|
||||
"""JSON-RPC 2.0 response."""
|
||||
id: str
|
||||
result: Any = None
|
||||
error: Optional[JSONRPCError] = None
|
||||
jsonrpc: str = "2.0"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d: dict[str, Any] = {
|
||||
"jsonrpc": self.jsonrpc,
|
||||
"id": self.id,
|
||||
}
|
||||
if self.error:
|
||||
d["error"] = self.error.to_dict()
|
||||
else:
|
||||
d["result"] = self.result
|
||||
return d
|
||||
|
||||
|
||||
# --- Standard A2A Error codes ---
|
||||
|
||||
class A2AError:
|
||||
"""Standard A2A / JSON-RPC error factories."""
|
||||
PARSE = JSONRPCError(-32700, "Invalid JSON payload")
|
||||
INVALID_REQUEST = JSONRPCError(-32600, "Request payload validation error")
|
||||
METHOD_NOT_FOUND = JSONRPCError(-32601, "Method not found")
|
||||
INVALID_PARAMS = JSONRPCError(-32602, "Invalid parameters")
|
||||
INTERNAL = JSONRPCError(-32603, "Internal error")
|
||||
|
||||
TASK_NOT_FOUND = JSONRPCError(-32001, "Task not found")
|
||||
TASK_NOT_CANCELABLE = JSONRPCError(-32002, "Task not cancelable")
|
||||
PUSH_NOT_SUPPORTED = JSONRPCError(-32003, "Push notifications not supported")
|
||||
UNSUPPORTED_OP = JSONRPCError(-32004, "Unsupported operation")
|
||||
CONTENT_TYPE = JSONRPCError(-32005, "Content type not supported")
|
||||
INVALID_RESPONSE = JSONRPCError(-32006, "Invalid agent response")
|
||||
EXTENDED_CARD = JSONRPCError(-32007, "Extended agent card not configured")
|
||||
EXTENSION_REQUIRED = JSONRPCError(-32008, "Extension support required")
|
||||
VERSION_NOT_SUPPORTED = JSONRPCError(-32009, "Version not supported")
|
||||
@@ -1,763 +0,0 @@
|
||||
"""
|
||||
Tests for A2A Protocol implementation.
|
||||
|
||||
Covers:
|
||||
- Type serialization roundtrips (Agent Card, Task, Message, Artifact, Part)
|
||||
- JSON-RPC envelope
|
||||
- Agent Card building from YAML config
|
||||
- Registry operations (register, list, filter)
|
||||
- Client/server integration (end-to-end task delegation)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from nexus.a2a.types import (
|
||||
A2AError,
|
||||
AgentCard,
|
||||
AgentCapabilities,
|
||||
AgentInterface,
|
||||
AgentSkill,
|
||||
Artifact,
|
||||
DataPart,
|
||||
FilePart,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
Message,
|
||||
Role,
|
||||
Task,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TextPart,
|
||||
part_from_dict,
|
||||
part_to_dict,
|
||||
)
|
||||
from nexus.a2a.card import build_card, load_card_config
|
||||
from nexus.a2a.registry import LocalFileRegistry
|
||||
|
||||
|
||||
# === Type Serialization Roundtrips ===
|
||||
|
||||
|
||||
class TestTextPart:
|
||||
def test_roundtrip(self):
|
||||
p = TextPart(text="hello world")
|
||||
d = p.to_dict()
|
||||
assert d == {"text": "hello world"}
|
||||
p2 = part_from_dict(d)
|
||||
assert isinstance(p2, TextPart)
|
||||
assert p2.text == "hello world"
|
||||
|
||||
def test_custom_media_type(self):
|
||||
p = TextPart(text="data", media_type="text/markdown")
|
||||
d = p.to_dict()
|
||||
assert d["mediaType"] == "text/markdown"
|
||||
p2 = part_from_dict(d)
|
||||
assert p2.media_type == "text/markdown"
|
||||
|
||||
|
||||
class TestFilePart:
|
||||
def test_inline_roundtrip(self):
|
||||
p = FilePart(media_type="image/png", raw="base64data", filename="img.png")
|
||||
d = p.to_dict()
|
||||
assert d["raw"] == "base64data"
|
||||
assert d["filename"] == "img.png"
|
||||
p2 = part_from_dict(d)
|
||||
assert isinstance(p2, FilePart)
|
||||
assert p2.raw == "base64data"
|
||||
|
||||
def test_url_roundtrip(self):
|
||||
p = FilePart(media_type="application/pdf", url="https://example.com/doc.pdf")
|
||||
d = p.to_dict()
|
||||
assert d["url"] == "https://example.com/doc.pdf"
|
||||
p2 = part_from_dict(d)
|
||||
assert isinstance(p2, FilePart)
|
||||
assert p2.url == "https://example.com/doc.pdf"
|
||||
|
||||
|
||||
class TestDataPart:
|
||||
def test_roundtrip(self):
|
||||
p = DataPart(data={"key": "value", "count": 42})
|
||||
d = p.to_dict()
|
||||
assert d["data"] == {"key": "value", "count": 42}
|
||||
p2 = part_from_dict(d)
|
||||
assert isinstance(p2, DataPart)
|
||||
assert p2.data["count"] == 42
|
||||
|
||||
|
||||
class TestMessage:
|
||||
def test_roundtrip(self):
|
||||
msg = Message(
|
||||
role=Role.USER,
|
||||
parts=[TextPart(text="Hello agent")],
|
||||
metadata={"priority": "high"},
|
||||
)
|
||||
d = msg.to_dict()
|
||||
assert d["role"] == "ROLE_USER"
|
||||
assert d["parts"] == [{"text": "Hello agent"}]
|
||||
assert d["metadata"]["priority"] == "high"
|
||||
|
||||
msg2 = Message.from_dict(d)
|
||||
assert msg2.role == Role.USER
|
||||
assert isinstance(msg2.parts[0], TextPart)
|
||||
assert msg2.parts[0].text == "Hello agent"
|
||||
assert msg2.metadata["priority"] == "high"
|
||||
|
||||
def test_multi_part(self):
|
||||
msg = Message(
|
||||
role=Role.AGENT,
|
||||
parts=[
|
||||
TextPart(text="Here's the report"),
|
||||
DataPart(data={"status": "healthy"}),
|
||||
],
|
||||
)
|
||||
d = msg.to_dict()
|
||||
assert len(d["parts"]) == 2
|
||||
msg2 = Message.from_dict(d)
|
||||
assert len(msg2.parts) == 2
|
||||
assert isinstance(msg2.parts[0], TextPart)
|
||||
assert isinstance(msg2.parts[1], DataPart)
|
||||
|
||||
|
||||
class TestArtifact:
|
||||
def test_roundtrip(self):
|
||||
art = Artifact(
|
||||
parts=[TextPart(text="result data")],
|
||||
name="report",
|
||||
description="CI health report",
|
||||
)
|
||||
d = art.to_dict()
|
||||
assert d["name"] == "report"
|
||||
assert d["description"] == "CI health report"
|
||||
|
||||
art2 = Artifact.from_dict(d)
|
||||
assert art2.name == "report"
|
||||
assert isinstance(art2.parts[0], TextPart)
|
||||
assert art2.parts[0].text == "result data"
|
||||
|
||||
|
||||
class TestTask:
|
||||
def test_roundtrip(self):
|
||||
task = Task(
|
||||
id="test-123",
|
||||
status=TaskStatus(state=TaskState.WORKING),
|
||||
history=[
|
||||
Message(role=Role.USER, parts=[TextPart(text="Do X")]),
|
||||
],
|
||||
)
|
||||
d = task.to_dict()
|
||||
assert d["id"] == "test-123"
|
||||
assert d["status"]["state"] == "TASK_STATE_WORKING"
|
||||
|
||||
task2 = Task.from_dict(d)
|
||||
assert task2.id == "test-123"
|
||||
assert task2.status.state == TaskState.WORKING
|
||||
assert len(task2.history) == 1
|
||||
|
||||
def test_with_artifacts(self):
|
||||
task = Task(
|
||||
id="art-task",
|
||||
status=TaskStatus(state=TaskState.COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
parts=[TextPart(text="42")],
|
||||
name="answer",
|
||||
)
|
||||
],
|
||||
)
|
||||
d = task.to_dict()
|
||||
assert len(d["artifacts"]) == 1
|
||||
task2 = Task.from_dict(d)
|
||||
assert task2.artifacts[0].name == "answer"
|
||||
|
||||
def test_terminal_states(self):
|
||||
for state in [
|
||||
TaskState.COMPLETED,
|
||||
TaskState.FAILED,
|
||||
TaskState.CANCELED,
|
||||
TaskState.REJECTED,
|
||||
]:
|
||||
assert state.terminal is True
|
||||
|
||||
for state in [
|
||||
TaskState.SUBMITTED,
|
||||
TaskState.WORKING,
|
||||
TaskState.INPUT_REQUIRED,
|
||||
TaskState.AUTH_REQUIRED,
|
||||
]:
|
||||
assert state.terminal is False
|
||||
|
||||
|
||||
class TestAgentCard:
|
||||
def test_roundtrip(self):
|
||||
card = AgentCard(
|
||||
name="TestAgent",
|
||||
description="A test agent",
|
||||
version="1.0.0",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:8080/a2a/v1")
|
||||
],
|
||||
capabilities=AgentCapabilities(streaming=True),
|
||||
skills=[
|
||||
AgentSkill(
|
||||
id="test-skill",
|
||||
name="Test Skill",
|
||||
description="Does tests",
|
||||
tags=["test"],
|
||||
)
|
||||
],
|
||||
)
|
||||
d = card.to_dict()
|
||||
assert d["name"] == "TestAgent"
|
||||
assert d["capabilities"]["streaming"] is True
|
||||
assert len(d["skills"]) == 1
|
||||
assert d["skills"][0]["id"] == "test-skill"
|
||||
|
||||
card2 = AgentCard.from_dict(d)
|
||||
assert card2.name == "TestAgent"
|
||||
assert card2.skills[0].id == "test-skill"
|
||||
assert card2.capabilities.streaming is True
|
||||
|
||||
|
||||
class TestJSONRPC:
|
||||
def test_request_roundtrip(self):
|
||||
req = JSONRPCRequest(
|
||||
method="SendMessage",
|
||||
params={"message": {"text": "hello"}},
|
||||
)
|
||||
d = req.to_dict()
|
||||
assert d["jsonrpc"] == "2.0"
|
||||
assert d["method"] == "SendMessage"
|
||||
|
||||
def test_response_success(self):
|
||||
resp = JSONRPCResponse(
|
||||
id="req-1",
|
||||
result={"task": {"id": "t1"}},
|
||||
)
|
||||
d = resp.to_dict()
|
||||
assert "error" not in d
|
||||
assert d["result"]["task"]["id"] == "t1"
|
||||
|
||||
def test_response_error(self):
|
||||
resp = JSONRPCResponse(
|
||||
id="req-1",
|
||||
error=A2AError.TASK_NOT_FOUND,
|
||||
)
|
||||
d = resp.to_dict()
|
||||
assert "result" not in d
|
||||
assert d["error"]["code"] == -32001
|
||||
|
||||
|
||||
# === Agent Card Building ===
|
||||
|
||||
|
||||
class TestBuildCard:
|
||||
def test_basic_config(self):
|
||||
config = {
|
||||
"name": "Bezalel",
|
||||
"description": "CI/CD specialist",
|
||||
"version": "2.0.0",
|
||||
"url": "https://bezalel.example.com",
|
||||
"skills": [
|
||||
{
|
||||
"id": "ci-health",
|
||||
"name": "CI Health",
|
||||
"description": "Check CI",
|
||||
"tags": ["ci"],
|
||||
},
|
||||
{
|
||||
"id": "deploy",
|
||||
"name": "Deploy",
|
||||
"description": "Deploy services",
|
||||
"tags": ["ops"],
|
||||
},
|
||||
],
|
||||
}
|
||||
card = build_card(config)
|
||||
assert card.name == "Bezalel"
|
||||
assert card.version == "2.0.0"
|
||||
assert len(card.skills) == 2
|
||||
assert card.skills[0].id == "ci-health"
|
||||
assert card.supported_interfaces[0].url == "https://bezalel.example.com"
|
||||
|
||||
def test_bearer_auth(self):
|
||||
config = {
|
||||
"name": "Test",
|
||||
"description": "Test",
|
||||
"auth": {"scheme": "bearer", "token_env": "MY_TOKEN"},
|
||||
}
|
||||
card = build_card(config)
|
||||
assert "bearerAuth" in card.security_schemes
|
||||
assert card.security_requirements[0]["schemes"]["bearerAuth"] == {"list": []}
|
||||
|
||||
def test_api_key_auth(self):
|
||||
config = {
|
||||
"name": "Test",
|
||||
"description": "Test",
|
||||
"auth": {"scheme": "api_key", "key_name": "X-Custom-Key"},
|
||||
}
|
||||
card = build_card(config)
|
||||
assert "apiKeyAuth" in card.security_schemes
|
||||
|
||||
|
||||
# === Registry ===
|
||||
|
||||
|
||||
class TestLocalFileRegistry:
|
||||
def _make_card(self, name: str, skills: list[dict] | None = None) -> AgentCard:
|
||||
return AgentCard(
|
||||
name=name,
|
||||
description=f"Agent {name}",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url=f"http://{name}:8080/a2a/v1")
|
||||
],
|
||||
skills=[
|
||||
AgentSkill(
|
||||
id=s["id"],
|
||||
name=s.get("name", s["id"]),
|
||||
description=s.get("description", ""),
|
||||
tags=s.get("tags", []),
|
||||
)
|
||||
for s in (skills or [])
|
||||
],
|
||||
)
|
||||
|
||||
def test_register_and_list(self, tmp_path):
|
||||
registry = LocalFileRegistry(tmp_path / "agents.json")
|
||||
registry.register(self._make_card("ezra"))
|
||||
registry.register(self._make_card("allegro"))
|
||||
|
||||
agents = registry.list_agents()
|
||||
assert len(agents) == 2
|
||||
names = {a.name for a in agents}
|
||||
assert names == {"ezra", "allegro"}
|
||||
|
||||
def test_filter_by_skill(self, tmp_path):
|
||||
registry = LocalFileRegistry(tmp_path / "agents.json")
|
||||
registry.register(
|
||||
self._make_card("ezra", [{"id": "ci-health", "tags": ["ci"]}])
|
||||
)
|
||||
registry.register(
|
||||
self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
|
||||
)
|
||||
|
||||
ci_agents = registry.list_agents(skill="ci-health")
|
||||
assert len(ci_agents) == 1
|
||||
assert ci_agents[0].name == "ezra"
|
||||
|
||||
def test_filter_by_tag(self, tmp_path):
|
||||
registry = LocalFileRegistry(tmp_path / "agents.json")
|
||||
registry.register(
|
||||
self._make_card("ezra", [{"id": "ci", "tags": ["devops", "ci"]}])
|
||||
)
|
||||
registry.register(
|
||||
self._make_card("allegro", [{"id": "research", "tags": ["research"]}])
|
||||
)
|
||||
|
||||
devops_agents = registry.list_agents(tag="devops")
|
||||
assert len(devops_agents) == 1
|
||||
assert devops_agents[0].name == "ezra"
|
||||
|
||||
def test_persistence(self, tmp_path):
|
||||
path = tmp_path / "agents.json"
|
||||
reg1 = LocalFileRegistry(path)
|
||||
reg1.register(self._make_card("ezra"))
|
||||
|
||||
# Load fresh from disk
|
||||
reg2 = LocalFileRegistry(path)
|
||||
agents = reg2.list_agents()
|
||||
assert len(agents) == 1
|
||||
assert agents[0].name == "ezra"
|
||||
|
||||
def test_unregister(self, tmp_path):
|
||||
registry = LocalFileRegistry(tmp_path / "agents.json")
|
||||
registry.register(self._make_card("ezra"))
|
||||
assert len(registry.list_agents()) == 1
|
||||
|
||||
assert registry.unregister("ezra") is True
|
||||
assert len(registry.list_agents()) == 0
|
||||
assert registry.unregister("nonexistent") is False
|
||||
|
||||
def test_get_endpoint(self, tmp_path):
|
||||
registry = LocalFileRegistry(tmp_path / "agents.json")
|
||||
registry.register(self._make_card("ezra"))
|
||||
|
||||
url = registry.get_endpoint("ezra")
|
||||
assert url == "http://ezra:8080/a2a/v1"
|
||||
|
||||
|
||||
# === Server Integration (FastAPI required) ===
|
||||
|
||||
|
||||
try:
|
||||
from fastapi.testclient import TestClient
|
||||
HAS_TEST_CLIENT = True
|
||||
except ImportError:
|
||||
HAS_TEST_CLIENT = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
|
||||
class TestA2AServerIntegration:
|
||||
"""End-to-end tests using FastAPI TestClient."""
|
||||
|
||||
def _make_server(self, auth_token: str = ""):
|
||||
from nexus.a2a.server import A2AServer, echo_handler
|
||||
|
||||
card = AgentCard(
|
||||
name="TestAgent",
|
||||
description="Test agent for A2A",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:8080/a2a/v1")
|
||||
],
|
||||
capabilities=AgentCapabilities(streaming=False),
|
||||
skills=[
|
||||
AgentSkill(
|
||||
id="echo",
|
||||
name="Echo",
|
||||
description="Echo back messages",
|
||||
tags=["test"],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
server = A2AServer(card=card, auth_token=auth_token)
|
||||
server.register_handler("echo", echo_handler)
|
||||
server.set_default_handler(echo_handler)
|
||||
return server
|
||||
|
||||
def test_agent_card_well_known(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
resp = client.get("/.well-known/agent-card.json")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "TestAgent"
|
||||
assert len(data["skills"]) == 1
|
||||
|
||||
def test_agent_card_fallback(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
resp = client.get("/agent.json")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "TestAgent"
|
||||
|
||||
def test_send_message(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
rpc_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "msg-1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "Hello from test"}],
|
||||
},
|
||||
"configuration": {
|
||||
"acceptedOutputModes": ["text/plain"],
|
||||
"historyLength": 10,
|
||||
"returnImmediately": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp = client.post("/a2a/v1", json=rpc_request)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "result" in data
|
||||
assert "task" in data["result"]
|
||||
|
||||
task = data["result"]["task"]
|
||||
assert task["status"]["state"] == "TASK_STATE_COMPLETED"
|
||||
assert len(task["artifacts"]) == 1
|
||||
assert "Echo" in task["artifacts"][0]["parts"][0]["text"]
|
||||
|
||||
def test_get_task(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
# Create a task first
|
||||
send_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "s1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "m1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "get me"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
send_resp = client.post("/a2a/v1", json=send_req)
|
||||
task_id = send_resp.json()["result"]["task"]["id"]
|
||||
|
||||
# Now fetch it
|
||||
get_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "g1",
|
||||
"method": "GetTask",
|
||||
"params": {"id": task_id},
|
||||
}
|
||||
get_resp = client.post("/a2a/v1", json=get_req)
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["result"]["id"] == task_id
|
||||
|
||||
def test_get_nonexistent_task(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "g2",
|
||||
"method": "GetTask",
|
||||
"params": {"id": "nonexistent"},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=req)
|
||||
assert resp.status_code == 400
|
||||
data = resp.json()
|
||||
assert "error" in data
|
||||
|
||||
def test_list_tasks(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
# Create two tasks
|
||||
for i in range(2):
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"s{i}",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": f"m{i}",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": f"task {i}"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
client.post("/a2a/v1", json=req)
|
||||
|
||||
list_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "l1",
|
||||
"method": "ListTasks",
|
||||
"params": {"pageSize": 10},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=list_req)
|
||||
assert resp.status_code == 200
|
||||
tasks = resp.json()["result"]["tasks"]
|
||||
assert len(tasks) >= 2
|
||||
|
||||
def test_cancel_task(self):
|
||||
from nexus.a2a.server import A2AServer
|
||||
|
||||
# Create a server with a slow handler so task stays WORKING
|
||||
async def slow_handler(task, card):
|
||||
import asyncio
|
||||
await asyncio.sleep(10) # never reached in test
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
return task
|
||||
|
||||
card = AgentCard(name="SlowAgent", description="Slow test agent")
|
||||
server = A2AServer(card=card)
|
||||
server.set_default_handler(slow_handler)
|
||||
client = TestClient(server.app)
|
||||
|
||||
# Create a task (but we need to intercept before handler runs)
|
||||
# Instead, manually insert a task and test cancel on it
|
||||
task = Task(
|
||||
id="cancel-me",
|
||||
status=TaskStatus(state=TaskState.WORKING),
|
||||
history=[
|
||||
Message(role=Role.USER, parts=[TextPart(text="cancel me")])
|
||||
],
|
||||
)
|
||||
server._tasks[task.id] = task
|
||||
|
||||
# Cancel it
|
||||
cancel_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "c2",
|
||||
"method": "CancelTask",
|
||||
"params": {"id": "cancel-me"},
|
||||
}
|
||||
cancel_resp = client.post("/a2a/v1", json=cancel_req)
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["result"]["status"]["state"] == "TASK_STATE_CANCELED"
|
||||
|
||||
def test_auth_required(self):
|
||||
server = self._make_server(auth_token="secret123")
|
||||
client = TestClient(server.app)
|
||||
|
||||
# No auth header — should get 401
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "a1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "am1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "hello"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=req)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_auth_success(self):
|
||||
server = self._make_server(auth_token="secret123")
|
||||
client = TestClient(server.app)
|
||||
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "a2",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "am2",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "authenticated"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
resp = client.post(
|
||||
"/a2a/v1",
|
||||
json=req,
|
||||
headers={"Authorization": "Bearer secret123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["result"]["task"]["status"]["state"] == "TASK_STATE_COMPLETED"
|
||||
|
||||
def test_unknown_method(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "u1",
|
||||
"method": "NonExistentMethod",
|
||||
"params": {},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=req)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["error"]["code"] == -32602
|
||||
|
||||
def test_audit_log(self):
|
||||
server = self._make_server()
|
||||
client = TestClient(server.app)
|
||||
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "au1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "aum1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "audit me"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
client.post("/a2a/v1", json=req)
|
||||
client.post("/a2a/v1", json=req)
|
||||
|
||||
log = server.get_audit_log()
|
||||
assert len(log) == 2
|
||||
assert all(entry["method"] == "SendMessage" for entry in log)
|
||||
|
||||
|
||||
# === Custom Handler Test ===
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TEST_CLIENT, reason="fastapi not installed")
|
||||
class TestCustomHandlers:
|
||||
"""Test custom task handlers."""
|
||||
|
||||
def test_skill_routing(self):
|
||||
from nexus.a2a.server import A2AServer
|
||||
from nexus.a2a.types import Task, AgentCard
|
||||
|
||||
async def ci_handler(task: Task, card: AgentCard) -> Task:
|
||||
task.artifacts.append(
|
||||
Artifact(
|
||||
parts=[TextPart(text="CI pipeline healthy: 5/5 passed")],
|
||||
name="ci_report",
|
||||
)
|
||||
)
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
return task
|
||||
|
||||
card = AgentCard(
|
||||
name="CI Agent",
|
||||
description="CI specialist",
|
||||
skills=[AgentSkill(id="ci-health", name="CI Health", description="Check CI", tags=["ci"])],
|
||||
)
|
||||
server = A2AServer(card=card)
|
||||
server.register_handler("ci-health", ci_handler)
|
||||
|
||||
client = TestClient(server.app)
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "h1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "hm1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "Check CI"}],
|
||||
"metadata": {"targetSkill": "ci-health"},
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=req)
|
||||
task_data = resp.json()["result"]["task"]
|
||||
assert task_data["status"]["state"] == "TASK_STATE_COMPLETED"
|
||||
assert "5/5 passed" in task_data["artifacts"][0]["parts"][0]["text"]
|
||||
|
||||
def test_handler_error(self):
|
||||
from nexus.a2a.server import A2AServer
|
||||
from nexus.a2a.types import Task, AgentCard
|
||||
|
||||
async def failing_handler(task: Task, card: AgentCard) -> Task:
|
||||
raise RuntimeError("Handler blew up")
|
||||
|
||||
card = AgentCard(name="Fail Agent", description="Fails")
|
||||
server = A2AServer(card=card)
|
||||
server.set_default_handler(failing_handler)
|
||||
|
||||
client = TestClient(server.app)
|
||||
req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "f1",
|
||||
"method": "SendMessage",
|
||||
"params": {
|
||||
"message": {
|
||||
"messageId": "fm1",
|
||||
"role": "ROLE_USER",
|
||||
"parts": [{"text": "break"}],
|
||||
},
|
||||
"configuration": {},
|
||||
},
|
||||
}
|
||||
resp = client.post("/a2a/v1", json=req)
|
||||
task_data = resp.json()["result"]["task"]
|
||||
assert task_data["status"]["state"] == "TASK_STATE_FAILED"
|
||||
assert "blew up" in task_data["status"]["message"]["parts"][0]["text"].lower()
|
||||
377
tests/test_agent_memory.py
Normal file
377
tests/test_agent_memory.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Tests for agent memory — cross-session agent memory via MemPalace.
|
||||
|
||||
Tests the memory module, hooks, and session mining without requiring
|
||||
a live ChromaDB instance. Uses mocking for the MemPalace backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.memory import (
|
||||
AgentMemory,
|
||||
MemoryContext,
|
||||
SessionTranscript,
|
||||
create_agent_memory,
|
||||
)
|
||||
from agent.memory_hooks import MemoryHooks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionTranscript tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionTranscript:
|
||||
def test_create(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
assert t.agent_name == "test"
|
||||
assert t.wing == "wing_test"
|
||||
assert len(t.entries) == 0
|
||||
|
||||
def test_add_user_turn(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
t.add_user_turn("Hello")
|
||||
assert len(t.entries) == 1
|
||||
assert t.entries[0]["role"] == "user"
|
||||
assert t.entries[0]["text"] == "Hello"
|
||||
|
||||
def test_add_agent_turn(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
t.add_agent_turn("Response")
|
||||
assert t.entries[0]["role"] == "agent"
|
||||
|
||||
def test_add_tool_call(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
t.add_tool_call("shell", "ls", "file1 file2")
|
||||
assert t.entries[0]["role"] == "tool"
|
||||
assert t.entries[0]["tool"] == "shell"
|
||||
|
||||
def test_summary_empty(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
assert t.summary() == "Empty session."
|
||||
|
||||
def test_summary_with_entries(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
t.add_user_turn("Do something")
|
||||
t.add_agent_turn("Done")
|
||||
t.add_tool_call("shell", "ls", "ok")
|
||||
|
||||
summary = t.summary()
|
||||
assert "USER: Do something" in summary
|
||||
assert "AGENT: Done" in summary
|
||||
assert "TOOL(shell): ok" in summary
|
||||
|
||||
def test_text_truncation(self):
|
||||
t = SessionTranscript(agent_name="test", wing="wing_test")
|
||||
long_text = "x" * 5000
|
||||
t.add_user_turn(long_text)
|
||||
assert len(t.entries[0]["text"]) == 2000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryContext tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
def test_empty_context(self):
|
||||
ctx = MemoryContext()
|
||||
assert ctx.to_prompt_block() == ""
|
||||
|
||||
def test_unloaded_context(self):
|
||||
ctx = MemoryContext()
|
||||
ctx.loaded = False
|
||||
assert ctx.to_prompt_block() == ""
|
||||
|
||||
def test_loaded_with_data(self):
|
||||
ctx = MemoryContext()
|
||||
ctx.loaded = True
|
||||
ctx.recent_diaries = [
|
||||
{"text": "Fixed PR #1386", "timestamp": "2026-04-13T10:00:00Z"}
|
||||
]
|
||||
ctx.facts = [
|
||||
{"text": "Bezalel runs on VPS Beta", "score": 0.95}
|
||||
]
|
||||
ctx.relevant_memories = [
|
||||
{"text": "Changed CI runner", "score": 0.87}
|
||||
]
|
||||
|
||||
block = ctx.to_prompt_block()
|
||||
assert "Recent Session Summaries" in block
|
||||
assert "Fixed PR #1386" in block
|
||||
assert "Known Facts" in block
|
||||
assert "Bezalel runs on VPS Beta" in block
|
||||
assert "Relevant Past Memories" in block
|
||||
|
||||
def test_loaded_empty(self):
|
||||
ctx = MemoryContext()
|
||||
ctx.loaded = True
|
||||
# No data — should return empty string
|
||||
assert ctx.to_prompt_block() == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentMemory tests (with mocked MemPalace)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
def test_create(self):
|
||||
mem = AgentMemory(agent_name="bezalel")
|
||||
assert mem.agent_name == "bezalel"
|
||||
assert mem.wing == "wing_bezalel"
|
||||
|
||||
def test_custom_wing(self):
|
||||
mem = AgentMemory(agent_name="bezalel", wing="custom_wing")
|
||||
assert mem.wing == "custom_wing"
|
||||
|
||||
def test_factory(self):
|
||||
mem = create_agent_memory("ezra")
|
||||
assert mem.agent_name == "ezra"
|
||||
assert mem.wing == "wing_ezra"
|
||||
|
||||
def test_unavailable_graceful(self):
|
||||
"""Test graceful degradation when MemPalace is unavailable."""
|
||||
mem = AgentMemory(agent_name="test")
|
||||
mem._available = False # Force unavailable
|
||||
|
||||
# Should not raise
|
||||
ctx = mem.recall_context("test query")
|
||||
assert ctx.loaded is False
|
||||
assert ctx.error == "MemPalace unavailable"
|
||||
|
||||
# remember returns None
|
||||
assert mem.remember("test") is None
|
||||
|
||||
# search returns empty
|
||||
assert mem.search("test") == []
|
||||
|
||||
def test_start_end_session(self):
|
||||
mem = AgentMemory(agent_name="test")
|
||||
mem._available = False
|
||||
|
||||
transcript = mem.start_session()
|
||||
assert isinstance(transcript, SessionTranscript)
|
||||
assert mem._transcript is not None
|
||||
|
||||
doc_id = mem.end_session()
|
||||
assert mem._transcript is None
|
||||
|
||||
def test_remember_graceful_when_unavailable(self):
|
||||
"""Test remember returns None when MemPalace is unavailable."""
|
||||
mem = AgentMemory(agent_name="test")
|
||||
mem._available = False
|
||||
|
||||
doc_id = mem.remember("some important fact")
|
||||
assert doc_id is None
|
||||
|
||||
def test_write_diary_from_transcript(self):
|
||||
mem = AgentMemory(agent_name="test")
|
||||
mem._available = False
|
||||
|
||||
transcript = mem.start_session()
|
||||
transcript.add_user_turn("Hello")
|
||||
transcript.add_agent_turn("Hi there")
|
||||
|
||||
# Write diary should handle unavailable gracefully
|
||||
doc_id = mem.write_diary()
|
||||
assert doc_id is None # MemPalace unavailable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryHooks tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryHooks:
|
||||
def test_create(self):
|
||||
hooks = MemoryHooks(agent_name="bezalel")
|
||||
assert hooks.agent_name == "bezalel"
|
||||
assert hooks.is_active is False
|
||||
|
||||
def test_session_lifecycle(self):
|
||||
hooks = MemoryHooks(agent_name="test")
|
||||
|
||||
# Force memory unavailable
|
||||
hooks._memory = AgentMemory(agent_name="test")
|
||||
hooks._memory._available = False
|
||||
|
||||
# Start session
|
||||
block = hooks.on_session_start()
|
||||
assert hooks.is_active is True
|
||||
assert block == "" # No memory available
|
||||
|
||||
# Record turns
|
||||
hooks.on_user_turn("Hello")
|
||||
hooks.on_agent_turn("Hi")
|
||||
hooks.on_tool_call("shell", "ls", "ok")
|
||||
|
||||
# Record decision
|
||||
hooks.on_important_decision("Switched to self-hosted CI")
|
||||
|
||||
# End session
|
||||
doc_id = hooks.on_session_end()
|
||||
assert hooks.is_active is False
|
||||
|
||||
def test_hooks_before_session(self):
|
||||
"""Hooks before session start should be no-ops."""
|
||||
hooks = MemoryHooks(agent_name="test")
|
||||
hooks._memory = AgentMemory(agent_name="test")
|
||||
hooks._memory._available = False
|
||||
|
||||
# Should not raise
|
||||
hooks.on_user_turn("Hello")
|
||||
hooks.on_agent_turn("Response")
|
||||
|
||||
def test_hooks_after_session_end(self):
|
||||
"""Hooks after session end should be no-ops."""
|
||||
hooks = MemoryHooks(agent_name="test")
|
||||
hooks._memory = AgentMemory(agent_name="test")
|
||||
hooks._memory._available = False
|
||||
|
||||
hooks.on_session_start()
|
||||
hooks.on_session_end()
|
||||
|
||||
# Should not raise
|
||||
hooks.on_user_turn("Late message")
|
||||
doc_id = hooks.on_session_end()
|
||||
assert doc_id is None
|
||||
|
||||
def test_search_during_session(self):
|
||||
hooks = MemoryHooks(agent_name="test")
|
||||
hooks._memory = AgentMemory(agent_name="test")
|
||||
hooks._memory._available = False
|
||||
|
||||
results = hooks.search("some query")
|
||||
assert results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session mining tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionMining:
|
||||
def test_parse_session_file(self):
|
||||
from bin.memory_mine import parse_session_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write('{"role": "user", "content": "Hello"}\n')
|
||||
f.write('{"role": "assistant", "content": "Hi there"}\n')
|
||||
f.write('{"role": "tool", "name": "shell", "content": "ls output"}\n')
|
||||
f.write("\n") # blank line
|
||||
f.write("not json\n") # malformed
|
||||
path = Path(f.name)
|
||||
|
||||
turns = parse_session_file(path)
|
||||
assert len(turns) == 3
|
||||
assert turns[0]["role"] == "user"
|
||||
assert turns[1]["role"] == "assistant"
|
||||
assert turns[2]["role"] == "tool"
|
||||
path.unlink()
|
||||
|
||||
def test_summarize_session(self):
|
||||
from bin.memory_mine import summarize_session
|
||||
|
||||
turns = [
|
||||
{"role": "user", "content": "Check CI"},
|
||||
{"role": "assistant", "content": "Running CI check..."},
|
||||
{"role": "tool", "name": "shell", "content": "5 tests passed"},
|
||||
{"role": "assistant", "content": "CI is healthy"},
|
||||
]
|
||||
|
||||
summary = summarize_session(turns, "bezalel")
|
||||
assert "bezalel" in summary
|
||||
assert "Check CI" in summary
|
||||
assert "shell" in summary
|
||||
|
||||
def test_summarize_empty(self):
|
||||
from bin.memory_mine import summarize_session
|
||||
|
||||
assert summarize_session([], "test") == "Empty session."
|
||||
|
||||
def test_find_session_files(self, tmp_path):
|
||||
from bin.memory_mine import find_session_files
|
||||
|
||||
# Create some test files
|
||||
(tmp_path / "session1.jsonl").write_text("{}\n")
|
||||
(tmp_path / "session2.jsonl").write_text("{}\n")
|
||||
(tmp_path / "notes.txt").write_text("not a session")
|
||||
|
||||
files = find_session_files(tmp_path, days=365)
|
||||
assert len(files) == 2
|
||||
assert all(f.suffix == ".jsonl" for f in files)
|
||||
|
||||
def test_find_session_files_missing_dir(self):
|
||||
from bin.memory_mine import find_session_files
|
||||
|
||||
files = find_session_files(Path("/nonexistent/path"), days=7)
|
||||
assert files == []
|
||||
|
||||
def test_mine_session_dry_run(self, tmp_path):
|
||||
from bin.memory_mine import mine_session
|
||||
|
||||
session_file = tmp_path / "test.jsonl"
|
||||
session_file.write_text(
|
||||
'{"role": "user", "content": "Hello"}\n'
|
||||
'{"role": "assistant", "content": "Hi"}\n'
|
||||
)
|
||||
|
||||
result = mine_session(session_file, wing="wing_test", dry_run=True)
|
||||
assert result is None # dry run doesn't store
|
||||
|
||||
def test_mine_session_empty_file(self, tmp_path):
|
||||
from bin.memory_mine import mine_session
|
||||
|
||||
session_file = tmp_path / "empty.jsonl"
|
||||
session_file.write_text("")
|
||||
|
||||
result = mine_session(session_file, wing="wing_test")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration test — full lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFullLifecycle:
|
||||
"""Test the full session lifecycle without a real MemPalace backend."""
|
||||
|
||||
def test_full_session_flow(self):
|
||||
hooks = MemoryHooks(agent_name="bezalel")
|
||||
|
||||
# Force memory unavailable
|
||||
hooks._memory = AgentMemory(agent_name="bezalel")
|
||||
hooks._memory._available = False
|
||||
|
||||
# 1. Session start
|
||||
context_block = hooks.on_session_start("What CI issues do I have?")
|
||||
assert isinstance(context_block, str)
|
||||
|
||||
# 2. User asks question
|
||||
hooks.on_user_turn("Check CI pipeline health")
|
||||
|
||||
# 3. Agent uses tool
|
||||
hooks.on_tool_call("shell", "pytest tests/", "12 passed")
|
||||
|
||||
# 4. Agent responds
|
||||
hooks.on_agent_turn("CI pipeline is healthy. All 12 tests passing.")
|
||||
|
||||
# 5. Important decision
|
||||
hooks.on_important_decision("Decided to keep current CI runner", room="forge")
|
||||
|
||||
# 6. More interaction
|
||||
hooks.on_user_turn("Good, check memory integration next")
|
||||
hooks.on_agent_turn("Will test agent.memory module")
|
||||
|
||||
# 7. Session end
|
||||
doc_id = hooks.on_session_end()
|
||||
assert hooks.is_active is False
|
||||
Reference in New Issue
Block a user