Compare commits
1 Commits
burn/838-1
...
fix/819
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1adbf7ed1b |
224
agent/agent_card.py
Normal file
224
agent/agent_card.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""A2A Agent Card — publish capabilities for fleet discovery.
|
||||
|
||||
Each fleet agent publishes an A2A-compliant agent card describing its capabilities.
|
||||
Standard discovery endpoint: /.well-known/agent-card.json
|
||||
|
||||
Issue #819: feat: A2A agent card — publish capabilities for fleet discovery
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class AgentSkill:
    """A single skill the agent can perform.

    ``id`` doubles as the de-duplication key when extra skills are merged
    in build_agent_card.
    """
    # Stable identifier (the skill directory name when loaded from disk).
    id: str
    # Human-readable display name.
    name: str
    # Short description; truncated to 200 chars by _load_skills_from_directory.
    description: str = ""
    # Free-form tag list from SKILL.md frontmatter.
    tags: List[str] = field(default_factory=list)
    # Example invocations (not populated by the directory loader).
    examples: List[str] = field(default_factory=list)
    # Accepted / produced MIME types for this skill.
    input_modes: List[str] = field(default_factory=lambda: ["text/plain"])
    output_modes: List[str] = field(default_factory=lambda: ["text/plain"])
|
||||
|
||||
|
||||
@dataclass
class AgentCapabilities:
    """What the agent can do."""
    # Capability flags advertised on the agent card; build_agent_card
    # currently always sets streaming=True, push_notifications=False,
    # state_transition_history=True.
    streaming: bool = True
    push_notifications: bool = False
    state_transition_history: bool = True
|
||||
|
||||
|
||||
@dataclass
class AgentCard:
    """A2A-compliant agent card."""
    name: str
    description: str
    url: str
    version: str = "1.0.0"
    capabilities: AgentCapabilities = field(default_factory=AgentCapabilities)
    skills: List[AgentSkill] = field(default_factory=list)
    default_input_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    default_output_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to JSON-serializable dict."""
        payload = asdict(self)
        # The A2A spec spells these two keys in camelCase; pop-then-assign
        # keeps the renamed keys at the end of the dict, as before.
        for snake, camel in (
            ("default_input_modes", "defaultInputModes"),
            ("default_output_modes", "defaultOutputModes"),
        ):
            payload[camel] = payload.pop(snake)
        return payload

    def to_json(self) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
|
||||
def _load_skills_from_directory(skills_dir: Path) -> List[AgentSkill]:
    """Scan ~/.hermes/skills/ for SKILL.md frontmatter."""
    found: List[AgentSkill] = []

    if not skills_dir.exists():
        return found

    for entry in skills_dir.iterdir():
        if not entry.is_dir():
            continue

        manifest = entry / "SKILL.md"
        if not manifest.exists():
            continue

        try:
            text = manifest.read_text(encoding="utf-8")

            # Only files opening with a YAML frontmatter fence are considered.
            if not text.startswith("---"):
                continue
            pieces = text.split("---", 2)
            if len(pieces) < 3:
                continue

            # Local import: if PyYAML is absent, the enclosing try turns
            # the ImportError into a silent skip of this skill directory.
            import yaml
            try:
                meta = yaml.safe_load(pieces[1]) or {}
            except Exception:
                meta = {}

            summary = meta.get("description", "")
            raw_tags = meta.get("tags", [])
            found.append(AgentSkill(
                id=entry.name,
                name=meta.get("name", entry.name),
                description=summary[:200] if summary else "",
                tags=raw_tags if isinstance(raw_tags, list) else [],
            ))
        except Exception:
            # Unreadable or malformed skill directory — skip it entirely.
            continue

    return found
|
||||
|
||||
|
||||
def validate_agent_card(card: AgentCard) -> List[str]:
    """Validate agent card against A2A schema requirements.

    Returns list of validation errors (empty if valid).
    """
    problems: List[str] = []

    # Required top-level fields.
    if not card.name:
        problems.append("name is required")
    if not card.url:
        problems.append("url is required")

    # MIME types must come from the supported set.
    allowed = {"text/plain", "application/json", "image/png", "audio/wav"}
    problems.extend(
        f"invalid input mode: {mode}"
        for mode in card.default_input_modes
        if mode not in allowed
    )
    problems.extend(
        f"invalid output mode: {mode}"
        for mode in card.default_output_modes
        if mode not in allowed
    )

    # Every skill needs a non-empty id.
    problems.extend(
        f"skill missing id: {skill.name}"
        for skill in card.skills
        if not skill.id
    )

    return problems
|
||||
|
||||
|
||||
def build_agent_card(
    name: Optional[str] = None,
    description: Optional[str] = None,
    url: Optional[str] = None,
    version: Optional[str] = None,
    skills: Optional[List[AgentSkill]] = None,
    extra_skills: Optional[List[AgentSkill]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> AgentCard:
    """Build an A2A agent card from config and environment.

    Priority: explicit params > env vars > config.yaml > defaults

    Args:
        name, description, url, version: explicit overrides for card fields.
        skills: explicit skill list; when None, skills are discovered from
            the ~/.hermes/skills directory.
        extra_skills: appended after the base skills, skipping duplicate ids.
        metadata: merged over the default metadata (model/provider/hostname).

    Returns:
        A populated AgentCard.
    """
    # Load config (best-effort — a missing config module/file is not fatal).
    config_model = ""
    config_provider = ""
    try:
        from hermes_cli.config import load_config
        cfg = load_config()
        model_cfg = cfg.get("model", {})
        if isinstance(model_cfg, dict):
            config_model = model_cfg.get("default", "")
            config_provider = model_cfg.get("provider", "")
        elif isinstance(model_cfg, str):
            config_model = model_cfg
    except Exception:
        pass

    # Resolve values with priority (param > env > default).
    agent_name = name or os.environ.get("HERMES_AGENT_NAME", "") or "hermes"
    agent_desc = description or os.environ.get("HERMES_AGENT_DESCRIPTION", "") or "Sovereign AI agent"
    agent_url = url or os.environ.get("HERMES_AGENT_URL", "") or f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}"
    agent_version = version or os.environ.get("HERMES_AGENT_VERSION", "") or "1.0.0"

    # Load skills. BUG FIX: copy the caller's list — previously
    # `agent_skills = skills` aliased it, so appending extra_skills below
    # mutated the caller's argument.
    if skills is not None:
        agent_skills = list(skills)
    else:
        from hermes_constants import get_hermes_home
        skills_dir = get_hermes_home() / "skills"
        agent_skills = _load_skills_from_directory(skills_dir)

    # Add extra skills, de-duplicated by id.
    if extra_skills:
        existing_ids = {s.id for s in agent_skills}
        for skill in extra_skills:
            if skill.id not in existing_ids:
                agent_skills.append(skill)

    # Build metadata (explicit metadata wins over the defaults).
    card_metadata = {
        "model": config_model or os.environ.get("HERMES_MODEL", ""),
        "provider": config_provider or os.environ.get("HERMES_PROVIDER", ""),
        "hostname": socket.gethostname(),
    }
    if metadata:
        card_metadata.update(metadata)

    # Build capabilities (fixed flags for this agent).
    capabilities = AgentCapabilities(
        streaming=True,
        push_notifications=False,
        state_transition_history=True,
    )

    return AgentCard(
        name=agent_name,
        description=agent_desc,
        url=agent_url,
        version=agent_version,
        capabilities=capabilities,
        skills=agent_skills,
        metadata=card_metadata,
    )
|
||||
|
||||
|
||||
def get_agent_card_json() -> str:
    """Get agent card as JSON string (for HTTP endpoint).

    Never raises: on any build failure a minimal fallback card is
    serialized instead so fleet discovery keeps working.
    """
    try:
        card = build_agent_card()
        return card.to_json()
    except Exception:
        # Graceful fallback — return minimal card so discovery doesn't break.
        # (Dropped the unused `as e` binding; the exception is deliberately
        # swallowed here.)
        fallback = AgentCard(
            name="hermes",
            description="Sovereign AI agent",
            url=f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}",
        )
        return fallback.to_json()
|
||||
@@ -1,235 +0,0 @@
|
||||
"""Context Budget Tracker — Proactive token counting with graduated warnings.
|
||||
|
||||
Poka-yoke (mistake-proofing) for context window overflow. Tracks approximate
|
||||
token usage per turn and emits warnings at 70%, 85%, and 95% thresholds
|
||||
relative to the compression threshold (not the raw context window).
|
||||
|
||||
Usage:
|
||||
tracker = ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
|
||||
tracker.update(estimated_tokens=45_000)
|
||||
level = tracker.warning_level # "elevated" | "critical" | "emergency" | None
|
||||
if level:
|
||||
print(tracker.warning_message)
|
||||
|
||||
Integration points in run_agent.py:
|
||||
1. After `estimate_messages_tokens_rough(messages)` in the agent loop,
|
||||
call `tracker.update(_real_tokens)`.
|
||||
2. Check `tracker.should_checkpoint()` to auto-save session state.
|
||||
3. Check `tracker.should_gate(content_tokens)` before loading large
|
||||
files or skills.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Warning thresholds relative to compression threshold
# NOTE(review): these tier constants map to warning_level labels one step
# up — crossing CAUTION yields level "elevated", ELEVATED yields
# "critical", CRITICAL yields "emergency" (see
# ContextBudgetTracker.warning_level).
THRESHOLD_CAUTION = 0.70  # 70% of threshold — yellow
THRESHOLD_ELEVATED = 0.85  # 85% of threshold — orange
THRESHOLD_CRITICAL = 0.95  # 95% of threshold — red

# Cooldown between repeated warnings at the same tier (seconds)
WARNING_COOLDOWN = 300

# Pre-flight safety margin: refuse loads that would push past this
PREFLIGHT_SAFETY_MARGIN = 0.90
|
||||
|
||||
|
||||
@dataclass
class ContextBudgetTracker:
    """Tracks context window usage and emits graduated warnings.

    All thresholds are relative to the compression threshold, which is
    itself a fraction of the total context window. For example, with a
    128K context window and 50% compression threshold:
    - threshold_tokens = 64,000
    - caution at 70% = 44,800 tokens
    - elevated at 85% = 54,400 tokens
    - critical at 95% = 60,800 tokens
    """

    # Total context window of the model, in tokens.
    context_length: int
    # Fraction of the window at which compression fires.
    threshold_percent: float = 0.50

    # Current state (updated by .update())
    current_tokens: int = 0
    peak_tokens: int = 0
    turn_count: int = 0

    # Warning state (dataclass fields: note these also become __init__
    # keyword parameters despite the leading underscore)
    _last_warned_tier: float = 0.0
    _last_warned_time: float = 0.0
    _checkpoint_saved_at: float = 0.0

    # History for trend analysis: list of (turn, tokens, timestamp) tuples,
    # capped at _max_history entries.
    _token_history: list = field(default_factory=list)
    _max_history: int = 50

    @property
    def threshold_tokens(self) -> int:
        """Compression threshold in absolute tokens."""
        return int(self.context_length * self.threshold_percent)

    @property
    def progress(self) -> float:
        """Current usage as fraction of compression threshold (0.0 to 1.0+)."""
        if self.threshold_tokens == 0:
            return 0.0
        return self.current_tokens / self.threshold_tokens

    @property
    def warning_level(self) -> Optional[str]:
        """Current warning level or None if below caution threshold.

        Note the label offset: crossing THRESHOLD_CAUTION returns
        "elevated", THRESHOLD_ELEVATED returns "critical", and
        THRESHOLD_CRITICAL returns "emergency".
        """
        p = self.progress
        if p >= THRESHOLD_CRITICAL:
            return "emergency"
        elif p >= THRESHOLD_ELEVATED:
            return "critical"
        elif p >= THRESHOLD_CAUTION:
            return "elevated"
        return None

    @property
    def warning_message(self) -> Optional[str]:
        """Human-readable warning message for current level, or None."""
        level = self.warning_level
        if level is None:
            return None
        pct = int(self.progress * 100)
        used = f"{self.current_tokens:,}"
        limit = f"{self.threshold_tokens:,}"
        msgs = {
            "elevated": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — consider summarizing or checkpointing]",
            "critical": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — checkpoint recommended, compaction approaching]",
            "emergency": f"[CONTEXT CRITICAL: {pct}% used ({used}/{limit} tokens) — compaction imminent, auto-summarizing older content]",
        }
        return msgs.get(level)

    @property
    def should_emit_warning(self) -> bool:
        """Whether a new warning should be emitted (dedup + cooldown).

        Read-only check: does not mutate state — call mark_warned() after
        actually emitting.
        """
        level = self.warning_level
        if level is None:
            return False

        tier = {"elevated": THRESHOLD_CAUTION, "critical": THRESHOLD_ELEVATED, "emergency": THRESHOLD_CRITICAL}
        tier_val = tier.get(level, 0)

        now = time.time()
        if tier_val <= self._last_warned_tier:
            # Same or lower tier — check cooldown
            if (now - self._last_warned_time) < WARNING_COOLDOWN:
                return False
        # New higher tier or cooldown expired
        return True

    def mark_warned(self):
        """Call after emitting a warning to update dedup state."""
        level = self.warning_level
        tier = {"elevated": THRESHOLD_CAUTION, "critical": THRESHOLD_ELEVATED, "emergency": THRESHOLD_CRITICAL}
        self._last_warned_tier = tier.get(level, 0)
        self._last_warned_time = time.time()

    def update(self, estimated_tokens: int) -> Optional[str]:
        """Update current token count and return warning message if warranted.

        Args:
            estimated_tokens: Rough token count of the current messages.

        Returns:
            Warning message string if a warning should be shown, else None.
        """
        self.current_tokens = estimated_tokens
        self.turn_count += 1
        if estimated_tokens > self.peak_tokens:
            self.peak_tokens = estimated_tokens

        # Record history
        self._token_history.append((self.turn_count, estimated_tokens, time.time()))
        if len(self._token_history) > self._max_history:
            self._token_history = self._token_history[-self._max_history:]

        # Emitting here marks the warning as delivered (side effect on
        # dedup state via mark_warned).
        if self.should_emit_warning:
            self.mark_warned()
            return self.warning_message
        return None

    def should_checkpoint(self) -> bool:
        """Whether session state should be auto-saved (85% threshold).

        Returns True once per crossing of the elevated threshold, with a
        cooldown to avoid repeated saves.
        """
        if self.progress < THRESHOLD_ELEVATED:
            return False
        now = time.time()
        if (now - self._checkpoint_saved_at) < WARNING_COOLDOWN:
            return False
        # Side effect: a True result immediately starts the cooldown.
        self._checkpoint_saved_at = now
        return True

    def can_fit(self, additional_tokens: int) -> bool:
        """Pre-flight check: would adding this many tokens exceed the safety margin?

        Use before loading large files or skills to prevent overflow.
        """
        projected = self.current_tokens + additional_tokens
        return projected < int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)

    def estimate_file_tokens(self, file_size_bytes: int) -> int:
        """Rough token estimate for a file of given size (~4 chars/token)."""
        return max(1, file_size_bytes // 4)

    def tokens_remaining(self) -> int:
        """Approximate tokens available before hitting the safety margin."""
        safe_limit = int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)
        return max(0, safe_limit - self.current_tokens)

    def trend(self, window: int = 10) -> str:
        """Token growth trend over the last N turns: 'growing' | 'stable' | 'shrinking'."""
        if len(self._token_history) < 2:
            return "stable"
        recent = self._token_history[-window:]
        if len(recent) < 2:
            return "stable"
        # Compare first vs last token counts in the window; small deltas
        # (within 5% of the threshold) count as stable.
        first = recent[0][1]
        last = recent[-1][1]
        delta = last - first
        threshold = self.threshold_tokens * 0.05  # 5% of threshold
        if delta > threshold:
            return "growing"
        elif delta < -threshold:
            return "shrinking"
        return "stable"

    def summary(self) -> Dict[str, Any]:
        """Machine-readable summary for logging/metrics."""
        return {
            "context_length": self.context_length,
            "threshold_tokens": self.threshold_tokens,
            "current_tokens": self.current_tokens,
            "peak_tokens": self.peak_tokens,
            "progress_pct": round(self.progress * 100, 1),
            "warning_level": self.warning_level,
            "turn_count": self.turn_count,
            "trend": self.trend(),
            "tokens_remaining": self.tokens_remaining(),
        }

    def format_status(self) -> str:
        """Human-readable status line for CLI display."""
        pct = int(self.progress * 100)
        bar_len = 20
        # Bar is clamped at 100% even when progress exceeds the threshold.
        filled = int(bar_len * min(self.progress, 1.0))
        bar = "█" * filled + "░" * (bar_len - filled)
        level = self.warning_level or "ok"
        return f"Context: [{bar}] {pct}% ({self.current_tokens:,}/{self.threshold_tokens:,}) {level}"
|
||||
52
run_agent.py
52
run_agent.py
@@ -92,7 +92,6 @@ from agent.model_metadata import (
|
||||
query_ollama_num_ctx,
|
||||
)
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.context_budget import ContextBudgetTracker
|
||||
from agent.subdirectory_hints import SubdirectoryHintTracker
|
||||
from agent.prompt_caching import apply_anthropic_cache_control
|
||||
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, build_environment_hints, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS, DEVELOPER_ROLE_MODELS, GOOGLE_MODEL_OPERATIONAL_GUIDANCE, OPENAI_MODEL_EXECUTION_GUIDANCE
|
||||
@@ -1401,12 +1400,6 @@ class AIAgent:
|
||||
)
|
||||
self.compression_enabled = compression_enabled
|
||||
|
||||
# Context budget tracker — poka-yoke for context window overflow
|
||||
self._context_budget = ContextBudgetTracker(
|
||||
context_length=getattr(self.context_compressor, "context_length", 0),
|
||||
threshold_percent=compression_threshold,
|
||||
)
|
||||
|
||||
# Reject models whose context window is below the minimum required
|
||||
# for reliable tool-calling workflows (64K tokens).
|
||||
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
|
||||
@@ -7617,34 +7610,6 @@ class AIAgent:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def check_context_budget(self, additional_tokens: int = 0) -> Optional[str]:
    """Pre-flight check: verify there's room in the context budget.

    Use before loading large files, skills, or other content to prevent
    overflow. Returns a warning message if the addition would exceed
    the safety margin, or None if there's room.

    Args:
        additional_tokens: Estimated tokens for the content about to be loaded.

    Returns:
        Warning string if budget would be exceeded, None if safe.
    """
    # getattr: instances built before the tracker was introduced may lack
    # the attribute; context_length == 0 means the tracker has no known
    # context window, so budgeting is effectively disabled.
    budget = getattr(self, "_context_budget", None)
    if not budget or budget.context_length == 0:
        return None
    if not budget.can_fit(additional_tokens):
        return (
            f"[CONTEXT BUDGET: {additional_tokens:,} tokens would exceed safety margin. "
            f"Current: {budget.current_tokens:,}, remaining: {budget.tokens_remaining():,}. "
            f"Consider summarizing or checkpointing first.]"
        )
    return None
|
||||
|
||||
|
||||
def _emit_context_pressure(self, compaction_progress: float, compressor) -> None:
|
||||
"""Notify the user that context is approaching the compaction threshold.
|
||||
|
||||
@@ -10250,7 +10215,7 @@ class AIAgent:
|
||||
# compaction fires, not the raw context window.
|
||||
# Does not inject into messages — just prints to CLI output
|
||||
# and fires status_callback for gateway platforms.
|
||||
# Tiered: 70% (yellow), 85% (orange), 95% (red/critical).
|
||||
# Tiered: 85% (orange) and 95% (red/critical).
|
||||
if _compressor.threshold_tokens > 0:
|
||||
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
||||
# Determine the warning tier for this progress level
|
||||
@@ -10259,8 +10224,6 @@ class AIAgent:
|
||||
_warn_tier = 0.95
|
||||
elif _compaction_progress >= 0.85:
|
||||
_warn_tier = 0.85
|
||||
elif _compaction_progress >= 0.70:
|
||||
_warn_tier = 0.70
|
||||
if _warn_tier > self._context_pressure_warned_at:
|
||||
# Class-level dedup: check if this session was already
|
||||
# warned at this tier within the cooldown window.
|
||||
@@ -10278,19 +10241,6 @@ class AIAgent:
|
||||
if v[1] > _cutoff
|
||||
}
|
||||
|
||||
# ── Auto-checkpoint at 85%+ ────────────────
|
||||
# Save session state so it can be resumed after
|
||||
# context compaction or window reset. Uses the
|
||||
# existing session log mechanism — writes current
|
||||
# messages to the session DB immediately rather
|
||||
# than waiting for the next normal save point.
|
||||
if _warn_tier >= 0.85 and self._session_db and self.session_id:
|
||||
try:
|
||||
self._save_session_log(messages)
|
||||
self._safe_print(" 💾 context checkpoint saved")
|
||||
except Exception:
|
||||
pass # Non-critical — don't block the loop
|
||||
|
||||
if self.compression_enabled and _compressor.should_compress(_real_tokens):
|
||||
self._safe_print(" ⟳ compacting context…")
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
|
||||
132
tests/test_agent_card.py
Normal file
132
tests/test_agent_card.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for A2A agent card — Issue #819."""
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.agent_card import (
|
||||
AgentSkill, AgentCapabilities, AgentCard,
|
||||
validate_agent_card, build_agent_card, get_agent_card_json,
|
||||
_load_skills_from_directory
|
||||
)
|
||||
|
||||
|
||||
class TestAgentSkill:
    """AgentSkill dataclass construction."""

    def test_creation(self):
        built = AgentSkill(id="code", name="Code", tags=["python"])
        assert built.id == "code"
        assert "python" in built.tags
|
||||
|
||||
|
||||
class TestAgentCapabilities:
    """Default capability flags."""

    def test_defaults(self):
        caps = AgentCapabilities()
        # Identity checks instead of `== True` / `== False` (flake8 E712).
        assert caps.streaming is True
        assert caps.push_notifications is False
|
||||
|
||||
|
||||
class TestAgentCard:
    """Serialization of the card dataclass."""

    def _card(self):
        # Minimal card shared by the serialization tests.
        return AgentCard(name="timmy", description="test", url="http://localhost:8642")

    def test_to_dict(self):
        d = self._card().to_dict()
        assert d["name"] == "timmy"
        assert "defaultInputModes" in d

    def test_to_json(self):
        parsed = json.loads(self._card().to_json())
        assert parsed["name"] == "timmy"
|
||||
|
||||
|
||||
class TestValidation:
    """validate_agent_card error reporting."""

    @staticmethod
    def _make(**overrides):
        # Build a known-good card, then apply per-test overrides.
        fields = dict(name="timmy", description="test", url="http://localhost:8642")
        fields.update(overrides)
        return AgentCard(**fields)

    def test_valid_card(self):
        assert validate_agent_card(self._make()) == []

    def test_missing_name(self):
        errors = validate_agent_card(self._make(name=""))
        assert any("name" in e for e in errors)

    def test_missing_url(self):
        errors = validate_agent_card(self._make(url=""))
        assert any("url" in e for e in errors)

    def test_invalid_input_mode(self):
        errors = validate_agent_card(self._make(default_input_modes=["invalid/mode"]))
        assert any("invalid input mode" in e for e in errors)

    def test_skill_missing_id(self):
        errors = validate_agent_card(self._make(skills=[AgentSkill(id="", name="test")]))
        assert any("skill missing id" in e for e in errors)
|
||||
|
||||
|
||||
class TestBuildAgentCard:
    """build_agent_card resolution and merging."""

    def test_builds_valid_card(self):
        card = build_agent_card()
        assert card.name
        assert card.url
        assert validate_agent_card(card) == []

    def test_explicit_params_override(self):
        card = build_agent_card(name="custom", description="custom desc")
        assert (card.name, card.description) == ("custom", "custom desc")

    def test_extra_skills(self):
        card = build_agent_card(extra_skills=[AgentSkill(id="extra", name="Extra")])
        assert "extra" in {s.id for s in card.skills}
|
||||
|
||||
|
||||
class TestGetAgentCardJson:
    """JSON endpoint helper."""

    def test_returns_valid_json(self):
        document = json.loads(get_agent_card_json())
        assert "name" in document

    def test_graceful_fallback(self):
        # Even if something fails internally, a non-empty string comes back.
        assert get_agent_card_json()  # Non-empty
|
||||
|
||||
|
||||
class TestLoadSkills:
    """Skill discovery from a skills directory."""

    def test_empty_dir(self, tmp_path):
        # A missing directory yields no skills (and no exception).
        skills = _load_skills_from_directory(tmp_path / "nonexistent")
        assert len(skills) == 0

    def test_parses_skill_md(self, tmp_path):
        # NOTE(review): this test requires PyYAML — if it is missing, the
        # loader silently skips the skill and the count assert fails.
        # TODO confirm yaml is a declared test dependency.
        skill_dir = tmp_path / "test-skill"
        skill_dir.mkdir()
        skill_md = skill_dir / "SKILL.md"
        skill_md.write_text("""---
name: Test Skill
description: A test skill
tags:
  - test
  - example
---
Content here
""")
        skills = _load_skills_from_directory(tmp_path)
        assert len(skills) == 1
        assert skills[0].name == "Test Skill"
        assert "test" in skills[0].tags
|
||||
|
||||
|
||||
# Allow running this test module directly, without the pytest CLI.
if __name__ == "__main__":
    import pytest
    pytest.main([__file__, "-v"])
|
||||
@@ -1,187 +0,0 @@
|
||||
"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.context_budget import (
|
||||
ContextBudgetTracker,
|
||||
THRESHOLD_CAUTION,
|
||||
THRESHOLD_ELEVATED,
|
||||
THRESHOLD_CRITICAL,
|
||||
PREFLIGHT_SAFETY_MARGIN,
|
||||
WARNING_COOLDOWN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def tracker():
    """Standard tracker: 128K context, 50% compression threshold = 64K threshold."""
    # Warning tiers for this configuration: 44.8K (elevated), 54.4K
    # (critical), 60.8K (emergency) — see ContextBudgetTracker docstring.
    return ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
|
||||
|
||||
|
||||
class TestThresholds:
    """Tier boundaries relative to the 64K compression threshold."""

    THRESHOLD = 64_000

    def _at(self, fraction):
        # Token count at the given fraction of the compression threshold.
        return int(self.THRESHOLD * fraction)

    def test_threshold_tokens_computed(self, tracker):
        assert tracker.threshold_tokens == self.THRESHOLD

    def test_caution_at_70_percent(self, tracker):
        tracker.update(self._at(0.70))
        assert tracker.warning_level == "elevated"

    def test_no_warning_below_caution(self, tracker):
        tracker.update(self._at(0.69))
        assert tracker.warning_level is None

    def test_critical_at_85_percent(self, tracker):
        tracker.update(self._at(0.85))
        assert tracker.warning_level == "critical"

    def test_emergency_at_95_percent(self, tracker):
        tracker.update(self._at(0.95))
        assert tracker.warning_level == "emergency"
|
||||
|
||||
|
||||
class TestWarningMessages:
    """Message text emitted at each warning tier."""

    def _fill(self, tracker, fraction):
        # Drive usage to the given fraction of the 64K threshold.
        tracker.update(int(64_000 * fraction))

    def test_elevated_message_contains_70(self, tracker):
        self._fill(tracker, 0.70)
        message = tracker.warning_message
        assert message is not None
        assert "CONTEXT WARNING" in message

    def test_critical_message(self, tracker):
        self._fill(tracker, 0.85)
        assert "compaction approaching" in tracker.warning_message

    def test_emergency_message(self, tracker):
        self._fill(tracker, 0.95)
        assert "CONTEXT CRITICAL" in tracker.warning_message

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
|
||||
|
||||
|
||||
class TestWarningDedup:
    """Cooldown and tier-escalation deduplication."""

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """Same tier within cooldown should not re-emit."""
        tracker.update(int(64_000 * 0.71))
        second = tracker.update(int(64_000 * 0.72))
        assert second is None  # suppressed by cooldown

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Crossing to a higher tier should always emit."""
        tracker.update(int(64_000 * 0.71))
        escalated = tracker.update(int(64_000 * 0.86))
        assert escalated is not None
        assert "compaction approaching" in escalated.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        tracker.update(int(64_000 * 0.71))
        # Rewind the last-warned timestamp past the cooldown window.
        tracker._last_warned_time = time.time() - WARNING_COOLDOWN - 1
        assert tracker.update(int(64_000 * 0.72)) is not None
|
||||
|
||||
|
||||
class TestCheckpoint:
    """Auto-checkpoint trigger at the 85% tier."""

    def test_should_checkpoint_at_85(self, tracker):
        tracker.update(int(64_000 * 0.86))
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        tracker.update(int(64_000 * 0.84))
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        tracker.update(int(64_000 * 0.86))
        tracker.should_checkpoint()  # first call saves and starts the cooldown
        assert tracker.should_checkpoint() is False  # cooldown
|
||||
|
||||
|
||||
class TestPreflight:
    """can_fit / tokens_remaining safety-margin math."""

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        tracker.update(int(64_000 * 0.88))
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        assert tracker.estimate_file_tokens(4_000) == 1_000
        assert tracker.estimate_file_tokens(100) >= 1  # minimum 1

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        expected = int(64_000 * PREFLIGHT_SAFETY_MARGIN) - 30_000
        assert tracker.tokens_remaining() == expected
|
||||
|
||||
|
||||
class TestTrend:
    """Trend detection over the recent update history."""

    def _feed(self, tracker, token_counts):
        for count in token_counts:
            tracker.update(count)

    def test_growing_trend(self, tracker):
        self._feed(tracker, [10_000 + i * 5_000 for i in range(10)])
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        self._feed(tracker, [60_000 - i * 5_000 for i in range(10)])
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        self._feed(tracker, [30_000] * 10)
        assert tracker.trend() == "stable"
|
||||
|
||||
|
||||
class TestSummary:
    """summary() payload and format_status() rendering."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        payload = tracker.summary()
        for key in ("context_length", "current_tokens", "warning_level", "trend"):
            assert key in payload
        assert payload["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        line = tracker.format_status()
        assert "Context:" in line
        assert "[" in line  # progress bar
        assert "%" in line
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Degenerate configurations and over-threshold inputs."""

    def test_zero_context_length(self):
        degenerate = ContextBudgetTracker(context_length=0)
        assert degenerate.threshold_tokens == 0
        assert degenerate.progress == 0.0
        assert degenerate.warning_level is None

    def test_different_threshold_percent(self):
        wide = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert wide.threshold_tokens == 80_000
        wide.update(int(80_000 * 0.70))
        assert wide.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Progress can exceed 1.0 (past compression threshold)."""
        tracker.update(70_000)
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        for tokens in (10_000, 50_000, 30_000):
            tracker.update(tokens)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        tracker.update(10_000)
        tracker.update(20_000)
        assert tracker.turn_count == 2
|
||||
Reference in New Issue
Block a user