Compare commits

..

2 Commits

Author SHA1 Message Date
07c5b5b83d test: add token budget poka-yoke tests (#925)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 44s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 45s
Tests / test (pull_request) Failing after 25m21s
Tests / e2e (pull_request) Successful in 3m18s
2026-04-21 11:41:39 +00:00
8ac26f54a5 feat: token budget with progressive poka-yoke thresholds (#925) 2026-04-21 11:40:39 +00:00
2 changed files with 507 additions and 119 deletions

View File

@@ -1,165 +1,316 @@
"""Token Budget — Poka-yoke guard against context overflow.
#!/usr/bin/env python3
"""
Token Budget — Poka-yoke guard against silent context overflow.
Progressive warning system with circuit breakers:
- 60%: Log warning, suggest summarization
- 80%: Auto-compress, drop raw tool outputs
- 90%: Block verbose tools, force wrap-up
- 95%: Graceful termination with summary
- 60%: WARNING — log + suggest summarization
- 80%: CAUTION — auto-compress, drop raw tool outputs
- 90%: CRITICAL — block verbose tool calls, force wrap-up
- 95%: STOP — graceful session termination with summary
Also provides tool output budgeting to truncate before overflow.
Usage:
from agent.token_budget import TokenBudget
budget = TokenBudget(max_tokens=128000)
budget.record_usage(prompt_tokens=500, completion_tokens=200)
status = budget.check()
# status.level: ok, warning, compress, block, terminate
budget = TokenBudget(context_length=128_000)
budget.update(8000) # from API response prompt_tokens
status = budget.check() # returns BudgetStatus with level + message
budget.should_block_tools() # True at 90%+
budget.should_terminate() # True at 95%+
# Tool output budgeting
remaining = budget.tool_output_budget()
truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Optional
logger = logging.getLogger(__name__)
# ── Thresholds ────────────────────────────────────────────────────────
WARN_PERCENT = 0.60
CAUTION_PERCENT = 0.80
CRITICAL_PERCENT = 0.90
STOP_PERCENT = 0.95
# Reserve 5% of context for system prompt, response, and overhead
RESPONSE_RESERVE_RATIO = 0.05
# Max tool output chars at each level
TOOL_OUTPUT_BUDGETS = {
"NORMAL": 50_000,
"WARNING": 20_000,
"CAUTION": 8_000,
"CRITICAL": 2_000,
"STOP": 500,
}
class BudgetLevel(Enum):
"""Token budget alert levels."""
OK = "ok" # < 60%
WARNING = "warning" # 60-80%
COMPRESS = "compress" # 80-90%
BLOCK = "block" # 90-95%
TERMINATE = "terminate" # > 95%
NORMAL = "NORMAL"
WARNING = "WARNING"
CAUTION = "CAUTION"
CRITICAL = "CRITICAL"
STOP = "STOP"
@property
def percent_threshold(self) -> float:
return {
BudgetLevel.NORMAL: 0.0,
BudgetLevel.WARNING: WARN_PERCENT,
BudgetLevel.CAUTION: CAUTION_PERCENT,
BudgetLevel.CRITICAL: CRITICAL_PERCENT,
BudgetLevel.STOP: STOP_PERCENT,
}[self]
@property
def emoji(self) -> str:
return {
BudgetLevel.NORMAL: "",
BudgetLevel.WARNING: "\u26a0\ufe0f",
BudgetLevel.CAUTION: "\U0001f525",
BudgetLevel.CRITICAL: "\U0001f6d1",
BudgetLevel.STOP: "\U0001f6d1",
}[self]
@dataclass
class BudgetStatus:
"""Current budget status."""
"""Current token budget status."""
level: BudgetLevel
used_tokens: int
max_tokens: int
percentage: float
remaining: int
message: str
actions: List[str] = field(default_factory=list)
tokens_used: int
context_length: int
percent_used: float
tokens_remaining: int
message: str = ""
should_compress: bool = False
should_block_tools: bool = False
should_terminate: bool = False
def to_indicator(self) -> str:
"""Compact status indicator for CLI display."""
pct = int(self.percent_used * 100)
if self.level == BudgetLevel.NORMAL:
return f"[{pct}%]"
return f"{self.level.emoji} [{pct}%]"
# Default thresholds
THRESHOLDS = {
BudgetLevel.WARNING: 0.60,
BudgetLevel.COMPRESS: 0.80,
BudgetLevel.BLOCK: 0.90,
BudgetLevel.TERMINATE: 0.95,
}
def to_bar(self, width: int = 10) -> str:
"""Visual progress bar."""
filled = int(width * self.percent_used)
bar = "\u2588" * filled + "\u2591" * (width - filled)
color = self._bar_color()
return f"{color}{bar}\033[0m {int(self.percent_used * 100)}%"
def _bar_color(self) -> str:
if self.level == BudgetLevel.STOP:
return "\033[41m" # red bg
if self.level == BudgetLevel.CRITICAL:
return "\033[31m" # red
if self.level == BudgetLevel.CAUTION:
return "\033[33m" # yellow
if self.level == BudgetLevel.WARNING:
return "\033[33m" # yellow
return "\033[32m" # green
class TokenBudget:
"""Track token usage and enforce context limits."""
"""
Progressive token budget tracker with poka-yoke circuit breakers.
def __init__(self, max_tokens: int = 128000,
thresholds: Optional[Dict[BudgetLevel, float]] = None):
self._max_tokens = max_tokens
self._thresholds = thresholds or THRESHOLDS
self._prompt_tokens = 0
self._completion_tokens = 0
self._tool_output_tokens = 0
self._history: List[Dict[str, Any]] = []
Tracks cumulative token usage against a context length and triggers
escalating actions at each threshold.
"""
@property
def used_tokens(self) -> int:
return self._prompt_tokens + self._completion_tokens
def __init__(
self,
context_length: int,
warn_percent: float = WARN_PERCENT,
caution_percent: float = CAUTION_PERCENT,
critical_percent: float = CRITICAL_PERCENT,
stop_percent: float = STOP_PERCENT,
response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
):
self.context_length = context_length
self.warn_threshold = int(context_length * warn_percent)
self.caution_threshold = int(context_length * caution_percent)
self.critical_threshold = int(context_length * critical_percent)
self.stop_threshold = int(context_length * stop_percent)
self.response_reserve = int(context_length * response_reserve_ratio)
@property
def remaining(self) -> int:
return max(0, self._max_tokens - self.used_tokens)
self.tokens_used = 0
self.completions_tokens = 0
self.total_tool_output_chars = 0
self._level = BudgetLevel.NORMAL
self._history: list[int] = []
@property
def percentage(self) -> float:
if self._max_tokens == 0:
return 0
return self.used_tokens / self._max_tokens
def record_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0,
tool_output_tokens: int = 0):
"""Record token usage from an API call."""
self._prompt_tokens += prompt_tokens
self._completion_tokens += completion_tokens
self._tool_output_tokens += tool_output_tokens
self._history.append({
"time": time.time(),
"prompt": prompt_tokens,
"completion": completion_tokens,
"tool_output": tool_output_tokens,
"total_used": self.used_tokens,
})
def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
"""Update budget from API response usage."""
self.tokens_used = prompt_tokens
self.completions_tokens = completion_tokens
self._history.append(prompt_tokens)
return self.check()
def check(self) -> BudgetStatus:
"""Check current budget status and return appropriate actions."""
pct = self.percentage
"""Evaluate current budget level and return status."""
pct = self.tokens_used / self.context_length if self.context_length > 0 else 0
remaining = max(0, self.context_length - self.tokens_used - self.response_reserve)
if pct >= self._thresholds.get(BudgetLevel.TERMINATE, 0.95):
level = BudgetLevel.TERMINATE
msg = f"Context {pct:.0%} full. Session must terminate with summary."
actions = ["generate_summary", "terminate_session"]
elif pct >= self._thresholds.get(BudgetLevel.BLOCK, 0.90):
level = BudgetLevel.BLOCK
msg = f"Context {pct:.0%} full. Blocking verbose tool calls."
actions = ["block_verbose_tools", "force_wrap_up", "suggest_summary"]
elif pct >= self._thresholds.get(BudgetLevel.COMPRESS, 0.80):
level = BudgetLevel.COMPRESS
msg = f"Context {pct:.0%} full. Auto-compressing conversation."
actions = ["auto_compress", "drop_raw_tool_outputs", "suggest_summary"]
elif pct >= self._thresholds.get(BudgetLevel.WARNING, 0.60):
# Determine level
if pct >= STOP_PERCENT:
level = BudgetLevel.STOP
elif pct >= CRITICAL_PERCENT:
level = BudgetLevel.CRITICAL
elif pct >= CAUTION_PERCENT:
level = BudgetLevel.CAUTION
elif pct >= WARN_PERCENT:
level = BudgetLevel.WARNING
msg = f"Context {pct:.0%} used. Consider summarizing."
actions = ["suggest_summary", "log_warning"]
else:
level = BudgetLevel.OK
msg = f"Context OK: {self.used_tokens}/{self._max_tokens} tokens ({pct:.0%})"
actions = []
level = BudgetLevel.NORMAL
# Log transitions (don\'t log every check)
if level != self._level:
self._log_transition(level, pct)
self._level = level
messages = {
BudgetLevel.NORMAL: "",
BudgetLevel.WARNING: (
f"Context at {int(pct*100)}%. Consider wrapping up soon or using /compress."
),
BudgetLevel.CAUTION: (
f"Context at {int(pct*100)}%. Auto-compressing. "
f"Tool outputs will be truncated."
),
BudgetLevel.CRITICAL: (
f"Context at {int(pct*100)}%. Verbose tools blocked. "
f"Session approaching limit — please wrap up."
),
BudgetLevel.STOP: (
f"Context at {int(pct*100)}%. Session must terminate. "
f"Saving summary before shutdown."
),
}
return BudgetStatus(
level=level,
used_tokens=self.used_tokens,
max_tokens=self._max_tokens,
percentage=round(pct, 3),
remaining=self.remaining,
message=msg,
actions=actions,
tokens_used=self.tokens_used,
context_length=self.context_length,
percent_used=pct,
tokens_remaining=remaining,
message=messages[level],
should_compress=level in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
should_terminate=level == BudgetLevel.STOP,
)
def should_truncate_tool_output(self, estimated_tokens: int) -> bool:
"""Check if a tool output should be truncated."""
if self.used_tokens + estimated_tokens > self._max_tokens * 0.95:
return True
return False
def should_compress(self) -> bool:
"""True at 80%+ — auto-compression should trigger."""
return self.tokens_used >= self.caution_threshold
def get_truncation_budget(self) -> int:
"""Get max tokens available for next tool output."""
budget = self.remaining - int(self._max_tokens * 0.05) # Reserve 5%
return max(0, budget)
def should_block_tools(self) -> bool:
"""True at 90%+ — verbose tool calls should be blocked."""
return self.tokens_used >= self.critical_threshold
def should_terminate(self) -> bool:
"""True at 95%+ — session should gracefully terminate."""
return self.tokens_used >= self.stop_threshold
def tool_output_budget(self) -> int:
"""Max chars allowed for next tool output based on current level."""
status = self.check()
return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)
def truncate_tool_output(self, output: str, max_chars: int = None) -> str:
"""Truncate tool output to fit budget. Adds truncation notice."""
if max_chars is None:
max_chars = self.tool_output_budget()
if len(output) <= max_chars:
return output
# Preserve start and end, truncate middle
if max_chars < 200:
return output[:max_chars] + "\n[...truncated...]"
head = max_chars // 2
tail = max_chars - head - 30 # reserve for truncation notice
truncated = (
output[:head]
+ f"\n\n[...{len(output) - head - tail:,} chars truncated...]\n\n"
+ output[-tail:]
)
return truncated
def remaining_for_response(self) -> int:
"""Tokens available for the model\'s response."""
return max(0, self.context_length - self.tokens_used - self.response_reserve)
def growth_rate(self) -> Optional[float]:
"""Average token increase per turn (from history)."""
if len(self._history) < 2:
return None
diffs = [self._history[i] - self._history[i-1] for i in range(1, len(self._history))]
return sum(diffs) / len(diffs)
def turns_remaining(self) -> Optional[int]:
"""Estimated turns until context is full (based on growth rate)."""
rate = self.growth_rate()
if rate is None or rate <= 0:
return None
remaining = self.context_length - self.tokens_used
return int(remaining / rate)
def reset(self):
"""Reset budget for new session."""
self._prompt_tokens = 0
self._completion_tokens = 0
self._tool_output_tokens = 0
self.tokens_used = 0
self.completions_tokens = 0
self.total_tool_output_chars = 0
self._level = BudgetLevel.NORMAL
self._history.clear()
def get_report(self) -> Dict[str, Any]:
"""Generate usage report."""
def _log_transition(self, new_level: BudgetLevel, pct: float):
"""Log budget level transitions."""
msg = (
f"Token budget: {self._level.value} -> {new_level.value} "
f"({self.tokens_used}/{self.context_length} = {pct:.0%})"
)
if new_level == BudgetLevel.WARNING:
logger.warning(msg)
elif new_level == BudgetLevel.CAUTION:
logger.warning(msg)
elif new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
logger.error(msg)
else:
logger.info(msg)
def summary(self) -> str:
"""Human-readable budget summary."""
status = self.check()
return {
"status": status.level.value,
"used_tokens": self.used_tokens,
"max_tokens": self._max_tokens,
"remaining": self.remaining,
"percentage": status.percentage,
"prompt_tokens": self._prompt_tokens,
"completion_tokens": self._completion_tokens,
"tool_output_tokens": self._tool_output_tokens,
"message": status.message,
"actions": status.actions,
}
turns = self.turns_remaining()
rate = self.growth_rate()
lines = [
f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
f"Level: {status.level.value}",
f"Remaining: {status.tokens_remaining:,} tokens",
]
if rate is not None:
lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
if turns is not None:
lines.append(f"Estimated turns left: ~{turns}")
if status.message:
lines.append(f"Action: {status.message}")
return "\n".join(lines)
# ── Convenience factory ───────────────────────────────────────────────
def create_budget(context_length: int, **kwargs) -> TokenBudget:
"""Create a TokenBudget with defaults."""
return TokenBudget(context_length=context_length, **kwargs)

237
tests/test_token_budget.py Normal file
View File

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Tests for agent/token_budget.py — Poka-yoke context overflow guard.
"""
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.token_budget import (
TokenBudget,
BudgetLevel,
BudgetStatus,
WARN_PERCENT,
CAUTION_PERCENT,
CRITICAL_PERCENT,
STOP_PERCENT,
)
@pytest.fixture
def budget():
"""Standard 128K context budget."""
return TokenBudget(context_length=128_000)
@pytest.fixture
def small_budget():
"""4K context for tight testing."""
return TokenBudget(context_length=4_000)
# ── Threshold Levels ──────────────────────────────────────────────────
class TestThresholds:
def test_normal_below_60(self, budget):
budget.update(50_000) # 39%
status = budget.check()
assert status.level == BudgetLevel.NORMAL
assert not status.should_compress
assert not status.should_block_tools
assert not status.should_terminate
def test_warning_at_60(self, budget):
budget.update(int(128_000 * 0.62)) # 62%
status = budget.check()
assert status.level == BudgetLevel.WARNING
assert not status.should_compress
assert not status.should_block_tools
def test_caution_at_80(self, budget):
budget.update(int(128_000 * 0.82)) # 82%
status = budget.check()
assert status.level == BudgetLevel.CAUTION
assert status.should_compress
assert not status.should_block_tools
assert not status.should_terminate
def test_critical_at_90(self, budget):
budget.update(int(128_000 * 0.91)) # 91%
status = budget.check()
assert status.level == BudgetLevel.CRITICAL
assert status.should_compress
assert status.should_block_tools
assert not status.should_terminate
def test_stop_at_95(self, budget):
budget.update(int(128_000 * 0.96)) # 96%
status = budget.check()
assert status.level == BudgetLevel.STOP
assert status.should_compress
assert status.should_block_tools
assert status.should_terminate
def test_small_context_thresholds(self, small_budget):
# 4K * 0.60 = 2400
small_budget.update(2450)
assert small_budget.check().level == BudgetLevel.WARNING
small_budget.update(3250) # 4K * 0.81
assert small_budget.check().level == BudgetLevel.CAUTION
small_budget.update(3650) # 4K * 0.91
assert small_budget.check().level == BudgetLevel.CRITICAL
small_budget.update(3850) # 4K * 0.96
assert small_budget.check().level == BudgetLevel.STOP
# ── Convenience Methods ───────────────────────────────────────────────
class TestConvenienceMethods:
def test_should_compress(self, budget):
budget.update(int(128_000 * 0.79))
assert not budget.should_compress()
budget.update(int(128_000 * 0.80))
assert budget.should_compress()
def test_should_block_tools(self, budget):
budget.update(int(128_000 * 0.89))
assert not budget.should_block_tools()
budget.update(int(128_000 * 0.90))
assert budget.should_block_tools()
def test_should_terminate(self, budget):
budget.update(int(128_000 * 0.94))
assert not budget.should_terminate()
budget.update(int(128_000 * 0.95))
assert budget.should_terminate()
# ── Tool Output Budgeting ─────────────────────────────────────────────
class TestToolOutputBudget:
def test_normal_budget(self, budget):
budget.update(int(128_000 * 0.50))
assert budget.tool_output_budget() == 50_000
def test_warning_budget(self, budget):
budget.update(int(128_000 * 0.65))
assert budget.tool_output_budget() == 20_000
def test_caution_budget(self, budget):
budget.update(int(128_000 * 0.85))
assert budget.tool_output_budget() == 8_000
def test_critical_budget(self, budget):
budget.update(int(128_000 * 0.92))
assert budget.tool_output_budget() == 2_000
def test_truncate_short_unchanged(self, budget):
result = budget.truncate_tool_output("short text", max_chars=1000)
assert result == "short text"
def test_truncate_long(self, budget):
long_text = "A" * 100_000
result = budget.truncate_tool_output(long_text, max_chars=5_000)
assert len(result) <= 5_100 # small overhead for notice
assert "truncated" in result
assert "A" in result[:2500] # head preserved
assert "A" in result[-2500:] # tail preserved
def test_truncate_very_small(self, budget):
long_text = "X" * 1000
result = budget.truncate_tool_output(long_text, max_chars=50)
assert len(result) <= 50 + 20
assert "truncated" in result
# ── Growth Tracking ───────────────────────────────────────────────────
class TestGrowthTracking:
def test_growth_rate(self, budget):
budget.update(10_000)
budget.update(15_000)
budget.update(20_000)
assert budget.growth_rate() == 5_000.0
def test_turns_remaining(self, budget):
budget.update(10_000)
budget.update(15_000)
budget.update(20_000)
# rate=5000, remaining=108000, turns=~21
turns = budget.turns_remaining()
assert turns is not None
assert 18 <= turns <= 24
def test_no_history(self, budget):
assert budget.growth_rate() is None
assert budget.turns_remaining() is None
# ── Status Indicators ─────────────────────────────────────────────────
class TestStatusIndicators:
def test_indicator_normal(self, budget):
budget.update(int(128_000 * 0.50))
status = budget.check()
indicator = status.to_indicator()
assert "50" in indicator
def test_indicator_warning(self, budget):
budget.update(int(128_000 * 0.65))
status = budget.check()
indicator = status.to_indicator()
assert "\u26a0" in indicator or "65" in indicator
def test_bar(self, budget):
budget.update(int(128_000 * 0.50))
status = budget.check()
bar = status.to_bar()
assert "50" in bar
def test_summary(self, budget):
budget.update(50_000)
summary = budget.summary()
assert "50,000" in summary
assert "128,000" in summary
assert "NORMAL" in summary
# ── Reset ─────────────────────────────────────────────────────────────
class TestReset:
def test_reset_clears_state(self, budget):
budget.update(int(128_000 * 0.90))
budget.reset()
assert budget.tokens_used == 0
assert budget.check().level == BudgetLevel.NORMAL
assert budget.growth_rate() is None
# ── Edge Cases ────────────────────────────────────────────────────────
class TestEdgeCases:
def test_exact_threshold_boundary(self, budget):
# Exactly at 60%
budget.update(int(128_000 * 0.60))
assert budget.check().level == BudgetLevel.WARNING
def test_zero_context(self):
budget = TokenBudget(context_length=0)
status = budget.check()
assert status.percent_used == 0
def test_remaining_for_response(self, budget):
budget.update(100_000)
remaining = budget.remaining_for_response()
# 128000 - 100000 - 6400 (5% reserve) = 21600
assert remaining > 0
assert remaining < 128_000
if __name__ == "__main__":
pytest.main([__file__, "-v"])