Files
hermes-agent/agent/token_budget.py

317 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Token Budget — Poka-yoke guard against silent context overflow.
Progressive warning system with circuit breakers:
- 60%: WARNING — log + suggest summarization
- 80%: CAUTION — auto-compress, drop raw tool outputs
- 90%: CRITICAL — block verbose tool calls, force wrap-up
- 95%: STOP — graceful session termination with summary
Also provides tool output budgeting to truncate before overflow.
Usage:
    from agent.token_budget import TokenBudget

    budget = TokenBudget(context_length=128_000)
    budget.update(8000)            # from API response prompt_tokens
    status = budget.check()        # returns BudgetStatus with level + message
    budget.should_block_tools()    # True at 90%+
    budget.should_terminate()      # True at 95%+

    # Tool output budgeting
    remaining = budget.tool_output_budget()
    truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
logger = logging.getLogger(__name__)
# ── Thresholds ────────────────────────────────────────────────────────
# Fraction of the context window at which each escalation level begins.
WARN_PERCENT = 0.6
CAUTION_PERCENT = 0.8
CRITICAL_PERCENT = 0.9
STOP_PERCENT = 0.95

# Share of the context held back for system prompt, response, and overhead (5%).
RESPONSE_RESERVE_RATIO = 0.05

# Character ceiling for a single tool output, keyed by budget level name.
TOOL_OUTPUT_BUDGETS = dict(
    NORMAL=50_000,
    WARNING=20_000,
    CAUTION=8_000,
    CRITICAL=2_000,
    STOP=500,
)
class BudgetLevel(Enum):
    """Escalation levels of the token budget, listed from mildest to most severe."""

    NORMAL = "NORMAL"
    WARNING = "WARNING"
    CAUTION = "CAUTION"
    CRITICAL = "CRITICAL"
    STOP = "STOP"

    @property
    def percent_threshold(self) -> float:
        """Fraction of the context window at which this level starts."""
        # Built per call so the module-level constants are read at call time.
        # Kept local to the method: a class-level table would risk becoming
        # an enum member.
        starts_at = {
            "NORMAL": 0.0,
            "WARNING": WARN_PERCENT,
            "CAUTION": CAUTION_PERCENT,
            "CRITICAL": CRITICAL_PERCENT,
            "STOP": STOP_PERCENT,
        }
        return starts_at[self.name]

    @property
    def emoji(self) -> str:
        """Emoji shown next to the level in compact indicators ("" for NORMAL)."""
        symbols = {
            "NORMAL": "",
            "WARNING": "\u26a0\ufe0f",
            "CAUTION": "\U0001f525",
            "CRITICAL": "\U0001f6d1",
            "STOP": "\U0001f6d1",
        }
        return symbols[self.name]
@dataclass
class BudgetStatus:
    """Snapshot of the token budget at one point in time."""

    # Escalation level the snapshot falls into.
    level: BudgetLevel
    # Tokens currently counted against the context window.
    tokens_used: int
    # Total size of the context window, in tokens.
    context_length: int
    # tokens_used / context_length as a fraction (may exceed 1.0 on overflow).
    percent_used: float
    # Tokens still available, as computed by whoever built the snapshot.
    tokens_remaining: int
    # Human-readable advice for the current level ("" when there is none).
    message: str = ""
    # Action flags for the caller.
    should_compress: bool = False
    should_block_tools: bool = False
    should_terminate: bool = False

    def to_indicator(self) -> str:
        """Return a compact status indicator for CLI display, e.g. ``[42%]``.

        Levels other than NORMAL are prefixed with the level's emoji.
        """
        pct = int(self.percent_used * 100)
        if self.level == BudgetLevel.NORMAL:
            return f"[{pct}%]"
        return f"{self.level.emoji} [{pct}%]"

    def to_bar(self, width: int = 10) -> str:
        """Return an ANSI-coloured progress bar followed by the percentage.

        Args:
            width: Number of cells in the bar.

        Returns:
            ``width`` cells (filled + empty), wrapped in a colour escape and
            a reset, then the integer percentage.
        """
        # Clamp so an over-budget (or negative) ratio cannot draw more filled
        # cells than the bar has. A negative repeat count yields "", so the
        # empty part would otherwise vanish while the bar overflows.
        filled = max(0, min(width, int(width * self.percent_used)))
        bar = "\u2588" * filled + "\u2591" * (width - filled)
        color = self._bar_color()
        return f"{color}{bar}\033[0m {int(self.percent_used * 100)}%"

    def _bar_color(self) -> str:
        """Return the ANSI escape used to colour the bar for the current level."""
        if self.level == BudgetLevel.STOP:
            return "\033[41m"  # red background
        if self.level == BudgetLevel.CRITICAL:
            return "\033[31m"  # red
        if self.level in (BudgetLevel.CAUTION, BudgetLevel.WARNING):
            return "\033[33m"  # yellow
        return "\033[32m"  # green
class TokenBudget:
    """Progressive token budget tracker with poka-yoke circuit breakers.

    Holds the most recently reported prompt size (``tokens_used``) and
    compares it with per-instance thresholds derived from ``context_length``
    to trigger escalating actions (warn, compress, block tools, terminate).
    """

    def __init__(
        self,
        context_length: int,
        warn_percent: float = WARN_PERCENT,
        caution_percent: float = CAUTION_PERCENT,
        critical_percent: float = CRITICAL_PERCENT,
        stop_percent: float = STOP_PERCENT,
        response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
    ):
        """Create a tracker for a context window of ``context_length`` tokens.

        Args:
            context_length: Size of the context window, in tokens.
            warn_percent: Fraction of the window at which WARNING begins.
            caution_percent: Fraction at which CAUTION begins.
            critical_percent: Fraction at which CRITICAL begins.
            stop_percent: Fraction at which STOP begins.
            response_reserve_ratio: Fraction of the window held back when
                computing the remaining tokens.
        """
        self.context_length = context_length
        # Thresholds are kept as absolute token counts (truncated toward zero).
        self.warn_threshold = int(context_length * warn_percent)
        self.caution_threshold = int(context_length * caution_percent)
        self.critical_threshold = int(context_length * critical_percent)
        self.stop_threshold = int(context_length * stop_percent)
        self.response_reserve = int(context_length * response_reserve_ratio)
        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history: list[int] = []

    def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
        """Record usage from an API response and return the resulting status.

        ``prompt_tokens`` replaces (does not add to) ``tokens_used``; each
        value is also appended to the history used by ``growth_rate()``.
        """
        self.tokens_used = prompt_tokens
        self.completions_tokens = completion_tokens
        self._history.append(prompt_tokens)
        return self.check()

    def _current_level(self) -> BudgetLevel:
        """Classify current usage against this instance's thresholds (no side effects)."""
        if self.context_length <= 0:
            # Nothing to measure against; every threshold would be <= 0 and
            # trip immediately, so report NORMAL instead.
            return BudgetLevel.NORMAL
        used = self.tokens_used
        if used >= self.stop_threshold:
            return BudgetLevel.STOP
        if used >= self.critical_threshold:
            return BudgetLevel.CRITICAL
        if used >= self.caution_threshold:
            return BudgetLevel.CAUTION
        if used >= self.warn_threshold:
            return BudgetLevel.WARNING
        return BudgetLevel.NORMAL

    @staticmethod
    def _message_for(level: BudgetLevel, pct: float) -> str:
        """Return the user-facing advice for ``level`` ("" for NORMAL)."""
        shown = int(pct * 100)
        if level == BudgetLevel.WARNING:
            return f"Context at {shown}%. Consider wrapping up soon or using /compress."
        if level == BudgetLevel.CAUTION:
            return (
                f"Context at {shown}%. Auto-compressing. "
                f"Tool outputs will be truncated."
            )
        if level == BudgetLevel.CRITICAL:
            return (
                f"Context at {shown}%. Verbose tools blocked. "
                f"Session approaching limit — please wrap up."
            )
        if level == BudgetLevel.STOP:
            return (
                f"Context at {shown}%. Session must terminate. "
                f"Saving summary before shutdown."
            )
        return ""

    def check(self) -> BudgetStatus:
        """Evaluate the current budget level and return a status snapshot.

        The level is derived from the thresholds configured on this instance,
        so it agrees with ``should_compress()``, ``should_block_tools()`` and
        ``should_terminate()``. A level change is logged once, on transition.
        """
        pct = self.tokens_used / self.context_length if self.context_length > 0 else 0.0
        level = self._current_level()

        # Log transitions only (don't log every check).
        if level != self._level:
            self._log_transition(level, pct)
            self._level = level

        return BudgetStatus(
            level=level,
            tokens_used=self.tokens_used,
            context_length=self.context_length,
            percent_used=pct,
            tokens_remaining=self.remaining_for_response(),
            message=self._message_for(level, pct),
            should_compress=level
            in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_terminate=level == BudgetLevel.STOP,
        )

    def should_compress(self) -> bool:
        """True at or above the caution threshold (80% by default)."""
        return self.context_length > 0 and self.tokens_used >= self.caution_threshold

    def should_block_tools(self) -> bool:
        """True at or above the critical threshold (90% by default)."""
        return self.context_length > 0 and self.tokens_used >= self.critical_threshold

    def should_terminate(self) -> bool:
        """True at or above the stop threshold (95% by default)."""
        return self.context_length > 0 and self.tokens_used >= self.stop_threshold

    def tool_output_budget(self) -> int:
        """Max chars allowed for the next tool output at the current level."""
        status = self.check()
        return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)

    def truncate_tool_output(self, output: str, max_chars: Optional[int] = None) -> str:
        """Truncate ``output`` so the result never exceeds ``max_chars``.

        Args:
            output: Raw tool output.
            max_chars: Character limit; defaults to ``tool_output_budget()``.

        Returns:
            ``output`` unchanged when it already fits. Otherwise, for limits
            of 200+ chars, the head and tail of ``output`` joined by a notice
            stating how many characters were dropped; for smaller limits, the
            head followed by a short notice. When the limit cannot even hold
            the short notice, a bare prefix of ``output`` is returned.
        """
        if max_chars is None:
            max_chars = self.tool_output_budget()
        total = len(output)
        if total <= max_chars:
            return output

        if max_chars < 200:
            short_notice = "\n[...truncated...]"
            keep = max_chars - len(short_notice)
            if keep <= 0:
                # No room for the notice; a negative limit is treated as 0.
                return output[: max(max_chars, 0)]
            return output[:keep] + short_notice

        # Reserve room for the widest notice this output could need: the
        # dropped count is at most ``total``, so the real notice is never
        # longer than this one. With max_chars >= 200 ``keep`` stays positive.
        reserve = len(f"\n\n[...{total:,} chars truncated...]\n\n")
        keep = max_chars - reserve
        head = (keep + 1) // 2
        tail = keep - head
        notice = f"\n\n[...{total - keep:,} chars truncated...]\n\n"
        # Slice the tail from an explicit start index: ``output[-0:]`` would
        # return the whole string if ``tail`` were ever 0.
        return output[:head] + notice + output[total - tail:]

    def remaining_for_response(self) -> int:
        """Tokens left after current usage and the response reserve (never negative)."""
        return max(0, self.context_length - self.tokens_used - self.response_reserve)

    def growth_rate(self) -> Optional[float]:
        """Average token increase per turn, or None with fewer than two updates."""
        if len(self._history) < 2:
            return None
        # The sum of consecutive differences telescopes to last - first.
        return (self._history[-1] - self._history[0]) / (len(self._history) - 1)

    def turns_remaining(self) -> Optional[int]:
        """Estimated turns until the context is full, based on ``growth_rate()``.

        Returns None when the rate is unknown or not positive, and 0 once
        usage has reached or passed the context length.
        """
        rate = self.growth_rate()
        if rate is None or rate <= 0:
            return None
        remaining = max(0, self.context_length - self.tokens_used)
        return int(remaining / rate)

    def reset(self):
        """Reset counters, level and history for a new session."""
        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history.clear()

    def _log_transition(self, new_level: BudgetLevel, pct: float):
        """Log a change of budget level at a severity matching the new level."""
        msg = (
            f"Token budget: {self._level.value} -> {new_level.value} "
            f"({self.tokens_used}/{self.context_length} = {pct:.0%})"
        )
        if new_level in (BudgetLevel.WARNING, BudgetLevel.CAUTION):
            logger.warning(msg)
        elif new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
            logger.error(msg)
        else:
            logger.info(msg)

    def summary(self) -> str:
        """Return a multi-line, human-readable budget summary."""
        status = self.check()
        turns = self.turns_remaining()
        rate = self.growth_rate()
        lines = [
            f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
            f"Level: {status.level.value}",
            f"Remaining: {status.tokens_remaining:,} tokens",
        ]
        if rate is not None:
            lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
        if turns is not None:
            lines.append(f"Estimated turns left: ~{turns}")
        if status.message:
            lines.append(f"Action: {status.message}")
        return "\n".join(lines)
# ── Convenience factory ───────────────────────────────────────────────
def create_budget(context_length: int, **kwargs) -> TokenBudget:
    """Build a TokenBudget for ``context_length`` tokens.

    Any extra keyword arguments are forwarded unchanged to ``TokenBudget``;
    omitted ones fall back to that class's defaults.
    """
    budget = TokenBudget(context_length, **kwargs)
    return budget