Compare commits
3 Commits
fix/format
...
fix/885-ci
| Author | SHA1 | Date | |
|---|---|---|---|
| 30509b9c7c | |||
| ccaa1cb021 | |||
| c6f2855745 |
273
agent/circuit_breaker.py
Normal file
273
agent/circuit_breaker.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
Circuit Breaker for Error Cascading — #885
|
||||||
|
|
||||||
|
P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
|
||||||
|
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
|
||||||
|
opens and the agent must take corrective action.
|
||||||
|
|
||||||
|
States:
|
||||||
|
- CLOSED: Normal operation, errors are counted
|
||||||
|
- OPEN: Too many consecutive errors, corrective action required
|
||||||
|
- HALF_OPEN: Testing if errors have cleared
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker
|
||||||
|
|
||||||
|
cb = ToolCircuitBreaker()
|
||||||
|
|
||||||
|
# After each tool call
|
||||||
|
if not cb.record_result(success=True):
|
||||||
|
# Circuit is open — take corrective action
|
||||||
|
cb.get_recovery_action()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitState(Enum):
|
||||||
|
CLOSED = "closed" # Normal operation
|
||||||
|
OPEN = "open" # Too many errors, block execution
|
||||||
|
HALF_OPEN = "half_open" # Testing recovery
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CircuitBreaker:
|
||||||
|
"""
|
||||||
|
Generic circuit breaker with configurable thresholds.
|
||||||
|
|
||||||
|
Tracks consecutive errors and opens the circuit when the
|
||||||
|
error streak exceeds the threshold.
|
||||||
|
"""
|
||||||
|
failure_threshold: int = 3
|
||||||
|
recovery_timeout: float = 30.0 # seconds before trying half-open
|
||||||
|
success_threshold: int = 2 # successes needed to close from half-open
|
||||||
|
|
||||||
|
state: CircuitState = field(default=CircuitState.CLOSED, init=False)
|
||||||
|
consecutive_failures: int = field(default=0, init=False)
|
||||||
|
consecutive_successes: int = field(default=0, init=False)
|
||||||
|
last_failure_time: Optional[float] = field(default=None, init=False)
|
||||||
|
total_trips: int = field(default=0, init=False)
|
||||||
|
error_streaks: List[int] = field(default_factory=list, init=False)
|
||||||
|
|
||||||
|
def record_result(self, success: bool) -> bool:
|
||||||
|
"""
|
||||||
|
Record a tool call result. Returns True if circuit allows execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if circuit is CLOSED or HALF_OPEN (execution allowed)
|
||||||
|
False if circuit is OPEN (execution blocked)
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
if self.state == CircuitState.OPEN:
|
||||||
|
# Check if recovery timeout has passed
|
||||||
|
if self.last_failure_time and (now - self.last_failure_time) >= self.recovery_timeout:
|
||||||
|
self.state = CircuitState.HALF_OPEN
|
||||||
|
self.consecutive_successes = 0
|
||||||
|
return True # Allow one test execution
|
||||||
|
return False # Still open
|
||||||
|
|
||||||
|
if success:
|
||||||
|
self.consecutive_failures = 0
|
||||||
|
self.consecutive_successes += 1
|
||||||
|
|
||||||
|
if self.state == CircuitState.HALF_OPEN:
|
||||||
|
if self.consecutive_successes >= self.success_threshold:
|
||||||
|
self.state = CircuitState.CLOSED
|
||||||
|
self.consecutive_successes = 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.consecutive_successes = 0
|
||||||
|
self.consecutive_failures += 1
|
||||||
|
self.last_failure_time = now
|
||||||
|
|
||||||
|
if self.state == CircuitState.HALF_OPEN:
|
||||||
|
# Failed during recovery — reopen immediately
|
||||||
|
self.state = CircuitState.OPEN
|
||||||
|
self.total_trips += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.consecutive_failures >= self.failure_threshold:
|
||||||
|
self.state = CircuitState.OPEN
|
||||||
|
self.total_trips += 1
|
||||||
|
self.error_streaks.append(self.consecutive_failures)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def can_execute(self) -> bool:
|
||||||
|
"""Check if execution is allowed."""
|
||||||
|
if self.state == CircuitState.OPEN:
|
||||||
|
if self.last_failure_time:
|
||||||
|
now = time.time()
|
||||||
|
if (now - self.last_failure_time) >= self.recovery_timeout:
|
||||||
|
self.state = CircuitState.HALF_OPEN
|
||||||
|
self.consecutive_successes = 0
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_state(self) -> Dict[str, Any]:
|
||||||
|
"""Get current circuit state."""
|
||||||
|
return {
|
||||||
|
"state": self.state.value,
|
||||||
|
"consecutive_failures": self.consecutive_failures,
|
||||||
|
"consecutive_successes": self.consecutive_successes,
|
||||||
|
"total_trips": self.total_trips,
|
||||||
|
"max_streak": max(self.error_streaks) if self.error_streaks else 0,
|
||||||
|
"can_execute": self.can_execute(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset the circuit breaker."""
|
||||||
|
self.state = CircuitState.CLOSED
|
||||||
|
self.consecutive_failures = 0
|
||||||
|
self.consecutive_successes = 0
|
||||||
|
self.last_failure_time = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCircuitBreaker(CircuitBreaker):
|
||||||
|
"""
|
||||||
|
Circuit breaker specifically for tool call error cascading.
|
||||||
|
|
||||||
|
Provides recovery actions when the circuit opens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tools that are most effective at recovery (from audit data)
|
||||||
|
RECOVERY_TOOLS = [
|
||||||
|
"terminal", # Most effective — 2300 recoveries
|
||||||
|
"read_file", # Reset context by reading something
|
||||||
|
"search_files", # Find what went wrong
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_recovery_action(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get the recommended recovery action when circuit is open.
|
||||||
|
|
||||||
|
Returns dict with action type and details.
|
||||||
|
"""
|
||||||
|
streak = self.consecutive_failures
|
||||||
|
|
||||||
|
if streak >= 9:
|
||||||
|
# After 9 errors: 41/46 recoveries via terminal
|
||||||
|
return {
|
||||||
|
"action": "terminal_only",
|
||||||
|
"reason": f"Error streak of {streak} — terminal is the only reliable recovery",
|
||||||
|
"suggested_tool": "terminal",
|
||||||
|
"suggested_command": "echo 'Resetting context'",
|
||||||
|
"severity": "critical",
|
||||||
|
}
|
||||||
|
elif streak >= 5:
|
||||||
|
return {
|
||||||
|
"action": "switch_tool_type",
|
||||||
|
"reason": f"Error streak of {streak} — switch to a different tool category",
|
||||||
|
"suggested_tools": ["read_file", "search_files", "terminal"],
|
||||||
|
"severity": "high",
|
||||||
|
}
|
||||||
|
elif streak >= self.failure_threshold:
|
||||||
|
return {
|
||||||
|
"action": "ask_user",
|
||||||
|
"reason": f"{streak} consecutive errors — ask user for guidance",
|
||||||
|
"suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
|
||||||
|
"severity": "medium",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"action": "continue",
|
||||||
|
"reason": f"Error streak of {streak} — within tolerance",
|
||||||
|
"severity": "low",
|
||||||
|
}
|
||||||
|
|
||||||
|
def should_compress_context(self) -> bool:
|
||||||
|
"""Determine if context compression would help recovery."""
|
||||||
|
return self.consecutive_failures >= 5
|
||||||
|
|
||||||
|
def get_blocked_tool(self) -> Optional[str]:
|
||||||
|
"""Get the tool that should be blocked (if any)."""
|
||||||
|
if self.state == CircuitState.OPEN:
|
||||||
|
return "last_failed_tool"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MultiToolCircuitBreaker:
|
||||||
|
"""
|
||||||
|
Manages per-tool circuit breakers and cross-tool cascade detection.
|
||||||
|
|
||||||
|
When one tool trips its breaker, related tools are also warned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.breakers: Dict[str, ToolCircuitBreaker] = {}
|
||||||
|
self.global_streak: int = 0
|
||||||
|
self.last_tool: Optional[str] = None
|
||||||
|
self.last_success: bool = True
|
||||||
|
|
||||||
|
def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
|
||||||
|
"""Get or create a circuit breaker for a tool."""
|
||||||
|
if tool_name not in self.breakers:
|
||||||
|
self.breakers[tool_name] = ToolCircuitBreaker()
|
||||||
|
return self.breakers[tool_name]
|
||||||
|
|
||||||
|
def record_result(self, tool_name: str, success: bool) -> bool:
|
||||||
|
"""
|
||||||
|
Record a tool call result. Returns True if execution should continue.
|
||||||
|
"""
|
||||||
|
breaker = self.get_breaker(tool_name)
|
||||||
|
allowed = breaker.record_result(success)
|
||||||
|
|
||||||
|
# Track global streak
|
||||||
|
if success:
|
||||||
|
self.global_streak = 0
|
||||||
|
self.last_success = True
|
||||||
|
else:
|
||||||
|
self.global_streak += 1
|
||||||
|
self.last_success = False
|
||||||
|
|
||||||
|
self.last_tool = tool_name
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
def can_execute(self, tool_name: str) -> bool:
|
||||||
|
"""Check if a specific tool can execute."""
|
||||||
|
breaker = self.get_breaker(tool_name)
|
||||||
|
return breaker.can_execute()
|
||||||
|
|
||||||
|
def get_global_state(self) -> Dict[str, Any]:
|
||||||
|
"""Get overall circuit breaker state."""
|
||||||
|
return {
|
||||||
|
"global_streak": self.global_streak,
|
||||||
|
"last_tool": self.last_tool,
|
||||||
|
"last_success": self.last_success,
|
||||||
|
"tool_states": {
|
||||||
|
name: breaker.get_state()
|
||||||
|
for name, breaker in self.breakers.items()
|
||||||
|
if breaker.consecutive_failures > 0 or breaker.total_trips > 0
|
||||||
|
},
|
||||||
|
"any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_recovery_action(self) -> Dict[str, Any]:
|
||||||
|
"""Get recovery action based on global state."""
|
||||||
|
if self.global_streak == 0:
|
||||||
|
return {"action": "continue", "reason": "No errors"}
|
||||||
|
|
||||||
|
# Find the breaker with the worst streak
|
||||||
|
worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
|
||||||
|
if worst and worst.consecutive_failures > 0:
|
||||||
|
return worst.get_recovery_action()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"action": "continue",
|
||||||
|
"reason": f"Global streak: {self.global_streak}",
|
||||||
|
"severity": "low",
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset_all(self):
|
||||||
|
"""Reset all circuit breakers."""
|
||||||
|
for breaker in self.breakers.values():
|
||||||
|
breaker.reset()
|
||||||
|
self.global_streak = 0
|
||||||
|
self.last_success = True
|
||||||
97
tests/test_circuit_breaker.py
Normal file
97
tests/test_circuit_breaker.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Tests for circuit breaker (#885)."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker, MultiToolCircuitBreaker, CircuitState
|
||||||
|
|
||||||
|
|
||||||
|
def test_closed_allows_execution():
|
||||||
|
cb = CircuitBreaker(failure_threshold=3)
|
||||||
|
assert cb.can_execute()
|
||||||
|
|
||||||
|
|
||||||
|
def test_opens_after_threshold():
|
||||||
|
cb = CircuitBreaker(failure_threshold=3)
|
||||||
|
cb.record_result(False)
|
||||||
|
cb.record_result(False)
|
||||||
|
assert cb.can_execute() # Still closed at 2
|
||||||
|
cb.record_result(False)
|
||||||
|
assert not cb.can_execute() # Open at 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_closes_on_success():
|
||||||
|
cb = CircuitBreaker(failure_threshold=3)
|
||||||
|
cb.record_result(False)
|
||||||
|
cb.record_result(True)
|
||||||
|
assert cb.consecutive_failures == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_half_open_recovery():
|
||||||
|
cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
|
||||||
|
cb.record_result(False)
|
||||||
|
cb.record_result(False)
|
||||||
|
assert cb.state == CircuitState.OPEN
|
||||||
|
|
||||||
|
import time
|
||||||
|
time.sleep(0.15)
|
||||||
|
|
||||||
|
assert cb.can_execute() # Moved to half-open
|
||||||
|
cb.record_result(True)
|
||||||
|
assert cb.state == CircuitState.CLOSED
|
||||||
|
|
||||||
|
|
||||||
|
def test_recovery_action_streak():
|
||||||
|
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||||
|
for _ in range(5):
|
||||||
|
cb.record_result(False)
|
||||||
|
action = cb.get_recovery_action()
|
||||||
|
assert action["action"] == "switch_tool_type"
|
||||||
|
|
||||||
|
|
||||||
|
def test_recovery_action_critical():
|
||||||
|
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||||
|
for _ in range(10):
|
||||||
|
cb.record_result(False)
|
||||||
|
action = cb.get_recovery_action()
|
||||||
|
assert action["action"] == "terminal_only"
|
||||||
|
assert action["severity"] == "critical"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_tool_breaker():
|
||||||
|
mcb = MultiToolCircuitBreaker()
|
||||||
|
mcb.record_result("read_file", False)
|
||||||
|
mcb.record_result("read_file", False)
|
||||||
|
mcb.record_result("read_file", False)
|
||||||
|
assert not mcb.can_execute("read_file")
|
||||||
|
assert mcb.can_execute("terminal") # Different tool unaffected
|
||||||
|
|
||||||
|
|
||||||
|
def test_global_state():
|
||||||
|
mcb = MultiToolCircuitBreaker()
|
||||||
|
mcb.record_result("tool_a", False)
|
||||||
|
mcb.record_result("tool_b", False)
|
||||||
|
state = mcb.get_global_state()
|
||||||
|
assert state["global_streak"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset():
|
||||||
|
cb = CircuitBreaker(failure_threshold=2)
|
||||||
|
cb.record_result(False)
|
||||||
|
cb.record_result(False)
|
||||||
|
assert cb.state == CircuitState.OPEN
|
||||||
|
cb.reset()
|
||||||
|
assert cb.state == CircuitState.CLOSED
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tests = [test_closed_allows_execution, test_opens_after_threshold,
|
||||||
|
test_closes_on_success, test_half_open_recovery,
|
||||||
|
test_recovery_action_streak, test_recovery_action_critical,
|
||||||
|
test_multi_tool_breaker, test_global_state, test_reset]
|
||||||
|
for t in tests:
|
||||||
|
print(f"Running {t.__name__}...")
|
||||||
|
t()
|
||||||
|
print(" PASS")
|
||||||
|
print("\nAll tests passed.")
|
||||||
@@ -44,6 +44,34 @@ from typing import Dict, Any, Optional, Tuple
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_error(
|
||||||
|
message: str,
|
||||||
|
skill_name: str = None,
|
||||||
|
file_path: str = None,
|
||||||
|
suggestion: str = None,
|
||||||
|
context: dict = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Format an error with rich context for better debugging."""
|
||||||
|
parts = [message]
|
||||||
|
if skill_name:
|
||||||
|
parts.append(f"Skill: {skill_name}")
|
||||||
|
if file_path:
|
||||||
|
parts.append(f"File: {file_path}")
|
||||||
|
if suggestion:
|
||||||
|
parts.append(f"Suggestion: {suggestion}")
|
||||||
|
if context:
|
||||||
|
for key, value in context.items():
|
||||||
|
parts.append(f"{key}: {value}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": " | ".join(parts),
|
||||||
|
"skill_name": skill_name,
|
||||||
|
"file_path": file_path,
|
||||||
|
"suggestion": suggestion,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Import security scanner — agent-created skills get the same scrutiny as
|
# Import security scanner — agent-created skills get the same scrutiny as
|
||||||
# community hub installs.
|
# community hub installs.
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user