317 lines
11 KiB
Python
317 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Token Budget — Poka-yoke guard against silent context overflow.
|
|
|
|
Progressive warning system with circuit breakers:
|
|
- 60%: WARNING — log + suggest summarization
|
|
- 80%: CAUTION — auto-compress, drop raw tool outputs
|
|
- 90%: CRITICAL — block verbose tool calls, force wrap-up
|
|
- 95%: STOP — graceful session termination with summary
|
|
|
|
Also provides tool output budgeting to truncate before overflow.
|
|
|
|
Usage:
|
|
from agent.token_budget import TokenBudget
|
|
|
|
budget = TokenBudget(context_length=128_000)
|
|
budget.update(8000) # from API response prompt_tokens
|
|
|
|
status = budget.check() # returns BudgetStatus with level + message
|
|
budget.should_block_tools() # True at 90%+
|
|
budget.should_terminate() # True at 95%+
|
|
|
|
# Tool output budgeting
|
|
remaining = budget.tool_output_budget()
|
|
truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
|
|
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ── Thresholds ────────────────────────────────────────────────────────

# Fraction of the context window at which each budget level begins.
WARN_PERCENT = 0.6
CAUTION_PERCENT = 0.8
CRITICAL_PERCENT = 0.9
STOP_PERCENT = 0.95

# Share of the context (5%) held back for system prompt, response, and overhead.
RESPONSE_RESERVE_RATIO = 0.05

# Upper bound on tool output size, in characters, keyed by budget level name.
TOOL_OUTPUT_BUDGETS = dict(
    NORMAL=50_000,
    WARNING=20_000,
    CAUTION=8_000,
    CRITICAL=2_000,
    STOP=500,
)
|
|
|
|
|
|
class BudgetLevel(Enum):
    """Escalation levels of the token budget, from NORMAL up to STOP."""

    NORMAL = "NORMAL"
    WARNING = "WARNING"
    CAUTION = "CAUTION"
    CRITICAL = "CRITICAL"
    STOP = "STOP"

    @property
    def percent_threshold(self) -> float:
        """Fraction of the context at which this level starts (0.0 for NORMAL)."""
        if self is BudgetLevel.WARNING:
            return WARN_PERCENT
        if self is BudgetLevel.CAUTION:
            return CAUTION_PERCENT
        if self is BudgetLevel.CRITICAL:
            return CRITICAL_PERCENT
        if self is BudgetLevel.STOP:
            return STOP_PERCENT
        return 0.0

    @property
    def emoji(self) -> str:
        """Symbol shown next to the usage indicator; empty string for NORMAL."""
        if self is BudgetLevel.WARNING:
            return "\u26a0\ufe0f"
        if self is BudgetLevel.CAUTION:
            return "\U0001f525"
        if self is BudgetLevel.CRITICAL or self is BudgetLevel.STOP:
            # CRITICAL and STOP deliberately share the same symbol.
            return "\U0001f6d1"
        return ""
|
|
|
|
|
|
@dataclass
class BudgetStatus:
    """Current token budget status."""

    level: BudgetLevel
    tokens_used: int
    context_length: int
    # Fraction of the context in use (1.0 == 100%); this class does not clamp it.
    percent_used: float
    tokens_remaining: int
    message: str = ""
    should_compress: bool = False
    should_block_tools: bool = False
    should_terminate: bool = False

    def to_indicator(self) -> str:
        """Compact status indicator for CLI display.

        Returns ``"[NN%]"`` at NORMAL level, otherwise the level's emoji
        followed by a space and ``"[NN%]"``.
        """
        pct = int(self.percent_used * 100)
        if self.level == BudgetLevel.NORMAL:
            return f"[{pct}%]"
        return f"{self.level.emoji} [{pct}%]"

    def to_bar(self, width: int = 10) -> str:
        """Visual progress bar.

        Args:
            width: Number of bar cells; negative values are treated as 0.

        Returns:
            An ANSI-coloured bar of exactly ``width`` cells, a colour reset,
            then the integer percentage (which is not clamped).
        """
        width = max(0, width)
        # Clamp the filled portion: percent_used above 1.0 (or below 0.0)
        # would otherwise draw a bar that is wider than ``width``.
        filled = min(width, max(0, int(width * self.percent_used)))
        bar = "\u2588" * filled + "\u2591" * (width - filled)
        color = self._bar_color()
        return f"{color}{bar}\033[0m {int(self.percent_used * 100)}%"

    def _bar_color(self) -> str:
        """Return the ANSI escape sequence used to colour the bar for this level."""
        if self.level == BudgetLevel.STOP:
            return "\033[41m"  # red bg
        if self.level == BudgetLevel.CRITICAL:
            return "\033[31m"  # red
        if self.level in (BudgetLevel.CAUTION, BudgetLevel.WARNING):
            return "\033[33m"  # yellow
        return "\033[32m"  # green
|
|
|
|
|
|
class TokenBudget:
    """
    Progressive token budget tracker with poka-yoke circuit breakers.

    Compares the most recently reported prompt size against a context
    length and triggers escalating actions at each threshold. Usage is not
    accumulated: every ``update()`` replaces ``tokens_used`` with the latest
    ``prompt_tokens`` value.

    The integer ``*_threshold`` attributes computed in ``__init__`` are the
    single source of truth for the current level, so ``check()`` and the
    ``should_*()`` helpers always agree and honour custom percentages.
    A non-positive ``context_length`` is always reported as NORMAL.
    """

    def __init__(
        self,
        context_length: int,
        warn_percent: float = WARN_PERCENT,
        caution_percent: float = CAUTION_PERCENT,
        critical_percent: float = CRITICAL_PERCENT,
        stop_percent: float = STOP_PERCENT,
        response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
    ):
        """
        Args:
            context_length: Size of the context window, in tokens.
            warn_percent: Fraction of the context at which WARNING starts.
            caution_percent: Fraction at which CAUTION starts.
            critical_percent: Fraction at which CRITICAL starts.
            stop_percent: Fraction at which STOP starts.
            response_reserve_ratio: Fraction of the context subtracted when
                computing the remaining tokens.
        """
        self.context_length = context_length
        # Thresholds are stored as absolute token counts (truncated to int).
        self.warn_threshold = int(context_length * warn_percent)
        self.caution_threshold = int(context_length * caution_percent)
        self.critical_threshold = int(context_length * critical_percent)
        self.stop_threshold = int(context_length * stop_percent)
        self.response_reserve = int(context_length * response_reserve_ratio)

        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL  # last level seen by check()
        self._history: list[int] = []  # prompt_tokens from each update()

    def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
        """Update budget from API response usage.

        Overwrites ``tokens_used`` and ``completions_tokens`` with the given
        values, records ``prompt_tokens`` in the history, and returns the
        result of ``check()``.
        """
        self.tokens_used = prompt_tokens
        self.completions_tokens = completion_tokens
        self._history.append(prompt_tokens)
        return self.check()

    def _current_level(self) -> BudgetLevel:
        """Classify ``tokens_used`` against the instance thresholds (no side effects)."""
        if self.context_length <= 0:
            # No usable context size: report NORMAL rather than tripping
            # every breaker on thresholds that are all zero.
            return BudgetLevel.NORMAL
        used = self.tokens_used
        if used >= self.stop_threshold:
            return BudgetLevel.STOP
        if used >= self.critical_threshold:
            return BudgetLevel.CRITICAL
        if used >= self.caution_threshold:
            return BudgetLevel.CAUTION
        if used >= self.warn_threshold:
            return BudgetLevel.WARNING
        return BudgetLevel.NORMAL

    @staticmethod
    def _message_for(level: BudgetLevel, pct: float) -> str:
        """Return the user-facing message for *level* ("" for NORMAL)."""
        shown = int(pct * 100)
        if level == BudgetLevel.WARNING:
            return f"Context at {shown}%. Consider wrapping up soon or using /compress."
        if level == BudgetLevel.CAUTION:
            return (
                f"Context at {shown}%. Auto-compressing. "
                "Tool outputs will be truncated."
            )
        if level == BudgetLevel.CRITICAL:
            return (
                f"Context at {shown}%. Verbose tools blocked. "
                "Session approaching limit — please wrap up."
            )
        if level == BudgetLevel.STOP:
            return (
                f"Context at {shown}%. Session must terminate. "
                "Saving summary before shutdown."
            )
        return ""

    def check(self) -> BudgetStatus:
        """Evaluate current budget level and return status.

        Side effect: when the level differs from the one seen by the previous
        call, the transition is logged and the stored level is updated.
        """
        pct = self.tokens_used / self.context_length if self.context_length > 0 else 0.0
        level = self._current_level()

        # Log transitions only (don't log every check)
        if level != self._level:
            self._log_transition(level, pct)
            self._level = level

        return BudgetStatus(
            level=level,
            tokens_used=self.tokens_used,
            context_length=self.context_length,
            percent_used=pct,
            tokens_remaining=self.remaining_for_response(),
            message=self._message_for(level, pct),
            should_compress=level
            in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_terminate=level == BudgetLevel.STOP,
        )

    def should_compress(self) -> bool:
        """True at CAUTION level or above — auto-compression should trigger."""
        return self._current_level() in (
            BudgetLevel.CAUTION,
            BudgetLevel.CRITICAL,
            BudgetLevel.STOP,
        )

    def should_block_tools(self) -> bool:
        """True at CRITICAL level or above — verbose tool calls should be blocked."""
        return self._current_level() in (BudgetLevel.CRITICAL, BudgetLevel.STOP)

    def should_terminate(self) -> bool:
        """True at STOP level — session should gracefully terminate."""
        return self._current_level() == BudgetLevel.STOP

    def tool_output_budget(self) -> int:
        """Max chars allowed for next tool output based on current level.

        Calls ``check()``, so a level transition may be logged here.
        """
        status = self.check()
        return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)

    def truncate_tool_output(self, output: str, max_chars: Optional[int] = None) -> str:
        """Truncate tool output to fit budget. Adds truncation notice.

        Args:
            output: Raw tool output.
            max_chars: Maximum length of the result; defaults to
                ``tool_output_budget()``. Negative values are treated as 0.

        Returns:
            ``output`` unchanged when it already fits; otherwise a string of
            at most ``max_chars`` characters. Budgets of 200+ keep the start
            and end with a notice in the middle; smaller budgets keep only
            the start followed by a short notice (or a bare prefix when even
            the notice does not fit).
        """
        if max_chars is None:
            max_chars = self.tool_output_budget()
        # A negative limit would make output[:max_chars] keep almost everything.
        max_chars = max(0, max_chars)

        total = len(output)
        if total <= max_chars:
            return output

        if max_chars < 200:
            short_notice = "\n[...truncated...]"
            keep = max_chars - len(short_notice)
            if keep <= 0:
                return output[:max_chars]
            return output[:keep] + short_notice

        # Preserve start and end, truncate middle. The notice is sized for
        # its longest possible form (the dropped count can never have more
        # digits than the total length) so the result never exceeds max_chars.
        longest_notice = f"\n\n[...{total:,} chars truncated...]\n\n"
        keep = max(0, max_chars - len(longest_notice))
        head = (keep + 1) // 2
        tail = keep - head
        notice = f"\n\n[...{total - head - tail:,} chars truncated...]\n\n"
        # total - tail (not -tail) so that tail == 0 yields an empty suffix.
        return output[:head] + notice + output[total - tail:]

    def remaining_for_response(self) -> int:
        """Tokens left after subtracting current usage and the response reserve.

        Never negative.
        """
        return max(0, self.context_length - self.tokens_used - self.response_reserve)

    def growth_rate(self) -> Optional[float]:
        """Average token increase per turn (from history).

        Returns None until at least two updates have been recorded.
        """
        if len(self._history) < 2:
            return None
        # The per-turn differences telescope to (last - first).
        return (self._history[-1] - self._history[0]) / (len(self._history) - 1)

    def turns_remaining(self) -> Optional[int]:
        """Estimated turns until context is full (based on growth rate).

        Returns None when the growth rate is unknown or not positive, and 0
        when usage already meets or exceeds the context length.
        """
        rate = self.growth_rate()
        if rate is None or rate <= 0:
            return None
        headroom = max(0, self.context_length - self.tokens_used)
        return int(headroom / rate)

    def reset(self):
        """Reset budget for new session."""
        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history.clear()

    def _log_transition(self, new_level: BudgetLevel, pct: float):
        """Log budget level transitions.

        CRITICAL/STOP log at error severity, WARNING/CAUTION at warning
        severity, and a return to NORMAL at info severity.
        """
        if new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
            severity = logging.ERROR
        elif new_level in (BudgetLevel.WARNING, BudgetLevel.CAUTION):
            severity = logging.WARNING
        else:
            severity = logging.INFO
        logger.log(
            severity,
            "Token budget: %s -> %s (%s/%s = %s)",
            self._level.value,
            new_level.value,
            self.tokens_used,
            self.context_length,
            f"{pct:.0%}",
        )

    def summary(self) -> str:
        """Human-readable budget summary.

        Calls ``check()``, so a level transition may be logged here.
        """
        status = self.check()
        turns = self.turns_remaining()
        rate = self.growth_rate()
        lines = [
            f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
            f"Level: {status.level.value}",
            f"Remaining: {status.tokens_remaining:,} tokens",
        ]
        if rate is not None:
            lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
        if turns is not None:
            lines.append(f"Estimated turns left: ~{turns}")
        if status.message:
            lines.append(f"Action: {status.message}")
        return "\n".join(lines)
|
|
|
|
|
|
# ── Convenience factory ───────────────────────────────────────────────
|
|
|
|
def create_budget(context_length: int, **kwargs) -> TokenBudget:
    """Build a ``TokenBudget`` for *context_length*.

    Extra keyword arguments are forwarded unchanged to the ``TokenBudget``
    constructor; anything not supplied keeps that constructor's default.
    """
    budget = TokenBudget(context_length=context_length, **kwargs)
    return budget
|