149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
"""
|
|
Context Budget Tracker - Prevent context window overflow

Poka-yoke: visual warnings at 70%, 85% and 95% capacity.
Auto-checkpoint at 85%. Pre-flight token estimation.

Issue: #838
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# On-disk state lives under ~/.hermes; auto-checkpoints go in a subfolder.
HERMES_HOME: Path = Path.home().joinpath(".hermes")
CHECKPOINT_DIR: Path = HERMES_HOME.joinpath("checkpoints")

# Heuristic divisor used to turn a character count into a token estimate.
CHARS_PER_TOKEN: int = 4

# Utilization fractions (of the usable window) at which warnings escalate.
THRESHOLD_WARNING: float = 0.70
THRESHOLD_CRITICAL: float = 0.85
THRESHOLD_DANGER: float = 0.95
|
|
|
|
|
|
class ContextBudget:
    """Token accounting for one context window.

    ``context_limit`` is the window size, ``system_tokens`` and
    ``used_tokens`` are what the system prompt and the conversation consume,
    and ``reserved_tokens`` is a safety margin that is never treated as
    usable space.
    """

    def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
                 used_tokens: int = 0, reserved_tokens: int = 2000):
        self.context_limit = context_limit
        self.system_tokens = system_tokens
        self.used_tokens = used_tokens
        self.reserved_tokens = reserved_tokens

    @property
    def total_used(self) -> int:
        """Tokens consumed so far: system prompt plus conversation."""
        consumed = self.system_tokens + self.used_tokens
        return consumed

    @property
    def available(self) -> int:
        """Usable window size (limit minus reserve), never below zero."""
        usable = self.context_limit - self.reserved_tokens
        return usable if usable > 0 else 0

    @property
    def remaining(self) -> int:
        """Usable tokens still free, never below zero."""
        left = self.available - self.total_used
        return left if left > 0 else 0

    @property
    def utilization(self) -> float:
        """Fraction of the usable window consumed (may exceed 1.0).

        Reports 1.0 (full) when there is no usable window at all.
        """
        capacity = self.available
        if capacity <= 0:
            return 1.0
        return self.total_used / capacity
|
|
|
|
|
|
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*: ``len(text) // CHARS_PER_TOKEN``.

    Falsy input (``""`` or ``None``) counts as zero tokens.
    """
    if not text:
        return 0
    return len(text) // CHARS_PER_TOKEN
|
|
|
|
|
|
def _text_parts(parts: List[Any]):
    """Yield the text of each text-bearing entry in a multi-part ``content`` list."""
    for part in parts:
        if isinstance(part, str):
            yield part
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            yield part["text"]


def _tool_call_argument_strings(tool_calls: Any):
    """Yield the ``function.arguments`` string of each dict-shaped tool call."""
    if not isinstance(tool_calls, (list, tuple)):
        return
    for call in tool_calls:
        function = call.get("function") if isinstance(call, dict) else None
        arguments = function.get("arguments") if isinstance(function, dict) else None
        if isinstance(arguments, str):
            yield arguments


def estimate_messages_tokens(messages: List[Dict]) -> int:
    """Return a rough token estimate for a list of chat messages.

    Counted per message:

    * ``content`` when it is a string;
    * ``content`` when it is a list of parts: every bare string part plus the
      string ``"text"`` of every dict part (assumes OpenAI-style multi-part
      content; confirm against callers);
    * when the message has ``tool_calls``: a flat 100-token overhead plus the
      estimated size of each call's ``function.arguments`` string.

    Any other ``content`` value (e.g. ``None``) contributes nothing.
    """
    total = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            total += estimate_tokens(content)
        elif isinstance(content, list):
            # Multi-part content used to be skipped entirely, under-counting
            # the budget -- the unsafe direction for an overflow guard.
            total += sum(estimate_tokens(text) for text in _text_parts(content))
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            total += 100  # fixed per-message overhead for tool-call framing
            # Arguments (e.g. file contents) can be large and were never counted.
            total += sum(
                estimate_tokens(args)
                for args in _tool_call_argument_strings(tool_calls)
            )
    return total
|
|
|
|
|
|
class ContextBudgetTracker:
    """Track one session's context usage and emit escalating warnings.

    Utilization (``ContextBudget.utilization``) is compared against
    THRESHOLD_WARNING, THRESHOLD_CRITICAL and THRESHOLD_DANGER. ``get_warning``
    reports each level once until utilization falls back below it. Reaching
    the critical or danger level attempts a one-time auto-checkpoint under
    CHECKPOINT_DIR (only when a ``session_id`` is set).
    """

    # Warning level names, ordered from mildest to most severe.
    _LEVEL_NAMES = ("warning", "critical", "danger")

    def __init__(self, context_limit: int = 128000, session_id: str = ""):
        """Create a tracker for a window of *context_limit* tokens.

        An empty *session_id* disables auto-checkpointing.
        """
        self.budget = ContextBudget(context_limit=context_limit)
        self.session_id = session_id
        self._checkpointed = False  # True once a checkpoint file was written
        self._warnings_given = set()  # level names already reported

    def update_from_messages(self, messages: List[Dict]):
        """Replace the conversation token count with an estimate for *messages*."""
        self.budget.used_tokens = estimate_messages_tokens(messages)

    def can_fit(self, additional_tokens: int) -> bool:
        """Return True if *additional_tokens* fit in the remaining budget."""
        return self.budget.remaining >= additional_tokens

    def preflight_check(self, text: str) -> Tuple[bool, str]:
        """Decide whether *text* can be loaded without overrunning the window.

        Returns ``(ok, message)``. ``ok`` is False when the text does not fit,
        or when loading it would push utilization to THRESHOLD_DANGER or
        beyond. ``message`` explains a refusal, or warns when utilization
        would reach THRESHOLD_CRITICAL, and is empty otherwise.
        """
        tokens = estimate_tokens(text)
        if not self.can_fit(tokens):
            return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"
        available = self.budget.available
        would_util = (self.budget.total_used + tokens) / available if available > 0 else 1.0
        # The "%" presentation type already scales by 100 and appends "%";
        # the former "{x:.0%%}" was an invalid format spec and raised ValueError.
        if would_util >= THRESHOLD_DANGER:
            return False, f"Would reach {would_util:.0%} capacity. Summarize or start new session."
        if would_util >= THRESHOLD_CRITICAL:
            return True, f"Warning: will reach {would_util:.0%} capacity."
        return True, ""

    def get_warning(self) -> Optional[str]:
        """Return a warning the first time utilization reaches a new level.

        The levels are ``warning``, ``critical`` and ``danger``, reached at
        THRESHOLD_WARNING, THRESHOLD_CRITICAL and THRESHOLD_DANGER.

        - Each level is reported at most once.
        - A level is re-armed when utilization drops below it (e.g. after
          summarization).
        - Reaching ``critical`` or ``danger`` attempts an auto-checkpoint.
        - Returns None when there is nothing new to report.
        """
        util = self.budget.utilization
        # 0 = below every threshold; 1..3 = warning/critical/danger.
        if util >= THRESHOLD_DANGER:
            level = 3
        elif util >= THRESHOLD_CRITICAL:
            level = 2
        elif util >= THRESHOLD_WARNING:
            level = 1
        else:
            level = 0

        # Re-arm every level above the current one so a later climb (e.g.
        # after the conversation was summarized) is reported again.
        self._warnings_given.difference_update(self._LEVEL_NAMES[level:])
        if level < 2:
            self._checkpointed = False
        if level == 0:
            return None

        name = self._LEVEL_NAMES[level - 1]
        if name in self._warnings_given:
            return None
        # Mark this level and all milder ones, so a following call at the
        # same utilization does not emit a less severe message.
        self._warnings_given.update(self._LEVEL_NAMES[:level])

        if level == 3:
            # Jumping straight past the critical threshold must still checkpoint.
            self._auto_checkpoint()
            return f"[CONTEXT CRITICAL: {util:.0%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"
        if level == 2:
            if self._auto_checkpoint():
                return f"[CONTEXT WARNING: {util:.0%} used -- consider summarizing. Auto-checkpoint saved.]"
            # Don't claim a checkpoint that was skipped or failed.
            return f"[CONTEXT WARNING: {util:.0%} used -- consider summarizing.]"
        return f"[CONTEXT: {util:.0%} used -- {self.budget.remaining:,} tokens remaining]"

    def _auto_checkpoint(self) -> bool:
        """Write a small JSON checkpoint for this session, at most once.

        Returns True when a checkpoint has been written (now or earlier).
        Returns False when it was skipped (no ``session_id``) or failed.
        Failures are logged, not raised, because checkpointing is best-effort.
        """
        if self._checkpointed:
            return True
        if not self.session_id:
            return False
        try:
            CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
            # NOTE(review): session_id is used verbatim as a filename;
            # presumably it is generated internally -- confirm it cannot
            # contain path separators.
            path = CHECKPOINT_DIR / f"{self.session_id}.json"
            path.write_text(json.dumps({
                "session_id": self.session_id,
                "timestamp": time.time(),
                "budget": {"utilization": round(self.budget.utilization * 100, 1)},
            }, indent=2), encoding="utf-8")
        except Exception as e:
            logger.error("Auto-checkpoint failed: %s", e)
            return False
        self._checkpointed = True
        logger.info("Auto-checkpoint saved: %s", path)
        return True

    def get_status_line(self) -> str:
        """Return a one-line status: ``"<COLOR> <pct> used (<remaining> left)"``.

        COLOR is GREEN, YELLOW, ORANGE or RED, chosen by the warning,
        critical and danger thresholds.
        """
        util = self.budget.utilization
        if util >= THRESHOLD_DANGER:
            color = "RED"
        elif util >= THRESHOLD_CRITICAL:
            color = "ORANGE"
        elif util >= THRESHOLD_WARNING:
            color = "YELLOW"
        else:
            color = "GREEN"
        return f"{color} {util:.0%} used ({self.budget.remaining:,} left)"
|
|
|
|
|
|
# Process-wide tracker, created lazily by get_tracker().
_tracker: "Optional[ContextBudgetTracker]" = None


def get_tracker(context_limit=128000, session_id=""):
    """Return the process-wide ContextBudgetTracker, creating it on first use.

    On creation, ``context_limit`` and ``session_id`` configure the tracker.
    On later calls ``context_limit`` is ignored. A non-empty ``session_id`` is
    attached only if the existing tracker has none. This lets auto-checkpoints
    be enabled after an earlier id-less call (e.g. preflight_token_check()),
    which previously dropped the id silently.
    """
    global _tracker
    if _tracker is None:
        _tracker = ContextBudgetTracker(context_limit, session_id)
    elif session_id and not _tracker.session_id:
        _tracker.session_id = session_id
    return _tracker
|
|
|
|
def check_context_budget(messages, context_limit=128000):
    """Refresh the shared tracker from *messages* and return any new warning.

    The result is whatever ``ContextBudgetTracker.get_warning()`` returns:
    a warning string, or None.
    """
    shared = get_tracker(context_limit)
    shared.update_from_messages(messages)
    warning = shared.get_warning()
    return warning
|
|
|
|
def preflight_token_check(text):
    """Ask the shared tracker whether *text* can be loaded.

    Returns the ``(ok, message)`` tuple from
    ``ContextBudgetTracker.preflight_check()``.
    """
    return get_tracker().preflight_check(text)
|