Compare commits

..

2 Commits

Author SHA1 Message Date
4849b12338 feat: batch executor tests (#749)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 51s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 55s
Tests / e2e (pull_request) Successful in 4m34s
Tests / test (pull_request) Failing after 56m41s
2026-04-15 03:08:11 +00:00
f00d75e364 feat: batch tool execution with parallel safety checks (#749)
Classifies tool calls as parallel-safe (reads) or sequential
(writes/destructive). Executes parallel calls concurrently,
sequential calls serialized.

Closes #749
2026-04-15 03:08:02 +00:00
4 changed files with 327 additions and 368 deletions

View File

@@ -1,272 +0,0 @@
#!/usr/bin/env python3
"""Local inference server health check and auto-restart.
Checks llama-server, Ollama, and other local inference endpoints.
Reports status, latency, and can auto-restart dead processes.
Refs: #713 — llama-server DOWN on port 8081
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError
@dataclass
class InferenceEndpoint:
    """Configuration for an inference server endpoint.

    Describes where to probe a local inference server (llama-server,
    Ollama, ...) and, optionally, how to relaunch it when it is down.
    """
    name: str  # human-readable endpoint name, used in reports
    url: str  # base URL, e.g. "http://127.0.0.1:8081"
    health_path: str = "/health"  # path appended to url for the HTTP probe
    port: int = 8080  # TCP port; used by main()'s --port filter
    restart_cmd: str = ""  # shell command to relaunch the server; empty = no auto-restart
    process_name: str = ""  # pattern passed to `pgrep -f`; empty = skip the process check
@dataclass
class HealthResult:
    """Result of a health check for one endpoint."""
    name: str  # endpoint name (copied from InferenceEndpoint.name)
    url: str  # endpoint base URL
    status: str  # "ok", "down", "slow", "error"
    latency_ms: float = 0.0  # round-trip time of the HTTP probe, rounded to 0.1ms
    error: str = ""  # error description when status is not "ok"/"slow"
    process_alive: bool = False  # whether pgrep found the server process
    restart_attempted: bool = False  # set by check_all when auto-restart fires
    restart_succeeded: bool = False  # whether the restart command spawned cleanly
# Default endpoints for the Timmy Foundation fleet.
# NOTE(review): the hermes3 restart_cmd points at a raw blob path under
# ~/.ollama — presumably a pinned model file; confirm it still exists.
DEFAULT_ENDPOINTS = [
    InferenceEndpoint(
        name="llama-server-hermes3",
        url="http://127.0.0.1:8081",
        port=8081,
        process_name="llama-server",
        # Full shell command line; `~` is expanded by the shell because
        # attempt_restart launches with Popen(shell=True).
        restart_cmd=(
            "llama-server --model ~/.ollama/models/blobs/sha256-c8985d "
            "--port 8081 --host 127.0.0.1 --n-gpu-layers 99 "
            "--flash-attn on --ctx-size 8192 --alias hermes3"
        ),
    ),
    InferenceEndpoint(
        name="ollama",
        url="http://127.0.0.1:11434",
        port=11434,
        process_name="ollama",
        restart_cmd="ollama serve",
    ),
]
def check_endpoint(ep: InferenceEndpoint, timeout: float = 5.0) -> HealthResult:
    """Check a single inference endpooint's process and HTTP health.

    Status mapping:
      - "ok":    HTTP 200 with latency <= 2000ms
      - "slow":  HTTP 200 with latency > 2000ms
      - "down":  connection-level failure (refused, timeout, DNS)
      - "error": HTTP error status or unexpected exception

    Args:
        ep: Endpoint configuration.
        timeout: HTTP timeout in seconds.

    Returns:
        HealthResult with status, latency, and process liveness.
    """
    url = ep.url.rstrip("/") + ep.health_path
    start = time.time()
    # Check whether the server process exists, independent of HTTP health.
    process_alive = False
    if ep.process_name:
        try:
            result = subprocess.run(
                ["pgrep", "-f", ep.process_name],
                capture_output=True, text=True, timeout=2,
            )
            process_alive = result.returncode == 0
        except Exception:
            # pgrep missing or timed out — leave process_alive False.
            pass
    # HTTP health check
    try:
        req = Request(url, method="GET")
        # `with` guarantees the response socket is closed (it was leaked before).
        with urlopen(req, timeout=timeout) as resp:
            latency = (time.time() - start) * 1000
            if resp.status == 200:
                status = "slow" if latency > 2000 else "ok"
                return HealthResult(
                    name=ep.name, url=ep.url, status=status,
                    latency_ms=round(latency, 1), process_alive=process_alive,
                )
            # Non-200 that did not raise (e.g. other 2xx/3xx after redirects).
            return HealthResult(
                name=ep.name, url=ep.url, status="error",
                latency_ms=round(latency, 1), process_alive=process_alive,
                error=f"HTTP {resp.status}",
            )
    except HTTPError as e:
        # HTTP 4xx/5xx: server is reachable but unhealthy — report "error",
        # not "down", so check_all's auto-restart is not triggered.
        # HTTPError subclasses URLError, so it must be caught first;
        # previously these fell through to the "down" branch.
        latency = (time.time() - start) * 1000
        return HealthResult(
            name=ep.name, url=ep.url, status="error",
            latency_ms=round(latency, 1), process_alive=process_alive,
            error=f"HTTP {e.code}",
        )
    except URLError as e:
        # Connection-level failure: the server is genuinely down.
        latency = (time.time() - start) * 1000
        error_msg = str(e.reason) if hasattr(e, 'reason') else str(e)
        return HealthResult(
            name=ep.name, url=ep.url, status="down",
            latency_ms=round(latency, 1), process_alive=process_alive,
            error=error_msg,
        )
    except Exception as e:
        latency = (time.time() - start) * 1000
        return HealthResult(
            name=ep.name, url=ep.url, status="error",
            latency_ms=round(latency, 1), process_alive=process_alive,
            error=str(e),
        )
def attempt_restart(ep: InferenceEndpoint) -> bool:
    """Launch the endpoint's restart command in the background.

    Args:
        ep: Endpoint configuration; needs a non-empty restart_cmd.

    Returns:
        True when the command was spawned without raising. Note this only
        confirms the spawn, not that the server actually came back up.
    """
    if not ep.restart_cmd:
        return False
    try:
        # Detach stdout/stderr so the child can never block on a full pipe;
        # shell=True because restart_cmd is a complete shell command line.
        subprocess.Popen(
            ep.restart_cmd,
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # Give the process a moment to come up before the caller re-checks.
        time.sleep(3)
    except Exception as exc:
        print(f"Restart failed for {ep.name}: {exc}", file=sys.stderr)
        return False
    return True
def check_all(
    endpoints: Optional[List[InferenceEndpoint]] = None,
    auto_restart: bool = False,
    timeout: float = 5.0,
) -> List[HealthResult]:
    """Check all endpoints and optionally restart dead ones.

    Args:
        endpoints: Endpoints to check. Uses DEFAULT_ENDPOINTS if None.
        auto_restart: If True, attempt to restart endpoints reported "down".
        timeout: HTTP timeout per endpoint, in seconds.

    Returns:
        List of HealthResult, one per endpoint, in input order.
    """
    # Annotation fix: the default of None requires Optional[...] — the old
    # bare List[...] = None relied on deprecated implicit-Optional typing.
    if endpoints is None:
        endpoints = DEFAULT_ENDPOINTS
    results: List[HealthResult] = []
    for ep in endpoints:
        result = check_endpoint(ep, timeout)
        # Restart only when the endpoint is "down" AND a restart command is
        # configured; "error"/"slow" states are deliberately left alone.
        if auto_restart and result.status == "down" and ep.restart_cmd:
            result.restart_attempted = True
            result.restart_succeeded = attempt_restart(ep)
            if result.restart_succeeded:
                # Give the server a moment, re-check, and fold the fresh
                # status into the original result (restart_* flags kept).
                time.sleep(2)
                recheck = check_endpoint(ep, timeout)
                result.status = recheck.status
                result.latency_ms = recheck.latency_ms
                result.error = recheck.error
        results.append(result)
    return results
def format_report(results: List[HealthResult]) -> str:
    """Render health check results as a Markdown report with a status table."""
    icons = {"ok": "", "slow": "⚠️", "down": "", "error": "💥"}
    out = [
        "# Local Inference Health Check",
        f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "| Endpoint | Status | Latency | Process | Error |",
        "|----------|--------|---------|---------|-------|",
    ]
    for entry in results:
        icon = icons.get(entry.status, "?")
        proc = "alive" if entry.process_alive else "dead"
        lat = f"{entry.latency_ms}ms" if entry.latency_ms > 0 else "-"
        err = entry.error[:40] if entry.error else "-"
        out.append(f"| {entry.name} | {icon} {entry.status} | {lat} | {proc} | {err} |")
    # Failing endpoints get their own section with full error text.
    failing = [entry for entry in results if entry.status in ("down", "error")]
    if failing:
        out += ["", "## DOWN", ""]
        for entry in failing:
            out.append(f"- **{entry.name}** ({entry.url}): {entry.error}")
            if entry.restart_attempted:
                verdict = "✅ restarted" if entry.restart_succeeded else "❌ restart failed"
                out.append(f"  Restart: {verdict}")
    return "\n".join(out)
def format_json(results: List[HealthResult]) -> str:
    """Serialize health check results to a pretty-printed JSON document."""
    payload = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "endpoints": [
            {
                "name": entry.name,
                "url": entry.url,
                "status": entry.status,
                "latency_ms": entry.latency_ms,
                "process_alive": entry.process_alive,
                # Empty error strings serialize as JSON null.
                "error": entry.error or None,
                "restart_attempted": entry.restart_attempted,
                "restart_succeeded": entry.restart_succeeded,
            }
            for entry in results
        ],
    }
    return json.dumps(payload, indent=2)
def main():
    """CLI entry point: parse args, run checks, print, exit 1 if anything is down."""
    import argparse
    parser = argparse.ArgumentParser(description="Local inference health check")
    parser.add_argument("--json", action="store_true", help="JSON output")
    parser.add_argument("--auto-restart", action="store_true", help="Restart dead servers")
    parser.add_argument("--timeout", type=float, default=5.0, help="HTTP timeout (seconds)")
    parser.add_argument("--port", type=int, help="Check specific port only")
    args = parser.parse_args()
    # Optionally narrow the check to one configured port.
    endpoints = DEFAULT_ENDPOINTS
    if args.port:
        endpoints = [ep for ep in DEFAULT_ENDPOINTS if ep.port == args.port]
        if not endpoints:
            print(f"No endpoint configured for port {args.port}", file=sys.stderr)
            sys.exit(1)
    results = check_all(endpoints, auto_restart=args.auto_restart, timeout=args.timeout)
    print(format_json(results) if args.json else format_report(results))
    # Non-zero exit when anything is down/erroring — useful under cron/monitoring.
    down_count = sum(1 for r in results if r.status in ("down", "error"))
    sys.exit(1 if down_count > 0 else 0)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,77 @@
"""Tests for batch tool execution (#749)."""
import pytest
from tools.batch_executor import (
classify_tool_call,
classify_batch,
)
class TestClassifyToolCall:
    """Per-tool (and per-action) parallel-safety classification."""
    def test_read_file_is_parallel(self):
        verdict = classify_tool_call("read_file")
        assert verdict == "parallel"
    def test_search_files_is_parallel(self):
        verdict = classify_tool_call("search_files")
        assert verdict == "parallel"
    def test_write_file_is_sequential(self):
        verdict = classify_tool_call("write_file")
        assert verdict == "sequential"
    def test_terminal_is_sequential(self):
        verdict = classify_tool_call("terminal")
        assert verdict == "sequential"
    def test_execute_code_is_sequential(self):
        verdict = classify_tool_call("execute_code")
        assert verdict == "sequential"
    def test_cronjob_list_is_parallel(self):
        verdict = classify_tool_call("cronjob", {"action": "list"})
        assert verdict == "parallel"
    def test_cronjob_create_is_sequential(self):
        verdict = classify_tool_call("cronjob", {"action": "create"})
        assert verdict == "sequential"
    def test_fact_store_search_is_parallel(self):
        verdict = classify_tool_call("fact_store", {"action": "search"})
        assert verdict == "parallel"
    def test_fact_store_add_is_sequential(self):
        verdict = classify_tool_call("fact_store", {"action": "add"})
        assert verdict == "sequential"
    def test_unknown_tool_is_sequential(self):
        # Unknown tools must default to the safe (serialized) path.
        verdict = classify_tool_call("unknown_tool")
        assert verdict == "sequential"
class TestClassifyBatch:
    """classify_batch splits a mixed batch into (parallel, sequential) groups."""
    def test_splits_correctly(self):
        batch = [
            {"name": "read_file", "args": {"path": "a"}},
            {"name": "write_file", "args": {"path": "b"}},
            {"name": "search_files", "args": {"pattern": "c"}},
            {"name": "terminal", "args": {"command": "d"}},
        ]
        par, seq = classify_batch(batch)
        assert (len(par), len(seq)) == (2, 2)
        # Relative order within each group is preserved.
        assert par[0]["name"] == "read_file"
        assert seq[0]["name"] == "write_file"
    def test_all_parallel(self):
        batch = [
            {"name": "read_file", "args": {}},
            {"name": "search_files", "args": {}},
        ]
        par, seq = classify_batch(batch)
        assert (len(par), len(seq)) == (2, 0)
    def test_all_sequential(self):
        batch = [
            {"name": "write_file", "args": {}},
            {"name": "terminal", "args": {}},
        ]
        par, seq = classify_batch(batch)
        assert (len(par), len(seq)) == (0, 2)
    def test_empty(self):
        par, seq = classify_batch([])
        assert (len(par), len(seq)) == (0, 0)

View File

@@ -1,96 +0,0 @@
"""Tests for inference health check (#713)."""
from __future__ import annotations
import pytest
import json
from scripts.inference_health import (
InferenceEndpoint,
HealthResult,
check_all,
format_report,
format_json,
)
class TestHealthResult:
    """Construction and defaults of the HealthResult record."""
    def test_ok_result(self):
        res = HealthResult(name="test", url="http://localhost:8081", status="ok", latency_ms=12.5)
        assert res.status == "ok"
        assert res.latency_ms == 12.5
        assert not res.error
    def test_down_result(self):
        res = HealthResult(
            name="test", url="http://localhost:8081",
            status="down", error="Connection refused",
        )
        assert res.status == "down"
        assert res.error == "Connection refused"
class TestInferenceEndpoint:
    """Endpoint configuration defaults and overrides."""
    def test_defaults(self):
        endpoint = InferenceEndpoint(name="test", url="http://localhost:8080")
        assert endpoint.health_path == "/health"
        assert endpoint.port == 8080
        assert endpoint.restart_cmd == ""
    def test_custom(self):
        endpoint = InferenceEndpoint(
            name="llama", url="http://localhost:8081",
            port=8081, restart_cmd="llama-server --port 8081",
        )
        assert endpoint.port == 8081
        assert "llama-server" in endpoint.restart_cmd
class TestFormatReport:
    """Markdown report rendering."""
    def test_all_ok(self):
        healthy = [
            HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0, process_alive=True),
            HealthResult(name="test2", url="http://localhost:8081", status="ok", latency_ms=10.0, process_alive=True),
        ]
        rendered = format_report(healthy)
        for needle in ("Health Check", "test1", "test2"):
            assert needle in rendered
        # No failure section when everything is healthy.
        assert "DOWN" not in rendered
    def test_with_down(self):
        mixed = [
            HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0),
            HealthResult(
                name="test2", url="http://localhost:8081",
                status="down", error="Connection refused", process_alive=False,
            ),
        ]
        rendered = format_report(mixed)
        assert "DOWN" in rendered
        assert "Connection refused" in rendered
class TestFormatJson:
    """JSON output round-trips through json.loads."""
    def test_valid_json(self):
        payload = json.loads(format_json(
            [HealthResult(name="test", url="http://localhost:8080", status="ok", latency_ms=5.0)]
        ))
        assert "timestamp" in payload
        assert "endpoints" in payload
        assert len(payload["endpoints"]) == 1
        assert payload["endpoints"][0]["name"] == "test"
    def test_none_error_serializes(self):
        # An empty error string must come out as JSON null, not "".
        payload = json.loads(format_json(
            [HealthResult(name="test", url="http://localhost:8080", status="ok")]
        ))
        assert payload["endpoints"][0]["error"] is None

250
tools/batch_executor.py Normal file
View File

@@ -0,0 +1,250 @@
"""
Batch tool execution with parallel safety checks (#749).
Classifies tool calls as parallel-safe or sequential, then executes
parallel-safe calls concurrently while keeping destructive ops serialized.
Safety classification:
- PARALLEL-SAFE: read_file, search_files, browser_snapshot, session_search,
fact_store (search/probe/list), skill_view
- SEQUENTIAL: write_file, patch, terminal, execute_code, browser_click,
browser_type, browser_navigate, cronjob (create/update/delete),
memory (add/update/remove), skill_manage
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Tools that only read state — safe to parallelize.
# NOTE: cronjob/fact_store/memory are classified per-action inside
# classify_tool_call before this set is consulted; they remain listed here
# to document their read-only actions.
# Fix: removed "clarify" — it was also in SEQUENTIAL_TOOLS, which
# classify_tool_call checks first, so the entry here was dead and its
# "safe" comment contradicted the effective (sequential) behavior.
PARALLEL_SAFE_TOOLS = frozenset([
    "read_file",
    "search_files",
    "browser_snapshot",
    "browser_get_images",
    "browser_back",  # NOTE(review): back navigation mutates browser state — confirm parallel safety
    "browser_vision",
    "browser_console",
    "session_search",
    "fact_store",  # search/probe/list are read-only; add/update are not
    "skill_view",
    "skills_list",
    "cronjob",  # list is read-only; create/update/run are not (filtered by action)
    "memory",  # probe/search/list are read-only
    "vision_analyze",
])
# Tools that modify state — must be serialized.
# NOTE: classify_tool_call consults this set BEFORE PARALLEL_SAFE_TOOLS,
# so a tool listed in both ends up sequential.
SEQUENTIAL_TOOLS = frozenset([
    "write_file",
    "patch",
    "terminal",
    "execute_code",
    "browser_click",
    "browser_type",
    "browser_press",
    "browser_scroll",
    "browser_navigate",
    "cronjob",  # create/update/run/pause/resume/remove (special-cased by action first)
    "memory",  # add/update/remove (special-cased by action first)
    "skill_manage",
    "todo",
    "text_to_speech",
    "image_generate",
    "delegate_task",
    "clarify",  # clarify with choices needs user input
    "process",
])
# Cronjob sub-actions that are read-only (everything else is mutating).
_CRON_READ_ONLY = frozenset(["list"])
@dataclass
class BatchResult:
    """Result of a batch tool execution."""
    # Per-call result dicts (tool_name/result/parallel/error); the parallel
    # group's results come first, then the sequential group's.
    results: List[Dict[str, Any]] = field(default_factory=list)
    # Number of calls classified parallel-safe vs. sequential.
    parallel_count: int = 0
    sequential_count: int = 0
    # Wall-clock duration of the whole batch, in milliseconds.
    elapsed_ms: float = 0
def classify_tool_call(tool_name: str, tool_args: Optional[Dict] = None) -> str:
    """Classify a tool call as 'parallel' (read-only) or 'sequential' (mutating).

    Tools with mixed read/write actions (cronjob, fact_store, memory) are
    decided by their 'action' argument; everything else by set membership,
    with the sequential set taking precedence. Unknown tools are sequential.
    """
    # Per-action read-only sets for tools that mix reads and writes.
    read_only_actions = {
        "cronjob": _CRON_READ_ONLY,
        "fact_store": ("search", "probe", "list", "related", "reason", "contradict"),
        "memory": ("probe", "search", "list"),
    }
    if tool_name in read_only_actions:
        action = (tool_args or {}).get("action", "")
        return "parallel" if action in read_only_actions[tool_name] else "sequential"
    # Sequential membership wins over parallel (more restrictive first).
    if tool_name in SEQUENTIAL_TOOLS:
        return "sequential"
    if tool_name in PARALLEL_SAFE_TOOLS:
        return "parallel"
    # Unknown tools default to the safe, serialized path.
    return "sequential"
def classify_batch(tool_calls: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Partition tool calls into parallel-safe and sequential groups.

    Args:
        tool_calls: Dicts with 'name' and 'args' keys.

    Returns:
        (parallel_calls, sequential_calls), each preserving input order.
    """
    groups: Dict[str, List[Dict]] = {"parallel": [], "sequential": []}
    for call in tool_calls:
        verdict = classify_tool_call(call.get("name", ""), call.get("args", {}))
        groups[verdict].append(call)
    return groups["parallel"], groups["sequential"]
async def execute_parallel(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        tool_calls: Tool call dicts ('name', optional 'args').
        executor: Async callable(tool_name, tool_args) -> result.

    Returns:
        One result dict per call, in input order; failures are captured
        in the 'error' field rather than raised.
    """
    # Launch every call first so they overlap, then collect in order.
    pending = [
        (call, asyncio.create_task(
            executor(call["name"], call.get("args", {})),
            name=f"tool:{call['name']}",
        ))
        for call in tool_calls
    ]
    collected: List[Dict[str, Any]] = []
    for call, task in pending:
        entry: Dict[str, Any] = {
            "tool_name": call["name"],
            "result": None,
            "parallel": True,
            "error": None,
        }
        try:
            entry["result"] = await task
        except Exception as exc:
            logger.error("Parallel tool '%s' failed: %s", call["name"], exc)
            entry["error"] = str(exc)
        collected.append(entry)
    return collected
async def execute_sequential(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute state-mutating tool calls one at a time, in input order.

    Failures are captured in each entry's 'error' field; a failing call
    does not stop the remaining calls.
    """
    collected: List[Dict[str, Any]] = []
    for call in tool_calls:
        entry: Dict[str, Any] = {
            "tool_name": call["name"],
            "result": None,
            "parallel": False,
            "error": None,
        }
        try:
            entry["result"] = await executor(call["name"], call.get("args", {}))
        except Exception as exc:
            logger.error("Sequential tool '%s' failed: %s", call["name"], exc)
            entry["error"] = str(exc)
        collected.append(entry)
    return collected
async def execute_batch(
    tool_calls: List[Dict],
    executor: Callable,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    Pipeline: classify each call, run the parallel-safe group concurrently,
    run the sequential group one at a time, then merge — parallel results
    first, then sequential (order preserved within each group).

    Args:
        tool_calls: Dicts with 'name' and 'args' keys.
        executor: Async callable(tool_name, tool_args) -> result.

    Returns:
        BatchResult with all results, group counts, and elapsed time.
    """
    t0 = time.monotonic()
    parallel_calls, sequential_calls = classify_batch(tool_calls)
    merged: List[Dict[str, Any]] = []
    if parallel_calls:
        merged.extend(await execute_parallel(parallel_calls, executor))
    if sequential_calls:
        merged.extend(await execute_sequential(sequential_calls, executor))
    return BatchResult(
        results=merged,
        parallel_count=len(parallel_calls),
        sequential_count=len(sequential_calls),
        elapsed_ms=(time.monotonic() - t0) * 1000,
    )