Compare commits

..

1 Commits

Author SHA1 Message Date
Timmy Time
1adbf7ed1b fix: implement A2A agent card for fleet discovery (closes #819)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 44s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 43s
Tests / e2e (pull_request) Successful in 2m33s
Tests / test (pull_request) Failing after 44m31s
2026-04-15 21:44:47 -04:00
5 changed files with 357 additions and 473 deletions

224
agent/agent_card.py Normal file
View File

@@ -0,0 +1,224 @@
"""A2A Agent Card — publish capabilities for fleet discovery.
Each fleet agent publishes an A2A-compliant agent card describing its capabilities.
Standard discovery endpoint: /.well-known/agent-card.json
Issue #819: feat: A2A agent card — publish capabilities for fleet discovery
"""
import json
import os
import socket
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass
class AgentSkill:
    """A single skill the agent can perform."""

    # Stable identifier (the skill's directory name when discovered on disk).
    id: str
    # Human-readable display name.
    name: str
    description: str = ""
    # Free-form labels used for discovery/filtering.
    tags: List[str] = field(default_factory=list)
    # Example invocations for this skill.
    examples: List[str] = field(default_factory=list)
    # Accepted / produced MIME types for this skill.
    input_modes: List[str] = field(default_factory=lambda: ["text/plain"])
    output_modes: List[str] = field(default_factory=lambda: ["text/plain"])
@dataclass
class AgentCapabilities:
    """What the agent can do."""

    # Supports streamed responses.
    streaming: bool = True
    # Can push task updates to the client.
    push_notifications: bool = False
    # Exposes state-transition history.
    state_transition_history: bool = True
@dataclass
class AgentCard:
    """A2A-compliant agent card."""

    name: str
    description: str
    url: str
    version: str = "1.0.0"
    capabilities: AgentCapabilities = field(default_factory=AgentCapabilities)
    skills: List[AgentSkill] = field(default_factory=list)
    default_input_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    default_output_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict using A2A camelCase key names."""
        payload = asdict(self)
        # The A2A spec spells these two keys in camelCase.
        for snake, camel in (
            ("default_input_modes", "defaultInputModes"),
            ("default_output_modes", "defaultOutputModes"),
        ):
            payload[camel] = payload.pop(snake)
        return payload

    def to_json(self) -> str:
        """Serialize the card to a pretty-printed JSON string."""
        return json.dumps(self.to_dict(), indent=2)
def _load_skills_from_directory(skills_dir: Path) -> List[AgentSkill]:
"""Scan ~/.hermes/skills/ for SKILL.md frontmatter."""
skills = []
if not skills_dir.exists():
return skills
for skill_dir in skills_dir.iterdir():
if not skill_dir.is_dir():
continue
skill_md = skill_dir / "SKILL.md"
if not skill_md.exists():
continue
try:
content = skill_md.read_text(encoding="utf-8")
# Parse YAML frontmatter
if content.startswith("---"):
parts = content.split("---", 2)
if len(parts) >= 3:
import yaml
try:
metadata = yaml.safe_load(parts[1]) or {}
except Exception:
metadata = {}
name = metadata.get("name", skill_dir.name)
desc = metadata.get("description", "")
tags = metadata.get("tags", [])
skills.append(AgentSkill(
id=skill_dir.name,
name=name,
description=desc[:200] if desc else "",
tags=tags if isinstance(tags, list) else [],
))
except Exception:
continue
return skills
def validate_agent_card(card: AgentCard) -> List[str]:
    """Validate an agent card against A2A schema requirements.

    Args:
        card: The card to check.

    Returns:
        List of human-readable validation errors; empty when the card is valid.
    """
    errors: List[str] = []
    # Required top-level fields.
    if not card.name:
        errors.append("name is required")
    if not card.url:
        errors.append("url is required")
    # MIME types this implementation accepts.
    allowed = {"text/plain", "application/json", "image/png", "audio/wav"}
    errors.extend(
        f"invalid input mode: {mode}"
        for mode in card.default_input_modes
        if mode not in allowed
    )
    errors.extend(
        f"invalid output mode: {mode}"
        for mode in card.default_output_modes
        if mode not in allowed
    )
    # Every skill needs a stable id.
    errors.extend(
        f"skill missing id: {skill.name}" for skill in card.skills if not skill.id
    )
    return errors
def build_agent_card(
    name: Optional[str] = None,
    description: Optional[str] = None,
    url: Optional[str] = None,
    version: Optional[str] = None,
    skills: Optional[List[AgentSkill]] = None,
    extra_skills: Optional[List[AgentSkill]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> AgentCard:
    """Build an A2A agent card from config and environment.

    Priority for each field: explicit params > env vars > config.yaml > defaults.

    Args:
        name: Card name (defaults to HERMES_AGENT_NAME env, then "hermes").
        description: Card description.
        url: Base URL the agent is reachable at.
        version: Card version string.
        skills: Explicit skill list; when None, skills are discovered from
            the skills directory under the hermes home.
        extra_skills: Appended after discovery, skipping duplicate ids.
        metadata: Extra metadata merged over the model/provider/hostname defaults.

    Returns:
        A populated AgentCard (not yet validated).
    """
    # Best-effort read of model/provider from config; missing or broken
    # config falls through to env vars / defaults.
    config_model = ""
    config_provider = ""
    try:
        from hermes_cli.config import load_config
        cfg = load_config()
        model_cfg = cfg.get("model", {})
        if isinstance(model_cfg, dict):
            config_model = model_cfg.get("default", "")
            config_provider = model_cfg.get("provider", "")
        elif isinstance(model_cfg, str):
            config_model = model_cfg
    except Exception:
        pass
    # Resolve values with priority: param > env > default.
    agent_name = name or os.environ.get("HERMES_AGENT_NAME", "") or "hermes"
    agent_desc = description or os.environ.get("HERMES_AGENT_DESCRIPTION", "") or "Sovereign AI agent"
    agent_url = url or os.environ.get("HERMES_AGENT_URL", "") or f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}"
    agent_version = version or os.environ.get("HERMES_AGENT_VERSION", "") or "1.0.0"
    # Load skills. BUG FIX: copy the caller's list so appending extra_skills
    # below cannot mutate the argument the caller passed in.
    if skills is not None:
        agent_skills = list(skills)
    else:
        try:
            from hermes_constants import get_hermes_home
            skills_dir = get_hermes_home() / "skills"
            agent_skills = _load_skills_from_directory(skills_dir)
        except Exception:
            # Discovery is best-effort, matching the config handling above.
            agent_skills = []
    # Merge extra skills, deduplicating by id (including within extra_skills).
    if extra_skills:
        existing_ids = {s.id for s in agent_skills}
        for skill in extra_skills:
            if skill.id not in existing_ids:
                agent_skills.append(skill)
                existing_ids.add(skill.id)
    # Default metadata, overridable by the caller's dict.
    card_metadata = {
        "model": config_model or os.environ.get("HERMES_MODEL", ""),
        "provider": config_provider or os.environ.get("HERMES_PROVIDER", ""),
        "hostname": socket.gethostname(),
    }
    if metadata:
        card_metadata.update(metadata)
    capabilities = AgentCapabilities(
        streaming=True,
        push_notifications=False,
        state_transition_history=True,
    )
    return AgentCard(
        name=agent_name,
        description=agent_desc,
        url=agent_url,
        version=agent_version,
        capabilities=capabilities,
        skills=agent_skills,
        metadata=card_metadata,
    )
def get_agent_card_json() -> str:
    """Return the agent card as a JSON string (for the HTTP discovery endpoint).

    Never raises: if building the full card fails for any reason, a minimal
    fallback card is returned so fleet discovery doesn't break.
    """
    try:
        return build_agent_card().to_json()
    except Exception:
        # Graceful fallback — mirrors build_agent_card()'s default fields.
        # (Previously bound the exception to an unused variable.)
        fallback = AgentCard(
            name="hermes",
            description="Sovereign AI agent",
            url=f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}",
        )
        return fallback.to_json()

View File

@@ -1,235 +0,0 @@
"""Context Budget Tracker — Proactive token counting with graduated warnings.
Poka-yoke (mistake-proofing) for context window overflow. Tracks approximate
token usage per turn and emits warnings at 70%, 85%, and 95% thresholds
relative to the compression threshold (not the raw context window).
Usage:
tracker = ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
tracker.update(estimated_tokens=45_000)
level = tracker.warning_level # "elevated" | "critical" | "emergency" | None
if level:
print(tracker.warning_message)
Integration points in run_agent.py:
1. After `estimate_messages_tokens_rough(messages)` in the agent loop,
call `tracker.update(_real_tokens)`.
2. Check `tracker.should_checkpoint()` to auto-save session state.
3. Check `tracker.should_gate(content_tokens)` before loading large
files or skills.
"""
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Warning thresholds relative to compression threshold
THRESHOLD_CAUTION = 0.70  # 70% of threshold — yellow
THRESHOLD_ELEVATED = 0.85  # 85% of threshold — orange
THRESHOLD_CRITICAL = 0.95  # 95% of threshold — red

# Cooldown between repeated warnings at the same tier (seconds)
WARNING_COOLDOWN = 300

# Pre-flight safety margin: refuse loads that would push past this
PREFLIGHT_SAFETY_MARGIN = 0.90

# Warning level name -> tier fraction it triggers at. Shared by
# should_emit_warning and mark_warned so the two maps cannot drift apart
# (previously each method had its own copy of this dict).
_TIER_BY_LEVEL = {
    "elevated": THRESHOLD_CAUTION,
    "critical": THRESHOLD_ELEVATED,
    "emergency": THRESHOLD_CRITICAL,
}


@dataclass
class ContextBudgetTracker:
    """Tracks context window usage and emits graduated warnings.

    All thresholds are relative to the compression threshold, which is
    itself a fraction of the total context window. For example, with a
    128K context window and 50% compression threshold:
      - threshold_tokens = 64,000
      - caution at 70%   = 44,800 tokens
      - elevated at 85%  = 54,400 tokens
      - critical at 95%  = 60,800 tokens
    """

    context_length: int
    threshold_percent: float = 0.50
    # Current state (updated by .update())
    current_tokens: int = 0
    peak_tokens: int = 0
    turn_count: int = 0
    # Warning dedup state
    _last_warned_tier: float = 0.0
    _last_warned_time: float = 0.0
    _checkpoint_saved_at: float = 0.0
    # History for trend analysis: (turn, tokens, timestamp) tuples
    _token_history: List[Tuple[int, int, float]] = field(default_factory=list)
    _max_history: int = 50

    @property
    def threshold_tokens(self) -> int:
        """Compression threshold in absolute tokens."""
        return int(self.context_length * self.threshold_percent)

    @property
    def progress(self) -> float:
        """Current usage as fraction of compression threshold (0.0 to 1.0+)."""
        if self.threshold_tokens == 0:
            return 0.0
        return self.current_tokens / self.threshold_tokens

    @property
    def warning_level(self) -> Optional[str]:
        """Current warning level or None if below the caution threshold."""
        p = self.progress
        if p >= THRESHOLD_CRITICAL:
            return "emergency"
        elif p >= THRESHOLD_ELEVATED:
            return "critical"
        elif p >= THRESHOLD_CAUTION:
            return "elevated"
        return None

    @property
    def warning_message(self) -> Optional[str]:
        """Human-readable warning message for the current level, or None."""
        level = self.warning_level
        if level is None:
            return None
        pct = int(self.progress * 100)
        used = f"{self.current_tokens:,}"
        limit = f"{self.threshold_tokens:,}"
        msgs = {
            "elevated": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — consider summarizing or checkpointing]",
            "critical": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — checkpoint recommended, compaction approaching]",
            "emergency": f"[CONTEXT CRITICAL: {pct}% used ({used}/{limit} tokens) — compaction imminent, auto-summarizing older content]",
        }
        return msgs.get(level)

    @property
    def should_emit_warning(self) -> bool:
        """Whether a new warning should be emitted (dedup + cooldown)."""
        level = self.warning_level
        if level is None:
            return False
        tier_val = _TIER_BY_LEVEL.get(level, 0)
        if tier_val <= self._last_warned_tier:
            # Same or lower tier — suppress while inside the cooldown window.
            if (time.time() - self._last_warned_time) < WARNING_COOLDOWN:
                return False
        # New higher tier, or cooldown expired.
        return True

    def mark_warned(self):
        """Call after emitting a warning to update dedup state."""
        self._last_warned_tier = _TIER_BY_LEVEL.get(self.warning_level, 0)
        self._last_warned_time = time.time()

    def update(self, estimated_tokens: int) -> Optional[str]:
        """Update the current token count and return a warning if warranted.

        Args:
            estimated_tokens: Rough token count of the current messages.

        Returns:
            Warning message string if a warning should be shown, else None.
        """
        self.current_tokens = estimated_tokens
        self.turn_count += 1
        if estimated_tokens > self.peak_tokens:
            self.peak_tokens = estimated_tokens
        # Record history (bounded to the last _max_history samples).
        self._token_history.append((self.turn_count, estimated_tokens, time.time()))
        if len(self._token_history) > self._max_history:
            self._token_history = self._token_history[-self._max_history:]
        if self.should_emit_warning:
            self.mark_warned()
            return self.warning_message
        return None

    def should_checkpoint(self) -> bool:
        """Whether session state should be auto-saved (85% threshold).

        Returns True once per crossing of the elevated threshold, with a
        cooldown to avoid repeated saves.
        """
        if self.progress < THRESHOLD_ELEVATED:
            return False
        now = time.time()
        if (now - self._checkpoint_saved_at) < WARNING_COOLDOWN:
            return False
        self._checkpoint_saved_at = now
        return True

    def can_fit(self, additional_tokens: int) -> bool:
        """Pre-flight check: does this addition stay inside the safety margin?

        Use before loading large files or skills to prevent overflow.
        """
        projected = self.current_tokens + additional_tokens
        return projected < int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)

    def estimate_file_tokens(self, file_size_bytes: int) -> int:
        """Rough token estimate for a file of given size (~4 chars/token)."""
        return max(1, file_size_bytes // 4)

    def tokens_remaining(self) -> int:
        """Approximate tokens available before hitting the safety margin."""
        safe_limit = int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)
        return max(0, safe_limit - self.current_tokens)

    def trend(self, window: int = 10) -> str:
        """Token growth trend over the last N turns: 'growing' | 'stable' | 'shrinking'."""
        if len(self._token_history) < 2:
            return "stable"
        recent = self._token_history[-window:]
        if len(recent) < 2:
            return "stable"
        delta = recent[-1][1] - recent[0][1]
        # Moves smaller than 5% of the threshold are treated as noise.
        noise_floor = self.threshold_tokens * 0.05
        if delta > noise_floor:
            return "growing"
        elif delta < -noise_floor:
            return "shrinking"
        return "stable"

    def summary(self) -> Dict[str, Any]:
        """Machine-readable summary for logging/metrics."""
        return {
            "context_length": self.context_length,
            "threshold_tokens": self.threshold_tokens,
            "current_tokens": self.current_tokens,
            "peak_tokens": self.peak_tokens,
            "progress_pct": round(self.progress * 100, 1),
            "warning_level": self.warning_level,
            "turn_count": self.turn_count,
            "trend": self.trend(),
            "tokens_remaining": self.tokens_remaining(),
        }

    def format_status(self) -> str:
        """Human-readable status line for CLI display."""
        pct = int(self.progress * 100)
        bar_len = 20
        filled = int(bar_len * min(self.progress, 1.0))
        # BUG FIX: both fill glyphs were empty strings, so the bar always
        # rendered as "[]". Use solid/light block characters instead.
        bar = "█" * filled + "░" * (bar_len - filled)
        level = self.warning_level or "ok"
        return f"Context: [{bar}] {pct}% ({self.current_tokens:,}/{self.threshold_tokens:,}) {level}"

View File

@@ -92,7 +92,6 @@ from agent.model_metadata import (
query_ollama_num_ctx,
)
from agent.context_compressor import ContextCompressor
from agent.context_budget import ContextBudgetTracker
from agent.subdirectory_hints import SubdirectoryHintTracker
from agent.prompt_caching import apply_anthropic_cache_control
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, build_environment_hints, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS, DEVELOPER_ROLE_MODELS, GOOGLE_MODEL_OPERATIONAL_GUIDANCE, OPENAI_MODEL_EXECUTION_GUIDANCE
@@ -1401,12 +1400,6 @@ class AIAgent:
)
self.compression_enabled = compression_enabled
# Context budget tracker — poka-yoke for context window overflow
self._context_budget = ContextBudgetTracker(
context_length=getattr(self.context_compressor, "context_length", 0),
threshold_percent=compression_threshold,
)
# Reject models whose context window is below the minimum required
# for reliable tool-calling workflows (64K tokens).
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
@@ -7617,34 +7610,6 @@ class AIAgent:
def check_context_budget(self, additional_tokens: int = 0) -> Optional[str]:
    """Pre-flight check: verify there's room in the context budget.

    Use before loading large files, skills, or other content to prevent
    overflow. Returns a warning message if the addition would exceed
    the safety margin, or None if there's room.

    Args:
        additional_tokens: Estimated tokens for the content about to be loaded.

    Returns:
        Warning string if budget would be exceeded, None if safe.
    """
    # The tracker may be absent (attribute never set) or unconfigured
    # (context_length == 0) — in either case the check is a no-op.
    budget = getattr(self, "_context_budget", None)
    if not budget or budget.context_length == 0:
        return None
    if not budget.can_fit(additional_tokens):
        return (
            f"[CONTEXT BUDGET: {additional_tokens:,} tokens would exceed safety margin. "
            f"Current: {budget.current_tokens:,}, remaining: {budget.tokens_remaining():,}. "
            f"Consider summarizing or checkpointing first.]"
        )
    return None
def _emit_context_pressure(self, compaction_progress: float, compressor) -> None:
"""Notify the user that context is approaching the compaction threshold.
@@ -10250,7 +10215,7 @@ class AIAgent:
# compaction fires, not the raw context window.
# Does not inject into messages — just prints to CLI output
# and fires status_callback for gateway platforms.
# Tiered: 70% (yellow), 85% (orange), 95% (red/critical).
# Tiered: 85% (orange) and 95% (red/critical).
if _compressor.threshold_tokens > 0:
_compaction_progress = _real_tokens / _compressor.threshold_tokens
# Determine the warning tier for this progress level
@@ -10259,8 +10224,6 @@ class AIAgent:
_warn_tier = 0.95
elif _compaction_progress >= 0.85:
_warn_tier = 0.85
elif _compaction_progress >= 0.70:
_warn_tier = 0.70
if _warn_tier > self._context_pressure_warned_at:
# Class-level dedup: check if this session was already
# warned at this tier within the cooldown window.
@@ -10278,19 +10241,6 @@ class AIAgent:
if v[1] > _cutoff
}
# ── Auto-checkpoint at 85%+ ────────────────
# Save session state so it can be resumed after
# context compaction or window reset. Uses the
# existing session log mechanism — writes current
# messages to the session DB immediately rather
# than waiting for the next normal save point.
if _warn_tier >= 0.85 and self._session_db and self.session_id:
try:
self._save_session_log(messages)
self._safe_print(" 💾 context checkpoint saved")
except Exception:
pass # Non-critical — don't block the loop
if self.compression_enabled and _compressor.should_compress(_real_tokens):
self._safe_print(" ⟳ compacting context…")
messages, active_system_prompt = self._compress_context(

132
tests/test_agent_card.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for A2A agent card — Issue #819."""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.agent_card import (
AgentSkill, AgentCapabilities, AgentCard,
validate_agent_card, build_agent_card, get_agent_card_json,
_load_skills_from_directory
)
class TestAgentSkill:
    """Construction of AgentSkill."""

    def test_creation(self):
        created = AgentSkill(id="code", name="Code", tags=["python"])
        assert created.id == "code"
        assert "python" in created.tags
class TestAgentCapabilities:
    """Default values of AgentCapabilities."""

    def test_defaults(self):
        caps = AgentCapabilities()
        # `is` comparison: these are real booleans, and `== True` (E712)
        # would also pass for any truthy non-bool value.
        assert caps.streaming is True
        assert caps.push_notifications is False
class TestAgentCard:
    """Serialization of AgentCard to dict and JSON."""

    @staticmethod
    def _card():
        # Shared fixture card for both serialization tests.
        return AgentCard(name="timmy", description="test", url="http://localhost:8642")

    def test_to_dict(self):
        payload = self._card().to_dict()
        assert payload["name"] == "timmy"
        assert "defaultInputModes" in payload

    def test_to_json(self):
        parsed = json.loads(self._card().to_json())
        assert parsed["name"] == "timmy"
class TestValidation:
    """Error reporting from validate_agent_card."""

    @staticmethod
    def _make(**overrides):
        # Valid baseline card; tests override one field at a time.
        fields = dict(name="timmy", description="test", url="http://localhost:8642")
        fields.update(overrides)
        return AgentCard(**fields)

    def test_valid_card(self):
        assert validate_agent_card(self._make()) == []

    def test_missing_name(self):
        errors = validate_agent_card(self._make(name=""))
        assert any("name" in e for e in errors)

    def test_missing_url(self):
        errors = validate_agent_card(self._make(url=""))
        assert any("url" in e for e in errors)

    def test_invalid_input_mode(self):
        errors = validate_agent_card(self._make(default_input_modes=["invalid/mode"]))
        assert any("invalid input mode" in e for e in errors)

    def test_skill_missing_id(self):
        errors = validate_agent_card(self._make(skills=[AgentSkill(id="", name="test")]))
        assert any("skill missing id" in e for e in errors)
class TestBuildAgentCard:
    """Value resolution and skill merging in build_agent_card."""

    def test_builds_valid_card(self):
        card = build_agent_card()
        assert card.name and card.url
        assert validate_agent_card(card) == []

    def test_explicit_params_override(self):
        card = build_agent_card(name="custom", description="custom desc")
        assert (card.name, card.description) == ("custom", "custom desc")

    def test_extra_skills(self):
        card = build_agent_card(extra_skills=[AgentSkill(id="extra", name="Extra")])
        assert "extra" in {s.id for s in card.skills}
class TestGetAgentCardJson:
    """HTTP-endpoint payload from get_agent_card_json."""

    def test_returns_valid_json(self):
        """Happy path: payload parses as JSON and carries a name."""
        parsed = json.loads(get_agent_card_json())
        assert "name" in parsed

    def test_graceful_fallback(self):
        """If card building blows up, the minimal fallback card is returned.

        The original test never exercised the fallback path — it just called
        the happy path twice. Force build_agent_card to raise instead.
        """
        from unittest.mock import patch
        with patch("agent.agent_card.build_agent_card", side_effect=RuntimeError("boom")):
            payload = get_agent_card_json()
        parsed = json.loads(payload)
        assert parsed["name"] == "hermes"
class TestLoadSkills:
    """Skill discovery from SKILL.md frontmatter."""

    def test_empty_dir(self, tmp_path):
        assert _load_skills_from_directory(tmp_path / "nonexistent") == []

    def test_parses_skill_md(self, tmp_path):
        frontmatter = "\n".join([
            "---",
            "name: Test Skill",
            "description: A test skill",
            "tags:",
            "- test",
            "- example",
            "---",
            "Content here",
            "",
        ])
        skill_home = tmp_path / "test-skill"
        skill_home.mkdir()
        (skill_home / "SKILL.md").write_text(frontmatter)
        found = _load_skills_from_directory(tmp_path)
        assert len(found) == 1
        assert found[0].name == "Test Skill"
        assert "test" in found[0].tags
if __name__ == "__main__":
    # Allow running this test file directly (python tests/test_agent_card.py).
    import pytest
    pytest.main([__file__, "-v"])

View File

@@ -1,187 +0,0 @@
"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
import time
from unittest.mock import patch
import pytest
from agent.context_budget import (
ContextBudgetTracker,
THRESHOLD_CAUTION,
THRESHOLD_ELEVATED,
THRESHOLD_CRITICAL,
PREFLIGHT_SAFETY_MARGIN,
WARNING_COOLDOWN,
)
@pytest.fixture
def tracker():
    """Standard tracker: 128K context, 50% compression threshold = 64K threshold."""
    # Fresh instance per test — the tracker carries mutable dedup state.
    return ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
class TestThresholds:
    """Warning-level boundaries relative to the 64K compression threshold."""

    @staticmethod
    def _level_at(tracker, fraction):
        # Drive usage to the given fraction of the threshold, return the level.
        tracker.update(int(64_000 * fraction))
        return tracker.warning_level

    def test_threshold_tokens_computed(self, tracker):
        assert tracker.threshold_tokens == 64_000

    def test_caution_at_70_percent(self, tracker):
        assert self._level_at(tracker, 0.70) == "elevated"

    def test_no_warning_below_caution(self, tracker):
        assert self._level_at(tracker, 0.69) is None

    def test_critical_at_85_percent(self, tracker):
        assert self._level_at(tracker, 0.85) == "critical"

    def test_emergency_at_95_percent(self, tracker):
        assert self._level_at(tracker, 0.95) == "emergency"
class TestWarningMessages:
    """Message text emitted at each warning tier."""

    def test_elevated_message_contains_70(self, tracker):
        tracker.update(int(64_000 * 0.70))
        message = tracker.warning_message
        assert message is not None and "CONTEXT WARNING" in message

    def test_critical_message(self, tracker):
        tracker.update(int(64_000 * 0.85))
        assert "compaction approaching" in tracker.warning_message

    def test_emergency_message(self, tracker):
        tracker.update(int(64_000 * 0.95))
        assert "CONTEXT CRITICAL" in tracker.warning_message

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
class TestWarningDedup:
    """Cooldown and tier-based deduplication of warnings."""

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """Same tier within cooldown should not re-emit."""
        tracker.update(int(64_000 * 0.71))
        assert tracker.update(int(64_000 * 0.72)) is None  # cooldown suppresses

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Crossing to a higher tier should always emit."""
        tracker.update(int(64_000 * 0.71))
        message = tracker.update(int(64_000 * 0.86))
        assert message is not None
        assert "compaction approaching" in message.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        tracker.update(int(64_000 * 0.71))
        # Rewind the last-warned timestamp past the cooldown window.
        tracker._last_warned_time = time.time() - WARNING_COOLDOWN - 1
        assert tracker.update(int(64_000 * 0.72)) is not None
class TestCheckpoint:
    """Auto-checkpoint trigger at the 85% tier."""

    def test_should_checkpoint_at_85(self, tracker):
        tracker.update(int(64_000 * 0.86))
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        tracker.update(int(64_000 * 0.84))
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        tracker.update(int(64_000 * 0.86))
        tracker.should_checkpoint()                  # first call records the save
        assert tracker.should_checkpoint() is False  # second call inside cooldown
class TestPreflight:
    """Pre-flight capacity helpers: can_fit and friends."""

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        tracker.update(int(64_000 * 0.88))
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        assert tracker.estimate_file_tokens(4_000) == 1_000
        assert tracker.estimate_file_tokens(100) >= 1  # never below one token

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        expected = int(64_000 * PREFLIGHT_SAFETY_MARGIN) - 30_000
        assert tracker.tokens_remaining() == expected
class TestTrend:
    """Trend classification over recent token history."""

    @staticmethod
    def _feed(tracker, start, step, turns=10):
        # Push a linear sequence of token counts through the tracker.
        for turn in range(turns):
            tracker.update(start + turn * step)

    def test_growing_trend(self, tracker):
        self._feed(tracker, 10_000, 5_000)
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        self._feed(tracker, 60_000, -5_000)
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        self._feed(tracker, 30_000, 0)
        assert tracker.trend() == "stable"
class TestSummary:
    """summary() payload and format_status() rendering."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        snapshot = tracker.summary()
        for key in ("context_length", "current_tokens", "warning_level", "trend"):
            assert key in snapshot
        assert snapshot["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        line = tracker.format_status()
        assert "Context:" in line
        assert "[" in line  # progress bar frame
        assert "%" in line
class TestEdgeCases:
    """Degenerate configurations and state counters."""

    def test_zero_context_length(self):
        empty = ContextBudgetTracker(context_length=0)
        assert empty.threshold_tokens == 0
        assert empty.progress == 0.0
        assert empty.warning_level is None

    def test_different_threshold_percent(self):
        custom = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert custom.threshold_tokens == 80_000
        custom.update(int(80_000 * 0.70))
        assert custom.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Progress can exceed 1.0 (past compression threshold)."""
        tracker.update(70_000)
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        for tokens in (10_000, 50_000, 30_000):
            tracker.update(tokens)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        tracker.update(10_000)
        tracker.update(20_000)
        assert tracker.turn_count == 2