178 lines
6.4 KiB
Python
178 lines
6.4 KiB
Python
"""Tool Orchestrator — Robust execution and circuit breaking for agent tools.

Provides a unified execution service that wraps the tool registry.
Implements the Circuit Breaker pattern to prevent the agent from getting
stuck in failure loops when a specific tool or its underlying service
is flapping or down.

Architecture:
    Discovery (tools/registry.py) -> Orchestration (agent/tool_orchestrator.py) -> Dispatch
"""
|
|
|
|
import json
|
|
import time
|
|
import logging
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from tools.registry import registry
|
|
|
|
# Module-level logger named after this module; used for circuit state
# transitions and orchestrator-level errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CircuitState:
    """States for the tool circuit breaker.

    Plain string constants (not an Enum), so a state can be compared with
    ``==`` against a string and serialized to JSON unchanged.
    """

    # Normal operation: calls are dispatched to the tool.
    CLOSED = "closed"
    # Failing: execution is blocked and an error payload is returned instead.
    OPEN = "open"
    # Testing whether the service recovered after the reset timeout elapsed.
    HALF_OPEN = "half_open"
|
|
|
|
|
|
@dataclass
class ToolStats:
    """Execution statistics for a tool."""

    # Name of the tool these counters belong to.
    name: str
    # Current circuit-breaker state; one of the CircuitState string constants.
    state: str = CircuitState.CLOSED
    # Failures counted since the last success (a success resets this to 0).
    failures: int = 0
    # Total number of successful executions.
    successes: int = 0
    # time.time() timestamp of the most recent failure; 0 until one occurs.
    last_failure_time: float = 0
    # Sum of the execution durations, in seconds, over all recorded calls.
    total_execution_time: float = 0
    # Number of recorded executions (successes plus failures).
    call_count: int = 0
|
|
|
|
|
|
class ToolOrchestrator:
    """Orchestrates tool execution with robustness patterns.

    Wraps ``registry.dispatch`` with a per-tool circuit breaker:

    * CLOSED    -- calls go through; failures are counted.
    * OPEN      -- entered once ``failure_threshold`` failures accumulate
                   (or a HALF_OPEN probe fails); calls are rejected with a
                   JSON error payload without reaching the registry.
    * HALF_OPEN -- entered when ``reset_timeout`` seconds have passed since
                   the last failure; the next success closes the circuit,
                   the next failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 3,
        reset_timeout: int = 300,
    ):
        """
        Args:
            failure_threshold: Number of failures before opening the circuit.
            reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
        """
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        # Per-tool statistics keyed by tool name; guarded by self._lock.
        self._stats: Dict[str, ToolStats] = {}
        self._lock = threading.Lock()

    def _get_stats(self, name: str) -> "ToolStats":
        """Get or initialize stats for a tool with thread-safe state transition.

        Also performs the time-based OPEN -> HALF_OPEN transition, so the
        returned stats reflect whether a probe call is currently allowed.
        """
        with self._lock:
            if name not in self._stats:
                self._stats[name] = ToolStats(name=name)

            stats = self._stats[name]

            # Transition from OPEN to HALF_OPEN if timeout expired
            if stats.state == CircuitState.OPEN:
                if time.time() - stats.last_failure_time > self.reset_timeout:
                    stats.state = CircuitState.HALF_OPEN
                    logger.info("Circuit breaker HALF_OPEN for tool: %s", name)

            return stats

    def _record_success(self, name: str, execution_time: float):
        """Record a successful tool execution and close the circuit.

        Expects stats for ``name`` to exist already (``dispatch`` creates them
        through ``_get_stats`` before calling this).
        """
        with self._lock:
            stats = self._stats[name]
            stats.successes += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time

            if stats.state != CircuitState.CLOSED:
                logger.info("Circuit breaker CLOSED for tool: %s (recovered)", name)

            # Any success fully resets the breaker, including the failure count.
            stats.state = CircuitState.CLOSED
            stats.failures = 0

    def _record_failure(self, name: str, execution_time: float):
        """Record a failed tool execution and potentially open the circuit.

        Expects stats for ``name`` to exist already (``dispatch`` creates them
        through ``_get_stats`` before calling this).
        """
        with self._lock:
            stats = self._stats[name]
            stats.failures += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time
            stats.last_failure_time = time.time()

            # A failed HALF_OPEN probe re-opens immediately; otherwise the
            # circuit opens once the failure threshold is reached.
            if stats.state == CircuitState.HALF_OPEN or stats.failures >= self.failure_threshold:
                stats.state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker OPEN for tool: %s (failures: %d)",
                    name, stats.failures
                )

    @staticmethod
    def _is_error_result(result: Any) -> bool:
        """Return True when a registry result should count as a failure.

        A result is a failure when it is a JSON object with a truthy
        ``"error"`` value, when it contains ``"error":`` but is not valid
        JSON, or when it cannot be searched as a string at all (e.g. None).
        An ``"error"`` key whose value is null or empty is NOT a failure, so
        tools that always emit an ``error`` field do not trip the breaker on
        successful calls.
        """
        try:
            # Lightweight check so ordinary results are never JSON-parsed.
            if '"error":' not in result:
                return False
            payload = json.loads(result)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON (or not a string): treat as a malformed result.
            return True
        return isinstance(payload, dict) and bool(payload.get("error"))

    def dispatch(self, name: str, args: dict, **kwargs) -> str:
        """Execute a tool via the registry with circuit breaker protection.

        Args:
            name: Tool name passed to ``registry.dispatch``.
            args: Tool arguments passed to ``registry.dispatch``.
            **kwargs: Forwarded unchanged to ``registry.dispatch``.

        Returns:
            The registry's result unchanged, or a JSON error payload when the
            circuit is OPEN or an exception escapes the registry call.
        """
        stats = self._get_stats(name)

        if stats.state == CircuitState.OPEN:
            return json.dumps({
                "error": (
                    f"Tool '{name}' is temporarily unavailable due to repeated failures. "
                    "Circuit breaker is OPEN. Please try again in a few minutes or use an alternative tool."
                ),
                "circuit_breaker": True,
                "tool_name": name
            })

        start_time = time.time()
        try:
            # Dispatch to the underlying registry
            result_str = registry.dispatch(name, args, **kwargs)
            execution_time = time.time() - start_time

            # Inspect result for errors. registry.dispatch catches internal
            # exceptions and returns a JSON error string.
            if self._is_error_result(result_str):
                self._record_failure(name, execution_time)
            else:
                self._record_success(name, execution_time)

            return result_str

        except Exception as e:
            # This should rarely be hit as registry.dispatch catches most things,
            # but we guard against orchestrator-level or registry-level bugs.
            execution_time = time.time() - start_time
            self._record_failure(name, execution_time)

            error_msg = f"Tool orchestrator error during {name}: {type(e).__name__}: {e}"
            logger.exception(error_msg)
            return json.dumps({
                "error": error_msg,
                "tool_name": name,
                "execution_time": execution_time
            })

    def get_fleet_stats(self) -> Dict[str, Any]:
        """Return execution statistics for all tools.

        Returns:
            A dict keyed by tool name; each value holds ``state``,
            ``failures``, ``successes``, ``avg_time`` (mean seconds per call,
            0 when the tool has no recorded calls) and ``calls``.
        """
        with self._lock:
            return {
                name: {
                    "state": s.state,
                    "failures": s.failures,
                    "successes": s.successes,
                    "avg_time": s.total_execution_time / s.call_count if s.call_count > 0 else 0,
                    "calls": s.call_count
                }
                for name, s in self._stats.items()
            }
|
|
|
|
|
|
# Global orchestrator instance, created at import time with the default
# failure_threshold and reset_timeout.
orchestrator = ToolOrchestrator()
|