Compare commits
4 Commits
fix/749-ba
...
fix/708
| Author | SHA1 | Date | |
|---|---|---|---|
| 71d3ad7879 | |||
| d86359cbb2 | |||
| f264b55b29 | |||
| dfe23f66b1 |
177
agent/tool_orchestrator.py
Normal file
177
agent/tool_orchestrator.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Tool Orchestrator — Robust execution and circuit breaking for agent tools.
|
||||
|
||||
Provides a unified execution service that wraps the tool registry.
|
||||
Implements the Circuit Breaker pattern to prevent the agent from getting
|
||||
stuck in failure loops when a specific tool or its underlying service
|
||||
is flapping or down.
|
||||
|
||||
Architecture:
|
||||
Discovery (tools/registry.py) -> Orchestration (agent/tool_orchestrator.py) -> Dispatch
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from tools.registry import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState:
    """Circuit-breaker states for a tool.

    Plain string constants (not an Enum) so state values serialize
    directly into stats dicts and log messages.
    """

    # Healthy: calls flow through to the tool normally.
    CLOSED = "closed"
    # Tripped: calls are short-circuited until the reset timeout elapses.
    OPEN = "open"
    # Probing: a trial call is allowed to test whether the tool recovered.
    HALF_OPEN = "half_open"
|
||||
|
||||
|
||||
@dataclass
class ToolStats:
    """Per-tool execution statistics and circuit-breaker state.

    Instances are mutated in place by ToolOrchestrator while holding its
    lock; they are not safe to mutate from elsewhere without it.
    """

    # Tool name as registered in the tool registry.
    name: str
    # Current circuit state; one of the CircuitState constants.
    state: str = CircuitState.CLOSED
    # Consecutive failures since the last success (cleared on success).
    failures: int = 0
    # Total successful calls.
    successes: int = 0
    # time.time() of the most recent failure; 0.0 if never failed.
    last_failure_time: float = 0.0
    # Cumulative wall-clock seconds across all calls (success and failure).
    total_execution_time: float = 0.0
    # Total calls recorded (successes + failures).
    call_count: int = 0
|
||||
|
||||
|
||||
class ToolOrchestrator:
    """Orchestrates tool execution with robustness patterns.

    Wraps ``registry.dispatch`` with a per-tool circuit breaker: after
    ``failure_threshold`` consecutive failures a tool's circuit OPENs and
    calls are short-circuited with a JSON error; after ``reset_timeout``
    seconds the circuit moves to HALF_OPEN and a trial call is allowed.
    Any success closes the circuit; a failure while HALF_OPEN re-opens it
    immediately.
    """

    def __init__(
        self,
        failure_threshold: int = 3,
        reset_timeout: int = 300,
    ) -> None:
        """
        Args:
            failure_threshold: Number of failures before opening the circuit.
            reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
        """
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        # Per-tool statistics keyed by tool name; guarded by _lock.
        self._stats: Dict[str, ToolStats] = {}
        # Serializes all reads/writes of _stats and ToolStats fields.
        self._lock = threading.Lock()

    def _get_stats(self, name: str) -> ToolStats:
        """Get or initialize stats for a tool with thread-safe state transition."""
        with self._lock:
            if name not in self._stats:
                self._stats[name] = ToolStats(name=name)

            stats = self._stats[name]

            # Transition from OPEN to HALF_OPEN if timeout expired
            if stats.state == CircuitState.OPEN:
                if time.time() - stats.last_failure_time > self.reset_timeout:
                    stats.state = CircuitState.HALF_OPEN
                    logger.info("Circuit breaker HALF_OPEN for tool: %s", name)

            # NOTE(review): the ToolStats object escapes the lock here;
            # dispatch() reads stats.state without holding _lock, so a
            # concurrent caller may observe a slightly stale state.
            return stats

    def _record_success(self, name: str, execution_time: float) -> None:
        """Record a successful tool execution and close the circuit.

        Args:
            name: Tool name; must already exist in _stats (dispatch calls
                _get_stats first, which creates the entry).
            execution_time: Wall-clock seconds the call took.
        """
        with self._lock:
            stats = self._stats[name]
            stats.successes += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time

            if stats.state != CircuitState.CLOSED:
                logger.info("Circuit breaker CLOSED for tool: %s (recovered)", name)

            # A single success fully closes the circuit and clears the
            # consecutive-failure counter.
            stats.state = CircuitState.CLOSED
            stats.failures = 0

    def _record_failure(self, name: str, execution_time: float) -> None:
        """Record a failed tool execution and potentially open the circuit.

        Args:
            name: Tool name; must already exist in _stats.
            execution_time: Wall-clock seconds the failed call took.
        """
        with self._lock:
            stats = self._stats[name]
            stats.failures += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time
            stats.last_failure_time = time.time()

            # Any failure while HALF_OPEN re-opens immediately; otherwise
            # open once the consecutive-failure threshold is reached.
            if stats.state == CircuitState.HALF_OPEN or stats.failures >= self.failure_threshold:
                stats.state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker OPEN for tool: %s (failures: %d)",
                    name, stats.failures
                )

    def dispatch(self, name: str, args: dict, **kwargs) -> str:
        """Execute a tool via the registry with circuit breaker protection.

        Args:
            name: Registered tool name.
            args: Tool arguments forwarded to registry.dispatch.
            **kwargs: Extra dispatch options passed through unchanged.

        Returns:
            The tool's result string, or a JSON error object when the
            circuit is open or the dispatch itself raises.
        """
        stats = self._get_stats(name)

        if stats.state == CircuitState.OPEN:
            # Short-circuit: do not touch the underlying tool at all.
            return json.dumps({
                "error": (
                    f"Tool '{name}' is temporarily unavailable due to repeated failures. "
                    f"Circuit breaker is OPEN. Please try again in a few minutes or use an alternative tool."
                ),
                "circuit_breaker": True,
                "tool_name": name
            })

        start_time = time.time()
        try:
            # Dispatch to the underlying registry
            result_str = registry.dispatch(name, args, **kwargs)
            execution_time = time.time() - start_time

            # Inspect result for errors. registry.dispatch catches internal
            # exceptions and returns a JSON error string.
            is_error = False
            try:
                # Lightweight check for error key in JSON; the substring
                # test avoids parsing every successful result.
                if '"error":' in result_str:
                    res_json = json.loads(result_str)
                    if isinstance(res_json, dict) and "error" in res_json:
                        is_error = True
            except (json.JSONDecodeError, TypeError):
                # If it's not valid JSON, it's a malformed result (error)
                is_error = True

            if is_error:
                self._record_failure(name, execution_time)
            else:
                self._record_success(name, execution_time)

            return result_str

        except Exception as e:
            # This should rarely be hit as registry.dispatch catches most things,
            # but we guard against orchestrator-level or registry-level bugs.
            execution_time = time.time() - start_time
            self._record_failure(name, execution_time)

            error_msg = f"Tool orchestrator error during {name}: {type(e).__name__}: {e}"
            logger.exception(error_msg)
            return json.dumps({
                "error": error_msg,
                "tool_name": name,
                "execution_time": execution_time
            })

    def get_fleet_stats(self) -> Dict[str, Any]:
        """Return execution statistics for all tools.

        Returns:
            Mapping of tool name to a dict with keys: "state", "failures",
            "successes", "avg_time" (mean seconds per call, 0 if never
            called), and "calls".
        """
        with self._lock:
            return {
                name: {
                    "state": s.state,
                    "failures": s.failures,
                    "successes": s.successes,
                    "avg_time": s.total_execution_time / s.call_count if s.call_count > 0 else 0,
                    "calls": s.call_count
                }
                for name, s in self._stats.items()
            }
|
||||
|
||||
|
||||
# Global orchestrator instance
# Module-level singleton shared by all importers (default thresholds:
# 3 consecutive failures to open, 300s cooldown before a HALF_OPEN probe).
orchestrator: ToolOrchestrator = ToolOrchestrator()
|
||||
122
evals/atlas_l40s_eval.py
Normal file
122
evals/atlas_l40s_eval.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Atlas Inference Engine Evaluation on RunPod L40S."""
|
||||
|
||||
import argparse, json, os, sys, time, urllib.request, urllib.error
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
RUNPOD_API = "https://api.runpod.io/graphql"
|
||||
POD_NAME = "atlas-eval-l40s"
|
||||
ATLAS_IMAGE = "avarok/atlas-gb10:alpha-2.8"
|
||||
MODEL = "Qwen/Qwen3.5-35B-A3B-NVFP4"
|
||||
COST_LOG = Path.home() / ".hermes" / "atlas_eval_log.jsonl"
|
||||
|
||||
def load_key():
    """Return the RunPod API key, stripped of surrounding whitespace.

    Checks the RUNPOD_API_KEY environment variable first, then the file
    ~/.config/runpod/access_key. Exits the process with status 1 if
    neither source yields a key.
    """
    env_key = os.environ.get("RUNPOD_API_KEY", "")
    if env_key:
        return env_key.strip()
    key_file = Path.home() / ".config" / "runpod" / "access_key"
    if key_file.exists():
        return key_file.read_text().strip()
    print("ERROR: No RunPod key")
    sys.exit(1)
|
||||
|
||||
def gql(query):
    """POST a GraphQL query to the RunPod API.

    Returns the parsed JSON response, or None on an HTTP error (the first
    300 bytes of the error body are printed).
    """
    body = json.dumps({"query": query}).encode()
    headers = {
        "Authorization": f"Bearer {load_key()}",
        "Content-Type": "application/json",
    }
    req = urllib.request.Request(RUNPOD_API, data=body, headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read().decode())
    except urllib.error.HTTPError as err:
        print(f"Error: {err.read().decode()[:300]}")
        return None
|
||||
|
||||
def find_pod():
    """Return the pod dict whose name matches POD_NAME, or None."""
    resp = gql("{ myself { pods { id name desiredStatus costPerHr gpuCount runtime { uptimeInSeconds } } } }")
    if not (resp and resp.get("data")):
        return None
    for pod in resp["data"]["myself"]["pods"]:
        if pod["name"] == POD_NAME:
            return pod
    return None
|
||||
|
||||
def deploy():
    """Create the eval pod, or resume it if it already exists.

    Returns the pod id on success; returns None (after printing a
    message) if the deploy mutation fails.
    """
    existing = find_pod()
    if existing:
        print(f"Exists: {existing['id']} ({existing['desiredStatus']})")
        if existing["desiredStatus"] == "STOPPED":
            gql(f'mutation {{ podResume(input: {{ podId: "{existing["id"]}" }}) {{ id }} }}')
            print("Resuming...")
        return existing["id"]

    # Deploy mutation assembled from the module constants; single-GPU L40S
    # community pod with an HTTP proxy on port 8888.
    mutation = (
        'mutation { podFindAndDeployOnDemand(input: { cloudType: COMMUNITY, gpuCount: 1, '
        'gpuTypeId: "NVIDIA L40S", name: "' + POD_NAME + '", containerDiskInGb: 50, '
        'imageName: "' + ATLAS_IMAGE + '", ports: "8888/http", volumeInGb: 100, '
        'volumeMountPath: "/workspace" }) { id desiredStatus } }'
    )
    resp = gql(mutation)
    if resp and resp.get("data"):
        pod = resp["data"]["podFindAndDeployOnDemand"]
        print(f"Deployed: {pod['id']} -> https://{pod['id']}-8888.proxy.runpod.net")
        return pod["id"]
    print("Deploy failed")
|
||||
|
||||
def status():
    """Print id, status, hourly cost, endpoint and uptime for the eval pod."""
    pod = find_pod()
    if not pod:
        print("No pod")
        return
    print(f"ID: {pod['id']}\nStatus: {pod['desiredStatus']}\nCost: ${pod['costPerHr']}/hr\nEndpoint: https://{pod['id']}-8888.proxy.runpod.net")
    # The query selects `runtime`, so GraphQL returns the key with a null
    # value while the pod is stopped; `.get("runtime", {})` would then hand
    # None to the nested .get() and raise AttributeError. Guard with `or {}`.
    runtime = pod.get("runtime") or {}
    u = runtime.get("uptimeInSeconds", 0)
    if u:
        print(f"Uptime: {u//3600}h {(u%3600)//60}m")
|
||||
|
||||
def benchmark():
    """Run the prompt suite against the pod's OpenAI-compatible endpoint.

    Prints per-prompt throughput, then a min/avg/max summary, and appends
    a JSONL record (timestamp, average tok/s, per-prompt results) to
    COST_LOG. Requires the pod to be RUNNING.
    """
    pod = find_pod()
    if not pod or pod["desiredStatus"] != "RUNNING":
        print("Pod not running")
        return
    ep = f"https://{pod['id']}-8888.proxy.runpod.net/v1"
    print(f"Benchmarking: {ep}")
    prompts = [
        "Explain sovereign AI in 100 words.",
        "Write quicksort in Python.",
        "Compare transformers vs state space models.",
        "Describe MoE architecture.",
        "Write a Dockerfile for Flask+Redis.",
    ]
    results = []
    for i, p in enumerate(prompts):
        # Derive the total from the prompt list instead of a hard-coded
        # "5" so the progress line stays correct if prompts change.
        print(f"\n[{i+1}/{len(prompts)}] {p[:40]}...")
        start = time.time()
        try:
            payload = json.dumps({"model": MODEL, "messages": [{"role": "user", "content": p}], "max_tokens": 512}).encode()
            req = urllib.request.Request(f"{ep}/chat/completions", data=payload,
                headers={"Content-Type": "application/json", "Authorization": "Bearer dummy"}, method="POST")
            with urllib.request.urlopen(req, timeout=120) as resp:
                r = json.loads(resp.read().decode())
            elapsed = time.time() - start
            usage = r.get("usage", {})
            # Hoist the repeated usage.get(...) lookup into one variable.
            completion_tokens = usage.get("completion_tokens", 0)
            tps = completion_tokens / elapsed if elapsed > 0 else 0
            results.append({"prompt": i, "tok_per_sec": round(tps, 2), "tokens": completion_tokens})
            print(f" {completion_tokens} tokens / {elapsed:.1f}s = {tps:.1f} tok/s")
        except Exception as e:
            print(f" Error: {e}")
            results.append({"prompt": i, "error": str(e)})
    ok = [r for r in results if "tok_per_sec" in r]
    if ok:
        avg = sum(r["tok_per_sec"] for r in ok) / len(ok)
        print(f"\nAvg: {avg:.1f} tok/s | Min: {min(r['tok_per_sec'] for r in ok):.1f} | Max: {max(r['tok_per_sec'] for r in ok):.1f}")
        COST_LOG.parent.mkdir(parents=True, exist_ok=True)
        with open(COST_LOG, "a") as f:
            f.write(json.dumps({"ts": datetime.now(timezone.utc).isoformat(), "avg_tps": round(avg, 2), "results": results}) + "\n")
|
||||
|
||||
def stop():
    """Stop (pause) the eval pod; the volume is retained."""
    pod = find_pod()
    if pod is None:
        print("No pod")
        return
    gql(f'mutation {{ podStop(input: {{ podId: "{pod["id"]}" }}) {{ id }} }}')
    print(f"Stopped: {pod['id']}")
|
||||
|
||||
def terminate():
    """Permanently delete the eval pod."""
    pod = find_pod()
    if pod is None:
        print("No pod")
        return
    gql(f'mutation {{ podTerminate(input: {{ podId: "{pod["id"]}" }}) }}')
    print(f"Terminated: {pod['id']}")
|
||||
|
||||
def main():
    """CLI entry point: parse a subcommand and dispatch to its handler.

    Prints help and exits with status 1 when no subcommand is given.
    """
    parser = argparse.ArgumentParser(description="Atlas L40S Eval")
    subparsers = parser.add_subparsers(dest="cmd")
    for cmd_name in ("deploy", "status", "benchmark", "stop", "terminate"):
        subparsers.add_parser(cmd_name)
    args = parser.parse_args()
    if not args.cmd:
        parser.print_help()
        sys.exit(1)
    handlers = {
        "deploy": deploy,
        "status": status,
        "benchmark": benchmark,
        "stop": stop,
        "terminate": terminate,
    }
    handlers[args.cmd]()


if __name__ == "__main__":
    main()
|
||||
@@ -28,6 +28,7 @@ from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import discover_builtin_tools, registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
from agent.tool_orchestrator import orchestrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -499,13 +500,13 @@ def handle_function_call(
|
||||
# Prefer the caller-provided list so subagents can't overwrite
|
||||
# the parent's tool set via the process-global.
|
||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||
result = registry.dispatch(
|
||||
result = orchestrator.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
enabled_tools=sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
result = registry.dispatch(
|
||||
result = orchestrator.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
|
||||
Reference in New Issue
Block a user