"""Tests for circuit breaker (#885)."""

import sys
import time
from pathlib import Path

# Make the project root (parent of this file's directory) importable so the
# module can be run directly as a script, not only via pytest.
sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.circuit_breaker import (  # noqa: E402  (must follow sys.path tweak)
    CircuitBreaker,
    CircuitState,
    MultiToolCircuitBreaker,
    ToolCircuitBreaker,
)


def test_closed_allows_execution():
    """A freshly constructed breaker permits execution."""
    cb = CircuitBreaker(failure_threshold=3)
    assert cb.can_execute()


def test_opens_after_threshold():
    """The breaker stays closed below the threshold and opens when it is reached."""
    cb = CircuitBreaker(failure_threshold=3)
    cb.record_result(False)
    cb.record_result(False)
    assert cb.can_execute()  # Still closed at 2
    cb.record_result(False)
    assert not cb.can_execute()  # Open at 3


def test_closes_on_success():
    """A success after a failure clears the failure streak and keeps the breaker closed."""
    cb = CircuitBreaker(failure_threshold=3)
    cb.record_result(False)
    cb.record_result(True)
    assert cb.consecutive_failures == 0
    # The test name promises a closed breaker, so assert it explicitly.
    assert cb.state == CircuitState.CLOSED
    assert cb.can_execute()


def test_half_open_recovery():
    """After the recovery timeout the breaker allows a probe; a success closes it."""
    cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
    cb.record_result(False)
    cb.record_result(False)
    assert cb.state == CircuitState.OPEN
    # Sleep past recovery_timeout (0.1s) with some margin for timer resolution.
    time.sleep(0.15)
    assert cb.can_execute()  # Moved to half-open
    cb.record_result(True)
    assert cb.state == CircuitState.CLOSED


def test_recovery_action_streak():
    """Five consecutive failures (threshold 3) recommend switching tool type."""
    cb = ToolCircuitBreaker(failure_threshold=3)
    for _ in range(5):
        cb.record_result(False)
    action = cb.get_recovery_action()
    assert action["action"] == "switch_tool_type"


def test_recovery_action_critical():
    """Ten consecutive failures escalate to a critical terminal-only recovery action."""
    cb = ToolCircuitBreaker(failure_threshold=3)
    for _ in range(10):
        cb.record_result(False)
    action = cb.get_recovery_action()
    assert action["action"] == "terminal_only"
    assert action["severity"] == "critical"


def test_multi_tool_breaker():
    """Failures on one tool open only that tool's breaker."""
    mcb = MultiToolCircuitBreaker()
    mcb.record_result("read_file", False)
    mcb.record_result("read_file", False)
    mcb.record_result("read_file", False)
    assert not mcb.can_execute("read_file")
    assert mcb.can_execute("terminal")  # Different tool unaffected


def test_global_state():
    """The global failure streak counts failures across different tools."""
    mcb = MultiToolCircuitBreaker()
    mcb.record_result("tool_a", False)
    mcb.record_result("tool_b", False)
    state = mcb.get_global_state()
    assert state["global_streak"] == 2


def test_reset():
    """reset() returns an open breaker to the closed, executable state."""
    cb = CircuitBreaker(failure_threshold=2)
    cb.record_result(False)
    cb.record_result(False)
    assert cb.state == CircuitState.OPEN
    cb.reset()
    assert cb.state == CircuitState.CLOSED
    assert cb.can_execute()


if __name__ == "__main__":
    # Minimal standalone runner; a failing assertion propagates and aborts
    # with a traceback (non-zero exit status).
    tests = [
        test_closed_allows_execution,
        test_opens_after_threshold,
        test_closes_on_success,
        test_half_open_recovery,
        test_recovery_action_streak,
        test_recovery_action_critical,
        test_multi_tool_breaker,
        test_global_state,
        test_reset,
    ]
    for t in tests:
        print(f"Running {t.__name__}...")
        t()
        print(" PASS")
    print("\nAll tests passed.")