Closes #885: 2.33x error cascade factor detected. After 3 consecutive errors, the circuit opens and the agent must take corrective action. Recovery pattern: the terminal is the safety net (2300 recoveries).
274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
"""
Circuit Breaker for Error Cascading — #885

P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
opens and the agent must take corrective action.

States:
- CLOSED: Normal operation, errors are counted
- OPEN: Too many consecutive errors, corrective action required
- HALF_OPEN: Testing if errors have cleared

Usage:
    from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker

    cb = ToolCircuitBreaker()

    # After each tool call
    if not cb.record_result(success=True):
        # Circuit is open — take corrective action
        cb.get_recovery_action()
"""
|
|
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
class CircuitState(Enum):
    """Possible states of a circuit breaker."""

    # Normal operation
    CLOSED = "closed"
    # Too many errors, block execution
    OPEN = "open"
    # Testing recovery
    HALF_OPEN = "half_open"
|
|
|
|
|
|
@dataclass
class CircuitBreaker:
    """
    Generic circuit breaker with configurable thresholds.

    Tracks consecutive errors and opens the circuit once the error streak
    reaches ``failure_threshold``. When ``recovery_timeout`` seconds have
    passed since the last recorded failure, the next ``can_execute()`` call
    or recorded result moves the circuit to HALF_OPEN. In HALF_OPEN,
    ``success_threshold`` consecutive successes close the circuit and any
    failure re-opens it.
    """

    failure_threshold: int = 3
    recovery_timeout: float = 30.0  # seconds before trying half-open
    success_threshold: int = 2  # successes needed to close from half-open

    state: CircuitState = field(default=CircuitState.CLOSED, init=False)
    consecutive_failures: int = field(default=0, init=False)
    consecutive_successes: int = field(default=0, init=False)
    last_failure_time: Optional[float] = field(default=None, init=False)
    total_trips: int = field(default=0, init=False)
    # Streak length recorded each time the circuit trips (feeds "max_streak").
    error_streaks: List[int] = field(default_factory=list, init=False)

    def _try_half_open(self, now: float) -> bool:
        """Move OPEN -> HALF_OPEN if the recovery timeout has elapsed.

        Returns True when the transition happened, False while still cooling down.
        """
        # `is None` rather than truthiness: a timestamp of 0.0 is still a timestamp.
        if self.last_failure_time is None:
            return False
        if now - self.last_failure_time < self.recovery_timeout:
            return False
        self.state = CircuitState.HALF_OPEN
        self.consecutive_successes = 0
        return True

    def _trip(self) -> None:
        """Open the circuit and record the streak length that caused the trip."""
        self.state = CircuitState.OPEN
        self.total_trips += 1
        self.error_streaks.append(self.consecutive_failures)

    def record_result(self, success: bool) -> bool:
        """
        Record a tool call result. Returns True if circuit allows execution.

        While the circuit is OPEN and the recovery timeout has not elapsed,
        the result is ignored. Once the timeout has elapsed, the circuit moves
        to HALF_OPEN and this result is counted as the first trial call, so a
        failure here re-opens the circuit immediately.

        Returns:
            True if circuit is CLOSED or HALF_OPEN (execution allowed)
            False if circuit is OPEN (execution blocked)
        """
        now = time.time()

        if self.state == CircuitState.OPEN and not self._try_half_open(now):
            return False  # Still open: the result is not counted

        if success:
            self.consecutive_failures = 0
            self.consecutive_successes += 1
            if (
                self.state == CircuitState.HALF_OPEN
                and self.consecutive_successes >= self.success_threshold
            ):
                self.state = CircuitState.CLOSED
                self.consecutive_successes = 0
            return True

        self.consecutive_successes = 0
        self.consecutive_failures += 1
        self.last_failure_time = now

        # A failure during HALF_OPEN re-opens immediately; while CLOSED, the
        # circuit opens once the streak reaches the threshold.
        if (
            self.state == CircuitState.HALF_OPEN
            or self.consecutive_failures >= self.failure_threshold
        ):
            self._trip()
            return False

        return True

    def can_execute(self) -> bool:
        """Check if execution is allowed.

        Side effect: an OPEN circuit whose recovery timeout has elapsed is
        moved to HALF_OPEN.
        """
        if self.state != CircuitState.OPEN:
            return True
        return self._try_half_open(time.time())

    def get_state(self) -> Dict[str, Any]:
        """Get a snapshot of the current circuit state."""
        # Evaluate can_execute() first: it may move OPEN -> HALF_OPEN, and the
        # reported "state" must reflect that transition to stay consistent.
        allowed = self.can_execute()
        return {
            "state": self.state.value,
            "consecutive_failures": self.consecutive_failures,
            "consecutive_successes": self.consecutive_successes,
            "total_trips": self.total_trips,
            "max_streak": max(self.error_streaks) if self.error_streaks else 0,
            "can_execute": allowed,
        }

    def reset(self):
        """Reset the circuit breaker to CLOSED.

        Lifetime statistics (``total_trips``, ``error_streaks``) are kept.
        """
        self.state = CircuitState.CLOSED
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.last_failure_time = None
|
|
|
|
|
|
class ToolCircuitBreaker(CircuitBreaker):
    """
    Circuit breaker specifically for tool call error cascading.

    Adds recovery-action recommendations that escalate with the length of
    the current consecutive-failure streak.
    """

    # Tools that are most effective at recovery (from audit data)
    RECOVERY_TOOLS = [
        "terminal",  # Most effective — 2300 recoveries
        "read_file",  # Reset context by reading something
        "search_files",  # Find what went wrong
    ]

    # Streak lengths at which recovery advice escalates. Left unannotated so
    # they remain plain class attributes rather than dataclass fields.
    HIGH_STREAK = 5
    CRITICAL_STREAK = 9

    def get_recovery_action(self) -> Dict[str, Any]:
        """
        Get the recommended recovery action for the current error streak.

        Intended for use once the circuit has opened, but callable at any
        time: a streak below every escalation threshold yields "continue".

        NOTE(review): if ``failure_threshold`` is configured above
        ``HIGH_STREAK``, "switch_tool_type" can be returned while the
        circuit is still CLOSED — confirm that is intended.

        Returns:
            Dict with "action", "reason" and "severity"; some actions also
            carry "suggested_tool", "suggested_tools", "suggested_command"
            or "suggested_response".
        """
        streak = self.consecutive_failures

        if streak >= self.CRITICAL_STREAK:
            # After 9 errors: 41/46 recoveries via terminal
            return {
                "action": "terminal_only",
                "reason": f"Error streak of {streak} — terminal is the only reliable recovery",
                "suggested_tool": "terminal",
                "suggested_command": "echo 'Resetting context'",
                "severity": "critical",
            }
        if streak >= self.HIGH_STREAK:
            return {
                "action": "switch_tool_type",
                "reason": f"Error streak of {streak} — switch to a different tool category",
                "suggested_tools": ["read_file", "search_files", "terminal"],
                "severity": "high",
            }
        if streak >= self.failure_threshold:
            return {
                "action": "ask_user",
                "reason": f"{streak} consecutive errors — ask user for guidance",
                "suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
                "severity": "medium",
            }
        return {
            "action": "continue",
            "reason": f"Error streak of {streak} — within tolerance",
            "severity": "low",
        }

    def should_compress_context(self) -> bool:
        """Determine if context compression would help recovery.

        True from the same streak length at which tool switching is advised.
        """
        return self.consecutive_failures >= self.HIGH_STREAK

    def get_blocked_tool(self) -> Optional[str]:
        """Get the tool that should be blocked (if any).

        NOTE(review): while OPEN this returns the literal placeholder
        "last_failed_tool", not a real tool name — this breaker does not
        record which tool failed.
        """
        if self.state == CircuitState.OPEN:
            return "last_failed_tool"
        return None
|
|
|
|
|
|
class MultiToolCircuitBreaker:
    """
    Manages one ToolCircuitBreaker per tool name plus a global error streak.

    The global streak counts consecutive failures across all tools, whichever
    tool failed. Each tool's breaker trips independently: recording a result
    for one tool does not change any other tool's breaker.
    """

    def __init__(self):
        self.breakers: Dict[str, ToolCircuitBreaker] = {}
        # Consecutive failures across all tools; reset by any success.
        self.global_streak: int = 0
        self.last_tool: Optional[str] = None
        self.last_success: bool = True

    def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
        """Get or create the circuit breaker for ``tool_name``."""
        if tool_name not in self.breakers:
            self.breakers[tool_name] = ToolCircuitBreaker()
        return self.breakers[tool_name]

    def record_result(self, tool_name: str, success: bool) -> bool:
        """
        Record a tool call result. Returns True if execution should continue.

        The returned verdict comes from the tool's own breaker; the global
        streak is updated independently of it.
        """
        allowed = self.get_breaker(tool_name).record_result(success)

        # Track global streak
        if success:
            self.global_streak = 0
            self.last_success = True
        else:
            self.global_streak += 1
            self.last_success = False

        self.last_tool = tool_name
        return allowed

    def can_execute(self, tool_name: str) -> bool:
        """Check if a specific tool can execute (creates its breaker if missing)."""
        return self.get_breaker(tool_name).can_execute()

    def get_global_state(self) -> Dict[str, Any]:
        """Get overall circuit breaker state.

        "tool_states" only lists breakers that have current failures or have
        tripped at least once.
        """
        return {
            "global_streak": self.global_streak,
            "last_tool": self.last_tool,
            "last_success": self.last_success,
            "tool_states": {
                name: breaker.get_state()
                for name, breaker in self.breakers.items()
                if breaker.consecutive_failures > 0 or breaker.total_trips > 0
            },
            "any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
        }

    def get_recovery_action(self) -> Dict[str, Any]:
        """
        Get recovery action based on global state.

        Delegates to the breaker with the longest current error streak. Every
        returned dict carries "action", "reason" and "severity".

        NOTE(review): global_streak does not escalate severity — failures
        spread over several tools can build a long global streak while each
        tool's own streak stays short.
        """
        if self.global_streak == 0:
            # "severity" included so callers can index it on every response.
            return {"action": "continue", "reason": "No errors", "severity": "low"}

        # Find the breaker with the worst streak
        worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
        if worst is not None and worst.consecutive_failures > 0:
            return worst.get_recovery_action()

        return {
            "action": "continue",
            "reason": f"Global streak: {self.global_streak}",
            "severity": "low",
        }

    def reset_all(self):
        """Reset all circuit breakers and the global streak.

        ``last_tool`` is left unchanged.
        """
        for breaker in self.breakers.values():
            breaker.reset()
        self.global_streak = 0
        self.last_success = True
|