Compare commits

..

1 Commits

Author SHA1 Message Date
8694261ee2 fix(#834): KeyError 'missing_vars' crashes CLI startup
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 21s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 20s
Tests / e2e (pull_request) Successful in 2m19s
Tests / test (pull_request) Failing after 1h3m51s
Fix _show_tool_availability_warnings() to use 'env_vars' key
instead of 'missing_vars' which doesn't exist in registry output.

Closes #834
2026-04-16 01:45:00 +00:00
3 changed files with 2 additions and 424 deletions

View File

@@ -1,235 +0,0 @@
"""Context Budget Tracker — Proactive token counting with graduated warnings.
Poka-yoke (mistake-proofing) for context window overflow. Tracks approximate
token usage per turn and emits warnings at 70%, 85%, and 95% thresholds
relative to the compression threshold (not the raw context window).
Usage:
tracker = ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
tracker.update(estimated_tokens=45_000)
level = tracker.warning_level # "elevated" | "critical" | "emergency" | None
if level:
print(tracker.warning_message)
Integration points in run_agent.py:
1. After `estimate_messages_tokens_rough(messages)` in the agent loop,
call `tracker.update(_real_tokens)`.
2. Check `tracker.should_checkpoint()` to auto-save session state.
3. Check `tracker.can_fit(content_tokens)` before loading large
files or skills.
"""
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Warning thresholds, expressed as fractions of the compression threshold
# (not of the raw context window).
THRESHOLD_CAUTION = 0.70   # 70% of threshold — yellow
THRESHOLD_ELEVATED = 0.85  # 85% of threshold — orange
THRESHOLD_CRITICAL = 0.95  # 95% of threshold — red

# Cooldown between repeated warnings at the same tier (seconds)
WARNING_COOLDOWN = 300

# Pre-flight safety margin: refuse loads that would push past this
# fraction of the compression threshold.
PREFLIGHT_SAFETY_MARGIN = 0.90


@dataclass
class ContextBudgetTracker:
    """Tracks context window usage and emits graduated warnings.

    All thresholds are relative to the compression threshold, which is
    itself a fraction of the total context window. For example, with a
    128K context window and 50% compression threshold:

    - threshold_tokens = 64,000
    - caution at 70% = 44,800 tokens
    - elevated at 85% = 54,400 tokens
    - critical at 95% = 60,800 tokens

    Note that the reported level names are shifted one step relative to the
    constant names: crossing THRESHOLD_CAUTION reports "elevated", crossing
    THRESHOLD_ELEVATED reports "critical", and crossing THRESHOLD_CRITICAL
    reports "emergency".
    """

    context_length: int
    threshold_percent: float = 0.50

    # Current state (updated by .update())
    current_tokens: int = 0
    peak_tokens: int = 0
    turn_count: int = 0

    # Warning state: tier value and wall-clock time of the last emitted
    # warning, and wall-clock time of the last granted checkpoint.
    _last_warned_tier: float = 0.0
    _last_warned_time: float = 0.0
    _checkpoint_saved_at: float = 0.0

    # History for trend analysis: (turn_count, tokens, timestamp) tuples,
    # trimmed to the most recent ``_max_history`` entries.
    _token_history: List[Tuple[int, int, float]] = field(default_factory=list)
    _max_history: int = 50

    @staticmethod
    def _tier_for_level(level: Optional[str]) -> float:
        """Map a warning level name to its threshold fraction (0.0 if none)."""
        # Built on each call so the mapping always reflects the current
        # module-level constants.
        tiers = {
            "elevated": THRESHOLD_CAUTION,
            "critical": THRESHOLD_ELEVATED,
            "emergency": THRESHOLD_CRITICAL,
        }
        return tiers.get(level, 0.0)

    @property
    def _safe_limit(self) -> int:
        """Absolute token count corresponding to the pre-flight safety margin."""
        return int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)

    @property
    def threshold_tokens(self) -> int:
        """Compression threshold in absolute tokens."""
        return int(self.context_length * self.threshold_percent)

    @property
    def progress(self) -> float:
        """Current usage as fraction of compression threshold (0.0 to 1.0+)."""
        if self.threshold_tokens == 0:
            # Avoid dividing by zero for a degenerate (empty) budget.
            return 0.0
        return self.current_tokens / self.threshold_tokens

    @property
    def warning_level(self) -> Optional[str]:
        """Current warning level or None if below caution threshold."""
        p = self.progress
        if p >= THRESHOLD_CRITICAL:
            return "emergency"
        if p >= THRESHOLD_ELEVATED:
            return "critical"
        if p >= THRESHOLD_CAUTION:
            return "elevated"
        return None

    @property
    def warning_message(self) -> Optional[str]:
        """Human-readable warning message for current level, or None."""
        level = self.warning_level
        if level is None:
            return None
        pct = int(self.progress * 100)
        used = f"{self.current_tokens:,}"
        limit = f"{self.threshold_tokens:,}"
        msgs = {
            "elevated": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — consider summarizing or checkpointing]",
            "critical": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — checkpoint recommended, compaction approaching]",
            "emergency": f"[CONTEXT CRITICAL: {pct}% used ({used}/{limit} tokens) — compaction imminent, auto-summarizing older content]",
        }
        return msgs.get(level)

    @property
    def should_emit_warning(self) -> bool:
        """Whether a new warning should be emitted (dedup + cooldown).

        A warning is suppressed only when the current tier is the same as or
        lower than the last warned tier AND the cooldown has not yet expired.
        A strictly higher tier always breaks through.
        """
        level = self.warning_level
        if level is None:
            return False
        same_or_lower = self._tier_for_level(level) <= self._last_warned_tier
        in_cooldown = (time.time() - self._last_warned_time) < WARNING_COOLDOWN
        return not (same_or_lower and in_cooldown)

    def mark_warned(self):
        """Call after emitting a warning to update dedup state."""
        self._last_warned_tier = self._tier_for_level(self.warning_level)
        self._last_warned_time = time.time()

    def update(self, estimated_tokens: int) -> Optional[str]:
        """Update current token count and return warning message if warranted.

        Args:
            estimated_tokens: Rough token count of the current messages.

        Returns:
            Warning message string if a warning should be shown, else None.
        """
        self.current_tokens = estimated_tokens
        self.turn_count += 1
        if estimated_tokens > self.peak_tokens:
            self.peak_tokens = estimated_tokens

        # Record history, keeping only the most recent entries.
        self._token_history.append((self.turn_count, estimated_tokens, time.time()))
        if len(self._token_history) > self._max_history:
            self._token_history = self._token_history[-self._max_history:]

        if self.should_emit_warning:
            self.mark_warned()
            return self.warning_message
        return None

    def should_checkpoint(self) -> bool:
        """Whether session state should be auto-saved (85% threshold).

        Returns True at most once per cooldown window while usage is at or
        above the elevated threshold. A True result records the current time,
        so this call has a side effect.
        """
        if self.progress < THRESHOLD_ELEVATED:
            return False
        now = time.time()
        if (now - self._checkpoint_saved_at) < WARNING_COOLDOWN:
            return False
        self._checkpoint_saved_at = now
        return True

    def can_fit(self, additional_tokens: int) -> bool:
        """Pre-flight check: would adding this many tokens stay under the safety margin?

        Use before loading large files or skills to prevent overflow.
        Returns True only when the projected total is strictly below the
        safety limit.
        """
        projected = self.current_tokens + additional_tokens
        return projected < self._safe_limit

    def estimate_file_tokens(self, file_size_bytes: int) -> int:
        """Rough token estimate for a file of given size (~4 chars/token)."""
        return max(1, file_size_bytes // 4)

    def tokens_remaining(self) -> int:
        """Approximate tokens available before hitting the safety margin."""
        return max(0, self._safe_limit - self.current_tokens)

    def trend(self, window: int = 10) -> str:
        """Token growth trend over the last N turns: 'growing' | 'stable' | 'shrinking'.

        Compares the first and last samples in the window; a change larger
        than 5% of the compression threshold counts as growth or shrinkage.
        """
        recent = self._token_history[-window:]
        if len(recent) < 2:
            return "stable"
        delta = recent[-1][1] - recent[0][1]
        threshold = self.threshold_tokens * 0.05  # 5% of threshold
        if delta > threshold:
            return "growing"
        if delta < -threshold:
            return "shrinking"
        return "stable"

    def summary(self) -> Dict[str, Any]:
        """Machine-readable summary for logging/metrics."""
        return {
            "context_length": self.context_length,
            "threshold_tokens": self.threshold_tokens,
            "current_tokens": self.current_tokens,
            "peak_tokens": self.peak_tokens,
            "progress_pct": round(self.progress * 100, 1),
            "warning_level": self.warning_level,
            "turn_count": self.turn_count,
            "trend": self.trend(),
            "tokens_remaining": self.tokens_remaining(),
        }

    def format_status(self) -> str:
        """Human-readable status line for CLI display."""
        pct = int(self.progress * 100)
        bar_len = 20
        # Clamp to [0, 1] so the bar is always exactly ``bar_len`` cells wide,
        # even when usage is past the threshold or (erroneously) negative.
        fraction = max(0.0, min(self.progress, 1.0))
        filled = int(bar_len * fraction)
        bar = "█" * filled + "░" * (bar_len - filled)
        level = self.warning_level or "ok"
        return f"Context: [{bar}] {pct}% ({self.current_tokens:,}/{self.threshold_tokens:,}) {level}"

4
cli.py
View File

@@ -3611,7 +3611,7 @@ class HermesCLI:
available, unavailable = check_tool_availability()
# Filter to only those missing API keys (not system deps)
api_key_missing = [u for u in unavailable if u["missing_vars"]]
api_key_missing = [u for u in unavailable if u.get("env_vars")]
if api_key_missing:
self.console.print()
@@ -3620,7 +3620,7 @@ class HermesCLI:
tools_str = ", ".join(item["tools"][:2]) # Show first 2 tools
if len(item["tools"]) > 2:
tools_str += f", +{len(item['tools'])-2} more"
self.console.print(f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['missing_vars'])})[/]")
self.console.print(f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['env_vars'])})[/]")
self.console.print("[dim] Run 'hermes setup' to configure[/]")
except Exception:
pass # Don't crash on import errors

View File

@@ -1,187 +0,0 @@
"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
import time
from unittest.mock import patch
import pytest
from agent.context_budget import (
ContextBudgetTracker,
THRESHOLD_CAUTION,
THRESHOLD_ELEVATED,
THRESHOLD_CRITICAL,
PREFLIGHT_SAFETY_MARGIN,
WARNING_COOLDOWN,
)
@pytest.fixture
def tracker():
    """Provide a fresh tracker with a 128K window and a 64K compression threshold."""
    context_window = 128_000
    compression_fraction = 0.50
    return ContextBudgetTracker(
        context_length=context_window,
        threshold_percent=compression_fraction,
    )
class TestThresholds:
    """Boundary checks for the mapping from usage to warning level."""

    def test_threshold_tokens_computed(self, tracker):
        expected_threshold = 64_000
        assert tracker.threshold_tokens == expected_threshold

    def test_caution_at_70_percent(self, tracker):
        usage = int(64_000 * 0.70)
        tracker.update(usage)
        assert tracker.warning_level == "elevated"

    def test_no_warning_below_caution(self, tracker):
        usage = int(64_000 * 0.69)
        tracker.update(usage)
        assert tracker.warning_level is None

    def test_critical_at_85_percent(self, tracker):
        usage = int(64_000 * 0.85)
        tracker.update(usage)
        assert tracker.warning_level == "critical"

    def test_emergency_at_95_percent(self, tracker):
        usage = int(64_000 * 0.95)
        tracker.update(usage)
        assert tracker.warning_level == "emergency"
class TestWarningMessages:
    """Checks on the human-readable text produced for each warning level."""

    def test_elevated_message_contains_70(self, tracker):
        tracker.update(int(64_000 * 0.70))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT WARNING" in msg
        # The test name promises the percentage is reported; verify it.
        assert "70%" in msg

    def test_critical_message(self, tracker):
        tracker.update(int(64_000 * 0.85))
        msg = tracker.warning_message
        # Guard first so a missing message fails as an assertion, not a TypeError.
        assert msg is not None
        assert "compaction approaching" in msg

    def test_emergency_message(self, tracker):
        tracker.update(int(64_000 * 0.95))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT CRITICAL" in msg

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
class TestWarningDedup:
    """Deduplication and cooldown behaviour of warnings returned by update()."""

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """A second update in the same tier, inside the cooldown, stays silent."""
        first_usage = int(64_000 * 0.71)
        second_usage = int(64_000 * 0.72)
        tracker.update(first_usage)
        repeat_msg = tracker.update(second_usage)
        assert repeat_msg is None  # suppressed by cooldown

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Moving up a tier emits immediately, regardless of cooldown."""
        tracker.update(int(64_000 * 0.71))
        escalated = tracker.update(int(64_000 * 0.86))
        assert escalated is not None
        assert "compaction approaching" in escalated.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        """Once the cooldown has elapsed, the same tier warns again."""
        tracker.update(int(64_000 * 0.71))
        # Pretend the last warning happened just over one cooldown ago.
        expired_at = time.time() - WARNING_COOLDOWN - 1
        tracker._last_warned_time = expired_at
        reemitted = tracker.update(int(64_000 * 0.72))
        assert reemitted is not None
class TestCheckpoint:
    """Auto-checkpoint signalling around the 85% mark."""

    def test_should_checkpoint_at_85(self, tracker):
        above_mark = int(64_000 * 0.86)
        tracker.update(above_mark)
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        below_mark = int(64_000 * 0.84)
        tracker.update(below_mark)
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        above_mark = int(64_000 * 0.86)
        tracker.update(above_mark)
        tracker.should_checkpoint()  # first call consumes the checkpoint
        second_call = tracker.should_checkpoint()
        assert second_call is False  # still inside the cooldown window
class TestPreflight:
    """Pre-flight capacity checks and token estimation helpers."""

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        tracker.update(int(64_000 * 0.88))
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        assert tracker.estimate_file_tokens(4_000) == 1_000
        assert tracker.estimate_file_tokens(100) >= 1  # minimum 1
        # Exercise the floor itself: inputs under 4 bytes would otherwise
        # estimate to zero tokens.
        assert tracker.estimate_file_tokens(0) == 1
        assert tracker.estimate_file_tokens(3) == 1

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        remaining = tracker.tokens_remaining()
        safe_limit = int(64_000 * PREFLIGHT_SAFETY_MARGIN)
        assert remaining == safe_limit - 30_000
class TestTrend:
    """Trend classification over a series of ten updates."""

    def test_growing_trend(self, tracker):
        # 10_000, 15_000, ..., 55_000
        for tokens in range(10_000, 60_000, 5_000):
            tracker.update(tokens)
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        # 60_000, 55_000, ..., 15_000
        for tokens in range(60_000, 10_000, -5_000):
            tracker.update(tokens)
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        flat_usage = [30_000] * 10
        for tokens in flat_usage:
            tracker.update(tokens)
        assert tracker.trend() == "stable"
class TestSummary:
    """Shape of the machine-readable summary and the CLI status line."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        report = tracker.summary()
        for key in ("context_length", "current_tokens", "warning_level", "trend"):
            assert key in report
        assert report["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        line = tracker.format_status()
        assert "Context:" in line
        assert "[" in line  # progress bar
        assert "%" in line
class TestEdgeCases:
    """Degenerate budgets, non-default thresholds and counter bookkeeping."""

    def test_zero_context_length(self):
        empty = ContextBudgetTracker(context_length=0)
        assert empty.threshold_tokens == 0
        assert empty.progress == 0.0
        assert empty.warning_level is None

    def test_different_threshold_percent(self):
        custom = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert custom.threshold_tokens == 80_000
        caution_usage = int(80_000 * 0.70)
        custom.update(caution_usage)
        assert custom.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Usage beyond the compression threshold reports progress above 1.0."""
        tracker.update(70_000)
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        for tokens in (10_000, 50_000, 30_000):
            tracker.update(tokens)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        for tokens in (10_000, 20_000):
            tracker.update(tokens)
        assert tracker.turn_count == 2