Compare commits

...

2 Commits

Author SHA1 Message Date
d4cdfdc604 test: Add context budget tracker tests (#838)
Some checks failed
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 19s
Contributor Attribution Check / check-attribution (pull_request) Failing after 16s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Tests / test (pull_request) Failing after 18m30s
Tests / e2e (pull_request) Successful in 1m16s
2026-04-17 05:06:54 +00:00
e3436e36c3 feat: Add context budget tracker for overflow prevention (#838) 2026-04-17 05:06:08 +00:00
2 changed files with 275 additions and 0 deletions

148
agent/context_budget.py Normal file
View File

@@ -0,0 +1,148 @@
"""
Context Budget Tracker - Prevent context window overflow
Poka-yoke: Visual warnings at 70%, 85%, 95% capacity.
Auto-checkpoint at 85%. Pre-flight token estimation.
Issue: #838
"""
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Auto-checkpoint files are written to ~/.hermes/checkpoints.
HERMES_HOME = Path.home() / ".hermes"
CHECKPOINT_DIR = HERMES_HOME / "checkpoints"
# Heuristic divisor used by estimate_tokens: characters per estimated token.
CHARS_PER_TOKEN = 4
# Utilization fractions (of the non-reserved context window) at which the
# tracker escalates: warning < critical (also triggers auto-checkpoint) < danger.
THRESHOLD_WARNING = 0.70
THRESHOLD_CRITICAL = 0.85
THRESHOLD_DANGER = 0.95
class ContextBudget:
    """Token accounting for one context window.

    Every figure is a token count. ``reserved_tokens`` is headroom that is
    subtracted from ``context_limit`` before usage is measured, so all
    derived values are relative to the non-reserved part of the window.
    """

    def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
                 used_tokens: int = 0, reserved_tokens: int = 2000):
        """Store the raw figures; all derived values are computed on access.

        Args:
            context_limit: Total size of the context window.
            system_tokens: Tokens counted separately from ``used_tokens``
                and added to them in ``total_used``.
            used_tokens: Tokens consumed so far.
            reserved_tokens: Headroom excluded from the usable window.
        """
        self.context_limit = context_limit
        self.reserved_tokens = reserved_tokens
        self.system_tokens = system_tokens
        self.used_tokens = used_tokens

    @property
    def total_used(self) -> int:
        """Sum of system and used tokens."""
        return self.system_tokens + self.used_tokens

    @property
    def available(self) -> int:
        """Usable window: the limit minus the reserve, never negative."""
        usable = self.context_limit - self.reserved_tokens
        return max(0, usable)

    @property
    def remaining(self) -> int:
        """Usable tokens not yet consumed, never negative."""
        free = self.available - self.total_used
        return max(0, free)

    @property
    def utilization(self) -> float:
        """Fraction of the usable window consumed.

        Returns 1.0 when there is no usable window at all, and may exceed
        1.0 when usage is larger than the usable window.
        """
        capacity = self.available
        if capacity > 0:
            return self.total_used / capacity
        return 1.0
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    The estimate is the character count floor-divided by
    ``CHARS_PER_TOKEN``; empty or falsy input counts as zero tokens.
    """
    if not text:
        return 0
    return len(text) // CHARS_PER_TOKEN
def estimate_messages_tokens(messages: List[Dict]) -> int:
    """Return a rough token count for a list of message dicts.

    For each message, string ``content`` is costed with ``estimate_tokens``
    (non-string content contributes nothing), and a flat 100 tokens is added
    when the message has a truthy ``tool_calls`` entry.
    """
    running = 0
    for message in messages:
        body = message.get("content", "")
        text_cost = estimate_tokens(body) if isinstance(body, str) else 0
        # Tool calls are not measured; they are charged a fixed allowance.
        call_cost = 100 if message.get("tool_calls") else 0
        running += text_cost + call_cost
    return running
class ContextBudgetTracker:
    """Track estimated context usage for a session and surface warnings.

    Wraps a ``ContextBudget`` and adds three things on top of it: one-shot
    warnings as utilization crosses the warning / critical / danger
    thresholds, a best-effort checkpoint file written once utilization
    reaches the critical threshold, and a pre-flight check for text that is
    about to be loaded into the context.
    """

    def __init__(self, context_limit: int = 128000, session_id: str = ""):
        """Create a tracker.

        Args:
            context_limit: Total context window size in tokens.
            session_id: Used as the checkpoint file name. When empty, no
                checkpoint is ever written.
        """
        self.budget = ContextBudget(context_limit=context_limit)
        self.session_id = session_id
        self._checkpointed = False
        # Names of the warning levels that have already been reported.
        self._warnings_given = set()

    def update_from_messages(self, messages: List[Dict]) -> None:
        """Replace the used-token figure with an estimate for *messages*."""
        self.budget.used_tokens = estimate_messages_tokens(messages)

    def can_fit(self, additional_tokens: int) -> bool:
        """Return True when *additional_tokens* fit in the remaining budget."""
        return self.budget.remaining >= additional_tokens

    def preflight_check(self, text: str) -> Tuple[bool, str]:
        """Decide whether *text* can be added to the context.

        Returns:
            ``(ok, message)``. ``ok`` is False when the text does not fit or
            would push utilization to the danger threshold; ``message`` is
            empty when there is nothing to report.
        """
        tokens = estimate_tokens(text)
        if not self.can_fit(tokens):
            return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"

        available = self.budget.available
        would_util = (self.budget.total_used + tokens) / available if available > 0 else 1.0
        # The format spec is ".0%" (a single percent sign); ".0%%" is not a
        # valid spec and raises ValueError when the f-string is evaluated.
        if would_util >= THRESHOLD_DANGER:
            return False, f"Would reach {would_util:.0%} capacity. Summarize or start new session."
        if would_util >= THRESHOLD_CRITICAL:
            return True, f"Warning: will reach {would_util:.0%} capacity."
        return True, ""

    def get_warning(self) -> Optional[str]:
        """Return a warning the first time a threshold is reached, else None.

        Each level is reported at most once per tracker. Reporting a level
        also marks every lower level as reported, so a jump straight to a
        high level is not followed by lower-severity warnings for the same
        state on subsequent calls. Reaching the critical threshold or above
        triggers the auto-checkpoint.
        """
        util = self.budget.utilization
        given = self._warnings_given

        if util >= THRESHOLD_DANGER:
            if "danger" in given:
                return None
            given.update(("danger", "critical", "warning"))
            # Danger implies critical was passed: make sure the checkpoint
            # exists even if the critical level was skipped over.
            self._auto_checkpoint()
            return f"[CONTEXT CRITICAL: {util:.0%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"

        if util >= THRESHOLD_CRITICAL:
            if "critical" in given:
                return None
            given.update(("critical", "warning"))
            self._auto_checkpoint()
            # NOTE(review): the text says a checkpoint was saved even when
            # _auto_checkpoint skipped (no session_id) or failed.
            return f"[CONTEXT WARNING: {util:.0%} used -- consider summarizing. Auto-checkpoint saved.]"

        if util >= THRESHOLD_WARNING and "warning" not in given:
            given.add("warning")
            return f"[CONTEXT: {util:.0%} used -- {self.budget.remaining:,} tokens remaining]"

        return None

    def _auto_checkpoint(self) -> None:
        """Write ``<session_id>.json`` into CHECKPOINT_DIR, at most once.

        Does nothing when a checkpoint was already written or when there is
        no session id. Failures are logged and swallowed.
        """
        if self._checkpointed or not self.session_id:
            return
        try:
            CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
            path = CHECKPOINT_DIR / f"{self.session_id}.json"
            payload = {
                "session_id": self.session_id,
                "timestamp": time.time(),
                "budget": {"utilization": round(self.budget.utilization * 100, 1)},
            }
            path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
            self._checkpointed = True
            logger.info("Auto-checkpoint saved: %s", path)
        except Exception as e:
            # Deliberately broad: checkpointing is best-effort and must not
            # break the warning path that called it.
            logger.error("Auto-checkpoint failed: %s", e)

    def get_status_line(self) -> str:
        """Return a one-line status: colour name, percent used, tokens left."""
        util = self.budget.utilization
        remaining = self.budget.remaining
        if util >= THRESHOLD_DANGER:
            level = "RED"
        elif util >= THRESHOLD_CRITICAL:
            level = "ORANGE"
        elif util >= THRESHOLD_WARNING:
            level = "YELLOW"
        else:
            level = "GREEN"
        return f"{level} {util:.0%} used ({remaining:,} left)"
# Process-wide tracker instance, created on first use by get_tracker().
_tracker = None


def get_tracker(context_limit=128000, session_id=""):
    """Return the shared ContextBudgetTracker, creating it on first call.

    NOTE(review): the arguments are only used when the tracker is first
    created; later calls return the existing instance and ignore them.
    """
    global _tracker
    if _tracker is not None:
        return _tracker
    _tracker = ContextBudgetTracker(context_limit, session_id)
    return _tracker
def check_context_budget(messages, context_limit=128000):
    """Re-estimate usage from *messages* on the shared tracker.

    Returns whatever the tracker's ``get_warning`` returns: a warning
    string, or None when there is nothing to report.
    """
    shared = get_tracker(context_limit)
    shared.update_from_messages(messages)
    warning = shared.get_warning()
    return warning
def preflight_token_check(text):
    """Run the shared tracker's pre-flight check on *text*.

    Returns the ``(ok, message)`` tuple produced by
    ``ContextBudgetTracker.preflight_check``.
    """
    return get_tracker().preflight_check(text)

View File

@@ -0,0 +1,127 @@
"""
Tests for context budget tracker
Issue: #838
"""
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from agent.context_budget import (
ContextBudget,
ContextBudgetTracker,
estimate_tokens,
estimate_messages_tokens,
check_context_budget,
preflight_token_check,
THRESHOLD_WARNING,
THRESHOLD_CRITICAL,
THRESHOLD_DANGER,
)
class TestContextBudget(unittest.TestCase):
    """Arithmetic of the ContextBudget derived properties."""

    def test_basic_budget(self):
        # Nothing used yet: the whole non-reserved window is free.
        budget = ContextBudget(context_limit=10000)
        self.assertEqual(budget.available, 8000)  # 10000 - 2000 reserved
        self.assertEqual(budget.remaining, 8000)
        self.assertEqual(budget.utilization, 0.0)

    def test_utilization(self):
        # 4000 of the 8000 usable tokens consumed.
        budget = ContextBudget(context_limit=10000, used_tokens=4000)
        self.assertEqual(budget.utilization, 0.5)
        self.assertEqual(budget.remaining, 4000)
class TestTokenEstimation(unittest.TestCase):
    """Character-count based token estimates."""

    def test_estimate_tokens(self):
        # (input text, expected estimate) at four characters per token.
        cases = (
            ("", 0),
            ("a" * 4, 1),
            ("a" * 400, 100),
        )
        for text, expected in cases:
            with self.subTest(length=len(text)):
                self.assertEqual(estimate_tokens(text), expected)

    def test_estimate_messages(self):
        conversation = [
            {"role": "user", "content": "a" * 400},
            {"role": "assistant", "content": "b" * 800},
        ]
        estimated = estimate_messages_tokens(conversation)
        self.assertEqual(estimated, 300)  # 100 + 200
class TestContextBudgetTracker(unittest.TestCase):
    """Threshold warnings, auto-checkpointing and fit checks."""

    def test_warning_at_70_percent(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        tracker.budget.used_tokens = 5600  # 70% of 8000 available
        warning = tracker.get_warning()
        self.assertIsNotNone(warning)
        self.assertIn("70", warning)
        # Each level is reported only once.
        self.assertIsNone(tracker.get_warning())

    def test_critical_at_85_percent(self):
        with tempfile.TemporaryDirectory() as tmp:
            with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
                tracker = ContextBudgetTracker(context_limit=10000, session_id="test")
                tracker.budget.used_tokens = 6800  # 85% of 8000
                warning = tracker.get_warning()
                self.assertIsNotNone(warning)
                self.assertIn("85", warning)
                # Reaching the critical threshold must leave a checkpoint
                # file named after the session in the (patched) directory.
                checkpoint = Path(tmp) / "test.json"
                self.assertTrue(checkpoint.exists())
                payload = json.loads(checkpoint.read_text())
                self.assertEqual(payload["session_id"], "test")
                self.assertEqual(payload["budget"]["utilization"], 85.0)

    def test_danger_at_95_percent(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        tracker.budget.used_tokens = 7600  # 95% of 8000
        warning = tracker.get_warning()
        self.assertIsNotNone(warning)
        self.assertIn("CRITICAL", warning)

    def test_can_fit(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        tracker.budget.used_tokens = 5000  # 3000 remaining
        self.assertTrue(tracker.can_fit(1000))
        self.assertFalse(tracker.can_fit(5000))

    def test_preflight_check(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        tracker.budget.used_tokens = 5000
        can_fit, msg = tracker.preflight_check("a" * 400)  # 100 tokens
        self.assertTrue(can_fit)
        self.assertEqual(msg, "")
        # 10000 estimated tokens cannot fit in the 3000 that remain.
        can_fit, msg = tracker.preflight_check("a" * 40000)
        self.assertFalse(can_fit)
        self.assertIn("Cannot load", msg)
class TestCheckContextBudget(unittest.TestCase):
    """Module-level check_context_budget helper using the shared tracker."""

    def test_no_warning_under_threshold(self):
        # Reset the shared tracker so earlier tests cannot leak state in.
        with patch("agent.context_budget._tracker", None):
            messages = [{"role": "user", "content": "short"}]
            warning = check_context_budget(messages)
            self.assertIsNone(warning)

    def test_warning_over_threshold(self):
        with patch("agent.context_budget._tracker", None):
            # Default budget: 128000 - 2000 reserved = 126000 usable tokens,
            # so the 70% warning needs at least 88200 estimated tokens.
            # 360000 chars -> ~90000 tokens (~71%). The previous 350000
            # chars gave 87500 tokens (~69%), which is below the threshold.
            messages = [{"role": "user", "content": "x" * 360000}]
            warning = check_context_budget(messages)
            self.assertIsNotNone(warning)
class TestStatusLine(unittest.TestCase):
    """Colour label chosen by get_status_line."""

    def test_green_status(self):
        # A fresh tracker has used nothing, so it reports the lowest level.
        fresh = ContextBudgetTracker(context_limit=10000)
        self.assertIn("GREEN", fresh.get_status_line())

    def test_red_status(self):
        # 7600 of 8000 usable tokens is 95%, the danger threshold.
        nearly_full = ContextBudgetTracker(context_limit=10000)
        nearly_full.budget.used_tokens = 7600
        self.assertIn("RED", nearly_full.get_status_line())
if __name__ == "__main__":
unittest.main()