Compare commits
2 Commits
fix/format
...
fix/838-17
| Author | SHA1 | Date | |
|---|---|---|---|
| d4cdfdc604 | |||
| e3436e36c3 |
148
agent/context_budget.py
Normal file
148
agent/context_budget.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Context Budget Tracker - Prevent context window overflow
|
||||
|
||||
Poka-yoke: Visual warnings at 70%%, 85%%, 95%% capacity.
|
||||
Auto-checkpoint at 85%%. Pre-flight token estimation.
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
CHECKPOINT_DIR = HERMES_HOME / "checkpoints"
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
THRESHOLD_WARNING = 0.70
|
||||
THRESHOLD_CRITICAL = 0.85
|
||||
THRESHOLD_DANGER = 0.95
|
||||
|
||||
|
||||
class ContextBudget:
|
||||
def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
|
||||
used_tokens: int = 0, reserved_tokens: int = 2000):
|
||||
self.context_limit = context_limit
|
||||
self.system_tokens = system_tokens
|
||||
self.used_tokens = used_tokens
|
||||
self.reserved_tokens = reserved_tokens
|
||||
|
||||
@property
|
||||
def total_used(self) -> int:
|
||||
return self.system_tokens + self.used_tokens
|
||||
|
||||
@property
|
||||
def available(self) -> int:
|
||||
return max(0, self.context_limit - self.reserved_tokens)
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
return max(0, self.available - self.total_used)
|
||||
|
||||
@property
|
||||
def utilization(self) -> float:
|
||||
return self.total_used / self.available if self.available > 0 else 1.0
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
return len(text) // CHARS_PER_TOKEN if text else 0
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: List[Dict]) -> int:
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total += estimate_tokens(content)
|
||||
if msg.get("tool_calls"):
|
||||
total += 100
|
||||
return total
|
||||
|
||||
|
||||
class ContextBudgetTracker:
|
||||
def __init__(self, context_limit: int = 128000, session_id: str = ""):
|
||||
self.budget = ContextBudget(context_limit=context_limit)
|
||||
self.session_id = session_id
|
||||
self._checkpointed = False
|
||||
self._warnings_given = set()
|
||||
|
||||
def update_from_messages(self, messages: List[Dict]):
|
||||
self.budget.used_tokens = estimate_messages_tokens(messages)
|
||||
|
||||
def can_fit(self, additional_tokens: int) -> bool:
|
||||
return self.budget.remaining >= additional_tokens
|
||||
|
||||
def preflight_check(self, text: str) -> Tuple[bool, str]:
|
||||
tokens = estimate_tokens(text)
|
||||
if not self.can_fit(tokens):
|
||||
return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"
|
||||
would_util = (self.budget.total_used + tokens) / self.budget.available if self.budget.available > 0 else 1.0
|
||||
if would_util >= THRESHOLD_DANGER:
|
||||
return False, f"Would reach {would_util:.0%%} capacity. Summarize or start new session."
|
||||
if would_util >= THRESHOLD_CRITICAL:
|
||||
return True, f"Warning: will reach {would_util:.0%%} capacity."
|
||||
return True, ""
|
||||
|
||||
def get_warning(self) -> Optional[str]:
|
||||
util = self.budget.utilization
|
||||
if util >= THRESHOLD_DANGER and "danger" not in self._warnings_given:
|
||||
self._warnings_given.add("danger")
|
||||
return f"[CONTEXT CRITICAL: {util:.0%%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"
|
||||
if util >= THRESHOLD_CRITICAL and "critical" not in self._warnings_given:
|
||||
self._warnings_given.add("critical")
|
||||
self._auto_checkpoint()
|
||||
return f"[CONTEXT WARNING: {util:.0%%} used -- consider summarizing. Auto-checkpoint saved.]"
|
||||
if util >= THRESHOLD_WARNING and "warning" not in self._warnings_given:
|
||||
self._warnings_given.add("warning")
|
||||
return f"[CONTEXT: {util:.0%%} used -- {self.budget.remaining:,} tokens remaining]"
|
||||
return None
|
||||
|
||||
def _auto_checkpoint(self):
|
||||
if self._checkpointed or not self.session_id:
|
||||
return
|
||||
try:
|
||||
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = CHECKPOINT_DIR / f"{self.session_id}.json"
|
||||
path.write_text(json.dumps({
|
||||
"session_id": self.session_id,
|
||||
"timestamp": time.time(),
|
||||
"budget": {"utilization": round(self.budget.utilization * 100, 1)}
|
||||
}, indent=2))
|
||||
self._checkpointed = True
|
||||
logger.info("Auto-checkpoint saved: %s", path)
|
||||
except Exception as e:
|
||||
logger.error("Auto-checkpoint failed: %s", e)
|
||||
|
||||
def get_status_line(self) -> str:
|
||||
util = self.budget.utilization
|
||||
remaining = self.budget.remaining
|
||||
if util >= THRESHOLD_DANGER:
|
||||
return f"RED {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_CRITICAL:
|
||||
return f"ORANGE {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_WARNING:
|
||||
return f"YELLOW {util:.0%%} used ({remaining:,} left)"
|
||||
return f"GREEN {util:.0%%} used ({remaining:,} left)"
|
||||
|
||||
|
||||
_tracker = None
|
||||
|
||||
def get_tracker(context_limit=128000, session_id=""):
|
||||
global _tracker
|
||||
if _tracker is None:
|
||||
_tracker = ContextBudgetTracker(context_limit, session_id)
|
||||
return _tracker
|
||||
|
||||
def check_context_budget(messages, context_limit=128000):
|
||||
tracker = get_tracker(context_limit)
|
||||
tracker.update_from_messages(messages)
|
||||
return tracker.get_warning()
|
||||
|
||||
def preflight_token_check(text):
|
||||
tracker = get_tracker()
|
||||
return tracker.preflight_check(text)
|
||||
127
tests/test_context_budget.py
Normal file
127
tests/test_context_budget.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Tests for context budget tracker
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.context_budget import (
|
||||
ContextBudget,
|
||||
ContextBudgetTracker,
|
||||
estimate_tokens,
|
||||
estimate_messages_tokens,
|
||||
check_context_budget,
|
||||
preflight_token_check,
|
||||
THRESHOLD_WARNING,
|
||||
THRESHOLD_CRITICAL,
|
||||
THRESHOLD_DANGER,
|
||||
)
|
||||
|
||||
|
||||
class TestContextBudget(unittest.TestCase):
|
||||
|
||||
def test_basic_budget(self):
|
||||
b = ContextBudget(context_limit=10000)
|
||||
self.assertEqual(b.available, 8000) # 10000 - 2000 reserved
|
||||
self.assertEqual(b.remaining, 8000)
|
||||
self.assertEqual(b.utilization, 0.0)
|
||||
|
||||
def test_utilization(self):
|
||||
b = ContextBudget(context_limit=10000, used_tokens=4000)
|
||||
self.assertEqual(b.utilization, 0.5)
|
||||
self.assertEqual(b.remaining, 4000)
|
||||
|
||||
|
||||
class TestTokenEstimation(unittest.TestCase):
|
||||
|
||||
def test_estimate_tokens(self):
|
||||
self.assertEqual(estimate_tokens(""), 0)
|
||||
self.assertEqual(estimate_tokens("a" * 4), 1)
|
||||
self.assertEqual(estimate_tokens("a" * 400), 100)
|
||||
|
||||
def test_estimate_messages(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 400},
|
||||
{"role": "assistant", "content": "b" * 800},
|
||||
]
|
||||
tokens = estimate_messages_tokens(messages)
|
||||
self.assertEqual(tokens, 300) # 100 + 200
|
||||
|
||||
|
||||
class TestContextBudgetTracker(unittest.TestCase):
|
||||
|
||||
def test_warning_at_70_percent(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5600 # 70% of 8000 available
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("70", warning)
|
||||
|
||||
def test_critical_at_85_percent(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
|
||||
tracker = ContextBudgetTracker(context_limit=10000, session_id="test")
|
||||
tracker.budget.used_tokens = 6800 # 85% of 8000
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("85", warning)
|
||||
|
||||
def test_danger_at_95_percent(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 7600 # 95% of 8000
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("CRITICAL", warning)
|
||||
|
||||
def test_can_fit(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5000
|
||||
self.assertTrue(tracker.can_fit(1000))
|
||||
self.assertFalse(tracker.can_fit(5000))
|
||||
|
||||
def test_preflight_check(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5000
|
||||
|
||||
can_fit, msg = tracker.preflight_check("a" * 400) # 100 tokens
|
||||
self.assertTrue(can_fit)
|
||||
self.assertEqual(msg, "")
|
||||
|
||||
|
||||
class TestCheckContextBudget(unittest.TestCase):
|
||||
|
||||
def test_no_warning_under_threshold(self):
|
||||
with patch("agent.context_budget._tracker", None):
|
||||
messages = [{"role": "user", "content": "short"}]
|
||||
warning = check_context_budget(messages)
|
||||
self.assertIsNone(warning)
|
||||
|
||||
def test_warning_over_threshold(self):
|
||||
with patch("agent.context_budget._tracker", None):
|
||||
# Create messages that exceed 70% of default 128k context
|
||||
messages = [{"role": "user", "content": "x" * 350000}] # ~87500 tokens
|
||||
warning = check_context_budget(messages)
|
||||
self.assertIsNotNone(warning)
|
||||
|
||||
|
||||
class TestStatusLine(unittest.TestCase):
|
||||
|
||||
def test_green_status(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
line = tracker.get_status_line()
|
||||
self.assertIn("GREEN", line)
|
||||
|
||||
def test_red_status(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 7600
|
||||
line = tracker.get_status_line()
|
||||
self.assertIn("RED", line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user