Compare commits
1 Commits
burn/838-c
...
fix/834
| Author | SHA1 | Date | |
|---|---|---|---|
| 8694261ee2 |
@@ -1,235 +0,0 @@
|
||||
"""Context Budget Tracker — Proactive token counting with graduated warnings.
|
||||
|
||||
Poka-yoke (mistake-proofing) for context window overflow. Tracks approximate
|
||||
token usage per turn and emits warnings at 70%, 85%, and 95% thresholds
|
||||
relative to the compression threshold (not the raw context window).
|
||||
|
||||
Usage:
|
||||
tracker = ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
|
||||
tracker.update(estimated_tokens=45_000)
|
||||
level = tracker.warning_level # "elevated" | "critical" | "emergency" | None
|
||||
if level:
|
||||
print(tracker.warning_message)
|
||||
|
||||
Integration points in run_agent.py:
|
||||
1. After `estimate_messages_tokens_rough(messages)` in the agent loop,
|
||||
call `tracker.update(_real_tokens)`.
|
||||
2. Check `tracker.should_checkpoint()` to auto-save session state.
|
||||
3. Check `tracker.should_gate(content_tokens)` before loading large
|
||||
files or skills.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Warning thresholds, expressed as fractions of the compression threshold
# (not of the raw context window — see ContextBudgetTracker docstring).
THRESHOLD_CAUTION = 0.70  # 70% of threshold — yellow
THRESHOLD_ELEVATED = 0.85  # 85% of threshold — orange
THRESHOLD_CRITICAL = 0.95  # 95% of threshold — red

# Cooldown between repeated warnings at the same tier (seconds)
WARNING_COOLDOWN = 300

# Pre-flight safety margin: refuse loads that would push past this
PREFLIGHT_SAFETY_MARGIN = 0.90

# Maps a warning-level name to the tier value that triggers it.  Shared by
# should_emit_warning and mark_warned so the two mappings cannot drift apart
# (previously each method carried its own inline copy of this dict).
_LEVEL_TIERS: Dict[str, float] = {
    "elevated": THRESHOLD_CAUTION,
    "critical": THRESHOLD_ELEVATED,
    "emergency": THRESHOLD_CRITICAL,
}


@dataclass
class ContextBudgetTracker:
    """Tracks context window usage and emits graduated warnings.

    All thresholds are relative to the compression threshold, which is
    itself a fraction of the total context window. For example, with a
    128K context window and 50% compression threshold:
      - threshold_tokens = 64,000
      - caution at 70% = 44,800 tokens
      - elevated at 85% = 54,400 tokens
      - critical at 95% = 60,800 tokens
    """

    # Configuration
    context_length: int  # total context window size, in tokens
    threshold_percent: float = 0.50  # compression threshold as fraction of window

    # Current state (updated by .update())
    current_tokens: int = 0
    peak_tokens: int = 0
    turn_count: int = 0

    # Warning dedup state (wall-clock seconds from time.time())
    _last_warned_tier: float = 0.0
    _last_warned_time: float = 0.0
    _checkpoint_saved_at: float = 0.0

    # History for trend analysis: (turn_count, tokens, timestamp) tuples
    _token_history: List[Tuple[int, int, float]] = field(default_factory=list)
    _max_history: int = 50

    @property
    def threshold_tokens(self) -> int:
        """Compression threshold in absolute tokens."""
        return int(self.context_length * self.threshold_percent)

    @property
    def progress(self) -> float:
        """Current usage as fraction of compression threshold (0.0 to 1.0+)."""
        if self.threshold_tokens == 0:
            # Degenerate config (zero-length context) — report no usage
            # rather than dividing by zero.
            return 0.0
        return self.current_tokens / self.threshold_tokens

    @property
    def warning_level(self) -> Optional[str]:
        """Current warning level or None if below caution threshold.

        Level names are one step "hotter" than the threshold constants that
        trigger them: crossing CAUTION yields "elevated", ELEVATED yields
        "critical", CRITICAL yields "emergency".
        """
        p = self.progress
        if p >= THRESHOLD_CRITICAL:
            return "emergency"
        elif p >= THRESHOLD_ELEVATED:
            return "critical"
        elif p >= THRESHOLD_CAUTION:
            return "elevated"
        return None

    @property
    def warning_message(self) -> Optional[str]:
        """Human-readable warning message for current level, or None."""
        level = self.warning_level
        if level is None:
            return None
        pct = int(self.progress * 100)
        used = f"{self.current_tokens:,}"
        limit = f"{self.threshold_tokens:,}"
        msgs = {
            "elevated": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — consider summarizing or checkpointing]",
            "critical": f"[CONTEXT WARNING: {pct}% used ({used}/{limit} tokens) — checkpoint recommended, compaction approaching]",
            "emergency": f"[CONTEXT CRITICAL: {pct}% used ({used}/{limit} tokens) — compaction imminent, auto-summarizing older content]",
        }
        return msgs.get(level)

    @property
    def should_emit_warning(self) -> bool:
        """Whether a new warning should be emitted (dedup + cooldown).

        Emits when usage crossed into a strictly higher tier than the last
        warning, or when the per-tier cooldown has expired; otherwise the
        warning is suppressed to avoid spamming every turn.
        """
        level = self.warning_level
        if level is None:
            return False

        tier_val = _LEVEL_TIERS.get(level, 0)

        now = time.time()
        if tier_val <= self._last_warned_tier:
            # Same or lower tier — check cooldown
            if (now - self._last_warned_time) < WARNING_COOLDOWN:
                return False
        # New higher tier or cooldown expired
        return True

    def mark_warned(self):
        """Call after emitting a warning to update dedup state."""
        level = self.warning_level
        self._last_warned_tier = _LEVEL_TIERS.get(level, 0)
        self._last_warned_time = time.time()

    def update(self, estimated_tokens: int) -> Optional[str]:
        """Update current token count and return warning message if warranted.

        Args:
            estimated_tokens: Rough token count of the current messages.

        Returns:
            Warning message string if a warning should be shown, else None.
        """
        self.current_tokens = estimated_tokens
        self.turn_count += 1
        if estimated_tokens > self.peak_tokens:
            self.peak_tokens = estimated_tokens

        # Record history (bounded ring of the last _max_history samples)
        self._token_history.append((self.turn_count, estimated_tokens, time.time()))
        if len(self._token_history) > self._max_history:
            self._token_history = self._token_history[-self._max_history:]

        if self.should_emit_warning:
            self.mark_warned()
            return self.warning_message
        return None

    def should_checkpoint(self) -> bool:
        """Whether session state should be auto-saved (85% threshold).

        Returns True once per crossing of the elevated threshold, with a
        cooldown to avoid repeated saves. Note: calling this consumes the
        cooldown (it records the save time as a side effect).
        """
        if self.progress < THRESHOLD_ELEVATED:
            return False
        now = time.time()
        if (now - self._checkpoint_saved_at) < WARNING_COOLDOWN:
            return False
        self._checkpoint_saved_at = now
        return True

    def can_fit(self, additional_tokens: int) -> bool:
        """Pre-flight check: would adding this many tokens exceed the safety margin?

        Use before loading large files or skills to prevent overflow.
        """
        projected = self.current_tokens + additional_tokens
        return projected < int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)

    def estimate_file_tokens(self, file_size_bytes: int) -> int:
        """Rough token estimate for a file of given size (~4 chars/token)."""
        return max(1, file_size_bytes // 4)

    def tokens_remaining(self) -> int:
        """Approximate tokens available before hitting the safety margin."""
        safe_limit = int(self.threshold_tokens * PREFLIGHT_SAFETY_MARGIN)
        return max(0, safe_limit - self.current_tokens)

    def trend(self, window: int = 10) -> str:
        """Token growth trend over the last N turns: 'growing' | 'stable' | 'shrinking'."""
        if len(self._token_history) < 2:
            return "stable"
        recent = self._token_history[-window:]
        if len(recent) < 2:
            return "stable"
        first = recent[0][1]
        last = recent[-1][1]
        delta = last - first
        # Deadband of 5% of the threshold so tiny fluctuations read as stable
        threshold = self.threshold_tokens * 0.05
        if delta > threshold:
            return "growing"
        elif delta < -threshold:
            return "shrinking"
        return "stable"

    def summary(self) -> Dict[str, Any]:
        """Machine-readable summary for logging/metrics."""
        return {
            "context_length": self.context_length,
            "threshold_tokens": self.threshold_tokens,
            "current_tokens": self.current_tokens,
            "peak_tokens": self.peak_tokens,
            "progress_pct": round(self.progress * 100, 1),
            "warning_level": self.warning_level,
            "turn_count": self.turn_count,
            "trend": self.trend(),
            "tokens_remaining": self.tokens_remaining(),
        }

    def format_status(self) -> str:
        """Human-readable status line for CLI display."""
        pct = int(self.progress * 100)
        bar_len = 20
        # Clamp the bar at full even when progress exceeds 1.0
        filled = int(bar_len * min(self.progress, 1.0))
        bar = "█" * filled + "░" * (bar_len - filled)
        level = self.warning_level or "ok"
        return f"Context: [{bar}] {pct}% ({self.current_tokens:,}/{self.threshold_tokens:,}) {level}"
|
||||
4
cli.py
4
cli.py
@@ -3611,7 +3611,7 @@ class HermesCLI:
|
||||
available, unavailable = check_tool_availability()
|
||||
|
||||
# Filter to only those missing API keys (not system deps)
|
||||
api_key_missing = [u for u in unavailable if u["missing_vars"]]
|
||||
api_key_missing = [u for u in unavailable if u.get("env_vars")]
|
||||
|
||||
if api_key_missing:
|
||||
self.console.print()
|
||||
@@ -3620,7 +3620,7 @@ class HermesCLI:
|
||||
tools_str = ", ".join(item["tools"][:2]) # Show first 2 tools
|
||||
if len(item["tools"]) > 2:
|
||||
tools_str += f", +{len(item['tools'])-2} more"
|
||||
self.console.print(f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['missing_vars'])})[/]")
|
||||
self.console.print(f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['env_vars'])})[/]")
|
||||
self.console.print("[dim] Run 'hermes setup' to configure[/]")
|
||||
except Exception:
|
||||
pass # Don't crash on import errors
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.context_budget import (
|
||||
ContextBudgetTracker,
|
||||
THRESHOLD_CAUTION,
|
||||
THRESHOLD_ELEVATED,
|
||||
THRESHOLD_CRITICAL,
|
||||
PREFLIGHT_SAFETY_MARGIN,
|
||||
WARNING_COOLDOWN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def tracker():
    """Standard tracker: 128K context, 50% compression threshold = 64K threshold."""
    # Fresh instance per test — the tracker is stateful (dedup timers, history).
    return ContextBudgetTracker(context_length=128_000, threshold_percent=0.50)
|
||||
|
||||
|
||||
class TestThresholds:
    """Warning-level boundaries relative to the 64K compression threshold."""

    def test_threshold_tokens_computed(self, tracker):
        # 128_000 * 0.50 = 64_000
        assert tracker.threshold_tokens == 64_000

    def test_caution_at_70_percent(self, tracker):
        # Exactly at the caution boundary -> lowest warning level ("elevated").
        tracker.update(int(64_000 * 0.70))
        assert tracker.warning_level == "elevated"

    def test_no_warning_below_caution(self, tracker):
        # One percent under the caution boundary -> no warning at all.
        tracker.update(int(64_000 * 0.69))
        assert tracker.warning_level is None

    def test_critical_at_85_percent(self, tracker):
        tracker.update(int(64_000 * 0.85))
        assert tracker.warning_level == "critical"

    def test_emergency_at_95_percent(self, tracker):
        tracker.update(int(64_000 * 0.95))
        assert tracker.warning_level == "emergency"
|
||||
|
||||
|
||||
class TestWarningMessages:
    """Content of the human-readable warning strings for each level."""

    def test_elevated_message_contains_70(self, tracker):
        tracker.update(int(64_000 * 0.70))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT WARNING" in msg

    def test_critical_message(self, tracker):
        # "critical" level carries the "compaction approaching" phrasing.
        tracker.update(int(64_000 * 0.85))
        msg = tracker.warning_message
        assert "compaction approaching" in msg

    def test_emergency_message(self, tracker):
        # Top tier escalates the prefix from WARNING to CRITICAL.
        tracker.update(int(64_000 * 0.95))
        msg = tracker.warning_message
        assert "CONTEXT CRITICAL" in msg

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
|
||||
|
||||
|
||||
class TestWarningDedup:
    """Suppression rules: same-tier cooldown vs. tier escalation."""

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """Same tier within cooldown should not re-emit."""
        tracker.update(int(64_000 * 0.71))
        msg1 = tracker.update(int(64_000 * 0.72))
        assert msg1 is None  # suppressed by cooldown

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Crossing to a higher tier should always emit."""
        tracker.update(int(64_000 * 0.71))
        msg = tracker.update(int(64_000 * 0.86))
        assert msg is not None
        assert "compaction approaching" in msg.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        tracker.update(int(64_000 * 0.71))
        # Fast-forward cooldown by rewinding the last-warned timestamp
        tracker._last_warned_time = time.time() - WARNING_COOLDOWN - 1
        msg = tracker.update(int(64_000 * 0.72))
        assert msg is not None
|
||||
|
||||
|
||||
class TestCheckpoint:
    """Auto-checkpoint trigger at the 85% (elevated) threshold."""

    def test_should_checkpoint_at_85(self, tracker):
        tracker.update(int(64_000 * 0.86))
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        tracker.update(int(64_000 * 0.84))
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        # should_checkpoint() records the save time as a side effect,
        # so a second immediate call is suppressed by the cooldown.
        tracker.update(int(64_000 * 0.86))
        tracker.should_checkpoint()  # saves
        assert tracker.should_checkpoint() is False  # cooldown
|
||||
|
||||
|
||||
class TestPreflight:
    """Pre-flight capacity checks against the 90% safety margin."""

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        # 88% + ~16% more would pass the 90% safety margin.
        tracker.update(int(64_000 * 0.88))
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        # ~4 bytes per token heuristic
        assert tracker.estimate_file_tokens(4_000) == 1_000
        assert tracker.estimate_file_tokens(100) >= 1  # minimum 1

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        remaining = tracker.tokens_remaining()
        safe_limit = int(64_000 * PREFLIGHT_SAFETY_MARGIN)
        assert remaining == safe_limit - 30_000
|
||||
|
||||
|
||||
class TestTrend:
    """Growth-trend classification over the recent token history."""

    def test_growing_trend(self, tracker):
        # Monotonic +5K per turn well exceeds the 5%-of-threshold deadband.
        for i in range(10):
            tracker.update(10_000 + i * 5_000)
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        for i in range(10):
            tracker.update(60_000 - i * 5_000)
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        # Flat usage -> delta of zero -> "stable".
        for _ in range(10):
            tracker.update(30_000)
        assert tracker.trend() == "stable"
|
||||
|
||||
|
||||
class TestSummary:
    """Machine-readable summary dict and CLI status line."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        s = tracker.summary()
        assert "context_length" in s
        assert "current_tokens" in s
        assert "warning_level" in s
        assert "trend" in s
        assert s["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        status = tracker.format_status()
        assert "Context:" in status
        assert "[" in status  # progress bar
        assert "%" in status
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Degenerate configurations and state-tracking invariants."""

    def test_zero_context_length(self):
        # Zero-length context must not divide by zero.
        t = ContextBudgetTracker(context_length=0)
        assert t.threshold_tokens == 0
        assert t.progress == 0.0
        assert t.warning_level is None

    def test_different_threshold_percent(self):
        # Warning tiers scale with threshold_percent, not the raw window.
        t = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert t.threshold_tokens == 80_000
        t.update(int(80_000 * 0.70))
        assert t.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Progress can exceed 1.0 (past compression threshold)."""
        tracker.update(70_000)
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        # Peak is a high-water mark; it does not fall with usage.
        tracker.update(10_000)
        tracker.update(50_000)
        tracker.update(30_000)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        tracker.update(10_000)
        tracker.update(20_000)
        assert tracker.turn_count == 2
|
||||
Reference in New Issue
Block a user