Compare commits

..

1 Commits

Author SHA1 Message Date
80f536e319 fix: token budget tracker for context overflow (#925)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 42s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 39s
Tests / e2e (pull_request) Successful in 1m50s
Tests / test (pull_request) Failing after 50m49s
2026-04-21 04:43:07 +00:00
3 changed files with 165 additions and 379 deletions

165
agent/token_budget.py Normal file
View File

@@ -0,0 +1,165 @@
"""Token Budget — Poka-yoke guard against context overflow.
Progressive warning system with circuit breakers:
- 60%: Log warning, suggest summarization
- 80%: Auto-compress, drop raw tool outputs
- 90%: Block verbose tools, force wrap-up
- 95%: Graceful termination with summary
Usage:
from agent.token_budget import TokenBudget
budget = TokenBudget(max_tokens=128000)
budget.record_usage(prompt_tokens=500, completion_tokens=200)
status = budget.check()
# status.level: ok, warning, compress, block, terminate
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class BudgetLevel(Enum):
"""Token budget alert levels."""
OK = "ok" # < 60%
WARNING = "warning" # 60-80%
COMPRESS = "compress" # 80-90%
BLOCK = "block" # 90-95%
TERMINATE = "terminate" # > 95%
@dataclass
class BudgetStatus:
"""Current budget status."""
level: BudgetLevel
used_tokens: int
max_tokens: int
percentage: float
remaining: int
message: str
actions: List[str] = field(default_factory=list)
# Default thresholds
THRESHOLDS = {
BudgetLevel.WARNING: 0.60,
BudgetLevel.COMPRESS: 0.80,
BudgetLevel.BLOCK: 0.90,
BudgetLevel.TERMINATE: 0.95,
}
class TokenBudget:
"""Track token usage and enforce context limits."""
def __init__(self, max_tokens: int = 128000,
thresholds: Optional[Dict[BudgetLevel, float]] = None):
self._max_tokens = max_tokens
self._thresholds = thresholds or THRESHOLDS
self._prompt_tokens = 0
self._completion_tokens = 0
self._tool_output_tokens = 0
self._history: List[Dict[str, Any]] = []
@property
def used_tokens(self) -> int:
return self._prompt_tokens + self._completion_tokens
@property
def remaining(self) -> int:
return max(0, self._max_tokens - self.used_tokens)
@property
def percentage(self) -> float:
if self._max_tokens == 0:
return 0
return self.used_tokens / self._max_tokens
def record_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0,
tool_output_tokens: int = 0):
"""Record token usage from an API call."""
self._prompt_tokens += prompt_tokens
self._completion_tokens += completion_tokens
self._tool_output_tokens += tool_output_tokens
self._history.append({
"time": time.time(),
"prompt": prompt_tokens,
"completion": completion_tokens,
"tool_output": tool_output_tokens,
"total_used": self.used_tokens,
})
def check(self) -> BudgetStatus:
"""Check current budget status and return appropriate actions."""
pct = self.percentage
if pct >= self._thresholds.get(BudgetLevel.TERMINATE, 0.95):
level = BudgetLevel.TERMINATE
msg = f"Context {pct:.0%} full. Session must terminate with summary."
actions = ["generate_summary", "terminate_session"]
elif pct >= self._thresholds.get(BudgetLevel.BLOCK, 0.90):
level = BudgetLevel.BLOCK
msg = f"Context {pct:.0%} full. Blocking verbose tool calls."
actions = ["block_verbose_tools", "force_wrap_up", "suggest_summary"]
elif pct >= self._thresholds.get(BudgetLevel.COMPRESS, 0.80):
level = BudgetLevel.COMPRESS
msg = f"Context {pct:.0%} full. Auto-compressing conversation."
actions = ["auto_compress", "drop_raw_tool_outputs", "suggest_summary"]
elif pct >= self._thresholds.get(BudgetLevel.WARNING, 0.60):
level = BudgetLevel.WARNING
msg = f"Context {pct:.0%} used. Consider summarizing."
actions = ["suggest_summary", "log_warning"]
else:
level = BudgetLevel.OK
msg = f"Context OK: {self.used_tokens}/{self._max_tokens} tokens ({pct:.0%})"
actions = []
return BudgetStatus(
level=level,
used_tokens=self.used_tokens,
max_tokens=self._max_tokens,
percentage=round(pct, 3),
remaining=self.remaining,
message=msg,
actions=actions,
)
def should_truncate_tool_output(self, estimated_tokens: int) -> bool:
"""Check if a tool output should be truncated."""
if self.used_tokens + estimated_tokens > self._max_tokens * 0.95:
return True
return False
def get_truncation_budget(self) -> int:
"""Get max tokens available for next tool output."""
budget = self.remaining - int(self._max_tokens * 0.05) # Reserve 5%
return max(0, budget)
def reset(self):
"""Reset budget for new session."""
self._prompt_tokens = 0
self._completion_tokens = 0
self._tool_output_tokens = 0
self._history.clear()
def get_report(self) -> Dict[str, Any]:
"""Generate usage report."""
status = self.check()
return {
"status": status.level.value,
"used_tokens": self.used_tokens,
"max_tokens": self._max_tokens,
"remaining": self.remaining,
"percentage": status.percentage,
"prompt_tokens": self._prompt_tokens,
"completion_tokens": self._completion_tokens,
"tool_output_tokens": self._tool_output_tokens,
"message": status.message,
"actions": status.actions,
}

View File

@@ -1,67 +0,0 @@
"""
Tests for tool hallucination detection (#922).
"""
import pytest
from tools.tool_validator import ToolHallucinationDetector, ValidationSeverity
class TestToolHallucinationDetector:
def setup_method(self):
self.detector = ToolHallucinationDetector()
self.detector.register_tool("read_file", {
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string"},
"encoding": {"type": "string"},
},
"required": ["path"]
}
})
def test_valid_tool_call(self):
result = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt"})
assert result.valid is True
assert len(result.blocking_issues) == 0
def test_unknown_tool(self):
result = self.detector.validate_tool_call("hallucinated_tool", {})
assert result.valid is False
assert any(i.code == "UNKNOWN_TOOL" for i in result.issues)
def test_missing_required_param(self):
result = self.detector.validate_tool_call("read_file", {})
assert result.valid is False
assert any(i.code == "MISSING_REQUIRED" for i in result.issues)
def test_wrong_type(self):
result = self.detector.validate_tool_call("read_file", {"path": 123})
assert result.valid is False
assert any(i.code == "WRONG_TYPE" for i in result.issues)
def test_unknown_param_warning(self):
result = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt", "unknown": "value"})
assert result.valid is True # Warning, not blocking
assert any(i.code == "UNKNOWN_PARAM" for i in result.issues)
def test_placeholder_detection(self):
result = self.detector.validate_tool_call("read_file", {"path": "<placeholder>"})
assert any(i.code == "PLACEHOLDER_VALUE" for i in result.issues)
def test_rejection_stats(self):
self.detector.validate_tool_call("unknown_tool", {})
self.detector.validate_tool_call("read_file", {})
stats = self.detector.get_rejection_stats()
assert stats["total"] >= 2
def test_rejection_response(self):
from tools.tool_validator import create_rejection_response
result = self.detector.validate_tool_call("unknown_tool", {})
response = create_rejection_response(result)
assert response["role"] == "tool"
assert "rejected" in response["content"].lower()
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,312 +0,0 @@
"""
Poka-Yoke: Tool Hallucination Detection — #922.
Validation firewall between LLM tool-call output and actual execution.
Detects and blocks:
1. Unknown tool names (hallucinated tools)
2. Malformed parameters (wrong types)
3. Missing required arguments
4. Extra unknown parameters
Poka-Yoke Type: Detection (catches errors at the boundary before harm)
"""
import json
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
class ValidationSeverity(Enum):
"""Severity of validation failure."""
BLOCK = "block" # Must block execution
WARN = "warn" # Warning, may proceed
INFO = "info" # Informational
@dataclass
class ValidationIssue:
"""A validation issue found."""
severity: ValidationSeverity
code: str
message: str
tool_name: str
parameter: Optional[str] = None
expected: Optional[str] = None
actual: Optional[Any] = None
@dataclass
class ValidationResult:
"""Result of tool call validation."""
valid: bool
tool_name: str
issues: List[ValidationIssue] = field(default_factory=list)
corrected_args: Optional[Dict[str, Any]] = None
@property
def blocking_issues(self) -> List[ValidationIssue]:
return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]
@property
def warnings(self) -> List[ValidationIssue]:
return [i for i in self.issues if i.severity == ValidationSeverity.WARN]
class ToolHallucinationDetector:
"""
Poka-yoke detector for tool hallucinations.
Validates tool calls against registered schemas before execution.
"""
def __init__(self, tool_registry: Optional[Dict] = None):
"""
Initialize detector.
Args:
tool_registry: Dict of tool_name -> tool_schema
"""
self.registry = tool_registry or {}
self._rejection_log: List[Dict] = []
def register_tool(self, name: str, schema: Dict):
"""Register a tool with its JSON Schema."""
self.registry[name] = schema
def register_tools(self, tools: Dict[str, Dict]):
"""Register multiple tools."""
self.registry.update(tools)
def validate_tool_call(
self,
tool_name: str,
arguments: Dict[str, Any],
model: str = "unknown",
) -> ValidationResult:
"""
Validate a tool call against the registry.
Args:
tool_name: Name of the tool being called
arguments: Arguments passed to the tool
model: Model that generated the call (for logging)
Returns:
ValidationResult with validation status
"""
issues = []
# 1. Check if tool exists
if tool_name not in self.registry:
issue = ValidationIssue(
severity=ValidationSeverity.BLOCK,
code="UNKNOWN_TOOL",
message=f"Tool '{tool_name}' does not exist. Available: {', '.join(sorted(self.registry.keys())[:10])}...",
tool_name=tool_name,
)
issues.append(issue)
self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
return ValidationResult(valid=False, tool_name=tool_name, issues=issues)
schema = self.registry[tool_name]
params_schema = schema.get("parameters", {}).get("properties", {})
required = set(schema.get("parameters", {}).get("required", []))
# 2. Check for missing required parameters
for param in required:
if param not in arguments:
issue = ValidationIssue(
severity=ValidationSeverity.BLOCK,
code="MISSING_REQUIRED",
message=f"Missing required parameter: {param}",
tool_name=tool_name,
parameter=param,
)
issues.append(issue)
# 3. Check parameter types
for param_name, param_value in arguments.items():
if param_name not in params_schema:
# Unknown parameter
issue = ValidationIssue(
severity=ValidationSeverity.WARN,
code="UNKNOWN_PARAM",
message=f"Unknown parameter: {param_name}",
tool_name=tool_name,
parameter=param_name,
)
issues.append(issue)
continue
param_schema = params_schema[param_name]
expected_type = param_schema.get("type")
if expected_type and not self._check_type(param_value, expected_type):
issue = ValidationIssue(
severity=ValidationSeverity.BLOCK,
code="WRONG_TYPE",
message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
tool_name=tool_name,
parameter=param_name,
expected=expected_type,
actual=type(param_value).__name__,
)
issues.append(issue)
# 4. Check for common hallucination patterns
hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
issues.extend(hallucination_issues)
# Determine validity
has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)
if has_blocking:
self._log_rejection(tool_name, arguments, model,
"; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))
return ValidationResult(
valid=not has_blocking,
tool_name=tool_name,
issues=issues,
)
def _check_type(self, value: Any, expected_type: str) -> bool:
"""Check if value matches expected JSON Schema type."""
type_map = {
"string": str,
"number": (int, float),
"integer": int,
"boolean": bool,
"array": list,
"object": dict,
}
expected = type_map.get(expected_type)
if expected is None:
return True # Unknown type, assume OK
return isinstance(value, expected)
def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
"""Detect common hallucination patterns."""
issues = []
# Pattern 1: Placeholder values
placeholder_patterns = [
r"^<.*>$", # <placeholder>
r"^\[.*\]$", # [placeholder]
r"^TODO$|^FIXME$", # TODO/FIXME
r"^example\.com$", # example.com
r"^127\.0\.0\.1$", # localhost
]
for param_name, param_value in arguments.items():
if isinstance(param_value, str):
for pattern in placeholder_patterns:
if re.match(pattern, param_value, re.IGNORECASE):
issues.append(ValidationIssue(
severity=ValidationSeverity.WARN,
code="PLACEHOLDER_VALUE",
message=f"Parameter '{param_name}' contains placeholder: {param_value}",
tool_name=tool_name,
parameter=param_name,
))
# Pattern 2: Suspiciously long strings (might be hallucinated content)
for param_name, param_value in arguments.items():
if isinstance(param_value, str) and len(param_value) > 10000:
issues.append(ValidationIssue(
severity=ValidationSeverity.WARN,
code="SUSPICIOUS_LENGTH",
message=f"Parameter '{param_name}' is unusually long ({len(param_value)} chars)",
tool_name=tool_name,
parameter=param_name,
))
return issues
def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
"""Log a rejected tool call for analysis."""
import time
entry = {
"timestamp": time.time(),
"tool_name": tool_name,
"arguments": {k: str(v)[:100] for k, v in arguments.items()},
"model": model,
"reason": reason,
}
self._rejection_log.append(entry)
# Keep log bounded
if len(self._rejection_log) > 1000:
self._rejection_log = self._rejection_log[-500:]
logger.warning(
"Tool hallucination blocked: tool=%s, model=%s, reason=%s",
tool_name, model, reason
)
def get_rejection_stats(self) -> Dict:
"""Get statistics on rejected tool calls."""
if not self._rejection_log:
return {"total": 0, "by_reason": {}, "by_tool": {}}
by_reason = {}
by_tool = {}
for entry in self._rejection_log:
reason = entry["reason"]
tool = entry["tool_name"]
by_reason[reason] = by_reason.get(reason, 0) + 1
by_tool[tool] = by_tool.get(tool, 0) + 1
return {
"total": len(self._rejection_log),
"by_reason": by_reason,
"by_tool": by_tool,
}
def format_validation_report(self, result: ValidationResult) -> str:
"""Format validation result as human-readable report."""
if result.valid:
return f"{result.tool_name}: valid"
lines = [f"{result.tool_name}: BLOCKED"]
for issue in result.blocking_issues:
lines.append(f" [{issue.code}] {issue.message}")
for issue in result.warnings:
lines.append(f" ⚠️ [{issue.code}] {issue.message}")
return "\n".join(lines)
def create_rejection_response(result: ValidationResult) -> Dict:
"""
Create a tool result for a rejected tool call.
This allows the agent to see the rejection and self-correct.
"""
issues_text = "\n".join(
f"- [{i.code}] {i.message}"
for i in result.blocking_issues
)
return {
"role": "tool",
"content": f"""Tool call rejected: {result.tool_name}
Issues found:
{issues_text}
Please check the tool name and parameters, then try again with valid arguments.""",
}