Compare commits
1 Commits
fix/885-ci
...
fix/892
| Author | SHA1 | Date | |
|---|---|---|---|
| 517e2c571e |
@@ -1,273 +0,0 @@
|
||||
"""
|
||||
Circuit Breaker for Error Cascading — #885
|
||||
|
||||
P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
|
||||
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
|
||||
opens and the agent must take corrective action.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, errors are counted
|
||||
- OPEN: Too many consecutive errors, corrective action required
|
||||
- HALF_OPEN: Testing if errors have cleared
|
||||
|
||||
Usage:
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker
|
||||
|
||||
cb = ToolCircuitBreaker()
|
||||
|
||||
# After each tool call
|
||||
if not cb.record_result(success=True):
|
||||
# Circuit is open — take corrective action
|
||||
cb.get_recovery_action()
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Too many errors, block execution
|
||||
HALF_OPEN = "half_open" # Testing recovery
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Generic circuit breaker with configurable thresholds.
|
||||
|
||||
Tracks consecutive errors and opens the circuit when the
|
||||
error streak exceeds the threshold.
|
||||
"""
|
||||
failure_threshold: int = 3
|
||||
recovery_timeout: float = 30.0 # seconds before trying half-open
|
||||
success_threshold: int = 2 # successes needed to close from half-open
|
||||
|
||||
state: CircuitState = field(default=CircuitState.CLOSED, init=False)
|
||||
consecutive_failures: int = field(default=0, init=False)
|
||||
consecutive_successes: int = field(default=0, init=False)
|
||||
last_failure_time: Optional[float] = field(default=None, init=False)
|
||||
total_trips: int = field(default=0, init=False)
|
||||
error_streaks: List[int] = field(default_factory=list, init=False)
|
||||
|
||||
def record_result(self, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if circuit allows execution.
|
||||
|
||||
Returns:
|
||||
True if circuit is CLOSED or HALF_OPEN (execution allowed)
|
||||
False if circuit is OPEN (execution blocked)
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
if self.state == CircuitState.OPEN:
|
||||
# Check if recovery timeout has passed
|
||||
if self.last_failure_time and (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True # Allow one test execution
|
||||
return False # Still open
|
||||
|
||||
if success:
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes += 1
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
if self.consecutive_successes >= self.success_threshold:
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_successes = 0
|
||||
|
||||
return True
|
||||
else:
|
||||
self.consecutive_successes = 0
|
||||
self.consecutive_failures += 1
|
||||
self.last_failure_time = now
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
# Failed during recovery — reopen immediately
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
return False
|
||||
|
||||
if self.consecutive_failures >= self.failure_threshold:
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
self.error_streaks.append(self.consecutive_failures)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def can_execute(self) -> bool:
|
||||
"""Check if execution is allowed."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self.last_failure_time:
|
||||
now = time.time()
|
||||
if (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get current circuit state."""
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"consecutive_failures": self.consecutive_failures,
|
||||
"consecutive_successes": self.consecutive_successes,
|
||||
"total_trips": self.total_trips,
|
||||
"max_streak": max(self.error_streaks) if self.error_streaks else 0,
|
||||
"can_execute": self.can_execute(),
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the circuit breaker."""
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes = 0
|
||||
self.last_failure_time = None
|
||||
|
||||
|
||||
class ToolCircuitBreaker(CircuitBreaker):
|
||||
"""
|
||||
Circuit breaker specifically for tool call error cascading.
|
||||
|
||||
Provides recovery actions when the circuit opens.
|
||||
"""
|
||||
|
||||
# Tools that are most effective at recovery (from audit data)
|
||||
RECOVERY_TOOLS = [
|
||||
"terminal", # Most effective — 2300 recoveries
|
||||
"read_file", # Reset context by reading something
|
||||
"search_files", # Find what went wrong
|
||||
]
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the recommended recovery action when circuit is open.
|
||||
|
||||
Returns dict with action type and details.
|
||||
"""
|
||||
streak = self.consecutive_failures
|
||||
|
||||
if streak >= 9:
|
||||
# After 9 errors: 41/46 recoveries via terminal
|
||||
return {
|
||||
"action": "terminal_only",
|
||||
"reason": f"Error streak of {streak} — terminal is the only reliable recovery",
|
||||
"suggested_tool": "terminal",
|
||||
"suggested_command": "echo 'Resetting context'",
|
||||
"severity": "critical",
|
||||
}
|
||||
elif streak >= 5:
|
||||
return {
|
||||
"action": "switch_tool_type",
|
||||
"reason": f"Error streak of {streak} — switch to a different tool category",
|
||||
"suggested_tools": ["read_file", "search_files", "terminal"],
|
||||
"severity": "high",
|
||||
}
|
||||
elif streak >= self.failure_threshold:
|
||||
return {
|
||||
"action": "ask_user",
|
||||
"reason": f"{streak} consecutive errors — ask user for guidance",
|
||||
"suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
|
||||
"severity": "medium",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Error streak of {streak} — within tolerance",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def should_compress_context(self) -> bool:
|
||||
"""Determine if context compression would help recovery."""
|
||||
return self.consecutive_failures >= 5
|
||||
|
||||
def get_blocked_tool(self) -> Optional[str]:
|
||||
"""Get the tool that should be blocked (if any)."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
return "last_failed_tool"
|
||||
return None
|
||||
|
||||
|
||||
class MultiToolCircuitBreaker:
|
||||
"""
|
||||
Manages per-tool circuit breakers and cross-tool cascade detection.
|
||||
|
||||
When one tool trips its breaker, related tools are also warned.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.breakers: Dict[str, ToolCircuitBreaker] = {}
|
||||
self.global_streak: int = 0
|
||||
self.last_tool: Optional[str] = None
|
||||
self.last_success: bool = True
|
||||
|
||||
def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a tool."""
|
||||
if tool_name not in self.breakers:
|
||||
self.breakers[tool_name] = ToolCircuitBreaker()
|
||||
return self.breakers[tool_name]
|
||||
|
||||
def record_result(self, tool_name: str, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if execution should continue.
|
||||
"""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
allowed = breaker.record_result(success)
|
||||
|
||||
# Track global streak
|
||||
if success:
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
else:
|
||||
self.global_streak += 1
|
||||
self.last_success = False
|
||||
|
||||
self.last_tool = tool_name
|
||||
return allowed
|
||||
|
||||
def can_execute(self, tool_name: str) -> bool:
|
||||
"""Check if a specific tool can execute."""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
return breaker.can_execute()
|
||||
|
||||
def get_global_state(self) -> Dict[str, Any]:
|
||||
"""Get overall circuit breaker state."""
|
||||
return {
|
||||
"global_streak": self.global_streak,
|
||||
"last_tool": self.last_tool,
|
||||
"last_success": self.last_success,
|
||||
"tool_states": {
|
||||
name: breaker.get_state()
|
||||
for name, breaker in self.breakers.items()
|
||||
if breaker.consecutive_failures > 0 or breaker.total_trips > 0
|
||||
},
|
||||
"any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
|
||||
}
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""Get recovery action based on global state."""
|
||||
if self.global_streak == 0:
|
||||
return {"action": "continue", "reason": "No errors"}
|
||||
|
||||
# Find the breaker with the worst streak
|
||||
worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
|
||||
if worst and worst.consecutive_failures > 0:
|
||||
return worst.get_recovery_action()
|
||||
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Global streak: {self.global_streak}",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def reset_all(self):
|
||||
"""Reset all circuit breakers."""
|
||||
for breaker in self.breakers.values():
|
||||
breaker.reset()
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
224
gateway/config_validator.py
Normal file
224
gateway/config_validator.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Gateway Config Validator & Fallback Fix — #892.
|
||||
|
||||
Validates gateway configuration and provides sensible defaults
|
||||
for missing keys to prevent fallback chain breaks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigIssue:
|
||||
"""A configuration issue found during validation."""
|
||||
key: str
|
||||
severity: str # error, warning, info
|
||||
message: str
|
||||
fix: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigValidation:
|
||||
"""Result of config validation."""
|
||||
valid: bool
|
||||
issues: List[ConfigIssue] = field(default_factory=list)
|
||||
warnings: int = 0
|
||||
errors: int = 0
|
||||
|
||||
|
||||
# Required keys and their defaults
|
||||
REQUIRED_KEYS = {
|
||||
"OPENROUTER_API_KEY": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "warning",
|
||||
"message": "OPENROUTER_API_KEY not set - fallback chain may break",
|
||||
"fix": "Set OPENROUTER_API_KEY in .env for OpenRouter provider",
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "warning",
|
||||
"message": "API_SERVER_KEY not configured",
|
||||
"fix": "Set API_SERVER_KEY in .env for API server auth",
|
||||
},
|
||||
"GITEA_TOKEN": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "info",
|
||||
"message": "GITEA_TOKEN not set - Gitea features disabled",
|
||||
"fix": "Set GITEA_TOKEN in .env for Gitea integration",
|
||||
},
|
||||
}
|
||||
|
||||
# Config validation rules
|
||||
VALIDATION_RULES = [
|
||||
{
|
||||
"key": "idle_minutes",
|
||||
"validate": lambda v: isinstance(v, (int, float)) and v > 0,
|
||||
"message": "Invalid idle_minutes={v} - must be > 0",
|
||||
"fix": "Set idle_minutes to positive integer (default: 30)",
|
||||
},
|
||||
{
|
||||
"key": "max_skills_discord",
|
||||
"validate": lambda v: isinstance(v, int) and v <= 100,
|
||||
"message": "Discord slash command limit reached ({v}/100) - skills not registered",
|
||||
"fix": "Reduce skills or paginate registration",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def validate_config(config: Dict[str, Any]) -> ConfigValidation:
|
||||
"""
|
||||
Validate gateway configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
ConfigValidation with issues found
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# Check required keys
|
||||
for key, spec in REQUIRED_KEYS.items():
|
||||
value = config.get(key) or os.environ.get(key) or spec["default"]
|
||||
if spec["required"] and not value:
|
||||
issues.append(ConfigIssue(
|
||||
key=key,
|
||||
severity=spec["severity"],
|
||||
message=spec["message"],
|
||||
fix=spec["fix"],
|
||||
))
|
||||
elif not value and spec["severity"] != "error":
|
||||
issues.append(ConfigIssue(
|
||||
key=key,
|
||||
severity=spec["severity"],
|
||||
message=spec["message"],
|
||||
fix=spec["fix"],
|
||||
))
|
||||
|
||||
# Check validation rules
|
||||
for rule in VALIDATION_RULES:
|
||||
value = config.get(rule["key"])
|
||||
if value is not None:
|
||||
if not rule["validate"](value):
|
||||
issues.append(ConfigIssue(
|
||||
key=rule["key"],
|
||||
severity="error",
|
||||
message=rule["message"].format(v=value),
|
||||
fix=rule["fix"],
|
||||
))
|
||||
|
||||
errors = sum(1 for i in issues if i.severity == "error")
|
||||
warnings = sum(1 for i in issues if i.severity == "warning")
|
||||
|
||||
return ConfigValidation(
|
||||
valid=errors == 0,
|
||||
issues=issues,
|
||||
warnings=warnings,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
def apply_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Apply default values for missing config keys.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Config with defaults applied
|
||||
"""
|
||||
result = dict(config)
|
||||
|
||||
for key, spec in REQUIRED_KEYS.items():
|
||||
if key not in result or not result[key]:
|
||||
default = os.environ.get(key) or spec["default"]
|
||||
if default:
|
||||
result[key] = default
|
||||
logger.debug("Applied default for %s", key)
|
||||
|
||||
# Apply validation defaults
|
||||
if "idle_minutes" not in result or not result["idle_minutes"] or result["idle_minutes"] <= 0:
|
||||
result["idle_minutes"] = 30
|
||||
logger.debug("Applied default idle_minutes=30")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fix_discord_skill_limit(skills: List[str], max_skills: int = 95) -> List[str]:
|
||||
"""
|
||||
Fix Discord slash command limit by reducing skills.
|
||||
|
||||
Args:
|
||||
skills: List of skill names
|
||||
max_skills: Maximum skills to register (default 95, leaving room for built-ins)
|
||||
|
||||
Returns:
|
||||
Reduced skill list
|
||||
"""
|
||||
if len(skills) <= max_skills:
|
||||
return skills
|
||||
|
||||
logger.warning(
|
||||
"Discord skill limit: %d skills exceeds %d limit, truncating",
|
||||
len(skills), max_skills
|
||||
)
|
||||
|
||||
# Keep first max_skills (alphabetical priority)
|
||||
return sorted(skills)[:max_skills]
|
||||
|
||||
|
||||
def validate_provider_config(provider: str, config: Dict[str, Any]) -> ConfigIssue:
|
||||
"""
|
||||
Validate provider-specific configuration.
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
config: Provider config
|
||||
|
||||
Returns:
|
||||
ConfigIssue if invalid, None if valid
|
||||
"""
|
||||
if provider == "local-llama.cpp":
|
||||
# Check if llama.cpp is configured
|
||||
if not config.get("model_path") and not config.get("base_url"):
|
||||
return ConfigIssue(
|
||||
key=f"provider.{provider}",
|
||||
severity="warning",
|
||||
message=f"{provider} provider not configured - fallback fails",
|
||||
fix=f"Configure {provider} model_path or base_url, or remove from provider list",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def format_validation_report(validation: ConfigValidation) -> str:
|
||||
"""Format validation results as a report."""
|
||||
lines = [
|
||||
"=" * 50,
|
||||
"GATEWAY CONFIG VALIDATION",
|
||||
"=" * 50,
|
||||
"",
|
||||
f"Status: {'VALID' if validation.valid else 'INVALID'}",
|
||||
f"Errors: {validation.errors}",
|
||||
f"Warnings: {validation.warnings}",
|
||||
"",
|
||||
]
|
||||
|
||||
if validation.issues:
|
||||
lines.append("Issues:")
|
||||
for issue in validation.issues:
|
||||
icon = "❌" if issue.severity == "error" else "⚠️" if issue.severity == "warning" else "ℹ️"
|
||||
lines.append(f" {icon} [{issue.key}] {issue.message}")
|
||||
lines.append(f" Fix: {issue.fix}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -1,97 +0,0 @@
|
||||
"""Tests for circuit breaker (#885)."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker, MultiToolCircuitBreaker, CircuitState
|
||||
|
||||
|
||||
def test_closed_allows_execution():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
assert cb.can_execute()
|
||||
|
||||
|
||||
def test_opens_after_threshold():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.can_execute() # Still closed at 2
|
||||
cb.record_result(False)
|
||||
assert not cb.can_execute() # Open at 3
|
||||
|
||||
|
||||
def test_closes_on_success():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(True)
|
||||
assert cb.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_half_open_recovery():
|
||||
cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
import time
|
||||
time.sleep(0.15)
|
||||
|
||||
assert cb.can_execute() # Moved to half-open
|
||||
cb.record_result(True)
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
def test_recovery_action_streak():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(5):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "switch_tool_type"
|
||||
|
||||
|
||||
def test_recovery_action_critical():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(10):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "terminal_only"
|
||||
assert action["severity"] == "critical"
|
||||
|
||||
|
||||
def test_multi_tool_breaker():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
assert not mcb.can_execute("read_file")
|
||||
assert mcb.can_execute("terminal") # Different tool unaffected
|
||||
|
||||
|
||||
def test_global_state():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("tool_a", False)
|
||||
mcb.record_result("tool_b", False)
|
||||
state = mcb.get_global_state()
|
||||
assert state["global_streak"] == 2
|
||||
|
||||
|
||||
def test_reset():
|
||||
cb = CircuitBreaker(failure_threshold=2)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
cb.reset()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [test_closed_allows_execution, test_opens_after_threshold,
|
||||
test_closes_on_success, test_half_open_recovery,
|
||||
test_recovery_action_streak, test_recovery_action_critical,
|
||||
test_multi_tool_breaker, test_global_state, test_reset]
|
||||
for t in tests:
|
||||
print(f"Running {t.__name__}...")
|
||||
t()
|
||||
print(" PASS")
|
||||
print("\nAll tests passed.")
|
||||
@@ -44,34 +44,6 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user