Compare commits
3 Commits
fix/913-sy
...
fix/885-ci
| Author | SHA1 | Date | |
|---|---|---|---|
| 30509b9c7c | |||
| ccaa1cb021 | |||
| c6f2855745 |
273
agent/circuit_breaker.py
Normal file
273
agent/circuit_breaker.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Circuit Breaker for Error Cascading — #885
|
||||
|
||||
P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
|
||||
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
|
||||
opens and the agent must take corrective action.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, errors are counted
|
||||
- OPEN: Too many consecutive errors, corrective action required
|
||||
- HALF_OPEN: Testing if errors have cleared
|
||||
|
||||
Usage:
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker
|
||||
|
||||
cb = ToolCircuitBreaker()
|
||||
|
||||
# After each tool call
|
||||
if not cb.record_result(success=True):
|
||||
# Circuit is open — take corrective action
|
||||
cb.get_recovery_action()
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Too many errors, block execution
|
||||
HALF_OPEN = "half_open" # Testing recovery
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Generic circuit breaker with configurable thresholds.
|
||||
|
||||
Tracks consecutive errors and opens the circuit when the
|
||||
error streak exceeds the threshold.
|
||||
"""
|
||||
failure_threshold: int = 3
|
||||
recovery_timeout: float = 30.0 # seconds before trying half-open
|
||||
success_threshold: int = 2 # successes needed to close from half-open
|
||||
|
||||
state: CircuitState = field(default=CircuitState.CLOSED, init=False)
|
||||
consecutive_failures: int = field(default=0, init=False)
|
||||
consecutive_successes: int = field(default=0, init=False)
|
||||
last_failure_time: Optional[float] = field(default=None, init=False)
|
||||
total_trips: int = field(default=0, init=False)
|
||||
error_streaks: List[int] = field(default_factory=list, init=False)
|
||||
|
||||
def record_result(self, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if circuit allows execution.
|
||||
|
||||
Returns:
|
||||
True if circuit is CLOSED or HALF_OPEN (execution allowed)
|
||||
False if circuit is OPEN (execution blocked)
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
if self.state == CircuitState.OPEN:
|
||||
# Check if recovery timeout has passed
|
||||
if self.last_failure_time and (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True # Allow one test execution
|
||||
return False # Still open
|
||||
|
||||
if success:
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes += 1
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
if self.consecutive_successes >= self.success_threshold:
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_successes = 0
|
||||
|
||||
return True
|
||||
else:
|
||||
self.consecutive_successes = 0
|
||||
self.consecutive_failures += 1
|
||||
self.last_failure_time = now
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
# Failed during recovery — reopen immediately
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
return False
|
||||
|
||||
if self.consecutive_failures >= self.failure_threshold:
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
self.error_streaks.append(self.consecutive_failures)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def can_execute(self) -> bool:
|
||||
"""Check if execution is allowed."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self.last_failure_time:
|
||||
now = time.time()
|
||||
if (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get current circuit state."""
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"consecutive_failures": self.consecutive_failures,
|
||||
"consecutive_successes": self.consecutive_successes,
|
||||
"total_trips": self.total_trips,
|
||||
"max_streak": max(self.error_streaks) if self.error_streaks else 0,
|
||||
"can_execute": self.can_execute(),
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the circuit breaker."""
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes = 0
|
||||
self.last_failure_time = None
|
||||
|
||||
|
||||
class ToolCircuitBreaker(CircuitBreaker):
|
||||
"""
|
||||
Circuit breaker specifically for tool call error cascading.
|
||||
|
||||
Provides recovery actions when the circuit opens.
|
||||
"""
|
||||
|
||||
# Tools that are most effective at recovery (from audit data)
|
||||
RECOVERY_TOOLS = [
|
||||
"terminal", # Most effective — 2300 recoveries
|
||||
"read_file", # Reset context by reading something
|
||||
"search_files", # Find what went wrong
|
||||
]
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the recommended recovery action when circuit is open.
|
||||
|
||||
Returns dict with action type and details.
|
||||
"""
|
||||
streak = self.consecutive_failures
|
||||
|
||||
if streak >= 9:
|
||||
# After 9 errors: 41/46 recoveries via terminal
|
||||
return {
|
||||
"action": "terminal_only",
|
||||
"reason": f"Error streak of {streak} — terminal is the only reliable recovery",
|
||||
"suggested_tool": "terminal",
|
||||
"suggested_command": "echo 'Resetting context'",
|
||||
"severity": "critical",
|
||||
}
|
||||
elif streak >= 5:
|
||||
return {
|
||||
"action": "switch_tool_type",
|
||||
"reason": f"Error streak of {streak} — switch to a different tool category",
|
||||
"suggested_tools": ["read_file", "search_files", "terminal"],
|
||||
"severity": "high",
|
||||
}
|
||||
elif streak >= self.failure_threshold:
|
||||
return {
|
||||
"action": "ask_user",
|
||||
"reason": f"{streak} consecutive errors — ask user for guidance",
|
||||
"suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
|
||||
"severity": "medium",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Error streak of {streak} — within tolerance",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def should_compress_context(self) -> bool:
|
||||
"""Determine if context compression would help recovery."""
|
||||
return self.consecutive_failures >= 5
|
||||
|
||||
def get_blocked_tool(self) -> Optional[str]:
|
||||
"""Get the tool that should be blocked (if any)."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
return "last_failed_tool"
|
||||
return None
|
||||
|
||||
|
||||
class MultiToolCircuitBreaker:
|
||||
"""
|
||||
Manages per-tool circuit breakers and cross-tool cascade detection.
|
||||
|
||||
When one tool trips its breaker, related tools are also warned.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.breakers: Dict[str, ToolCircuitBreaker] = {}
|
||||
self.global_streak: int = 0
|
||||
self.last_tool: Optional[str] = None
|
||||
self.last_success: bool = True
|
||||
|
||||
def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a tool."""
|
||||
if tool_name not in self.breakers:
|
||||
self.breakers[tool_name] = ToolCircuitBreaker()
|
||||
return self.breakers[tool_name]
|
||||
|
||||
def record_result(self, tool_name: str, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if execution should continue.
|
||||
"""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
allowed = breaker.record_result(success)
|
||||
|
||||
# Track global streak
|
||||
if success:
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
else:
|
||||
self.global_streak += 1
|
||||
self.last_success = False
|
||||
|
||||
self.last_tool = tool_name
|
||||
return allowed
|
||||
|
||||
def can_execute(self, tool_name: str) -> bool:
|
||||
"""Check if a specific tool can execute."""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
return breaker.can_execute()
|
||||
|
||||
def get_global_state(self) -> Dict[str, Any]:
|
||||
"""Get overall circuit breaker state."""
|
||||
return {
|
||||
"global_streak": self.global_streak,
|
||||
"last_tool": self.last_tool,
|
||||
"last_success": self.last_success,
|
||||
"tool_states": {
|
||||
name: breaker.get_state()
|
||||
for name, breaker in self.breakers.items()
|
||||
if breaker.consecutive_failures > 0 or breaker.total_trips > 0
|
||||
},
|
||||
"any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
|
||||
}
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""Get recovery action based on global state."""
|
||||
if self.global_streak == 0:
|
||||
return {"action": "continue", "reason": "No errors"}
|
||||
|
||||
# Find the breaker with the worst streak
|
||||
worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
|
||||
if worst and worst.consecutive_failures > 0:
|
||||
return worst.get_recovery_action()
|
||||
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Global streak: {self.global_streak}",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def reset_all(self):
|
||||
"""Reset all circuit breakers."""
|
||||
for breaker in self.breakers.values():
|
||||
breaker.reset()
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
97
tests/test_circuit_breaker.py
Normal file
97
tests/test_circuit_breaker.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for circuit breaker (#885)."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker, MultiToolCircuitBreaker, CircuitState
|
||||
|
||||
|
||||
def test_closed_allows_execution():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
assert cb.can_execute()
|
||||
|
||||
|
||||
def test_opens_after_threshold():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.can_execute() # Still closed at 2
|
||||
cb.record_result(False)
|
||||
assert not cb.can_execute() # Open at 3
|
||||
|
||||
|
||||
def test_closes_on_success():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(True)
|
||||
assert cb.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_half_open_recovery():
|
||||
cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
import time
|
||||
time.sleep(0.15)
|
||||
|
||||
assert cb.can_execute() # Moved to half-open
|
||||
cb.record_result(True)
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
def test_recovery_action_streak():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(5):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "switch_tool_type"
|
||||
|
||||
|
||||
def test_recovery_action_critical():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(10):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "terminal_only"
|
||||
assert action["severity"] == "critical"
|
||||
|
||||
|
||||
def test_multi_tool_breaker():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
assert not mcb.can_execute("read_file")
|
||||
assert mcb.can_execute("terminal") # Different tool unaffected
|
||||
|
||||
|
||||
def test_global_state():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("tool_a", False)
|
||||
mcb.record_result("tool_b", False)
|
||||
state = mcb.get_global_state()
|
||||
assert state["global_streak"] == 2
|
||||
|
||||
|
||||
def test_reset():
|
||||
cb = CircuitBreaker(failure_threshold=2)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
cb.reset()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [test_closed_allows_execution, test_opens_after_threshold,
|
||||
test_closes_on_success, test_half_open_recovery,
|
||||
test_recovery_action_streak, test_recovery_action_critical,
|
||||
test_multi_tool_breaker, test_global_state, test_reset]
|
||||
for t in tests:
|
||||
print(f"Running {t.__name__}...")
|
||||
t()
|
||||
print(" PASS")
|
||||
print("\nAll tests passed.")
|
||||
@@ -44,6 +44,34 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user